2 changes: 2 additions & 0 deletions .gitignore
@@ -9,8 +9,10 @@ __pycache__/
# Distribution / packaging
.Python
build/
checkpoints/
develop-eggs/
dist/
data/
downloads/
eggs/
.eggs/
27 changes: 20 additions & 7 deletions README.md
@@ -1,5 +1,4 @@
# χ₀: Resource-Aware Robust Manipulation viaTaming Distributional Inconsistencies

# χ₀: Resource-Aware Robust Manipulation via Taming Distributional Inconsistencies

<div id="top" align="center">

@@ -16,6 +15,7 @@
</div>

χ₀ (**kai0**) is a resource-efficient framework for achieving production-level robustness in robotic manipulation by taming distributional inconsistencies.
<!-- This repository is built on top of [openpi](https://github.com/Physical-Intelligence/openpi), the open-source models and packages for robotics published by the [Physical Intelligence team](https://www.physicalintelligence.company/). -->

χ₀ addresses the systematic distributional shift among the human demonstration distribution ($P_\text{train}$), the inductive bias learned by the policy ($Q_\text{model}$), and the test-time execution distribution ($P_\text{test}$) through three technical modules:

@@ -31,7 +31,7 @@ https://github.com/user-attachments/assets/e662f096-d273-4458-abd4-e12b9685a9bc

## Table of Contents

- [Updates](#updates)
- [Update](#update)
- [Acknowledgement](#acknowledgement)
- [Requirements](#requirements)
- [Compute](#compute)
@@ -75,7 +75,7 @@ This repository is built on top of [openpi](https://github.com/Physical-Intellig

For Model Arithmetic (mixing checkpoints), GPU memory requirements depend on the model size and number of checkpoints being mixed. A single A100 (80GB) is sufficient for most use cases.
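
As a rough illustration of what mixing checkpoints involves, weight-space averaging of two checkpoints can look like the following (a minimal sketch, assuming safetensors-format checkpoints; file names and the mixing ratio are illustrative, and the repo's actual Model Arithmetic tooling may differ):

```python
# Minimal sketch: mix two checkpoints by weighted averaging of parameters.
# File names and the 50/50 ratio are illustrative, not the repo's API.
from safetensors.numpy import load_file, save_file

ckpt_a = load_file("checkpoint_a.safetensors")
ckpt_b = load_file("checkpoint_b.safetensors")

alpha = 0.5  # mixing coefficient
mixed = {name: alpha * ckpt_a[name] + (1 - alpha) * ckpt_b[name] for name in ckpt_a}

save_file(mixed, "checkpoint_mixed.safetensors")
```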

The repo has been tested with Ubuntu 22.04.
Non-edge components (e.g., Policy Training, Model Arithmetic) have been tested on Ubuntu 22.04.

### Hardware

@@ -112,15 +112,28 @@ uv pip install safetensors
Download the Kai0 dataset so it is available under `./data` for training and evaluation. From the repository root, run:

```bash
pip install huggingface_hub # if not already installed
python scripts/download_dataset.py
```

This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see [DATASET.md](DATASET.md#step-1-download-the-dataset).
This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).

### 2. Download checkpoints (optional, for testing)

We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main). Download the task folder(s) you need and set `weight_loader` in config to the path of the downloaded checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead.
We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).

From the repository root, you can download all best-model checkpoints to `./checkpoints` with:

```bash
python scripts/download_checkpoints.py
```

To download only specific tasks or use a custom path, run:

```bash
python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_checkpoints
```

After download, set `weight_loader` in the training config to the path of the corresponding checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead.
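
As a rough sketch, pointing the config at a downloaded checkpoint could look like this (`CheckpointWeightLoader` follows openpi's `weight_loaders` module; the checkpoint subdirectory shown is an example, not the guaranteed layout of the downloaded folders):

```python
# Illustrative config fragment (see src/openpi/training/config.py);
# the checkpoint path below is an example path, not the guaranteed layout.
TrainConfig(
    ...,
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "./checkpoints/FlattenFold/params"
    ),
)
```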

### 3. Fine-tune with normal π₀.5

File renamed without changes.
186 changes: 136 additions & 50 deletions docs/norm_stats_fast.md
100755 → 100644
@@ -1,69 +1,155 @@
# Normalization statistics
## Fast normalization stats computation (`compute_norm_states_fast.py`)

Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
This script provides a **fast path** to compute normalization statistics for Kai0 configs by
directly reading local parquet files instead of going through the full data loader. It produces
`norm_stats` that are **compatible with the original openpi pipeline** (same `RunningStats`
implementation and batching scheme).

## Reloading normalization statistics
---

When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
### When to use this script

**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
- You have already **downloaded the dataset locally** (e.g. under `./data`, see
[`docs/dataset.md`](./dataset.md#step-1-download-the-dataset)).
- You have a **training config** in `src/openpi/training/config.py` (e.g.
`pi05_flatten_fold_normal`) and you want to compute `norm_stats` before running
`scripts/train.py`.
- You prefer a **simpler / faster** pipeline compared to the original `compute_norm_stats.py`
while keeping numerically compatible statistics.

```python
TrainConfig(
...
data=LeRobotAlohaDataConfig(
...
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
asset_id="trossen",
),
),
)
```

For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
---

**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
### Script entry point

**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task.
The script lives at:

- `scripts/compute_norm_states_fast.py`

## Provided Pre-training Normalization Statistics
Main entry:

Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
| Robot | Description | Asset ID |
|-------|-------------|----------|
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
- `main(config_name: str, base_dir: str | None = None, max_frames: int | None = None)`

CLI is handled via [`tyro`](https://github.com/brentyi/tyro), so you call it from the repo root as:

## Pi0 Model Action Space Definitions

Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
```bash
uv run python scripts/compute_norm_states_fast.py --config-name <config_name> [--base-dir <path>] [--max-frames N]
```
"dim_0:dim_5": "left arm joint angles",
"dim_6": "left arm gripper position",
"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
"dim_13": "right arm gripper position (for bi-manual only)",

# For mobile robots:
"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
---

### Arguments

- **`--config-name`** (`str`, required)
- Name of the TrainConfig defined in `src/openpi/training/config.py`, e.g.:
- `pi05_flatten_fold_normal`
- `pi05_tee_shirt_sort_normal`
- `pi05_hang_cloth_normal`
- Internally resolved via `_config.get_config(config_name)`.

- **`--base-dir`** (`str`, optional)
- Base directory containing the parquet data for this config.
- If omitted, the script will read it from `config.data` (see the snippet after this list):
- `data_config = config.data.create(config.assets_dirs, config.model)`
- `base_dir` defaults to `data_config.repo_id`
- This means you can either:
- Set `repo_id` in the config to your local dataset path (e.g.
`<path_to_repo_root>/data/FlattenFold/base`), or
- Keep `repo_id` as-is and pass `--base-dir` explicitly to point to your local copy.

- **`--max-frames`** (`int`, optional)
- If set, stops after processing at most `max_frames` frames across all parquet files.
- Useful for **quick sanity checks** or debugging smaller subsets.
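
For reference, the auto-resolution described above amounts to roughly the following (a sketch using the config API names quoted in this doc; the config name is an example):

```python
# Sketch of how --base-dir is resolved when omitted (names as quoted above).
from openpi.training import config as _config

config = _config.get_config("pi05_flatten_fold_normal")
data_config = config.data.create(config.assets_dirs, config.model)
base_dir = data_config.repo_id  # falls back to the config's repo_id
```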

---

### What the script does

1. **Load config**
- Uses `_config.get_config(config_name)` to get the `TrainConfig`.
- Calls `config.data.create(config.assets_dirs, config.model)` to build a data config.
- Reads `action_dim` from `config.model.action_dim`.

2. **Resolve input data directory**
- If `base_dir` is not provided:
- Uses `data_config.repo_id` as the base directory.
- Prints a message like:
- `Auto-detected base directory from config: <base_dir>`
- Verifies that the directory exists.

3. **Scan parquet files**
- Recursively walks `base_dir` and collects all files ending with `.parquet`.
- Sorts them lexicographically for **deterministic ordering** (matches dataset order).

4. **Read and process data**
- For each parquet file:
- Loads it with `pandas.read_parquet`.
- Expects columns:
- `observation.state`
- `action`
- For each row:
- Extracts `state` and `action` arrays.
- Applies:
- `process_state(state, action_dim)`
- `process_actions(actions, action_dim)`
- These helpers (see the sketch after this list):
- **Pad** to `action_dim` (if the dimension is smaller).
- **Zero out** abnormal values outside \([-π, π]\) (for robustness, consistent with `FakeInputs` logic).
- Accumulates processed arrays into:
- `collected_data["state"]`
- `collected_data["actions"]`
- Maintains a running `total_frames` counter and respects `max_frames` if provided.

5. **Concatenate and pad**
- Concatenates all collected batches per key:
- `all_data["state"]`, `all_data["actions"]`
- Ensures the last dimension matches `action_dim` (pads with zeros if needed).

6. **Compute statistics with `RunningStats`**
- Initializes one `normalize.RunningStats()` per key (`state`, `actions`).
- Feeds data in **batches of 32** to match the original implementation’s floating-point accumulation behavior (see the sketch after this list).
- For each key, computes:
- `mean`, `std`, `q01`, `q99`, etc.

7. **Save `norm_stats`**
- Collects results into a dict `norm_stats`.
- Saves them with `openpi.shared.normalize.save` to:
- `output_path = config.assets_dirs / data_config.repo_id`
- Prints the output path and a success message:
- `✅ Normalization stats saved to <output_path>`

> **Note:** The save logic mirrors the original openpi `compute_norm_stats.py` behavior so that
> training code can load `norm_stats` transparently.
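
Putting steps 3–6 together, the core loop is roughly the following (a condensed sketch: helper names match the script, but exact signatures, the dataset path, and the `action_dim` value are assumptions):

```python
# Condensed sketch of the fast norm-stats pipeline (illustrative, 1-D state
# case; actions are handled analogously with process_actions).
import glob
import numpy as np
import pandas as pd
import openpi.shared.normalize as normalize

def process_state(state, action_dim):
    state = np.asarray(state, dtype=np.float32)
    if state.shape[-1] < action_dim:  # pad with zeros up to action_dim
        state = np.pad(state, (0, action_dim - state.shape[-1]))
    state[np.abs(state) > np.pi] = 0.0  # zero out abnormal values
    return state

# Deterministic lexicographic ordering over all parquet files (step 3).
files = sorted(glob.glob("data/FlattenFold/base/**/*.parquet", recursive=True))
states = []
for path in files:
    df = pd.read_parquet(path)
    states.extend(process_state(s, 32) for s in df["observation.state"])  # 32 = example action_dim

stats = normalize.RunningStats()
data = np.stack(states)
for start in range(0, len(data), 32):  # batches of 32, matching openpi
    stats.update(data[start : start + 32])

norm_stats = {"state": stats.get_statistics()}  # mean/std/q01/q99
# Step 7 would then persist them, e.g.: normalize.save(output_path, norm_stats)
```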

---

### Typical workflow with Kai0 configs

1. **Download dataset**
- Follow [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset) to download the Kai0
dataset under `./data` at the repo root.

2. **Set config paths**
- Edit `src/openpi/training/config.py` for the normal π₀.5 configs (see README `Preparation`):
- `repo_id` → absolute path to the dataset subset, e.g.
`<path_to_repo_root>/data/FlattenFold/base`
- `weight_loader` → path to the π₀.5 base checkpoint (e.g. Kai0 best model per task).

3. **Compute normalization stats**
- From the repo root:

```bash
uv run python scripts/compute_norm_states_fast.py --config-name pi05_flatten_fold_normal
```

The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
4. **Train**
- Then run JAX training with:

For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
uv run scripts/train.py pi05_flatten_fold_normal --exp_name=<your_experiment_name>
```

General info for Pi robots:
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
- Control frequencies are either 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms.
The training code will pick up the normalization statistics saved by this script and use them
for input normalization, in the same way as the original openpi pipeline.

For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
108 changes: 108 additions & 0 deletions scripts/download_checkpoints.py
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
Download Kai0 best-model checkpoints from Hugging Face to the repo's ./checkpoints directory.

Run from the repository root:
python scripts/download_checkpoints.py

Optional: download only specific tasks or set a custom output path:
python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_ckpts
"""
from __future__ import annotations

import argparse
import sys
from multiprocessing import Process
from pathlib import Path


def get_repo_root() -> Path:
"""Return the repository root (directory containing .git)."""
path = Path(__file__).resolve().parent.parent
if (path / ".git").exists():
return path
# Fallback: assume cwd is repo root
return Path.cwd()


def main() -> int:
parser = argparse.ArgumentParser(
description="Download Kai0 best-model checkpoints from Hugging Face to ./checkpoints (or --local-dir)."
)
parser.add_argument(
"--local-dir",
type=str,
default=None,
help="Directory to save checkpoints (default: <repo_root>/checkpoints)",
)
parser.add_argument(
"--tasks",
nargs="+",
choices=["FlattenFold", "HangCloth", "TeeShirtSort"],
default=None,
help="Download only these task folders from the repo (default: all)",
)
parser.add_argument(
"--repo-id",
type=str,
default="OpenDriveLab-org/Kai0",
help="Hugging Face repo id that hosts best-model checkpoints (default: OpenDriveLab-org/Kai0)",
)
args = parser.parse_args()

try:
from huggingface_hub import snapshot_download # type: ignore
except ImportError:
print("Install huggingface_hub first: pip install huggingface_hub", file=sys.stderr)
return 1

repo_root = get_repo_root()
local_dir = Path(args.local_dir) if args.local_dir else repo_root / "checkpoints"
local_dir = local_dir.resolve()

allow_patterns = None
if args.tasks:
# Each task corresponds to a top-level folder in the repo.
allow_patterns = [f"{t}/*" for t in args.tasks]
allow_patterns.append("README.md")

print(f"Downloading checkpoints to {local_dir}")
print(f"Repo: {args.repo_id}" + (f", tasks: {args.tasks}" if args.tasks else " (all tasks)"))

# Run snapshot_download in a separate process so Ctrl+C in the main process
# can reliably terminate the download, even if the library swallows signals.
def _worker():
snapshot_download(
repo_id=args.repo_id,
repo_type="model",
local_dir=str(local_dir),
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)

proc = Process(target=_worker)
proc.start()

try:
proc.join()
except KeyboardInterrupt:
print(
"\nCheckpoint download interrupted by user (Ctrl+C). Terminating download process...",
file=sys.stderr,
)
proc.terminate()
proc.join()
print("Partial checkpoint data may remain in:", local_dir, file=sys.stderr)
return 130

if proc.exitcode != 0:
print(f"\nCheckpoint download process exited with code {proc.exitcode}", file=sys.stderr)
return proc.exitcode or 1

print(f"\nDone. Checkpoints are at: {local_dir}")
return 0


if __name__ == "__main__":
sys.exit(main())
