Skip to content

[BugFix] Fix GAE compact path bias on recurrent value nets at internal truncations#3771

Merged
vmoens merged 5 commits into
gh/vmoens/285/basefrom
gh/vmoens/285/head
May 19, 2026
Merged

[BugFix] Fix GAE compact path bias on recurrent value nets at internal truncations#3771
vmoens merged 5 commits into
gh/vmoens/285/basefrom
gh/vmoens/285/head

Conversation

@vmoens
Copy link
Copy Markdown
Collaborator

@vmoens vmoens commented May 17, 2026

Stack from ghstack (oldest at bottom):

_call_value_net_compact previously built data_in = [root[0:T], boundary]
along the time dim and read value_[t] = V(data_in[t+1]). For every internal
done[t]=True (t < T-1), that read was V(root_obs[t+1]) -- the
post-reset first observation of the next episode -- rather than the
env-returned ("next", obs)[t] (the true pre-reset truncation observation).
GAE bootstraps with (1 - terminated), so on envs that truncate without
terminating (Isaac-Ant, Pendulum-on-timeout, ...) the wrong
next_state_value was not masked out and flowed straight into the value
target. With recurrent value nets the wrong observation also corrupted the
LSTM hidden state going forward, cascading into downstream slots. The wandb
runs g71pk34w / x05igvw7 (shifted='compact') trailed sln6yf2a /
6c8ihgh7 (shifted=False) by ~20% end-of-traj reward at iter 1000 for
exactly this reason.

The fix replaces the T+1 interleave with a fused batched call: the root
and ("next", ...) streams are concatenated along a non-time batch dim
into a constant-shape [2*B, T, *F] tensor and the value net is invoked
once. Reads of value and value_ are simple batch-half slices. The
("final", k) collector contract still overrides the next side at slot
T-1. For recurrent value nets ("next", "is_init") |= root_is_init is
applied so the LSTM resets at every trajectory boundary, exactly matching
the shifted=False reference (verified byte-exact on the regression
fixture). _fill_missing_next_inputs handles compact_obs=True rollouts
that don't populate ("next", k). A time_idx == 0 guard keeps 1D
rollouts correct by unsqueezing a batch dim before the cat.

Shape stays constant across calls in a training run (no .item() syncs,
no Python branching on tensor values), so torch.compile and vmap
remain happy.

The regression test test_gae_recurrent_shifted_compact_matches_unshifted_isaac_shape
asserts compact matches shifted=False to within rel < 0.05 on an
Isaac-shaped multi-trajectory rollout (B=4, T=16, truncations every 4
steps, compact_obs=False semantics). Pre-fix LSTM/GRU rel-err: 0.52 /
0.52. Post-fix: < 1e-7 (FP noise) for both. All other 1972 tests in
test/objectives/test_values.py continue to pass.

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 17, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3771

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 1 New Failure, 1 Cancelled Job, 2 Unrelated Failures

As of commit 45c9ef4 with merge base 5d11fa3 (image):

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
vmoens added 3 commits May 18, 2026 19:43
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
vmoens added a commit that referenced this pull request May 19, 2026
…l truncations

`_call_value_net_compact` previously built `data_in = [root[0:T], boundary]`
along the time dim and read `value_[t] = V(data_in[t+1])`. For every internal
`done[t]=True` (`t < T-1`), that read was `V(root_obs[t+1])` -- the
post-reset first observation of the next episode -- rather than the
env-returned `("next", obs)[t]` (the true pre-reset truncation observation).
GAE bootstraps with `(1 - terminated)`, so on envs that truncate without
terminating (Isaac-Ant, Pendulum-on-timeout, ...) the wrong
`next_state_value` was not masked out and flowed straight into the value
target. With recurrent value nets the wrong observation also corrupted the
LSTM hidden state going forward, cascading into downstream slots. The wandb
runs `g71pk34w` / `x05igvw7` (`shifted='compact'`) trailed `sln6yf2a` /
`6c8ihgh7` (`shifted=False`) by ~20% end-of-traj reward at iter 1000 for
exactly this reason.

The fix replaces the `T+1` interleave with a fused batched call: the root
and `("next", ...)` streams are concatenated along a non-time batch dim
into a constant-shape `[2*B, T, *F]` tensor and the value net is invoked
once. Reads of `value` and `value_` are simple batch-half slices. The
`("final", k)` collector contract still overrides the next side at slot
`T-1`. For recurrent value nets `("next", "is_init") |= root_is_init` is
applied so the LSTM resets at every trajectory boundary, exactly matching
the `shifted=False` reference (verified byte-exact on the regression
fixture). `_fill_missing_next_inputs` handles `compact_obs=True` rollouts
that don't populate `("next", k)`. A `time_idx == 0` guard keeps 1D
rollouts correct by unsqueezing a batch dim before the cat.

Shape stays constant across calls in a training run (no `.item()` syncs,
no Python branching on tensor values), so `torch.compile` and `vmap`
remain happy.

The regression test `test_gae_recurrent_shifted_compact_matches_unshifted_isaac_shape`
asserts compact matches `shifted=False` to within `rel < 0.05` on an
Isaac-shaped multi-trajectory rollout (`B=4`, `T=16`, truncations every 4
steps, `compact_obs=False` semantics). Pre-fix LSTM/GRU rel-err: 0.52 /
0.52. Post-fix: < 1e-7 (FP noise) for both. All other 1972 tests in
`test/objectives/test_values.py` continue to pass.

ghstack-source-id: f2d1a3e
Pull-Request: #3771
@vmoens vmoens merged commit 45c9ef4 into gh/vmoens/285/base May 19, 2026
107 of 113 checks passed
@vmoens vmoens deleted the gh/vmoens/285/head branch May 19, 2026 07:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

BugFix CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Objectives

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant