diff --git a/.gitignore b/.gitignore
index b5835bb..5cffc9b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,8 +9,10 @@ __pycache__/
# Distribution / packaging
.Python
build/
+checkpoints/
develop-eggs/
dist/
+data/
downloads/
eggs/
.eggs/
diff --git a/README.md b/README.md
index 166561d..f1a1a58 100755
--- a/README.md
+++ b/README.md
@@ -1,5 +1,4 @@
-# χ₀: Resource-Aware Robust Manipulation viaTaming Distributional Inconsistencies
-
+# χ₀: Resource-Aware Robust Manipulation via Taming Distributional Inconsistencies
@@ -16,6 +15,7 @@
χ₀ (**kai0**) is a resource-efficient framework for achieving production-level robustness in robotic manipulation by taming distributional inconsistencies.
+
χ₀ addresses the systematic distributional shift among the human demonstration distribution ($P_\text{train}$), the inductive bias learned by the policy ($Q_\text{model}$), and the test-time execution distribution ($P_\text{test}$) through three technical modules:
@@ -31,7 +31,7 @@ https://github.com/user-attachments/assets/e662f096-d273-4458-abd4-e12b9685a9bc
## Table of Contents
-- [Updates](#updates)
+- [Update](#update)
- [Acknowledgement](#acknowledgement)
- [Requirements](#requirements)
- [Compute](#compute)
@@ -75,7 +75,7 @@ This repository is built on top of [openpi](https://github.com/Physical-Intellig
For Model Arithmetic (mixing checkpoints), GPU memory requirements depend on the model size and number of checkpoints being mixed. A single A100 (80GB) is sufficient for most use cases.
-The repo has been tested with Ubuntu 22.04.
+Non-edge components (e.g., Policy Training, Model Arithmetic) have been tested on Ubuntu 22.04.
### Hardware
@@ -112,15 +112,28 @@ uv pip install safetensors
Download the Kai0 dataset so it is available under `./data` for training and evaluation. From the repository root, run:
```bash
-pip install huggingface_hub # if not already installed
python scripts/download_dataset.py
```
-This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see [DATASET.md](DATASET.md#step-1-download-the-dataset).
+This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).
### 2. Download checkpoints (optional, for testing)
-We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main). Download the task folder(s) you need and set `weight_loader` in config to the path of the downloaded checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead.
+We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).
+
+From the repository root, you can download all best-model checkpoints to `./checkpoints` with:
+
+```bash
+python scripts/download_checkpoints.py
+```
+
+To download only specific tasks or use a custom path, run:
+
+```bash
+python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_checkpoints
+```
+
+After download, set `weight_loader` in the training config to the path of the corresponding checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead.
### 3. Fine-tune with normal π₀.5
diff --git a/DATASET.md b/docs/dataset.md
similarity index 100%
rename from DATASET.md
rename to docs/dataset.md
diff --git a/docs/norm_stats_fast.md b/docs/norm_stats_fast.md
old mode 100755
new mode 100644
index bc8f72c..3cb011b
--- a/docs/norm_stats_fast.md
+++ b/docs/norm_stats_fast.md
@@ -1,69 +1,155 @@
-# Normalization statistics
+## Fast normalization stats computation (`compute_norm_states_fast.py`)
-Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
+This script provides a **fast path** to compute normalization statistics for Kai0 configs by
+directly reading local parquet files instead of going through the full data loader. It produces
+`norm_stats` that are **compatible with the original openpi pipeline** (same `RunningStats`
+implementation and batching scheme).
-## Reloading normalization statistics
+---
-When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
+### When to use this script
-**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
+- You have already **downloaded the dataset locally** (e.g. under `./data`, see
+ [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset)).
+- You have a **training config** in `src/openpi/training/config.py` (e.g.
+ `pi05_flatten_fold_normal`) and you want to compute `norm_stats` before running
+ `scripts/train.py`.
+- You prefer a **simpler / faster** pipeline compared to the original `compute_norm_stats.py`
+ while keeping numerically compatible statistics.
-```python
-TrainConfig(
- ...
- data=LeRobotAlohaDataConfig(
- ...
- assets=AssetsConfig(
- assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
- asset_id="trossen",
- ),
- ),
-)
-```
-
-For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+---
-**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
+### Script entry point
-**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task.
+The script lives at:
+- `scripts/compute_norm_states_fast.py`
-## Provided Pre-training Normalization Statistics
+Main entry:
-Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
-| Robot | Description | Asset ID |
-|-------|-------------|----------|
-| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
-| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
-| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
-| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
-| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
-| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
-| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
-| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
-| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
+- `main(config_name: str, base_dir: str | None = None, max_frames: int | None = None)`
+CLI is handled via [`tyro`](https://github.com/brentyi/tyro), so you call it from the repo root as:
-## Pi0 Model Action Space Definitions
-
-Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
+```bash
+uv run python scripts/compute_norm_states_fast.py --config-name <CONFIG_NAME> [--base-dir <DIR>] [--max-frames N]
```
- "dim_0:dim_5": "left arm joint angles",
- "dim_6": "left arm gripper position",
- "dim_7:dim_12": "right arm joint angles (for bi-manual only)",
- "dim_13": "right arm gripper position (for bi-manual only)",
- # For mobile robots:
- "dim_14:dim_15": "x-y base velocity (for mobile robots only)",
+---
+
+### Arguments
+
+- **`--config-name`** (`str`, required)
+ - Name of the TrainConfig defined in `src/openpi/training/config.py`, e.g.:
+ - `pi05_flatten_fold_normal`
+ - `pi05_tee_shirt_sort_normal`
+ - `pi05_hang_cloth_normal`
+ - Internally resolved via `_config.get_config(config_name)`.
+
+- **`--base-dir`** (`str`, optional)
+ - Base directory containing the parquet data for this config.
+ - If omitted, the script will read it from `config.data`:
+ - `data_config = config.data.create(config.assets_dirs, config.model)`
+ - `base_dir` defaults to `data_config.repo_id`
+ - This means you can either:
+ - Set `repo_id` in the config to your local dataset path (e.g.
+      `<repo_root>/data/FlattenFold/base`), or
+ - Keep `repo_id` as-is and pass `--base-dir` explicitly to point to your local copy.
+
+- **`--max-frames`** (`int`, optional)
+ - If set, stops after processing at most `max_frames` frames across all parquet files.
+ - Useful for **quick sanity checks** or debugging smaller subsets.
+
+---
+
+### What the script does
+
+1. **Load config**
+ - Uses `_config.get_config(config_name)` to get the `TrainConfig`.
+ - Calls `config.data.create(config.assets_dirs, config.model)` to build a data config.
+ - Reads `action_dim` from `config.model.action_dim`.
+
+2. **Resolve input data directory**
+ - If `base_dir` is not provided:
+ - Uses `data_config.repo_id` as the base directory.
+ - Prints a message like:
+     - `Auto-detected base directory from config: <base_dir>`
+ - Verifies that the directory exists.
+
+3. **Scan parquet files**
+ - Recursively walks `base_dir` and collects all files ending with `.parquet`.
+ - Sorts them lexicographically for **deterministic ordering** (matches dataset order).
+
+4. **Read and process data**
+ - For each parquet file:
+ - Loads it with `pandas.read_parquet`.
+ - Expects columns:
+ - `observation.state`
+ - `action`
+ - For each row:
+ - Extracts `state` and `action` arrays.
+ - Applies:
+ - `process_state(state, action_dim)`
+ - `process_actions(actions, action_dim)`
+ - These helpers:
+ - **Pad** to `action_dim` (if dimension is smaller).
+ - **Clip abnormal values** outside \([-π, π]\) to 0 (for robustness, consistent with `FakeInputs` logic).
+ - Accumulates processed arrays into:
+ - `collected_data["state"]`
+ - `collected_data["actions"]`
+ - Maintains a running `total_frames` counter and respects `max_frames` if provided.
+
+5. **Concatenate and pad**
+ - Concatenates all collected batches per key:
+ - `all_data["state"]`, `all_data["actions"]`
+ - Ensures the last dimension matches `action_dim` (pads with zeros if needed).
+
+6. **Compute statistics with `RunningStats`**
+ - Initializes one `normalize.RunningStats()` per key (`state`, `actions`).
+ - Feeds data in **batches of 32** to match the original implementation’s floating-point
+ accumulation behavior.
+ - For each key, computes:
+ - `mean`, `std`, `q01`, `q99`, etc.
+
+7. **Save `norm_stats`**
+ - Collects results into a dict `norm_stats`.
+ - Saves them with `openpi.shared.normalize.save` to:
+ - `output_path = config.assets_dirs / data_config.repo_id`
+ - Prints the output path and a success message:
+   - `✅ Normalization stats saved to <output_path>`
+
+> **Note:** The save logic mirrors the original openpi `compute_norm_stats.py` behavior so that
+> training code can load `norm_stats` transparently.
+
+---
+
+### Typical workflow with Kai0 configs
+
+1. **Download dataset**
+ - Follow [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset) to download the Kai0
+ dataset under `./data` at the repo root.
+
+2. **Set config paths**
+ - Edit `src/openpi/training/config.py` for the normal π₀.5 configs (see README `Preparation`):
+ - `repo_id` → absolute path to the dataset subset, e.g.
+     `<repo_root>/data/FlattenFold/base`
+ - `weight_loader` → path to the π₀.5 base checkpoint (e.g. Kai0 best model per task).
+
+3. **Compute normalization stats**
+ - From the repo root:
+
+```bash
+uv run python scripts/compute_norm_states_fast.py --config-name pi05_flatten_fold_normal
```
-The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
+4. **Train**
+ - Then run JAX training with:
-For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
+uv run scripts/train.py pi05_flatten_fold_normal --exp_name=<EXP_NAME>
+```
-General info for Pi robots:
-- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
-- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
-- Control frequencies are either 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms.
+The training code will pick up the normalization statistics saved by this script and use them
+for input normalization, in the same way as the original openpi pipeline.
-For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
diff --git a/scripts/download_checkpoints.py b/scripts/download_checkpoints.py
new file mode 100644
index 0000000..ec9e893
--- /dev/null
+++ b/scripts/download_checkpoints.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+"""
+Download Kai0 best-model checkpoints from Hugging Face to the repo's ./checkpoints directory.
+
+Run from the repository root:
+ python scripts/download_checkpoints.py
+
+Optional: download only specific tasks or set a custom output path:
+ python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_ckpts
+"""
+from __future__ import annotations
+
+import argparse
+import sys
+from multiprocessing import Process
+from pathlib import Path
+
+
+def get_repo_root() -> Path:
+ """Return the repository root (directory containing .git)."""
+ path = Path(__file__).resolve().parent.parent
+ if (path / ".git").exists():
+ return path
+ # Fallback: assume cwd is repo root
+ return Path.cwd()
+
+
+def main() -> int:
+ parser = argparse.ArgumentParser(
+ description="Download Kai0 best-model checkpoints from Hugging Face to ./checkpoints (or --local-dir)."
+ )
+ parser.add_argument(
+ "--local-dir",
+ type=str,
+ default=None,
+        help="Directory to save checkpoints (default: <repo-root>/checkpoints)",
+ )
+ parser.add_argument(
+ "--tasks",
+ nargs="+",
+ choices=["FlattenFold", "HangCloth", "TeeShirtSort"],
+ default=None,
+ help="Download only these task folders from the repo (default: all)",
+ )
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ default="OpenDriveLab-org/Kai0",
+ help="Hugging Face repo id that hosts best-model checkpoints (default: OpenDriveLab-org/Kai0)",
+ )
+ args = parser.parse_args()
+
+ try:
+ from huggingface_hub import snapshot_download # type: ignore
+ except ImportError:
+ print("Install huggingface_hub first: pip install huggingface_hub", file=sys.stderr)
+ return 1
+
+ repo_root = get_repo_root()
+ local_dir = Path(args.local_dir) if args.local_dir else repo_root / "checkpoints"
+ local_dir = local_dir.resolve()
+
+ allow_patterns = None
+ if args.tasks:
+ # Each task corresponds to a top-level folder in the repo.
+ allow_patterns = [f"{t}/*" for t in args.tasks]
+ allow_patterns.append("README.md")
+
+ print(f"Downloading checkpoints to {local_dir}")
+ print(f"Repo: {args.repo_id}" + (f", tasks: {args.tasks}" if args.tasks else " (all tasks)"))
+
+ # Run snapshot_download in a separate process so Ctrl+C in the main process
+ # can reliably terminate the download, even if the library swallows signals.
+ def _worker():
+ snapshot_download(
+ repo_id=args.repo_id,
+ repo_type="model",
+ local_dir=str(local_dir),
+ local_dir_use_symlinks=False,
+ allow_patterns=allow_patterns,
+ )
+
+ proc = Process(target=_worker)
+ proc.start()
+
+ try:
+ proc.join()
+ except KeyboardInterrupt:
+ print(
+ "\nCheckpoint download interrupted by user (Ctrl+C). Terminating download process...",
+ file=sys.stderr,
+ )
+ proc.terminate()
+ proc.join()
+ print("Partial checkpoint data may remain in:", local_dir, file=sys.stderr)
+ return 130
+
+ if proc.exitcode != 0:
+ print(f"\nCheckpoint download process exited with code {proc.exitcode}", file=sys.stderr)
+ return proc.exitcode or 1
+
+ print(f"\nDone. Checkpoints are at: {local_dir}")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
+
diff --git a/scripts/download_dataset.py b/scripts/download_dataset.py
index 3fe2beb..866765e 100644
--- a/scripts/download_dataset.py
+++ b/scripts/download_dataset.py
@@ -13,6 +13,7 @@
import argparse
import os
import sys
+from multiprocessing import Process
from pathlib import Path
@@ -51,7 +52,7 @@ def main() -> int:
args = parser.parse_args()
try:
- from huggingface_hub import snapshot_download
+ from huggingface_hub import snapshot_download # type: ignore
except ImportError:
print("Install huggingface_hub first: pip install huggingface_hub", file=sys.stderr)
return 1
@@ -69,15 +70,37 @@ def main() -> int:
print(f"Downloading dataset to {local_dir}")
print(f"Repo: {args.repo_id}" + (f", tasks: {args.tasks}" if args.tasks else " (all tasks)"))
- snapshot_download(
- repo_id=args.repo_id,
- repo_type="dataset",
- local_dir=str(local_dir),
- local_dir_use_symlinks=False,
- allow_patterns=allow_patterns,
- )
+ # Run snapshot_download in a separate process so Ctrl+C in the main process
+ # can reliably terminate the download, even if the library swallows signals.
+ def _worker():
+ snapshot_download(
+ repo_id=args.repo_id,
+ repo_type="dataset",
+ local_dir=str(local_dir),
+ local_dir_use_symlinks=False,
+ allow_patterns=allow_patterns,
+ )
+
+ proc = Process(target=_worker)
+ proc.start()
+
+ try:
+ proc.join()
+ except KeyboardInterrupt:
+ print(
+ "\nDownload interrupted by user (Ctrl+C). Terminating download process...",
+ file=sys.stderr,
+ )
+ proc.terminate()
+ proc.join()
+ print("Partial data may remain in:", local_dir, file=sys.stderr)
+ return 130
+
+ if proc.exitcode != 0:
+ print(f"\nDownload process exited with code {proc.exitcode}", file=sys.stderr)
+ return proc.exitcode or 1
- print(f"Done. Dataset is at: {local_dir}")
+ print(f"\nDone. Dataset is at: {local_dir}")
return 0
diff --git a/scripts/train.py b/scripts/train.py
index 5d28941..5a37560 100755
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -15,6 +15,9 @@
import optax
import tqdm_loggable.auto as tqdm
import wandb
+import shutil
+from pathlib import Path
+
import openpi.models.model as _model
import openpi.shared.array_typing as at
@@ -215,6 +218,11 @@ def main(config: _config.TrainConfig):
overwrite=config.overwrite,
resume=config.resume,
)
+
+ dst_dir = config.checkpoint_dir
+ src_file = Path(config.data.repo_id) / 'norm_stats.json'
+ shutil.copy(src_file, dst_dir)
+
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
data_loader = _data_loader.create_data_loader(
@@ -270,7 +278,7 @@ def main(config: _config.TrainConfig):
batch = next(data_iter)
if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
- _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
+ _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step, config.save_train_state)
logging.info("Waiting for checkpoint manager to finish")
checkpoint_manager.wait_until_finished()
diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py
index 7090e8b..e60dc88 100755
--- a/src/openpi/training/config.py
+++ b/src/openpi/training/config.py
@@ -1224,110 +1224,110 @@ def __post_init__(self) -> None:
#************************Advantage Estimator***************************
- TrainConfig(
- name="ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD",
- advantage_estimator=True,
- model=pi0_config.AdvantageEstimatorConfig(
- pi05=True,
- loss_value_weight=1.,
- loss_action_weight=0.,
- discrete_state_input=False,
- ),
- data=LerobotAgilexDataConfig(
- repo_id = "Path/to/your/advantage/dataset",
- assets=AssetsConfig(
- assets_dir="Path/to/your/advantage/dataset/assets",
- asset_id="Your_advantage_dataset_name",
- ),
- default_prompt="Flatten and fold the cloth.",
- # * why removing "prompt" here will lead to an error in transforms.py
- repack_transforms=_transforms.Group(
- inputs=[
- _transforms.RepackTransform(
- {
- "images": {
- "top_head": "observation.images.top_head",
- "hand_left": "observation.images.hand_left",
- "hand_right": "observation.images.hand_right",
- "his_-100_top_head": "his_-100_observation.images.top_head",
- "his_-100_hand_left": "his_-100_observation.images.hand_left",
- "his_-100_hand_right": "his_-100_observation.images.hand_right",
- },
- "state": "observation.state",
- "actions": "action",
- # "prompt": "prompt", # ! Not adding this for default prompt.
- "episode_length": "episode_length",
- "frame_index": "frame_index",
- "episode_index": "episode_index",
- "progress_gt": "progress_gt",
- "stage_progress_gt": "stage_progress_gt",
- "progress": "progress",
- # "is_suboptimal": "is_suboptimal",
- }
- )
- ]
- )
- ),
- pytorch_weight_path="Path/to/your/pi05_base/checkpoint",
- num_train_steps=100_000,
- keep_period=10000,
- save_interval=10000,
- num_workers=8,
- batch_size=16, # * 1 gpus
- # batch_size=128, # * 8 gpus
- skip_norm_stats=True, # * No norm stats used.
- ),
- TrainConfig(
- name="ADVANTAGE_TORCH_PI06_FLATTEN_FOLD",
- advantage_estimator=True,
- model=pi0_config.AdvantageEstimatorConfig(
- pi05=True,
- loss_value_weight=1.,
- loss_action_weight=0., # No action loss in advantage estimator training
- discrete_state_input=False, # Not using states into prompt like pi05
- ),
- data=LerobotAgilexDataConfig(
- # repo_id = "/cpfs01/shared/filtered_cut_data/short_sleeve/flatten_fold/v9-3/1022_20_590_v9-3_2000_lerobot",
- repo_id = "Path/to/your/advantage/dataset",
- assets=AssetsConfig(
- assets_dir="Path/to/your/advantage/dataset/assets",
- asset_id="Your_advantage_dataset_name",
- ),
- default_prompt="Flatten and fold the cloth.",
- # * why removing "prompt" here will lead to an error in transforms.py
- repack_transforms=_transforms.Group(
- inputs=[
- _transforms.RepackTransform(
- {
- "images": {
- "top_head": "observation.images.top_head",
- "hand_left": "observation.images.hand_left",
- "hand_right": "observation.images.hand_right",
- },
- "state": "observation.state",
- "actions": "action",
- # "prompt": "prompt", # No need if default prompt is used.
- "episode_length": "episode_length",
- "frame_index": "frame_index",
- "episode_index": "episode_index",
- "progress_gt": "progress_gt",
- "stage_progress_gt": "stage_progress_gt",
- "progress": "progress",
- # "is_suboptimal": "is_suboptimal",
- }
- )
- ]
- )
- ),
- pytorch_weight_path="Path/to/your/pi06_base/checkpoint",
- num_train_steps=100_000,
- keep_period=10000,
- save_interval=10000,
- num_workers=55,
- # batch_size=16, # * 1 gpus
- batch_size=18*8, # * 8 gpus
- skip_norm_stats=True, # * No norm stats used.
- ),
+ # TrainConfig(
+ # name="ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD",
+ # advantage_estimator=True,
+ # model=pi0_config.AdvantageEstimatorConfig(
+ # pi05=True,
+ # loss_value_weight=1.,
+ # loss_action_weight=0.,
+ # discrete_state_input=False,
+ # ),
+ # data=LerobotAgilexDataConfig(
+ # repo_id = "Path/to/your/advantage/dataset",
+ # assets=AssetsConfig(
+ # assets_dir="Path/to/your/advantage/dataset/assets",
+ # asset_id="Your_advantage_dataset_name",
+ # ),
+ # default_prompt="Flatten and fold the cloth.",
+ # # * why removing "prompt" here will lead to an error in transforms.py
+ # repack_transforms=_transforms.Group(
+ # inputs=[
+ # _transforms.RepackTransform(
+ # {
+ # "images": {
+ # "top_head": "observation.images.top_head",
+ # "hand_left": "observation.images.hand_left",
+ # "hand_right": "observation.images.hand_right",
+ # "his_-100_top_head": "his_-100_observation.images.top_head",
+ # "his_-100_hand_left": "his_-100_observation.images.hand_left",
+ # "his_-100_hand_right": "his_-100_observation.images.hand_right",
+ # },
+ # "state": "observation.state",
+ # "actions": "action",
+ # # "prompt": "prompt", # ! Not adding this for default prompt.
+ # "episode_length": "episode_length",
+ # "frame_index": "frame_index",
+ # "episode_index": "episode_index",
+ # "progress_gt": "progress_gt",
+ # "stage_progress_gt": "stage_progress_gt",
+ # "progress": "progress",
+ # # "is_suboptimal": "is_suboptimal",
+ # }
+ # )
+ # ]
+ # )
+ # ),
+ # pytorch_weight_path="Path/to/your/pi05_base/checkpoint",
+ # num_train_steps=100_000,
+ # keep_period=10000,
+ # save_interval=10000,
+ # num_workers=8,
+ # batch_size=16, # * 1 gpus
+ # # batch_size=128, # * 8 gpus
+ # skip_norm_stats=True, # * No norm stats used.
+ # ),
+ # TrainConfig(
+ # name="ADVANTAGE_TORCH_PI06_FLATTEN_FOLD",
+ # advantage_estimator=True,
+ # model=pi0_config.AdvantageEstimatorConfig(
+ # pi05=True,
+ # loss_value_weight=1.,
+ # loss_action_weight=0., # No action loss in advantage estimator training
+ # discrete_state_input=False, # Not using states into prompt like pi05
+ # ),
+ # data=LerobotAgilexDataConfig(
+ # # repo_id = "/cpfs01/shared/filtered_cut_data/short_sleeve/flatten_fold/v9-3/1022_20_590_v9-3_2000_lerobot",
+ # repo_id = "Path/to/your/advantage/dataset",
+ # assets=AssetsConfig(
+ # assets_dir="Path/to/your/advantage/dataset/assets",
+ # asset_id="Your_advantage_dataset_name",
+ # ),
+ # default_prompt="Flatten and fold the cloth.",
+ # # * why removing "prompt" here will lead to an error in transforms.py
+ # repack_transforms=_transforms.Group(
+ # inputs=[
+ # _transforms.RepackTransform(
+ # {
+ # "images": {
+ # "top_head": "observation.images.top_head",
+ # "hand_left": "observation.images.hand_left",
+ # "hand_right": "observation.images.hand_right",
+ # },
+ # "state": "observation.state",
+ # "actions": "action",
+ # # "prompt": "prompt", # No need if default prompt is used.
+ # "episode_length": "episode_length",
+ # "frame_index": "frame_index",
+ # "episode_index": "episode_index",
+ # "progress_gt": "progress_gt",
+ # "stage_progress_gt": "stage_progress_gt",
+ # "progress": "progress",
+ # # "is_suboptimal": "is_suboptimal",
+ # }
+ # )
+ # ]
+ # )
+ # ),
+ # pytorch_weight_path="Path/to/your/pi06_base/checkpoint",
+ # num_train_steps=100_000,
+ # keep_period=10000,
+ # save_interval=10000,
+ # num_workers=55,
+ # # batch_size=16, # * 1 gpus
+ # batch_size=18*8, # * 8 gpus
+ # skip_norm_stats=True, # * No norm stats used.
+ # ),
#************************advantage estimator***************************
# RoboArena & PolaRiS configs.
*roboarena_config.get_roboarena_configs(),