@@ -109,6 +109,8 @@ def _compare_step_outputs(
109109 )
110110 msg += "\n • Review compute_observation() implementation for data flow correctness"
111111 return False , msg
112+ else :
113+ msg += "\n ✅ Observations match between environment and ONNX model."
112114
113115 step_export_ok = step_export_ok and obs_ok
114116
@@ -133,6 +135,8 @@ def _compare_step_outputs(
133135 msg += "\n • Verify actor network matches between env and ONNX"
134136 msg += "\n • Ensure action normalizer parameters are correctly exported"
135137 return False , msg
138+ else :
139+ msg += "\n ✅ Actions match between environment and ONNX model."
136140
137141 step_export_ok = step_export_ok and actions_ok
138142
@@ -173,6 +177,8 @@ def _compare_step_outputs(
173177 "\n • Review process_action() and apply_action() implementations for consistency"
174178 )
175179 return False , msg
180+ else :
181+ msg += "\n ✅ Outputs match between environment and ONNX model."
176182
177183 step_export_ok = step_export_ok and outputs_ok
178184
@@ -233,8 +239,21 @@ def evaluate(
233239 failed_steps = 0
234240 inference_times = []
235241
236- # Evaluate a single substep at sim dt.
242+ # Hold a dictionary of environment outputs to compare against ONNX outputs at each step.
243+ # This is populated at each decimation step of each environment update, and later
244+ # compared against the ONNX outputs after running inference.
245+ env_outputs = {}
246+
237247 def evaluate_substep (step_ctr : int ):
248+ """Evaluate a single substep at sim dt.
249+
250+ Args:
251+ step_ctr: The current decimation step counter.
252+ """
253+ # Get the environment's outputs.
254+ for component in context_manager .get_output_components ():
255+ env_outputs [component .output_name ] = component .get_from_env_cb ().clone ().cpu ()
256+
238257 # Skip first step, as we evaluate the policy in the main evaluation loop before calling env.step().
239258 # Skip if we have not run the session yet.
240259 if step_ctr == 0 or session_wrapper ._results is None :
@@ -249,9 +268,15 @@ def evaluate_substep(step_ctr: int):
249268 session_wrapper (** onnx_inputs )
250269
def update():
    """Re-read the context manager's inputs.

    Takes no arguments and returns nothing; it only delegates to
    ``context_manager.read_inputs()``. Meant to be handed to the
    environment's evaluation hooks as the per-step callback, so the
    inputs track the environment's state at each step.
    """
    context_manager.read_inputs()
253275
def reset():
    """Re-read the context manager's inputs after an environment reset.

    Takes no arguments and returns nothing. Nothing is cleared here: the
    callback performs the same ``context_manager.read_inputs()`` call as
    the per-step update. Meant to be handed to the environment's
    evaluation hooks as the reset callback.
    """
    context_manager.read_inputs()
256281
257282 env .register_evaluation_hooks (
@@ -321,16 +346,17 @@ def reset():
321346 t_inference_s = time .perf_counter () - t_start
322347 inference_times .append (t_inference_s )
323348
324- # Get observations and actions. Needs to be called before env.step() to get them
325- # from the full model.
349+ # Get observations and actions.
326350 ort_observations = torch .from_numpy (session_wrapper .get_output_value ("obs" )).clone ()
327351 ort_actions = torch .from_numpy (session_wrapper .get_output_value ("actions" )).clone ()
328352
329- # Get the environment's outputs.
330- env_outputs = {
331- component .output_name : component .get_from_env_cb ().clone ().cpu ()
332- for component in context_manager .get_output_components ()
333- }
353+ if not env_outputs :
354+ # The env_outputs dict is empty. This happens if the exportable environment
355+ # does not register its evaluation hooks. Populate the environment outputs here.
356+ env_outputs = {
357+ component .output_name : component .get_from_env_cb ().clone ().cpu ()
358+ for component in context_manager .get_output_components ()
359+ }
334360
335361 # Compute actions from the new observations.
336362 env_actions = actor (obs )
@@ -349,6 +375,8 @@ def reset():
349375 rtol = rtol ,
350376 )
351377
378+ env_outputs = {}
379+
352380 export_ok = export_ok and step_export_ok
353381 if not step_export_ok :
354382 failed_steps += 1
0 commit comments