From aa070e333f0d75aba9795268a55b92763e0064fd Mon Sep 17 00:00:00 2001 From: Tradewindycc Date: Tue, 10 Feb 2026 16:31:01 +0800 Subject: [PATCH 1/4] [minor]: rm sa config for future release --- README.md | 3 +- src/openpi/training/config.py | 208 +++++++++++++++++----------------- 2 files changed, 105 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index d14cac5..e1f1dd8 100755 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ This repository is built on top of [openpi](https://github.com/Physical-Intellig For Model Arithmetic (mixing checkpoints), GPU memory requirements depend on the model size and number of checkpoints being mixed. A single A100 (80GB) is sufficient for most use cases. -The repo has been tested with Ubuntu 22.04. +Non-edge components (e.g., Policy Training, Model Arithmetic) have been tested on Ubuntu 22.04. ### Hardware @@ -112,7 +112,6 @@ uv pip install safetensors Download the Kai0 dataset so it is available under `./data` for training and evaluation. From the repository root, run: ```bash -pip install huggingface_hub # if not already installed python scripts/download_dataset.py ``` diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 7090e8b..e60dc88 100755 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -1224,110 +1224,110 @@ def __post_init__(self) -> None: #************************Advantage Estimator*************************** - TrainConfig( - name="ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD", - advantage_estimator=True, - model=pi0_config.AdvantageEstimatorConfig( - pi05=True, - loss_value_weight=1., - loss_action_weight=0., - discrete_state_input=False, - ), - data=LerobotAgilexDataConfig( - repo_id = "Path/to/your/advantage/dataset", - assets=AssetsConfig( - assets_dir="Path/to/your/advantage/dataset/assets", - asset_id="Your_advantage_dataset_name", - ), - default_prompt="Flatten and fold the cloth.", - # * why removing "prompt" here will lead to an error in transforms.py - repack_transforms=_transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "images": { - "top_head": "observation.images.top_head", - "hand_left": "observation.images.hand_left", - "hand_right": "observation.images.hand_right", - "his_-100_top_head": "his_-100_observation.images.top_head", - "his_-100_hand_left": "his_-100_observation.images.hand_left", - "his_-100_hand_right": "his_-100_observation.images.hand_right", - }, - "state": "observation.state", - "actions": "action", - # "prompt": "prompt", # ! Not adding this for default prompt. - "episode_length": "episode_length", - "frame_index": "frame_index", - "episode_index": "episode_index", - "progress_gt": "progress_gt", - "stage_progress_gt": "stage_progress_gt", - "progress": "progress", - # "is_suboptimal": "is_suboptimal", - } - ) - ] - ) - ), - pytorch_weight_path="Path/to/your/pi05_base/checkpoint", - num_train_steps=100_000, - keep_period=10000, - save_interval=10000, - num_workers=8, - batch_size=16, # * 1 gpus - # batch_size=128, # * 8 gpus - skip_norm_stats=True, # * No norm stats used. 
- ), - TrainConfig( - name="ADVANTAGE_TORCH_PI06_FLATTEN_FOLD", - advantage_estimator=True, - model=pi0_config.AdvantageEstimatorConfig( - pi05=True, - loss_value_weight=1., - loss_action_weight=0., # No action loss in advantage estimator training - discrete_state_input=False, # Not using states into prompt like pi05 - ), - data=LerobotAgilexDataConfig( - # repo_id = "/cpfs01/shared/filtered_cut_data/short_sleeve/flatten_fold/v9-3/1022_20_590_v9-3_2000_lerobot", - repo_id = "Path/to/your/advantage/dataset", - assets=AssetsConfig( - assets_dir="Path/to/your/advantage/dataset/assets", - asset_id="Your_advantage_dataset_name", - ), - default_prompt="Flatten and fold the cloth.", - # * why removing "prompt" here will lead to an error in transforms.py - repack_transforms=_transforms.Group( - inputs=[ - _transforms.RepackTransform( - { - "images": { - "top_head": "observation.images.top_head", - "hand_left": "observation.images.hand_left", - "hand_right": "observation.images.hand_right", - }, - "state": "observation.state", - "actions": "action", - # "prompt": "prompt", # No need if default prompt is used. - "episode_length": "episode_length", - "frame_index": "frame_index", - "episode_index": "episode_index", - "progress_gt": "progress_gt", - "stage_progress_gt": "stage_progress_gt", - "progress": "progress", - # "is_suboptimal": "is_suboptimal", - } - ) - ] - ) - ), - pytorch_weight_path="Path/to/your/pi06_base/checkpoint", - num_train_steps=100_000, - keep_period=10000, - save_interval=10000, - num_workers=55, - # batch_size=16, # * 1 gpus - batch_size=18*8, # * 8 gpus - skip_norm_stats=True, # * No norm stats used. - ), + # TrainConfig( + # name="ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD", + # advantage_estimator=True, + # model=pi0_config.AdvantageEstimatorConfig( + # pi05=True, + # loss_value_weight=1., + # loss_action_weight=0., + # discrete_state_input=False, + # ), + # data=LerobotAgilexDataConfig( + # repo_id = "Path/to/your/advantage/dataset", + # assets=AssetsConfig( + # assets_dir="Path/to/your/advantage/dataset/assets", + # asset_id="Your_advantage_dataset_name", + # ), + # default_prompt="Flatten and fold the cloth.", + # # * why removing "prompt" here will lead to an error in transforms.py + # repack_transforms=_transforms.Group( + # inputs=[ + # _transforms.RepackTransform( + # { + # "images": { + # "top_head": "observation.images.top_head", + # "hand_left": "observation.images.hand_left", + # "hand_right": "observation.images.hand_right", + # "his_-100_top_head": "his_-100_observation.images.top_head", + # "his_-100_hand_left": "his_-100_observation.images.hand_left", + # "his_-100_hand_right": "his_-100_observation.images.hand_right", + # }, + # "state": "observation.state", + # "actions": "action", + # # "prompt": "prompt", # ! Not adding this for default prompt. + # "episode_length": "episode_length", + # "frame_index": "frame_index", + # "episode_index": "episode_index", + # "progress_gt": "progress_gt", + # "stage_progress_gt": "stage_progress_gt", + # "progress": "progress", + # # "is_suboptimal": "is_suboptimal", + # } + # ) + # ] + # ) + # ), + # pytorch_weight_path="Path/to/your/pi05_base/checkpoint", + # num_train_steps=100_000, + # keep_period=10000, + # save_interval=10000, + # num_workers=8, + # batch_size=16, # * 1 gpus + # # batch_size=128, # * 8 gpus + # skip_norm_stats=True, # * No norm stats used. 
+ # ), + # TrainConfig( + # name="ADVANTAGE_TORCH_PI06_FLATTEN_FOLD", + # advantage_estimator=True, + # model=pi0_config.AdvantageEstimatorConfig( + # pi05=True, + # loss_value_weight=1., + # loss_action_weight=0., # No action loss in advantage estimator training + # discrete_state_input=False, # Not using states into prompt like pi05 + # ), + # data=LerobotAgilexDataConfig( + # # repo_id = "/cpfs01/shared/filtered_cut_data/short_sleeve/flatten_fold/v9-3/1022_20_590_v9-3_2000_lerobot", + # repo_id = "Path/to/your/advantage/dataset", + # assets=AssetsConfig( + # assets_dir="Path/to/your/advantage/dataset/assets", + # asset_id="Your_advantage_dataset_name", + # ), + # default_prompt="Flatten and fold the cloth.", + # # * why removing "prompt" here will lead to an error in transforms.py + # repack_transforms=_transforms.Group( + # inputs=[ + # _transforms.RepackTransform( + # { + # "images": { + # "top_head": "observation.images.top_head", + # "hand_left": "observation.images.hand_left", + # "hand_right": "observation.images.hand_right", + # }, + # "state": "observation.state", + # "actions": "action", + # # "prompt": "prompt", # No need if default prompt is used. + # "episode_length": "episode_length", + # "frame_index": "frame_index", + # "episode_index": "episode_index", + # "progress_gt": "progress_gt", + # "stage_progress_gt": "stage_progress_gt", + # "progress": "progress", + # # "is_suboptimal": "is_suboptimal", + # } + # ) + # ] + # ) + # ), + # pytorch_weight_path="Path/to/your/pi06_base/checkpoint", + # num_train_steps=100_000, + # keep_period=10000, + # save_interval=10000, + # num_workers=55, + # # batch_size=16, # * 1 gpus + # batch_size=18*8, # * 8 gpus + # skip_norm_stats=True, # * No norm stats used. + # ), #************************advantage estimator*************************** # RoboArena & PolaRiS configs. *roboarena_config.get_roboarena_configs(), From d27eb973259ec6d2f30db299830cafadbffd8e02 Mon Sep 17 00:00:00 2001 From: Chonghao Sima Date: Tue, 10 Feb 2026 17:23:56 +0800 Subject: [PATCH 2/4] Update README.md --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index e1f1dd8..fb485b2 100755 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ -# χ₀ - +# χ₀: Resource-Aware Robust Manipulation viaTaming Distributional Inconsistencies
@@ -15,7 +14,8 @@
-χ₀ (**kai0**) is a resource-efficient framework for achieving production-level robustness in robotic manipulation by taming distributional inconsistencies. This repository is built on top of [openpi](https://github.com/Physical-Intelligence/openpi), the open-source models and packages for robotics published by the [Physical Intelligence team](https://www.physicalintelligence.company/). +χ₀ (**kai0**) is a resource-efficient framework for achieving production-level robustness in robotic manipulation by taming distributional inconsistencies. + χ₀ addresses the systematic distributional shift among the human demonstration distribution ($P_\text{train}$), the inductive bias learned by the policy ($Q_\text{model}$), and the test-time execution distribution ($P_\text{test}$) through three technical modules: @@ -23,7 +23,7 @@ - **[Stage Advantage](#stage-advantage-coming-soon)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Coming Soon]** - **[Train-Deploy Alignment](#train-deploy-alignment-coming-soon)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Coming Soon]** -χ₀ enables two sets of dual-arm robots to collaboratively orchestrate long-horizon garment manipulation — flattening, folding, and hanging — surpassing the state-of-the-art $\pi_{0.5}$ baseline by approximately 250% in success rate, with only 20 hours of data and 8 A100 GPUs. +χ₀ enables two sets of dual-arm robots to collaboratively orchestrate long-horizon garment manipulation — flattening, folding, and hanging — surpassing the state-of-the-art $\pi_{0.5}$ baseline by approximately 250% in success rate,with `only 20 hours of data and 8 A100 GPUs`. @@ -31,7 +31,7 @@ https://github.com/user-attachments/assets/e662f096-d273-4458-abd4-e12b9685a9bc ## Table of Contents -- [Updates](#updates) +- [Update](#update) - [Acknowledgement](#acknowledgement) - [Requirements](#requirements) - [Compute](#compute) @@ -52,7 +52,7 @@ https://github.com/user-attachments/assets/e662f096-d273-4458-abd4-e12b9685a9bc - [Troubleshooting](#troubleshooting) - [Links and Community](#links-and-community) -## Updates +## Update - [Feb 10 2026] Initial release of the **Model Arithmetic** module with support for both JAX and PyTorch checkpoints (not tested thoroughly). - [Feb 10 2025] χ₀ paper released. 
From 12d45547cde0c9cab75503f8dcaa1b8852fe9b69 Mon Sep 17 00:00:00 2001 From: Tradewindycc Date: Tue, 10 Feb 2026 17:25:56 +0800 Subject: [PATCH 3/4] [update]: update docs and ignores --- .gitignore | 2 + README.md | 18 +++- DATASET.md => docs/dataset.md | 0 docs/norm_stats_fast.md | 186 +++++++++++++++++++++++--------- scripts/download_checkpoints.py | 108 +++++++++++++++++++ scripts/download_dataset.py | 41 +++++-- scripts/train.py | 10 +- 7 files changed, 303 insertions(+), 62 deletions(-) rename DATASET.md => docs/dataset.md (100%) mode change 100755 => 100644 docs/norm_stats_fast.md create mode 100644 scripts/download_checkpoints.py diff --git a/.gitignore b/.gitignore index b5835bb..5cffc9b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,8 +9,10 @@ __pycache__/ # Distribution / packaging .Python build/ +checkpoints/ develop-eggs/ dist/ +data/ downloads/ eggs/ .eggs/ diff --git a/README.md b/README.md index fb485b2..5e5afbc 100755 --- a/README.md +++ b/README.md @@ -115,11 +115,25 @@ Download the Kai0 dataset so it is available under `./data` for training and eva python scripts/download_dataset.py ``` -This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see [DATASET.md](DATASET.md#step-1-download-the-dataset). +This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset). ### 2. Download checkpoints (optional, for testing) -We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main). Download the task folder(s) you need and set `weight_loader` in config to the path of the downloaded checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead. +We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main). + +From the repository root, you can download all best-model checkpoints to `./checkpoints` with: + +```bash +python scripts/download_checkpoints.py +``` + +To download only specific tasks or use a custom path, run: + +```bash +python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_checkpoints +``` + +After download, set `weight_loader` in the training config to the path of the corresponding checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead. ### 3. Fine-tune with normal π₀.5 diff --git a/DATASET.md b/docs/dataset.md similarity index 100% rename from DATASET.md rename to docs/dataset.md diff --git a/docs/norm_stats_fast.md b/docs/norm_stats_fast.md old mode 100755 new mode 100644 index bc8f72c..3cb011b --- a/docs/norm_stats_fast.md +++ b/docs/norm_stats_fast.md @@ -1,69 +1,155 @@ -# Normalization statistics +## Fast normalization stats computation (`compute_norm_states_fast.py`) -Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint. 
+This script provides a **fast path** to compute normalization statistics for Kai0 configs by +directly reading local parquet files instead of going through the full data loader. It produces +`norm_stats` that are **compatible with the original openpi pipeline** (same `RunningStats` +implementation and batching scheme). -## Reloading normalization statistics +--- -When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model. +### When to use this script -**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint: +- You have already **downloaded the dataset locally** (e.g. under `./data`, see + [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset)). +- You have a **training config** in `src/openpi/training/config.py` (e.g. + `pi05_flatten_fold_normal`) and you want to compute `norm_stats` before running + `scripts/train.py`. +- You prefer a **simpler / faster** pipeline compared to the original `compute_norm_stats.py` + while keeping numerically compatible statistics. -```python -TrainConfig( - ... - data=LeRobotAlohaDataConfig( - ... - assets=AssetsConfig( - assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", - asset_id="trossen", - ), - ), -) -``` - -For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py). +--- -**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below. +### Script entry point -**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task. +The script lives at: +- `scripts/compute_norm_states_fast.py` -## Provided Pre-training Normalization Statistics +Main entry: -Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`. 
-| Robot | Description | Asset ID | -|-------|-------------|----------| -| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen | -| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile | -| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid | -| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka | -| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e | -| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual | -| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx | -| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile | -| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile | +- `main(config_name: str, base_dir: str | None = None, max_frames: int | None = None)` +CLI is handled via [`tyro`](https://github.com/brentyi/tyro), so you call it from the repo root as: -## Pi0 Model Action Space Definitions - -Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace): +```bash +uv run python scripts/compute_norm_states_fast.py --config-name [--base-dir ] [--max-frames N] ``` - "dim_0:dim_5": "left arm joint angles", - "dim_6": "left arm gripper position", - "dim_7:dim_12": "right arm joint angles (for bi-manual only)", - "dim_13": "right arm gripper position (for bi-manual only)", - # For mobile robots: - "dim_14:dim_15": "x-y base velocity (for mobile robots only)", +--- + +### Arguments + +- **`--config-name`** (`str`, required) + - Name of the TrainConfig defined in `src/openpi/training/config.py`, e.g.: + - `pi05_flatten_fold_normal` + - `pi05_tee_shirt_sort_normal` + - `pi05_hang_cloth_normal` + - Internally resolved via `_config.get_config(config_name)`. + +- **`--base-dir`** (`str`, optional) + - Base directory containing the parquet data for this config. + - If omitted, the script will read it from `config.data`: + - `data_config = config.data.create(config.assets_dirs, config.model)` + - `base_dir` defaults to `data_config.repo_id` + - This means you can either: + - Set `repo_id` in the config to your local dataset path (e.g. + `/data/FlattenFold/base`), or + - Keep `repo_id` as-is and pass `--base-dir` explicitly to point to your local copy. + +- **`--max-frames`** (`int`, optional) + - If set, stops after processing at most `max_frames` frames across all parquet files. + - Useful for **quick sanity checks** or debugging smaller subsets. + +--- + +### What the script does + +1. **Load config** + - Uses `_config.get_config(config_name)` to get the `TrainConfig`. + - Calls `config.data.create(config.assets_dirs, config.model)` to build a data config. + - Reads `action_dim` from `config.model.action_dim`. + +2. **Resolve input data directory** + - If `base_dir` is not provided: + - Uses `data_config.repo_id` as the base directory. + - Prints a message like: + - `Auto-detected base directory from config: ` + - Verifies that the directory exists. + +3. **Scan parquet files** + - Recursively walks `base_dir` and collects all files ending with `.parquet`. + - Sorts them lexicographically for **deterministic ordering** (matches dataset order). + +4. **Read and process data** + - For each parquet file: + - Loads it with `pandas.read_parquet`. + - Expects columns: + - `observation.state` + - `action` + - For each row: + - Extracts `state` and `action` arrays. 
+ - Applies: + - `process_state(state, action_dim)` + - `process_actions(actions, action_dim)` + - These helpers: + - **Pad** to `action_dim` (if dimension is smaller). + - **Clip abnormal values** outside \([-π, π]\) to 0 (for robustness, consistent with `FakeInputs` logic). + - Accumulates processed arrays into: + - `collected_data["state"]` + - `collected_data["actions"]` + - Maintains a running `total_frames` counter and respects `max_frames` if provided. + +5. **Concatenate and pad** + - Concatenates all collected batches per key: + - `all_data["state"]`, `all_data["actions"]` + - Ensures the last dimension matches `action_dim` (pads with zeros if needed). + +6. **Compute statistics with `RunningStats`** + - Initializes one `normalize.RunningStats()` per key (`state`, `actions`). + - Feeds data in **batches of 32** to match the original implementation’s floating-point + accumulation behavior. + - For each key, computes: + - `mean`, `std`, `q01`, `q99`, etc. + +7. **Save `norm_stats`** + - Collects results into a dict `norm_stats`. + - Saves them with `openpi.shared.normalize.save` to: + - `output_path = config.assets_dirs / data_config.repo_id` + - Prints the output path and a success message: + - `✅ Normalization stats saved to ` + +> **Note:** The save logic mirrors the original openpi `compute_norm_stats.py` behavior so that +> training code can load `norm_stats` transparently. + +--- + +### Typical workflow with Kai0 configs + +1. **Download dataset** + - Follow [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset) to download the Kai0 + dataset under `./data` at the repo root. + +2. **Set config paths** + - Edit `src/openpi/training/config.py` for the normal π₀.5 configs (see README `Preparation`): + - `repo_id` → absolute path to the dataset subset, e.g. + `/data/FlattenFold/base` + - `weight_loader` → path to the π₀.5 base checkpoint (e.g. Kai0 best model per task). + +3. **Compute normalization stats** + - From the repo root: + +```bash +uv run python scripts/compute_norm_states_fast.py --config-name pi05_flatten_fold_normal ``` -The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state. +4. **Train** + - Then run JAX training with: -For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action. +```bash +XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \ +uv run scripts/train.py pi05_flatten_fold_normal --exp_name= +``` -General info for Pi robots: -- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details). -- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed. -- Control frequencies are either 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms. +The training code will pick up the normalization statistics saved by this script and use them +for input normalization, in the same way as the original openpi pipeline. -For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz. 
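
As a reading aid for the new `docs/norm_stats_fast.md` above, a minimal self-contained sketch of the scan/pad/clip/accumulate flow it describes (steps 3 to 6) might look like the following. This is not the actual `scripts/compute_norm_states_fast.py`: the `ACTION_DIM` constant, the helper names, and the plain-numpy statistics are stand-ins for `config.model.action_dim`, the script's `process_state`/`process_actions` helpers, and openpi's `RunningStats`, and it assumes each parquet row stores fixed-length `observation.state` and `action` vectors.

```python
import glob

import numpy as np
import pandas as pd

ACTION_DIM = 32  # assumption: the real script reads this from config.model.action_dim


def pad_and_clip(batch: np.ndarray, target_dim: int) -> np.ndarray:
    """Pad the last dimension to target_dim and zero out values outside [-pi, pi]."""
    batch = np.asarray(batch, dtype=np.float32)
    if batch.shape[-1] < target_dim:
        pad_width = [(0, 0)] * (batch.ndim - 1) + [(0, target_dim - batch.shape[-1])]
        batch = np.pad(batch, pad_width)
    return np.where(np.abs(batch) > np.pi, 0.0, batch)


def compute_stats(base_dir: str, max_frames: int | None = None) -> dict:
    """Scan parquet files under base_dir and return per-key normalization statistics."""
    collected: dict[str, list[np.ndarray]] = {"state": [], "actions": []}
    total = 0
    # Deterministic ordering: lexicographic sort over all parquet files, as the doc describes.
    for path in sorted(glob.glob(f"{base_dir}/**/*.parquet", recursive=True)):
        df = pd.read_parquet(path, columns=["observation.state", "action"])
        collected["state"].append(pad_and_clip(np.stack(df["observation.state"].to_list()), ACTION_DIM))
        collected["actions"].append(pad_and_clip(np.stack(df["action"].to_list()), ACTION_DIM))
        total += len(df)
        if max_frames is not None and total >= max_frames:
            break
    stats = {}
    for key, chunks in collected.items():
        data = np.concatenate(chunks, axis=0)
        # The real script feeds batches of 32 into openpi's RunningStats to reproduce its
        # floating-point accumulation; here the same quantities are computed directly.
        stats[key] = {
            "mean": data.mean(axis=0),
            "std": data.std(axis=0),
            "q01": np.quantile(data, 0.01, axis=0),
            "q99": np.quantile(data, 0.99, axis=0),
        }
    return stats
```

The real script additionally saves the result with `openpi.shared.normalize.save` under `config.assets_dirs / data_config.repo_id`, as described in step 7 above.
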
diff --git a/scripts/download_checkpoints.py b/scripts/download_checkpoints.py new file mode 100644 index 0000000..ec9e893 --- /dev/null +++ b/scripts/download_checkpoints.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +Download Kai0 best-model checkpoints from Hugging Face to the repo's ./checkpoints directory. + +Run from the repository root: + python scripts/download_checkpoints.py + +Optional: download only specific tasks or set a custom output path: + python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_ckpts +""" +from __future__ import annotations + +import argparse +import sys +from multiprocessing import Process +from pathlib import Path + + +def get_repo_root() -> Path: + """Return the repository root (directory containing .git).""" + path = Path(__file__).resolve().parent.parent + if (path / ".git").exists(): + return path + # Fallback: assume cwd is repo root + return Path.cwd() + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Download Kai0 best-model checkpoints from Hugging Face to ./checkpoints (or --local-dir)." + ) + parser.add_argument( + "--local-dir", + type=str, + default=None, + help="Directory to save checkpoints (default: /checkpoints)", + ) + parser.add_argument( + "--tasks", + nargs="+", + choices=["FlattenFold", "HangCloth", "TeeShirtSort"], + default=None, + help="Download only these task folders from the repo (default: all)", + ) + parser.add_argument( + "--repo-id", + type=str, + default="OpenDriveLab-org/Kai0", + help="Hugging Face repo id that hosts best-model checkpoints (default: OpenDriveLab-org/Kai0)", + ) + args = parser.parse_args() + + try: + from huggingface_hub import snapshot_download # type: ignore + except ImportError: + print("Install huggingface_hub first: pip install huggingface_hub", file=sys.stderr) + return 1 + + repo_root = get_repo_root() + local_dir = Path(args.local_dir) if args.local_dir else repo_root / "checkpoints" + local_dir = local_dir.resolve() + + allow_patterns = None + if args.tasks: + # Each task corresponds to a top-level folder in the repo. + allow_patterns = [f"{t}/*" for t in args.tasks] + allow_patterns.append("README.md") + + print(f"Downloading checkpoints to {local_dir}") + print(f"Repo: {args.repo_id}" + (f", tasks: {args.tasks}" if args.tasks else " (all tasks)")) + + # Run snapshot_download in a separate process so Ctrl+C in the main process + # can reliably terminate the download, even if the library swallows signals. + def _worker(): + snapshot_download( + repo_id=args.repo_id, + repo_type="model", + local_dir=str(local_dir), + local_dir_use_symlinks=False, + allow_patterns=allow_patterns, + ) + + proc = Process(target=_worker) + proc.start() + + try: + proc.join() + except KeyboardInterrupt: + print( + "\nCheckpoint download interrupted by user (Ctrl+C). Terminating download process...", + file=sys.stderr, + ) + proc.terminate() + proc.join() + print("Partial checkpoint data may remain in:", local_dir, file=sys.stderr) + return 130 + + if proc.exitcode != 0: + print(f"\nCheckpoint download process exited with code {proc.exitcode}", file=sys.stderr) + return proc.exitcode or 1 + + print(f"\nDone. 
Checkpoints are at: {local_dir}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/scripts/download_dataset.py b/scripts/download_dataset.py index 3fe2beb..866765e 100644 --- a/scripts/download_dataset.py +++ b/scripts/download_dataset.py @@ -13,6 +13,7 @@ import argparse import os import sys +from multiprocessing import Process from pathlib import Path @@ -51,7 +52,7 @@ def main() -> int: args = parser.parse_args() try: - from huggingface_hub import snapshot_download + from huggingface_hub import snapshot_download # type: ignore except ImportError: print("Install huggingface_hub first: pip install huggingface_hub", file=sys.stderr) return 1 @@ -69,15 +70,37 @@ def main() -> int: print(f"Downloading dataset to {local_dir}") print(f"Repo: {args.repo_id}" + (f", tasks: {args.tasks}" if args.tasks else " (all tasks)")) - snapshot_download( - repo_id=args.repo_id, - repo_type="dataset", - local_dir=str(local_dir), - local_dir_use_symlinks=False, - allow_patterns=allow_patterns, - ) + # Run snapshot_download in a separate process so Ctrl+C in the main process + # can reliably terminate the download, even if the library swallows signals. + def _worker(): + snapshot_download( + repo_id=args.repo_id, + repo_type="dataset", + local_dir=str(local_dir), + local_dir_use_symlinks=False, + allow_patterns=allow_patterns, + ) + + proc = Process(target=_worker) + proc.start() + + try: + proc.join() + except KeyboardInterrupt: + print( + "\nDownload interrupted by user (Ctrl+C). Terminating download process...", + file=sys.stderr, + ) + proc.terminate() + proc.join() + print("Partial data may remain in:", local_dir, file=sys.stderr) + return 130 + + if proc.exitcode != 0: + print(f"\nDownload process exited with code {proc.exitcode}", file=sys.stderr) + return proc.exitcode or 1 - print(f"Done. Dataset is at: {local_dir}") + print(f"\nDone. 
Dataset is at: {local_dir}") return 0 diff --git a/scripts/train.py b/scripts/train.py index 5d28941..5a37560 100755 --- a/scripts/train.py +++ b/scripts/train.py @@ -15,6 +15,9 @@ import optax import tqdm_loggable.auto as tqdm import wandb +import shutil +from pathlib import Path + import openpi.models.model as _model import openpi.shared.array_typing as at @@ -215,6 +218,11 @@ def main(config: _config.TrainConfig): overwrite=config.overwrite, resume=config.resume, ) + + dst_dir = config.checkpoint_dir + src_file = Path(config.data.repo_id) / 'norm_stats.json' + shutil.copy(src_file, dst_dir) + init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) data_loader = _data_loader.create_data_loader( @@ -270,7 +278,7 @@ def main(config: _config.TrainConfig): batch = next(data_iter) if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1: - _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step) + _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step, config.save_train_state) logging.info("Waiting for checkpoint manager to finish") checkpoint_manager.wait_until_finished() From 1e15fb7233a1655bee98812bf73f2d97a5fdadf3 Mon Sep 17 00:00:00 2001 From: Tradewindycc Date: Tue, 10 Feb 2026 17:30:29 +0800 Subject: [PATCH 4/4] minor touch --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5e5afbc..6722394 100755 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# χ₀: Resource-Aware Robust Manipulation viaTaming Distributional Inconsistencies +# χ₀: Resource-Aware Robust Manipulation via Taming Distributional Inconsistencies
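
The README above references the Model Arithmetic module (mixing checkpoints), but its implementation is outside this patch series. Purely as an illustration of what checkpoint mixing generally means, the sketch below takes a weighted average of matching parameter tensors from two PyTorch/safetensors checkpoints; the function name, file paths, and the convex-combination rule are assumptions and may differ from the repository's actual Model Arithmetic module.

```python
# Illustrative sketch only, not the Kai0 Model Arithmetic module: a per-tensor weighted
# average of two safetensors checkpoints, the generic form of "mixing checkpoints".
from safetensors.torch import load_file, save_file


def mix_checkpoints(path_a: str, path_b: str, alpha: float, out_path: str) -> None:
    """Blend two safetensors checkpoints: out = alpha * A + (1 - alpha) * B, per tensor."""
    a = load_file(path_a)
    b = load_file(path_b)
    if a.keys() != b.keys():
        raise ValueError("checkpoints must contain the same parameter names")
    mixed = {
        # Accumulate in float32 for numerical safety, then cast back to the original dtype.
        name: (alpha * a[name].float() + (1.0 - alpha) * b[name].float()).to(a[name].dtype)
        for name in a
    }
    save_file(mixed, out_path)


# Hypothetical usage with two of the per-task checkpoints fetched by
# scripts/download_checkpoints.py (actual file names may differ):
# mix_checkpoints("checkpoints/FlattenFold/model.safetensors",
#                 "checkpoints/HangCloth/model.safetensors",
#                 alpha=0.5, out_path="checkpoints/mixed/model.safetensors")
```
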