[BugFix] Fix GAE compact path bias on recurrent value nets at internal truncations#3771
Merged
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3771
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 New Failure, 1 Cancelled Job, 2 Unrelated FailuresAs of commit 45c9ef4 with merge base 5d11fa3 ( NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This was referenced May 17, 2026
This was referenced May 18, 2026
vmoens
added a commit
that referenced
this pull request
May 19, 2026
…l truncations
`_call_value_net_compact` previously built `data_in = [root[0:T], boundary]`
along the time dim and read `value_[t] = V(data_in[t+1])`. For every internal
`done[t]=True` (`t < T-1`), that read was `V(root_obs[t+1])` -- the
post-reset first observation of the next episode -- rather than the
env-returned `("next", obs)[t]` (the true pre-reset truncation observation).
GAE bootstraps with `(1 - terminated)`, so on envs that truncate without
terminating (Isaac-Ant, Pendulum-on-timeout, ...) the wrong
`next_state_value` was not masked out and flowed straight into the value
target. With recurrent value nets the wrong observation also corrupted the
LSTM hidden state going forward, cascading into downstream slots. The wandb
runs `g71pk34w` / `x05igvw7` (`shifted='compact'`) trailed `sln6yf2a` /
`6c8ihgh7` (`shifted=False`) by ~20% end-of-traj reward at iter 1000 for
exactly this reason.
The fix replaces the `T+1` interleave with a fused batched call: the root
and `("next", ...)` streams are concatenated along a non-time batch dim
into a constant-shape `[2*B, T, *F]` tensor and the value net is invoked
once. Reads of `value` and `value_` are simple batch-half slices. The
`("final", k)` collector contract still overrides the next side at slot
`T-1`. For recurrent value nets `("next", "is_init") |= root_is_init` is
applied so the LSTM resets at every trajectory boundary, exactly matching
the `shifted=False` reference (verified byte-exact on the regression
fixture). `_fill_missing_next_inputs` handles `compact_obs=True` rollouts
that don't populate `("next", k)`. A `time_idx == 0` guard keeps 1D
rollouts correct by unsqueezing a batch dim before the cat.
Shape stays constant across calls in a training run (no `.item()` syncs,
no Python branching on tensor values), so `torch.compile` and `vmap`
remain happy.
The regression test `test_gae_recurrent_shifted_compact_matches_unshifted_isaac_shape`
asserts compact matches `shifted=False` to within `rel < 0.05` on an
Isaac-shaped multi-trajectory rollout (`B=4`, `T=16`, truncations every 4
steps, `compact_obs=False` semantics). Pre-fix LSTM/GRU rel-err: 0.52 /
0.52. Post-fix: < 1e-7 (FP noise) for both. All other 1972 tests in
`test/objectives/test_values.py` continue to pass.
ghstack-source-id: f2d1a3e
Pull-Request: #3771
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stack from ghstack (oldest at bottom):
_call_value_net_compactpreviously builtdata_in = [root[0:T], boundary]along the time dim and read
value_[t] = V(data_in[t+1]). For every internaldone[t]=True(t < T-1), that read wasV(root_obs[t+1])-- thepost-reset first observation of the next episode -- rather than the
env-returned
("next", obs)[t](the true pre-reset truncation observation).GAE bootstraps with
(1 - terminated), so on envs that truncate withoutterminating (Isaac-Ant, Pendulum-on-timeout, ...) the wrong
next_state_valuewas not masked out and flowed straight into the valuetarget. With recurrent value nets the wrong observation also corrupted the
LSTM hidden state going forward, cascading into downstream slots. The wandb
runs
g71pk34w/x05igvw7(shifted='compact') trailedsln6yf2a/6c8ihgh7(shifted=False) by ~20% end-of-traj reward at iter 1000 forexactly this reason.
The fix replaces the
T+1interleave with a fused batched call: the rootand
("next", ...)streams are concatenated along a non-time batch diminto a constant-shape
[2*B, T, *F]tensor and the value net is invokedonce. Reads of
valueandvalue_are simple batch-half slices. The("final", k)collector contract still overrides the next side at slotT-1. For recurrent value nets("next", "is_init") |= root_is_initisapplied so the LSTM resets at every trajectory boundary, exactly matching
the
shifted=Falsereference (verified byte-exact on the regressionfixture).
_fill_missing_next_inputshandlescompact_obs=Truerolloutsthat don't populate
("next", k). Atime_idx == 0guard keeps 1Drollouts correct by unsqueezing a batch dim before the cat.
Shape stays constant across calls in a training run (no
.item()syncs,no Python branching on tensor values), so
torch.compileandvmapremain happy.
The regression test
test_gae_recurrent_shifted_compact_matches_unshifted_isaac_shapeasserts compact matches
shifted=Falseto withinrel < 0.05on anIsaac-shaped multi-trajectory rollout (
B=4,T=16, truncations every 4steps,
compact_obs=Falsesemantics). Pre-fix LSTM/GRU rel-err: 0.52 /0.52. Post-fix: < 1e-7 (FP noise) for both. All other 1972 tests in
test/objectives/test_values.pycontinue to pass.