Skip to content

Commit ae730dc

Browse files
committed
[UPDATE]: update tda
1 parent bf793ff commit ae730dc

1,169 files changed

Lines changed: 349011 additions & 10 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/openpi/models/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class ModelType(enum.Enum):
3333
PI0 = "pi0"
3434
PI0_FAST = "pi0_fast"
3535
PI05 = "pi05"
36+
PI0_RTC = "pi0_rtc"
37+
PI05_RTC = "pi05_rtc"
3638

3739

3840
# The model always expects these images

src/openpi/models/pi0_config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
if TYPE_CHECKING:
1515
from openpi.models.pi0 import Pi0
16+
from openpi.models.pi0_rtc import Pi0RTC
1617

1718

1819
@dataclasses.dataclass(frozen=True)
@@ -107,6 +108,33 @@ def get_freeze_filter(self) -> nnx.filterlib.Filter:
107108
return nnx.Nothing
108109
return nnx.All(*filters)
109110

111+
112+
@dataclasses.dataclass(frozen=True)
113+
class Pi0RTCConfig(Pi0Config):
114+
"""Config for Pi0RTC (real-time control) model. Uses same architecture as Pi0/Pi05 but sample_actions supports
115+
prev_action_chunk, inference_delay, execute_horizon for RTC guidance. Use this config when serving
116+
for RTC inference (e.g. agilex_inference_openpi_rtc.py). Set pi05=True for Pi05-based RTC (model_type PI05_RTC)."""
117+
118+
@property
119+
@override
120+
def model_type(self) -> _model.ModelType:
121+
return _model.ModelType.PI05_RTC if self.pi05 else _model.ModelType.PI0_RTC
122+
123+
@override
124+
def create(self, rng: at.KeyArrayLike) -> "Pi0RTC":
125+
from openpi.models.pi0_rtc import Pi0RTC
126+
127+
return Pi0RTC(self, rngs=nnx.Rngs(rng))
128+
129+
@override
130+
def load_pytorch(self, train_config, weight_path: str):
131+
"""RTC model is JAX-only; use a JAX checkpoint with serve_policy and Pi0RTCConfig."""
132+
raise NotImplementedError(
133+
"Pi0RTC is only supported with JAX checkpoints. Use a checkpoint saved from OpenPi JAX training "
134+
"(params directory, not model.safetensors) and serve with --policy.config=pi05_rtc_flatten_fold_inference (or your RTC config name)."
135+
)
136+
137+
110138
@dataclasses.dataclass(frozen=True)
111139
class AdvantageEstimatorConfig(Pi0Config):
112140
# * Custom

src/openpi/models/pi0_rtc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def rtc_step(carry):
322322
x_t_for_denoise = x_t
323323
if mask_prefix_delay and provided_dim > 0:
324324
mask_time = (jnp.arange(self.action_horizon) < d).astype(bool)[None, :, None]
325-
# 仅覆盖提供的维度,其余保持 x_t 原值
325+
# Overwrite only the provided dims in the delay prefix; leave the rest as x_t.
326326
overwrite = jnp.where(mask_time, prev_chunk[..., :provided_dim], x_t_for_denoise[..., :provided_dim])
327327
x_t_for_denoise = x_t_for_denoise.at[..., :provided_dim].set(overwrite)
328328

src/openpi/policies/agilex_policy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class AgilexInputs(transforms.DataTransformFn):
5252
mask_state: bool = False
5353

5454
def __call__(self, data: dict) -> dict:
55-
# We only mask padding for pi0 model, not pi0-FAST
56-
mask_padding = self.model_type == _model.ModelType.PI0
55+
# We only mask padding for pi0/pi0_rtc model, not pi05/pi05_rtc or pi0-FAST
56+
mask_padding = self.model_type in (_model.ModelType.PI0, _model.ModelType.PI0_RTC)
5757

5858
in_images = data["images"]
5959

src/openpi/policies/arx_policy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class ARXInputs(transforms.DataTransformFn):
5252
mask_state: bool = False
5353

5454
def __call__(self, data: dict) -> dict:
55-
# We only mask padding for pi0 model, not pi0-FAST
56-
mask_padding = self.model_type == _model.ModelType.PI0
55+
# We only mask padding for pi0/pi0_rtc model, not pi05/pi05_rtc or pi0-FAST
56+
mask_padding = self.model_type in (_model.ModelType.PI0, _model.ModelType.PI0_RTC)
5757

5858
in_images = data["images"]
5959

src/openpi/policies/droid_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __call__(self, data: dict) -> dict:
4545
wrist_image = _parse_image(data["observation/wrist_image_left"])
4646

4747
match self.model_type:
48-
case _model.ModelType.PI0 | _model.ModelType.PI05:
48+
case _model.ModelType.PI0 | _model.ModelType.PI05 | _model.ModelType.PI0_RTC | _model.ModelType.PI05_RTC:
4949
names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
5050
images = (base_image, wrist_image, np.zeros_like(base_image))
5151
image_masks = (np.True_, np.True_, np.False_)

src/openpi/training/config.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ModelTransformFactory(GroupFactory):
115115

116116
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
117117
match model_config.model_type:
118-
case _model.ModelType.PI0:
118+
case _model.ModelType.PI0 | _model.ModelType.PI0_RTC:
119119
return _transforms.Group(
120120
inputs=[
121121
_transforms.InjectDefaultPrompt(self.default_prompt),
@@ -126,7 +126,7 @@ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
126126
_transforms.PadStatesAndActions(model_config.action_dim),
127127
],
128128
)
129-
case _model.ModelType.PI05:
129+
case _model.ModelType.PI05 | _model.ModelType.PI05_RTC:
130130
assert isinstance(model_config, pi0_config.Pi0Config)
131131
return _transforms.Group(
132132
inputs=[
@@ -187,7 +187,7 @@ def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.Bas
187187
repo_id=repo_id,
188188
asset_id=asset_id,
189189
norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
190-
use_quantile_norm=model_config.model_type != ModelType.PI0,
190+
use_quantile_norm=model_config.model_type not in (ModelType.PI0, ModelType.PI0_RTC),
191191
)
192192

193193
def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
@@ -1371,6 +1371,23 @@ def __post_init__(self) -> None:
13711371
num_workers=8,
13721372
batch_size=256,
13731373
),
1374+
1375+
#**************************FlattenFold RTC Inference*******************************
1376+
# Use this config when serving the policy for agilex_inference_openpi_rtc.py (JAX checkpoints only).
1377+
TrainConfig(
1378+
name="pi05_rtc_flatten_fold_inference",
1379+
model=pi0_config.Pi0RTCConfig(pi05=True),
1380+
data=LerobotAgilexDataConfig(
1381+
repo_id="<path_to_repo_root>/data/FlattenFold/base",
1382+
default_prompt="Flatten and fold the cloth.",
1383+
use_delta_joint_actions=False,
1384+
),
1385+
weight_loader=weight_loaders.CheckpointWeightLoader("<path_to/pi05_base/checkpoint>"),
1386+
num_train_steps=100_000,
1387+
keep_period=5000,
1388+
num_workers=8,
1389+
batch_size=256,
1390+
),
13741391
# RoboArena & PolaRiS configs.
13751392
*roboarena_config.get_roboarena_configs(),
13761393
*polaris_config.get_polaris_configs(),

stage_advantage/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ This module implements a pipeline for training an **Advantage Estimator** and us
2222

2323
**End-to-end order for AWBC:** (1) Stage 0 on data with `progress` → optional for Stage 1. (2) Stage 1 → train estimator. (3) Stage 2 → run eval on your dataset so it gets `data_PI06_100000/` or `data_KAI0_100000/` with advantage columns. (4) Run Stage 0 again with `--advantage-source absolute_advantage` on that dataset (e.g. via `gt_labeling.sh` with `DATA_PATH` = the repo you ran eval on, and source subdirs `data_PI06_100000` / `data_KAI0_100000`). (5) Point AWBC config `repo_id` at the resulting advantage-labeled directory and run Stage 3 training.
2424

25+
**Pre-annotated data:** The downloaded dataset includes **`data/Task_A/advantage`**, a fully annotated advantage dataset that can be used **directly for AWBC training** (Stage 3) without running Stage 0–2. Set the AWBC config `repo_id` to that path and run training.
26+
2527
---
2628

2729
## Stage 0: GT Data Labeling
@@ -287,6 +289,8 @@ So during AWBC training the model is conditioned on prompts that explicitly incl
287289

288290
At **inference** time you must use the **same prompt format** as in training. To run the policy in the high-advantage regime, pass the **positive**-advantage prompt, e.g. `"<task>, Advantage: positive"` (with the same `<task>` wording as in your `tasks.jsonl`). Using a different format or omitting the advantage part can hurt performance, since the model was trained to condition on this exact style of prompt.
289291

292+
**Where to set the prompt when deploying:** The language prompt is set in the **inference code** (e.g. the `lang_embeddings` variable in the Agilex inference scripts). See the [train_deploy_alignment/inference README](../train_deploy_alignment/inference/README.md) and [Agilex README — Prompt and AWBC](../train_deploy_alignment/inference/agilex/README.md#prompt-and-awbc-important) for how to configure it so it matches your training and, for AWBC, uses the positive-advantage format above.
293+
290294
### How it works (data flow)
291295

292296
1. **Data**: The advantage dataset must contain `task_index` in each parquet and `meta/tasks.jsonl` mapping `task_index` → prompt string. This is produced by running Stage 2 (eval) to get advantage columns, then Stage 0 (`gt_label.py --advantage-source absolute_advantage`) to discretize into `task_index` and write `tasks.jsonl`.

stage_advantage/awbc/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ Each uses `base_config=DataConfig(prompt_from_task=True)` so that the dataset’
1717
## Prerequisites
1818

1919
1. **Advantage dataset**
20-
The data must have `task_index` in each parquet and `meta/tasks.jsonl` (prompt strings per `task_index`). To build it:
20+
The data must have `task_index` in each parquet and `meta/tasks.jsonl` (prompt strings per `task_index`).
21+
22+
**Pre-annotated data:** The downloaded dataset includes **`data/Task_A/advantage`**, a fully annotated advantage dataset that can be used **directly for AWBC training** (no need to run Stage 0–2 first). Set the AWBC config `repo_id` to that path and run the training commands below.
23+
24+
To build your own advantage dataset instead:
2125
- Run **Stage 2** (eval) on your dataset → get `data_PI06_100000/` or `data_KAI0_100000/` with advantage columns.
2226
- Run **Stage 0** on that output: `gt_label.py --advantage-source absolute_advantage` (or `gt_labeling.sh` with `DATA_PATH` = the eval repo). The resulting directory (with `data/`, `meta/tasks.jsonl`, `videos/`) is your advantage dataset.
2327
- Place or link it at e.g. `./data/FlattenFold/advantage` and set `repo_id` in config to that path.

train_deploy_alignment/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Train–Deploy Alignment
2+
3+
This directory contains three modules used to align training data and deployment/inference:
4+
5+
| Module | Description |
6+
|--------|-------------|
7+
| **dagger** | DAgger-style data collection (policy-in-the-loop, intervention, save). See [dagger/README.md](dagger/README.md) for ARX and Agilex. |
8+
| **inference** | Deployment and inference code, including ARX, Agilex. |
9+
| **data_augment** | Data augmentation and format conversion (time scaling, space mirroring, HDF5 → LeRobot). See [data_augment/README.md](data_augment/README.md). |
10+
11+
See each module’s README for setup and usage.

0 commit comments

Comments
 (0)