
Add optional pi0.5 subtask prediction stage #942

Open
taivu1998 wants to merge 1 commit into Physical-Intelligence:main from taivu1998:tdv/issue-664-subtask-prediction

@taivu1998

Summary

Adds opt-in pi0.5 subtask prediction support for research workflows that want to reproduce the two-stage inference pattern described in issue #664:

  1. predict intermediate subtask text from the prompt/images/state;
  2. condition flow-matching action decoding on that generated subtask.

The existing action-only pi0/pi0.5 behavior remains the default.
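
The two-stage pattern in the list above can be sketched in plain Python. This is only an illustration of the control flow, not the PR's implementation: `staged_inference`, `toy_predict_subtask`, and `toy_decode_actions` are hypothetical stand-ins, and the exact prompt layout (a prefix ending in `Subtask:` followed by the generated text) is an assumption based on the tokenizer description further down.

```python
from typing import Callable, List, Tuple


def staged_inference(
    prompt: str,
    predict_subtask: Callable[[str], str],
    decode_actions: Callable[[str], List[float]],
) -> Tuple[List[float], str]:
    """Run stage 1 (subtask text), then stage 2 (actions conditioned on it)."""
    # Stage 1: the prefix ends in "Subtask:" so the model continues with subtask text.
    subtask = predict_subtask(f"{prompt}\nSubtask:")
    # Stage 2: the action decoder sees the prompt plus the generated subtask.
    actions = decode_actions(f"{prompt}\nSubtask: {subtask}")
    return actions, subtask


# Hypothetical stand-ins for the language and action heads.
def toy_predict_subtask(prefix: str) -> str:
    return "pick up the cup"


def toy_decode_actions(conditioned_prompt: str) -> List[float]:
    # A real decoder would run flow matching; this just returns a fixed chunk.
    return [0.0, 0.0, 0.0]


actions, subtask = staged_inference(
    "clean the table", toy_predict_subtask, toy_decode_actions
)
```

The point of the sketch is only that the second stage receives the first stage's output as extra conditioning; the action-only path skips stage 1 entirely.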

Motivation

Issue #664 points out that the pi0.5 paper describes an intermediate subtask prediction stage, while openpi currently tokenizes the prompt and directly predicts actions. This PR adds that missing path behind explicit config flags instead of changing default inference or training behavior for released checkpoints.

Changes

  • Adds Pi0Config flags for pi0.5 subtask workflows:
    • train_subtask_prediction
    • sample_subtask_prediction
    • subtask_loss_weight
    • max_subtask_len
    • subtask_temperature
    • subtask_eos_token
  • Extends Observation with optional tokenized action-suffix fields used by staged inference.
  • Adds PaliGemma tokenizer helpers for:
    • supervised subtask training sequences;
    • staged inference prefixes ending in Subtask:;
    • action cue suffixes;
    • output detokenization.
  • Adds pi0.5 JAX model support for:
    • batch-shaped token autoregressive masks;
    • text de-embedding from Gemma hidden states;
    • optional subtask cross-entropy loss;
    • autoregressive subtask token generation;
    • sample_actions_with_subtask() returning actions plus generated subtask tokens.
  • Wires policy inference to use staged sampling only when sample_subtask_prediction=True.
  • Keeps PyTorch behavior explicit by raising NotImplementedError if subtask mode is requested there.
  • Adds tests for tokenizer masks, transforms, and dummy pi0.5 staged model behavior.
  • Updates README wording to describe action-only defaults plus the new opt-in subtask path.
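
As a rough illustration of how `max_subtask_len`, `subtask_temperature`, and `subtask_eos_token` could interact during autoregressive subtask generation, here is a minimal pure-Python sketch. It is not the JAX code in this PR: `generate_subtask`, `sample_token`, and `toy_logits` are hypothetical names, treating a temperature of 0 as greedy decoding is an assumption, and so is dropping the EOS token from the returned sequence.

```python
import math
import random
from typing import Callable, List, Sequence


def sample_token(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Pick one token id; temperature <= 0 is treated as greedy argmax (assumption)."""
    if temperature <= 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(value - peak) for value in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def generate_subtask(
    next_logits: Callable[[List[int]], Sequence[float]],
    prefix: Sequence[int],
    max_subtask_len: int,
    eos_token: int,
    temperature: float = 0.0,
    seed: int = 0,
) -> List[int]:
    """Append tokens one at a time until EOS or the length cap is reached."""
    rng = random.Random(seed)
    generated: List[int] = []
    for _ in range(max_subtask_len):
        logits = next_logits(list(prefix) + generated)
        token = sample_token(logits, temperature, rng)
        if token == eos_token:
            break  # EOS ends generation and is not returned (assumption)
        generated.append(token)
    return generated


# Hypothetical toy "model": vocabulary of 4 ids, where id 3 plays the role of EOS.
def toy_logits(context: List[int]) -> List[float]:
    target = 3 if len(context) >= 4 else len(context) % 3
    return [5.0 if i == target else 0.0 for i in range(4)]


subtask_tokens = generate_subtask(toy_logits, prefix=[0, 1], max_subtask_len=8, eos_token=3)
```

A real implementation would batch this loop and reuse a KV cache instead of re-running the model on the full context at every step; the sketch keeps only the stopping logic.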

Compatibility

  • Default Pi0Config(pi05=True) behavior is unchanged.
  • Existing sample_actions() remains action-only and keeps the same return type.
  • Staged inference is exposed separately through sample_actions_with_subtask() and policy wiring when enabled.
  • PyTorch checkpoints fail early for subtask mode rather than silently ignoring the feature.
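
The compatibility rules above amount to a small dispatch decision at inference time. The sketch below shows one way that decision could look. Only `sample_subtask_prediction`, `sample_actions()`, and `sample_actions_with_subtask()` come from this PR; `SketchConfig`, `ToyJaxModel`, `infer`, the `backend` argument, and the output dictionary keys are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class SketchConfig:
    # Mirrors the opt-in flag named in this PR; off by default.
    sample_subtask_prediction: bool = False


class ToyJaxModel:
    """Hypothetical stand-in exposing both sampling entry points."""

    def sample_actions(self, observation: Dict[str, Any]) -> List[float]:
        return [0.0, 0.0]

    def sample_actions_with_subtask(
        self, observation: Dict[str, Any]
    ) -> Tuple[List[float], List[int]]:
        return [0.0, 0.0], [7, 8]


def infer(model: Any, config: SketchConfig, observation: Dict[str, Any], backend: str) -> Dict[str, Any]:
    """Use staged sampling only when the flag is set; otherwise stay action-only."""
    if not config.sample_subtask_prediction:
        return {"actions": model.sample_actions(observation)}
    if backend == "pytorch":
        # Fail early instead of silently ignoring the requested feature.
        raise NotImplementedError("subtask prediction is only sketched for the JAX path")
    actions, subtask_tokens = model.sample_actions_with_subtask(observation)
    return {"actions": actions, "subtask_tokens": subtask_tokens}


default_output = infer(ToyJaxModel(), SketchConfig(), {}, backend="jax")
staged_output = infer(
    ToyJaxModel(), SketchConfig(sample_subtask_prediction=True), {}, backend="jax"
)
```

Keeping `sample_actions()` untouched and adding a second entry point means callers that never set the flag see the same return type as before.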

Validation

  • uvx ruff check src/openpi/models/model.py src/openpi/models/pi0.py src/openpi/models/pi0_config.py src/openpi/models/gemma.py src/openpi/models/tokenizer.py src/openpi/models/tokenizer_test.py src/openpi/models/model_test.py src/openpi/models_pytorch/pi0_pytorch.py src/openpi/policies/policy.py src/openpi/training/config.py src/openpi/transforms.py src/openpi/transforms_test.py
  • .venv/bin/python -m py_compile src/openpi/models/model.py src/openpi/models/pi0.py src/openpi/models/pi0_config.py src/openpi/models/gemma.py src/openpi/models/tokenizer.py src/openpi/models/tokenizer_test.py src/openpi/models/model_test.py src/openpi/models_pytorch/pi0_pytorch.py src/openpi/policies/policy.py src/openpi/training/config.py src/openpi/transforms.py src/openpi/transforms_test.py
  • .venv/bin/python -m pytest src/openpi/models/tokenizer_test.py::test_subtask_tokenize_training src/openpi/models/tokenizer_test.py::test_subtask_tokenize_inference src/openpi/transforms_test.py::test_tokenize_pi05_subtask_training src/openpi/transforms_test.py::test_tokenize_pi05_subtask_inference src/openpi/transforms_test.py::test_tokenize_pi05_subtask_training_requires_subtask src/openpi/models/model_test.py::test_pi05_subtask_model -q
  • git diff --check

Note: the full uv run pytest ... path is not runnable on this macOS arm64 machine because the repo pins jax[cuda12]==0.5.3, whose CUDA plugin wheel is Linux-only. The targeted runtime tests above used a lightweight CPU-JAX environment.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:36
@jimmyt857 jimmyt857 removed their request for review May 11, 2026 04:08
Contributor

@wadeKeith wadeKeith left a comment

Nice feature! The optional pi0.5 subtask prediction stage is well scoped with clean separation from the main action prediction path. Tokenizer extensions and test coverage look solid. LGTM! Reviewed by Hermes Agent.


2 participants