Skip to content

Commit b716fd9

Browse files
committed
[Test] Add tests and benchmarks for collector throughput optimizations
Cover all 7 performance features: _skip_maybe_reset, _StepMDP out= reuse, _trust_step_output, update_traj_ids, combined optimization flags, torch.compile fullgraph, and fast-path benchmarks.

Made-with: Cursor
ghstack-source-id: ad18afe
Pull-Request: #3567
1 parent 383415a commit b716fd9

5 files changed

Lines changed: 285 additions & 0 deletions

File tree

benchmarks/test_collectors_benchmark.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,32 @@ def test_async_pixels(benchmark):
247247
benchmark(execute_collector, c)
248248

249249

250+
def single_collector_fast_setup():
    """Build a warmed-up single-process collector with the fast-path flags set.

    The environment is a DMControl cheetah/run task wrapped with a
    ``StepCounter(50)`` transform.  ``_trust_step_output`` is switched on for
    both the transformed env and its base env, ``_skip_maybe_reset`` is
    switched on for the transformed env, and the collector is created with
    ``update_traj_ids=False``.

    Returns:
        A ``((collector_iterator,), {})`` pair, i.e. positional and keyword
        arguments for the benchmarked callable.
    """
    device = "cpu"
    if torch.cuda.device_count():
        device = "cuda:0"

    env = TransformedEnv(
        DMControlEnv("cheetah", "run", device=device),
        StepCounter(50),
    )
    # Enable the optimization flags on the wrapper first, then on the base env.
    for target in (env, env.base_env):
        target._trust_step_output = True
    env._skip_maybe_reset = True

    collector = iter(
        SyncDataCollector(
            env,
            RandomPolicy(env.action_spec),
            total_frames=-1,
            frames_per_batch=100,
            device=device,
            update_traj_ids=False,
        )
    )

    # Warm-up: pull batches 0..10 so one-off setup costs are not benchmarked.
    for batch_idx, _ in enumerate(collector):
        if batch_idx == 10:
            break
    return ((collector,), {})
269+
270+
271+
def test_single_fast(benchmark):
    """Benchmark ``execute_collector`` on the fast-path single collector."""
    args, kwargs = single_collector_fast_setup()
    benchmark(execute_collector, *args, **kwargs)
274+
275+
250276
class TestRBGCollector:
251277
@pytest.mark.parametrize(
252278
"n_col,n_wokrers_per_col",

benchmarks/test_envs_benchmark.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,47 @@ def test_step_mdp_speed(
128128
)
129129

130130

131+
def _step_and_maybe_reset_loop(env, td):
132+
for _ in range(100):
133+
_, td = env.step_and_maybe_reset(td)
134+
td = env.rand_action(td)
135+
136+
137+
def _make_warm_cheetah_env(fast_path, warmup_steps=3):
    """Create a warmed-up DMControl cheetah/run env for the step benchmarks.

    Shared by ``make_env_fast_path`` and ``make_env_normal_path`` so the two
    benchmarks differ only in the optimization flags.

    Args:
        fast_path: when true, set ``_trust_step_output`` on the transformed
            env and its base env, and ``_skip_maybe_reset`` on the
            transformed env.
        warmup_steps: number of step-and-maybe-reset iterations executed
            before returning.

    Returns:
        A ``((env, td), {})`` pair where ``td`` already holds a random action.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = TransformedEnv(DMControlEnv("cheetah", "run", device=device), StepCounter(50))
    if fast_path:
        env._trust_step_output = True
        env.base_env._trust_step_output = True
        env._skip_maybe_reset = True
    td = env.reset()
    td = env.rand_action(td)
    # Warm-up iterations run before the measured loop starts.
    for _ in range(warmup_steps):
        _, td = env.step_and_maybe_reset(td)
        td = env.rand_action(td)
    return ((env, td), {})


def make_env_fast_path():
    """Return ``((env, td), {})`` for an env with the fast-path flags enabled."""
    return _make_warm_cheetah_env(fast_path=True)


def make_env_normal_path():
    """Return ``((env, td), {})`` for an env with the default flag values."""
    return _make_warm_cheetah_env(fast_path=False)
160+
161+
162+
def test_step_and_maybe_reset_fast_path(benchmark):
    """Benchmark the stepping loop on the env built by ``make_env_fast_path``."""
    args, kwargs = make_env_fast_path()
    benchmark(_step_and_maybe_reset_loop, *args, **kwargs)
165+
166+
167+
def test_step_and_maybe_reset_normal(benchmark):
    """Benchmark the stepping loop on the env built by ``make_env_normal_path``."""
    args, kwargs = make_env_normal_path()
    benchmark(_step_and_maybe_reset_loop, *args, **kwargs)
170+
171+
131172
if __name__ == "__main__":
132173
args, unknown = argparse.ArgumentParser().parse_known_args()
133174
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/compile/test_compile_collectors.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,42 @@ def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
117117
collector.shutdown()
118118
del collector
119119

120+
def test_compile_step_and_maybe_reset_fullgraph(self):
    """Check ``step_and_maybe_reset`` compiles as one graph and matches eager.

    With ``_trust_step_output`` and ``_skip_maybe_reset`` set on a
    ``ContinuousActionVecMockEnv``, the test asserts that
    ``torch._dynamo.explain`` reports exactly one graph and no graph break,
    then compares every tensor leaf of the eager outputs against the outputs
    of a ``fullgraph=True`` compiled call.
    """
    torch._dynamo.reset_code_caches()

    env = ContinuousActionVecMockEnv()
    env._trust_step_output = True
    env._skip_maybe_reset = True

    td = env.rand_action(env.reset())

    # A few eager iterations before tracing.
    for _ in range(3):
        _, td = env.step_and_maybe_reset(td)
        td = env.rand_action(td)

    torch._dynamo.reset()
    report = torch._dynamo.explain(env.step_and_maybe_reset)(td)
    assert report.graph_count == 1
    assert report.graph_break_count == 0

    out_eager, next_eager = env.step_and_maybe_reset(td.clone())

    compiled_fn = torch.compile(env.step_and_maybe_reset, fullgraph=True)
    out_compiled, next_compiled = compiled_fn(td.clone())

    # Same leaf-by-leaf comparison for the step output and the next-step input.
    pairs = ((out_eager, out_compiled), (next_eager, next_compiled))
    for eager_td, compiled_td in pairs:
        for key in eager_td.keys(True, True):
            eager_val = eager_td.get(key)
            compiled_val = compiled_td.get(key)
            if isinstance(eager_val, torch.Tensor):
                torch.testing.assert_close(eager_val, compiled_val)
155+
120156

121157
if __name__ == "__main__":
122158
pytest.main([__file__, "-v"])

test/test_collectors.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4088,6 +4088,69 @@ def test_unique_traj_sync(self, cat_results):
40884088
del c
40894089

40904090

4091+
class TestUpdateTrajIds:
    """Tests for the ``update_traj_ids`` argument of ``Collector``."""

    @staticmethod
    def _make_env_and_policy():
        """Return a mock env and a linear policy mapping observation to action."""
        env = ContinuousActionVecMockEnv()
        policy = TensorDictModule(
            nn.Linear(
                env.observation_spec["observation"].shape[-1],
                env.action_spec.shape[-1],
            ),
            in_keys=["observation"],
            out_keys=["action"],
        )
        return env, policy

    def test_update_traj_ids_default_is_true(self):
        """A collector built without the argument exposes ``update_traj_ids is True``."""
        env, policy = self._make_env_and_policy()
        collector = Collector(env, policy, frames_per_batch=10, total_frames=10)
        try:
            assert collector.update_traj_ids is True
        finally:
            collector.shutdown()

    def test_update_traj_ids_false_skips_tracking(self):
        """With ``update_traj_ids=False`` the ids are constant along the last dim."""
        env, policy = self._make_env_and_policy()
        collector = Collector(
            env,
            policy,
            frames_per_batch=10,
            total_frames=20,
            update_traj_ids=False,
        )
        num_batches = 0
        try:
            for data in collector:
                num_batches += 1
                traj_ids = data.get(("collector", "traj_ids"))
                # Every entry must equal the first entry of its row.
                assert (traj_ids == traj_ids[..., 0:1]).all()
            # Guard against a vacuous pass if the collector yields nothing.
            assert num_batches > 0, "collector yielded no batch"
        finally:
            collector.shutdown()

    def test_update_traj_ids_true_updates(self):
        """With ``update_traj_ids=True`` each batch carries a ``traj_ids`` entry.

        NOTE(review): this only checks the entry is present; it does not
        verify that the ids actually change across trajectories.
        """
        env, policy = self._make_env_and_policy()
        collector = Collector(
            env,
            policy,
            frames_per_batch=50,
            total_frames=100,
            update_traj_ids=True,
        )
        num_batches = 0
        try:
            for data in collector:
                num_batches += 1
                traj_ids = data.get(("collector", "traj_ids"))
                assert traj_ids is not None
            # Guard against a vacuous pass if the collector yields nothing.
            assert num_batches > 0, "collector yielded no batch"
        finally:
            collector.shutdown()
4152+
4153+
40914154
class TestDynamicEnvs:
40924155
def test_dynamic_sync_collector(self):
40934156
env = EnvWithDynamicSpec()
@@ -5296,6 +5359,36 @@ def env_fn():
52965359
assert expected_trace.exists(), f"Trace file not found at {expected_trace}"
52975360

52985361

5362+
class TestCollectorOptimizationFlags:
    """Tests running a collector with all optimization flags enabled together."""

    def test_collector_all_optimizations(self):
        """Collected batches keep their size and expected keys with every flag on.

        ``_trust_step_output`` is set on the transformed env and its base env,
        ``_skip_maybe_reset`` on the transformed env, and the collector is
        created with ``update_traj_ids=False``.
        """
        env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter())
        env._trust_step_output = True
        env.base_env._trust_step_output = True
        env._skip_maybe_reset = True
        policy = TensorDictModule(
            nn.Linear(
                env.observation_spec["observation"].shape[-1],
                env.action_spec.shape[-1],
            ),
            in_keys=["observation"],
            out_keys=["action"],
        )
        collector = Collector(
            env,
            policy,
            frames_per_batch=20,
            total_frames=40,
            update_traj_ids=False,
        )
        num_batches = 0
        try:
            for data in collector:
                num_batches += 1
                assert data.shape[-1] == 20
                assert "observation" in data.keys()
                assert ("next", "observation") in data.keys(True)
                assert ("next", "reward") in data.keys(True)
            # Guard against a vacuous pass if the collector yields nothing.
            assert num_batches > 0, "collector yielded no batch"
        finally:
            collector.shutdown()
5390+
5391+
52995392
if __name__ == "__main__":
53005393
args, unknown = argparse.ArgumentParser().parse_known_args()
53015394
pytest.main(

test/test_envs.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,28 @@ def test_step_class(
22322232
out_cls = step_func(tensordict)
22332233
assert (out_func == out_cls).all()
22342234

2235+
@pytest.mark.parametrize(
    "envcls",
    [ContinuousActionVecMockEnv, CountingBatchedEnv, CountingEnv],
)
def test_step_class_out_reuse(self, envcls):
    """``_StepMDP`` called with ``out=`` returns that same object, equal to a fresh result."""
    torch.manual_seed(0)
    env = envcls()
    tensordict = env.rand_step(env.reset())

    step_func = _StepMDP(env, exclude_action=False)
    reference = step_func(tensordict.clone())

    # NOTE(review): the buffer starts as a copy of the expected result, so
    # the equality check below cannot tell an overwrite from a no-op;
    # consider seeding it with different values.
    out_buf = reference.clone()

    reused = step_func(tensordict.clone(), out=out_buf)
    assert reused is out_buf
    assert (reference == reused).all()
2256+
22352257
@pytest.mark.parametrize("nested_obs", [True, False])
22362258
@pytest.mark.parametrize("nested_action", [True, False])
22372259
@pytest.mark.parametrize("nested_done", [True, False])
@@ -3780,6 +3802,32 @@ def policy(td):
37803802
assert not lazy["lidar"][~done.squeeze()].isnan().any()
37813803
assert (lazy_root["lidar"][1:][done[:-1].squeeze()] == 0).all()
37823804

3805+
def test_skip_maybe_reset_default(self):
    """A freshly built env has a falsy ``_skip_maybe_reset``."""
    assert not AutoResettingCountingEnv(4, auto_reset=True)._skip_maybe_reset
3808+
3809+
def test_skip_maybe_reset_step_and_maybe_reset(self):
    """One ``step_and_maybe_reset`` call gives the same observations with or without the flag."""

    def _fresh_env_and_input(skip):
        # Build an env, optionally set the flag, reset it and attach an all-ones action.
        env = AutoResettingCountingEnv(100, auto_reset=True)
        if skip:
            env._skip_maybe_reset = True
        td = env.reset()
        td.set("action", torch.ones((*td.shape, 1), dtype=torch.int64))
        return env, td

    env_normal, td_normal = _fresh_env_and_input(skip=False)
    env_skip, td_skip = _fresh_env_and_input(skip=True)

    out_normal, next_normal = env_normal.step_and_maybe_reset(td_normal)
    out_skip, next_skip = env_skip.step_and_maybe_reset(td_skip)

    # Step output and next-step input must agree between the two envs.
    torch.testing.assert_close(
        out_normal["next", "observation"], out_skip["next", "observation"]
    )
    torch.testing.assert_close(
        next_normal["observation"], next_skip["observation"]
    )
3830+
37833831

37843832
class TestEnvWithDynamicSpec:
37853833
def test_dynamic_rollout(self):
@@ -5026,6 +5074,47 @@ def test_parallel_env_no_buffers_mps_rollout(self):
50265074
env.close(raise_if_closed=False)
50275075

50285076

5077+
class TestTrustStepOutput:
    """Tests for the ``_trust_step_output`` environment flag."""

    def test_trust_step_output_default(self):
        """A freshly built mock env has a falsy ``_trust_step_output``."""
        assert not ContinuousActionVecMockEnv()._trust_step_output

    def test_trust_step_output_fast_path(self):
        """``step`` yields the same next observation and reward once the flag is set."""
        env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter())
        td = env.rand_action(env.reset())

        out_normal = env.step(td.clone())

        # Turn the flag on for the wrapper, then for the wrapped env.
        for target in (env, env.base_env):
            target._trust_step_output = True
        out_fast = env.step(td.clone())

        for leaf in ("observation", "reward"):
            torch.testing.assert_close(
                out_normal["next", leaf], out_fast["next", leaf]
            )

    def test_trust_step_fast_path_step_and_maybe_reset(self):
        """``step_and_maybe_reset`` still returns the expected keys with all flags set."""
        env = TransformedEnv(ContinuousActionVecMockEnv(), StepCounter())
        env._trust_step_output = True
        env.base_env._trust_step_output = True
        env._skip_maybe_reset = True

        td = env.rand_action(env.reset())

        out, next_out = env.step_and_maybe_reset(td)

        assert "next" in out.keys()
        for name in ("observation", "step_count"):
            assert name in next_out.keys()
5116+
5117+
50295118
if __name__ == "__main__":
50305119
args, unknown = argparse.ArgumentParser().parse_known_args()
50315120
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)