Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions test/objectives/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,94 @@ def test_shifted_gae_accepts_noncanonical_strides(self):
assert torch.isfinite(out["advantage"]).all()
assert torch.isfinite(out["value_target"]).all()

@pytest.mark.parametrize(
"estimator_cls,kwargs",
[
(TD0Estimator, {"gamma": 0.9}),
(TD1Estimator, {"gamma": 0.9}),
(TDLambdaEstimator, {"gamma": 0.9, "lmbda": 0.95}),
(GAE, {"gamma": 0.9, "lmbda": 0.95}),
],
)
def test_final_obs_bootstrap_shifted(self, estimator_cls, kwargs):
"""``("final", obs)`` carries the true bootstrap obs at the window edge.

Without it, shifted-GAE under ``compact_obs=True`` falls back to
``V(s_{T-1})`` via :meth:`_fill_missing_next_inputs`, corrupting the
last step's advantage when the window boundary is not a real done.
The fix overrides those positions with the values carried under
``("final", obs)`` and matches the non-compact reference exactly.

Also verifies the consumer drops ``("final", ...)`` from the returned
tensordict so it survives a contiguous-storage replay buffer.
"""
from tensordict import UnbatchedTensor

torch.manual_seed(0)
value_net = TensorDictModule(
nn.Linear(3, 1, bias=False),
in_keys=["obs"],
out_keys=["state_value"],
)
B, T, F = 2, 5, 3
obs = torch.randn(B, T, F)
# No real done inside the window: the last step is a "soft" boundary.
done = torch.zeros(B, T, 1, dtype=torch.bool)
reward = torch.ones(B, T, 1)
# The "true" next obs after step T-1 (one per env, no time dim).
final_obs = torch.randn(B, F)

# Reference: full ('next', obs) at every step.
next_obs_full = torch.empty(B, T, F)
next_obs_full[:, :-1] = obs[:, 1:]
next_obs_full[:, -1] = final_obs
td_ref = TensorDict(
{
"obs": obs,
"next": {
"obs": next_obs_full,
"reward": reward,
"done": done.clone(),
"terminated": done.clone(),
},
},
[B, T],
)

# Compact + final_obs: no ('next', obs) but a ('final', obs) UnbatchedTensor.
td_compact = TensorDict(
{
"obs": obs,
"next": {
"reward": reward,
"done": done.clone(),
"terminated": done.clone(),
},
"final": TensorDict(
{"obs": UnbatchedTensor(final_obs)},
batch_size=(B, T),
),
},
[B, T],
)

est = estimator_cls(**kwargs, value_network=value_net, shifted=True)
out_ref = est(td_ref.clone())
out_compact = est(td_compact.clone())

# Must match the non-compact reference exactly at the boundary
# (and everywhere else).
torch.testing.assert_close(out_compact["advantage"], out_ref["advantage"])
torch.testing.assert_close(out_compact["value_target"], out_ref["value_target"])

# Drop must have happened — the rollout is now safe to extend into a
# contiguous-storage RB.
out_inplace = td_compact.clone()
est(out_inplace)
assert (
"final" not in out_inplace.keys()
), "('final', ...) should have been consumed and dropped"

@pytest.mark.skipif(not _has_gym, reason="requires gym")
def test_gae_multi_done(self):

Expand Down
111 changes: 111 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,117 @@ def make_env():
root_shifted = ref_data.get(k)[..., 1:, :]
torch.testing.assert_close(ref_next[mask], root_shifted[mask])

def test_final_obs_requires_compact_obs(self):
"""``final_obs=True`` without ``compact_obs=True`` must raise."""

def make_env():
return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

with pytest.raises(ValueError, match="requires compact_obs=True"):
Collector(
create_env_fn=make_env,
policy=RandomPolicy(make_env().action_spec),
frames_per_batch=10,
total_frames=10,
compact_obs=False,
final_obs=True,
)

@pytest.mark.parametrize("use_buffers", [True, False])
def test_final_obs_matches_compact_off(self, use_buffers):
"""``final_obs=True`` carries the same boundary obs as a non-compact run.

Runs two rollouts with identical seeds: one with
``compact_obs=False`` (full ``('next', obs)`` retained at every step)
and one with ``compact_obs=True, final_obs=True`` (boundary obs
stored under ``('final', obs)`` as
:class:`~tensordict.UnbatchedTensor`). The two must agree on the
boundary obs (non-done envs only — at done envs the bootstrap is
masked downstream so the value there is unconstrained).
"""
from tensordict import UnbatchedTensor

def make_env():
return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

dummy_env = make_env()
obs_keys = list(dummy_env._observation_keys_step_mdp)
dummy_env.close()

def run(compact, final):
torch.manual_seed(0)
return Collector(
create_env_fn=make_env,
policy=RandomPolicy(make_env().action_spec),
frames_per_batch=20,
total_frames=20,
use_buffers=use_buffers,
compact_obs=compact,
final_obs=final,
)

ref = run(False, False)
ref_data = next(iter(ref))
ref.shutdown()
del ref

comp = run(True, True)
comp_data = next(iter(comp))
comp.shutdown()
del comp

# Batch shape preserved (no time dim leakage from the UnbatchedTensor).
assert comp_data.batch_size == ref_data.batch_size

# ('final', k) is present and is an UnbatchedTensor.
for k in obs_keys:
full_final = ("final", *k) if isinstance(k, tuple) else ("final", k)
assert full_final in comp_data.keys(True, True), f"missing {full_final}"
val = comp_data.get(full_final)
assert isinstance(val, UnbatchedTensor), type(val)

# Compare against the reference's ('next', k) at the last step,
# masked to non-done envs.
full_next = ("next", *k) if isinstance(k, tuple) else ("next", k)
ref_last_next = ref_data.get(full_next)[..., -1, :]
done_last = ref_data.get(("next", "done"))[..., -1, :].squeeze(-1)
mask = ~done_last
torch.testing.assert_close(val[mask], ref_last_next[mask])

def test_final_obs_unbatched_survives_indexing(self):
"""The ``("final", obs)`` UnbatchedTensor must not collapse on reshape.

Closes (a): if the leaf were a regular tensor, indexing or reshaping
the rollout along the time axis would either drop the time dim or
propagate a shape mismatch into a contiguous-storage replay buffer.
"""
from tensordict import UnbatchedTensor

def make_env():
return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker())

c = Collector(
create_env_fn=make_env,
policy=RandomPolicy(make_env().action_spec),
frames_per_batch=20,
total_frames=20,
compact_obs=True,
final_obs=True,
)
data = next(iter(c))
c.shutdown()

original = data.get(("final", "observation"))
# Slicing along time must preserve the same underlying tensor.
sliced = data[..., :5].get(("final", "observation"))
assert isinstance(sliced, UnbatchedTensor)
assert torch.equal(original, sliced)
# exclude("final") must yield a td whose batch shape reshapes cleanly.
without = data.exclude("final")
assert "final" not in without.keys()
flat = without.reshape(-1)
assert flat.batch_size.numel() == 20

@pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv])
def test_initial_obs_consistency(self, env_class, seed=1):
# non regression test on #938
Expand Down
16 changes: 16 additions & 0 deletions torchrl/collectors/_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,13 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
:class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor`
at sampling time.
Defaults to ``False``.
final_obs (bool, optional): if ``True`` (requires ``compact_obs=True``),
each worker additionally stores the true next-observation reached
after the last step of its rollout under ``("final", k)`` as an
:class:`tensordict.UnbatchedTensor`. Closes the shifted-GAE
bootstrap-correctness gap at window boundaries. See
:class:`~torchrl.collectors.SyncDataCollector` for details.
Defaults to ``False``.
worker_idx (int, optional): the index of the worker.

Examples:
Expand Down Expand Up @@ -416,6 +423,7 @@ def __init__(
pre_collect_hook: Callable[[], None] | None = None,
post_collect_hook: Callable[[TensorDictBase], None] | None = None,
compact_obs: bool = False,
final_obs: bool = False,
):
self.closed = True
self.worker_idx = worker_idx
Expand Down Expand Up @@ -527,6 +535,13 @@ def __init__(
self.reset_at_each_iter = reset_at_each_iter
self.postproc = postproc
self.compact_obs = bool(compact_obs)
self.final_obs = bool(final_obs)
if self.final_obs and not self.compact_obs:
raise ValueError(
"final_obs=True requires compact_obs=True; otherwise the true "
"next observation is already stored at every step under "
"('next', ...)."
)
self.max_frames_per_traj = (
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
)
Expand Down Expand Up @@ -1323,6 +1338,7 @@ def _run_processes(self) -> None:
"pre_collect_hook": self._worker_pre_collect_hook,
"post_collect_hook": self._worker_post_collect_hook,
"compact_obs": self.compact_obs,
"final_obs": self.final_obs,
}
proc = _ProcessNoWarnCtx(
target=_main_async_collector,
Expand Down
2 changes: 2 additions & 0 deletions torchrl/collectors/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _main_async_collector(
pre_collect_hook: Callable[[], None] | None = None,
post_collect_hook: Callable[[TensorDictBase], None] | None = None,
compact_obs: bool = False,
final_obs: bool = False,
) -> None:
# Process-level initialisation hook (e.g. Isaac Lab ``AppLauncher``).
# Runs before any CUDA/torchrl work in the child process.
Expand Down Expand Up @@ -142,6 +143,7 @@ def _main_async_collector(
pre_collect_hook=pre_collect_hook,
post_collect_hook=post_collect_hook,
compact_obs=compact_obs,
final_obs=final_obs,
)
# Set up weight receivers for worker process using the standard register_scheme_receiver API.
# This properly initializes the schemes on the receiver side and stores them in _receiver_schemes.
Expand Down
Loading
Loading