Skip to content

fix(train): ensure consistent color augmentation RNG across base and wrist cameras#914

Open
wadeKeith wants to merge 2 commits into
Physical-Intelligence:mainfrom
wadeKeith:fix/augmentation-rng-consistency
Open

fix(train): ensure consistent color augmentation RNG across base and wrist cameras#914
wadeKeith wants to merge 2 commits into
Physical-Intelligence:mainfrom
wadeKeith:fix/augmentation-rng-consistency

Conversation

@wadeKeith
Copy link
Copy Markdown
Contributor

Problem

Fixes #859.

augmax.Chain splits the input RNG into sub-keys by transform count. Base cameras use 4 transforms (RandomCrop, Resize, Rotate, ColorJitter) while wrist cameras use only 1 (ColorJitter). This means ColorJitter receives different sub-RNGs for base vs. wrist cameras, even when given the same input seed.

The result: base and wrist cameras see inconsistent color augmentation within the same training sample, which may confuse the VLM about object colors across camera views.

Fix

Split augmentation into two deterministic stages:

  1. Spatial transforms (base cameras only): RandomCrop + Resize + Rotate
  2. Color transforms (all cameras, shared RNG): ColorJitter

Both stages derive their RNG from the original key using jax.random.fold_in with different constants, ensuring:

  • ColorJitter always uses the same RNG regardless of camera type
  • Spatial transforms use a separate, independent RNG stream
  • Results are fully deterministic and reproducible

Changes

  • src/openpi/models/model.py: Refactored image augmentation to use split RNG stages

@wadeKeith wadeKeith requested a review from kvablack as a code owner March 29, 2026 09:32
…wrist cameras

Fixes Physical-Intelligence#859. augmax.Chain splits RNG by transform count, so base cameras
(4 transforms) and wrist cameras (1 transform) received different sub-RNGs
for ColorJitter despite the same input seed. This caused inconsistent color
semantics between camera views.

Split augmentation into spatial (camera-specific) and color (shared) stages
with deterministic RNG derivation via jax.random.fold_in, ensuring all cameras
see the same ColorJitter parameters within each sample.
@wadeKeith wadeKeith force-pushed the fix/augmentation-rng-consistency branch from 0dcc94f to e9bc7ee Compare May 9, 2026 11:21
@wadeKeith
Copy link
Copy Markdown
Contributor Author

@kvablack gentle ping — all CI is green, ready for review. Thanks!

@wadeKeith
Copy link
Copy Markdown
Contributor Author

Friendly ping for review - all CI checks are passing. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Possible bug in data augmentation pipeline

2 participants