@@ -184,25 +184,23 @@ def evaluate(
184184 context_manager : ContextManager ,
185185 session_wrapper : SessionWrapper ,
186186 num_steps : int ,
187- observations : torch .Tensor | None = None ,
188187 verbose : bool = True ,
189188 reset_from_onnx_counter_steps : int = 50 ,
190189 atol : float = 1.0e-5 ,
191190 rtol : float = 1.0e-5 ,
192191 pause_on_failure : bool = True ,
193192) -> tuple [bool , torch .Tensor ]:
194- """Evaluate an ONNX exported model against the original IsaacLab environment and torch policy .
193+ """Evaluate an ONNX exported model against an `ExportableEnvironment` stepped through a `SessionWrapper` .
195194
196195 This function runs the simulation for a specified number of steps and compares the
197- outputs of the ONNX model with the environment's state and the original torch model's
198- outputs at each step. This is useful for verifying the correctness of the ONNX export.
196+ outputs of the ONNX model with the environment's state and actor's actions at each step.
197+ This is useful for verifying the correctness of the ONNX export.
199198
200199 Args:
201200 env: The environment to run the evaluation in.
202201 context_manager: The context manager handling inputs and outputs.
203202 session_wrapper: An ONNX session wrapper.
204203 num_steps: The number of steps to run the evaluation for.
205- observations: The initial observations. If None, the environment is reset. Defaults to None.
206204 verbose: Whether to print verbose output during evaluation. Defaults to True.
207205 reset_from_onnx_counter_steps: Set after how many steps we should set memory inputs from ONNX instead of using
208206 the environment's state.
@@ -220,7 +218,15 @@ def evaluate(
220218 the final observations tensor.
221219 """
222220
223- obs = observations .clone () if observations is not None else env .observations_reset ()
221+ # Reset both the environment and the actor.
222+ obs = env .observations_reset ()
223+
224+ actor = session_wrapper .get_actor ()
225+ if actor is None :
226+ raise ValueError (
227+ "Session wrapper has no actor. Cannot evaluate ONNX model without access to original actor for comparison."
228+ )
229+ actor .reset (torch .tensor ([True ], device = obs .device ))
224230
225231 # Print ONNX graph structure if verbose
226232 if verbose :
@@ -259,9 +265,10 @@ def reset():
259265 )
260266
261267 # Compute actions for the initial observations.
262- env_actions : torch .Tensor = session_wrapper . get_torch_model () (obs )
268+ env_actions : torch .Tensor = actor (obs )
263269
264270 reset_memory_from_env = False
271+ env .context_manager ().read_inputs ()
265272
266273 while step_ctr < num_steps :
267274 reset_memory_from_env = (
@@ -270,15 +277,16 @@ def reset():
270277 next_obs , is_reset_step = env .step (env_actions )
271278 # Use the environment's observations for the next step.
272279 obs [:] = next_obs
273- # Compute actions from the new observations.
274- env_actions = session_wrapper .get_torch_model ()(obs )
275280
276281 # Check if the environment was reset.
277282 if is_reset_step :
278283 # Re-read the ONNX inputs from the environment after a reset to avoid mismatch between
279284 # ONNX inputs and environment state after reset.
280285 env .context_manager ().read_inputs ()
281286
287+ # Reset the actor state.
288+ actor .reset (torch .tensor ([is_reset_step ], device = env_actions .device ))
289+
282290 # We need to reset the memory inputs from the environment after a reset.
283291 reset_memory_from_env = True
284292
@@ -328,6 +336,9 @@ def reset():
328336 for component in context_manager .get_output_components ()
329337 }
330338
339+ # Compute actions from the new observations.
340+ env_actions = actor (obs )
341+
331342 # Compare outputs from environment and ONNX model.
332343 step_export_ok , msg = _compare_step_outputs (
333344 env_obs = obs ,
0 commit comments