@@ -2232,6 +2232,28 @@ def test_step_class(
22322232 out_cls = step_func (tensordict )
22332233 assert (out_func == out_cls ).all ()
22342234
@pytest.mark.parametrize(
    "envcls",
    [ContinuousActionVecMockEnv, CountingBatchedEnv, CountingEnv],
)
def test_step_class_out_reuse(self, envcls):
    """Check that ``_StepMDP`` hands back the ``out`` buffer it was given.

    The step function is called once without ``out`` to obtain a reference
    result, then again with a pre-allocated buffer. The second call must
    return that very buffer object, and its content must match the
    reference result.
    """
    torch.manual_seed(0)
    env = envcls()
    stepped = env.rand_step(env.reset())
    step_mdp = _StepMDP(env, exclude_action=False)

    expected = step_mdp(stepped.clone())
    # NOTE(review): the buffer starts out as a copy of the expected result,
    # so the equality check below cannot, on its own, detect a call that
    # leaves ``out`` untouched.
    buffer = expected.clone()

    actual = step_mdp(stepped.clone(), out=buffer)

    # Identity, not equality: the buffer itself must be returned.
    assert actual is buffer
    assert (expected == actual).all()

2256+
22352257 @pytest .mark .parametrize ("nested_obs" , [True , False ])
22362258 @pytest .mark .parametrize ("nested_action" , [True , False ])
22372259 @pytest .mark .parametrize ("nested_done" , [True , False ])
@@ -3780,6 +3802,32 @@ def policy(td):
37803802 assert not lazy ["lidar" ][~ done .squeeze ()].isnan ().any ()
37813803 assert (lazy_root ["lidar" ][1 :][done [:- 1 ].squeeze ()] == 0 ).all ()
37823804
def test_skip_maybe_reset_default(self):
    """A freshly constructed env must report a falsy ``_skip_maybe_reset``."""
    assert not AutoResettingCountingEnv(4, auto_reset=True)._skip_maybe_reset
3808+
def test_skip_maybe_reset_step_and_maybe_reset(self):
    """Compare ``step_and_maybe_reset`` with and without ``_skip_maybe_reset``.

    Two identically configured envs are driven with the same all-ones
    action; one of them has ``_skip_maybe_reset`` set to ``True`` before
    being reset. The observations in both returned tensordicts must match.
    """

    def reset_with_unit_action(env):
        # Reset the env and attach an int64 action of ones shaped
        # ``(*batch, 1)`` to the resulting tensordict.
        td = env.reset()
        td.set("action", torch.ones((*td.shape, 1), dtype=torch.int64))
        return td

    baseline_env = AutoResettingCountingEnv(100, auto_reset=True)
    baseline_td = reset_with_unit_action(baseline_env)

    skipping_env = AutoResettingCountingEnv(100, auto_reset=True)
    # The flag is flipped before the first reset, as in the reference setup.
    skipping_env._skip_maybe_reset = True
    skipping_td = reset_with_unit_action(skipping_env)

    baseline_out, baseline_next = baseline_env.step_and_maybe_reset(baseline_td)
    skipping_out, skipping_next = skipping_env.step_and_maybe_reset(skipping_td)

    torch.testing.assert_close(
        baseline_out["next", "observation"],
        skipping_out["next", "observation"],
    )
    torch.testing.assert_close(
        baseline_next["observation"],
        skipping_next["observation"],
    )

3830+
37833831
37843832class TestEnvWithDynamicSpec :
37853833 def test_dynamic_rollout (self ):
@@ -5026,6 +5074,47 @@ def test_parallel_env_no_buffers_mps_rollout(self):
50265074 env .close (raise_if_closed = False )
50275075
50285076
class TestTrustStepOutput:
    """Tests around the private ``_trust_step_output`` env flag."""

    @staticmethod
    def _make_counting_env():
        # Mock continuous-action env wrapped with a StepCounter transform.
        return TransformedEnv(ContinuousActionVecMockEnv(), StepCounter())

    @staticmethod
    def _enable_trust(env):
        # Set the flag on both the transformed env and its base env.
        env._trust_step_output = True
        env.base_env._trust_step_output = True

    def test_trust_step_output_default(self):
        """The flag must be falsy on a freshly built env."""
        assert not ContinuousActionVecMockEnv()._trust_step_output

    def test_trust_step_output_fast_path(self):
        """Stepping with the flag enabled must match stepping without it.

        The same input tensordict is stepped twice on one env, first with
        the flag at its initial value and then with it set to ``True``;
        the next observation and reward of both results are compared.
        """
        env = self._make_counting_env()
        td = env.rand_action(env.reset())

        reference = env.step(td.clone())

        self._enable_trust(env)
        trusted = env.step(td.clone())

        for key in ("observation", "reward"):
            torch.testing.assert_close(
                reference["next", key],
                trusted["next", key],
            )

    def test_trust_step_fast_path_step_and_maybe_reset(self):
        """``step_and_maybe_reset`` must still return the expected keys.

        Both ``_trust_step_output`` and ``_skip_maybe_reset`` are enabled
        before the env is reset.
        """
        env = self._make_counting_env()
        self._enable_trust(env)
        env._skip_maybe_reset = True

        td = env.rand_action(env.reset())
        stepped, carried = env.step_and_maybe_reset(td)

        assert "next" in stepped.keys()
        for key in ("observation", "step_count"):
            assert key in carried.keys()

5116+
5117+
if __name__ == "__main__":
    # Arguments the (empty) parser does not recognise are forwarded to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
0 commit comments