This script provides a fast path to compute normalization statistics for Kai0 configs by
directly reading local parquet files instead of going through the full data loader. It produces
norm_stats that are compatible with the original openpi pipeline (same RunningStats
implementation and batching scheme).
Use this script when:

- You have already downloaded the dataset locally (e.g. under `./data`, see `docs/dataset.md`).
- You have a training config in `src/openpi/training/config.py` (e.g. `pi05_flatten_fold_normal`) and you want to compute `norm_stats` before running `scripts/train.py`.
- You prefer a simpler / faster pipeline than the original `compute_norm_stats.py` while keeping numerically compatible statistics.
The script lives at `scripts/compute_norm_states_fast.py`.

Main entry:

```python
main(config_name: str, base_dir: str | None = None, max_frames: int | None = None)
```

The CLI is handled via tyro, so you call it from the repo root as:

```bash
uv run python scripts/compute_norm_states_fast.py --config-name <config_name> [--base-dir <path>] [--max-frames N]
```
Arguments:

- `--config-name` (str, required)
  - Name of the `TrainConfig` defined in `src/openpi/training/config.py`, e.g. `pi05_flatten_fold_normal`, `pi05_tee_shirt_sort_normal`, `pi05_hang_cloth_normal`.
  - Internally resolved via `_config.get_config(config_name)`.
- `--base-dir` (str, optional)
  - Base directory containing the parquet data for this config.
  - If omitted, the script reads it from `config.data`:
    - `data_config = config.data.create(config.assets_dirs, config.model)`
    - `base_dir` defaults to `data_config.repo_id`
  - This means you can either:
    - set `repo_id` in the config to your local dataset path (e.g. `<path_to_repo_root>/data/FlattenFold/base`), or
    - keep `repo_id` as-is and pass `--base-dir` explicitly to point to your local copy.
- `--max-frames` (int, optional)
  - If set, stops after processing at most `max_frames` frames across all parquet files.
  - Useful for quick sanity checks or debugging smaller subsets.
What the script does:

1. Load config
   - Uses `_config.get_config(config_name)` to get the `TrainConfig`.
   - Calls `config.data.create(config.assets_dirs, config.model)` to build a data config.
   - Reads `action_dim` from `config.model.action_dim`.
2. Resolve input data directory
   - If `base_dir` is not provided:
     - Uses `data_config.repo_id` as the base directory.
     - Prints a message like: `Auto-detected base directory from config: <base_dir>`
   - Verifies that the directory exists.
3. Scan parquet files
   - Recursively walks `base_dir` and collects all files ending with `.parquet`.
   - Sorts them lexicographically for deterministic ordering (matches dataset order).
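The scan step above amounts to a sorted recursive glob. A minimal sketch (the function name is illustrative, not taken from the script):

```python
from pathlib import Path

def scan_parquet_files(base_dir: str) -> list[Path]:
    """Recursively collect .parquet files under base_dir in deterministic (sorted) order."""
    return sorted(Path(base_dir).rglob("*.parquet"))
```

Sorting the `Path` objects sorts by path components, so the file order is stable across runs and machines.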
4. Read and process data
   - For each parquet file:
     - Loads it with `pandas.read_parquet`.
     - Expects the columns `observation.state` and `action`.
     - For each row:
       - Extracts the `state` and `action` arrays.
       - Applies `process_state(state, action_dim)` and `process_actions(actions, action_dim)`. These helpers:
         - Pad to `action_dim` (if the dimension is smaller).
         - Clip abnormal values outside [-π, π] to 0 (for robustness, consistent with the `FakeInputs` logic).
     - Accumulates processed arrays into `collected_data["state"]` and `collected_data["actions"]`.
     - Maintains a running `total_frames` counter and respects `max_frames` if provided.
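The pad-and-clip behavior of the helpers can be sketched as below. The exact implementation in the script may differ; this is an assumed version that matches the description (zero out values outside [-π, π], then zero-pad to `action_dim`):

```python
import numpy as np

def process_state(state: np.ndarray, action_dim: int) -> np.ndarray:
    """Assumed sketch of the helper: zero abnormal values, then pad to action_dim."""
    state = np.asarray(state, dtype=np.float32)
    # Values outside [-pi, pi] are treated as abnormal and clipped to 0.
    state = np.where(np.abs(state) > np.pi, 0.0, state)
    if state.shape[-1] < action_dim:
        # Zero-pad the trailing dimension up to action_dim.
        state = np.pad(state, (0, action_dim - state.shape[-1]))
    return state
```

`process_actions` would follow the same pattern on the `action` column.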
5. Concatenate and pad
   - Concatenates all collected batches per key into `all_data["state"]` and `all_data["actions"]`.
   - Ensures the last dimension matches `action_dim` (pads with zeros if needed).
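The concatenate-and-pad step can be sketched as follows (the function name is illustrative):

```python
import numpy as np

def concat_and_pad(batches: list[np.ndarray], action_dim: int) -> np.ndarray:
    """Stack per-file batches along axis 0 and zero-pad the last dim to action_dim."""
    data = np.concatenate(batches, axis=0)
    missing = action_dim - data.shape[-1]
    if missing > 0:
        # Pad only the feature (last) dimension with zeros.
        data = np.pad(data, ((0, 0), (0, missing)))
    return data
```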
6. Compute statistics with `RunningStats`
   - Initializes one `normalize.RunningStats()` per key (`state`, `actions`).
   - Feeds data in batches of 32 to match the original implementation's floating-point accumulation behavior.
   - For each key, computes `mean`, `std`, `q01`, `q99`, etc.
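`RunningStats` comes from openpi's `normalize` module. As a self-contained illustration of the batching scheme, here is a simplified stand-in accumulator plus the batch-of-32 feeding loop (not openpi's actual implementation):

```python
import numpy as np

class SimpleRunningStats:
    """Simplified stand-in for openpi's normalize.RunningStats (illustration only)."""

    def __init__(self) -> None:
        self._chunks: list[np.ndarray] = []

    def update(self, batch: np.ndarray) -> None:
        self._chunks.append(np.asarray(batch, dtype=np.float32))

    def get_statistics(self) -> dict[str, np.ndarray]:
        data = np.concatenate(self._chunks, axis=0)
        return {
            "mean": data.mean(axis=0),
            "std": data.std(axis=0),
            "q01": np.quantile(data, 0.01, axis=0),
            "q99": np.quantile(data, 0.99, axis=0),
        }

def feed_in_batches(stats: SimpleRunningStats, data: np.ndarray, batch_size: int = 32) -> None:
    """Feed data in fixed-size batches, mirroring the batching scheme described above."""
    for start in range(0, len(data), batch_size):
        stats.update(data[start:start + batch_size])
```

Feeding in fixed-size batches keeps the floating-point accumulation order identical to the original pipeline, which is what makes the resulting statistics numerically compatible.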
7. Save `norm_stats`
   - Collects results into a dict `norm_stats`.
   - Saves them with `openpi.shared.normalize.save` to `output_path = config.assets_dirs / data_config.repo_id`.
   - Prints the output path and a success message: `✅ Normalization stats saved to <output_path>`

Note: The save logic mirrors the original openpi `compute_norm_stats.py` behavior, so that training code can load `norm_stats` transparently.
Typical workflow:

1. Download dataset
   - Follow `docs/dataset.md` to download the Kai0 dataset under `./data` at the repo root.
2. Set config paths
   - Edit `src/openpi/training/config.py` for the normal π₀.5 configs (see the README's Preparation section):
     - `repo_id` → absolute path to the dataset subset, e.g. `<path_to_repo_root>/data/FlattenFold/base`
     - `weight_loader` → path to the π₀.5 base checkpoint (e.g. the Kai0 best model per task).
3. Compute normalization stats
   - From the repo root:

     ```bash
     uv run python scripts/compute_norm_states_fast.py --config-name pi05_flatten_fold_normal
     ```

4. Train
   - Then run JAX training with:

     ```bash
     XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
     uv run scripts/train.py pi05_flatten_fold_normal --exp_name=<your_experiment_name>
     ```

The training code will pick up the normalization statistics saved by this script and use them for input normalization, in the same way as the original openpi pipeline.