@@ -18,7 +18,6 @@ def _print_progress_bar(
1818 num_steps : int ,
1919 failed_steps : int ,
2020 step_export_ok : bool ,
21- is_reset_step : bool ,
2221 inference_times : list [float ],
2322) -> None :
2423 """Print progress bar with step information.
@@ -28,7 +27,6 @@ def _print_progress_bar(
2827 num_steps: Total number of steps.
2928 failed_steps: Number of failed steps so far.
3029 step_export_ok: Whether current step passed validation.
31- is_reset_step: Whether environment was reset this step.
3230 inference_times: List of inference times.
3331 """
3432 status_emoji = "🔴" if not step_export_ok else "🟢"
@@ -41,8 +39,6 @@ def _print_progress_bar(
4139 mean_time = np .mean (inference_times ) * 1.0e3
4240 std_time = np .std (inference_times ) * 1.0e3
4341 extra_info .append (f"⏱️ μ={ mean_time :.3f} ms σ={ std_time :.3f} ms" )
44- if is_reset_step :
45- extra_info .append ("RESET" )
4642 extra_str = " | " .join (extra_info )
4743
4844 print (
@@ -189,7 +185,8 @@ def evaluate(
189185 env : ExportableEnvironment ,
190186 context_manager : ContextManager ,
191187 session_wrapper : SessionWrapper ,
192- num_steps : int ,
188+ num_episodes : int ,
189+ max_episode_steps : int | None = None ,
193190 verbose : bool = True ,
194191 reset_from_onnx_counter_steps : int = 50 ,
195192 atol : float = 1.0e-5 ,
@@ -198,15 +195,16 @@ def evaluate(
198195) -> tuple [bool , torch .Tensor ]:
199196 """Evaluate an ONNX exported model against an `ExportableEnvironment` stepped through a `SessionWrapper`.
200197
201- This function runs the simulation for a specified number of steps and compares the
202- outputs of the ONNX model with the environment's state and actor's actions at each step.
203- This is useful for verifying the correctness of the ONNX export.
198+ This function runs the simulation for a specified number of episodes, each with a maximum number
199+ of steps, and compares the outputs of the ONNX model with the environment's state and actor's
200+ actions at each step. This is useful for verifying the correctness of the ONNX export.
204201
205202 Args:
206203 env: The environment to run the evaluation in.
207204 context_manager: The context manager handling inputs and outputs.
208205 session_wrapper: An ONNX session wrapper.
209- num_steps: The number of steps to run the evaluation for.
206+ num_episodes: The number of episodes to run the evaluation for.
207+ max_episode_steps: The maximum number of steps per episode.
210208 verbose: Whether to print verbose output during evaluation. Defaults to True.
211209 reset_from_onnx_counter_steps: Set after how many steps we should set memory inputs from ONNX instead of using
212210 the environment's state.
@@ -218,14 +216,76 @@ def evaluate(
218216 Note: this value is chosen arbitrarily.
219217 atol: Absolute tolerance used to compare tensors.
220218 rtol: Relative tolerance used to compare tensors.
219+ pause_on_failure: Whether to pause on each failed step and wait for user input before
220+ continuing. Defaults to True.
221221
222222 Returns:
223223 A tuple containing a boolean indicating if the evaluation was successful and
224224 the final observations tensor.
225225 """
226+ if verbose :
227+ print ("Starting evaluation..." )
228+
229+ for i_episode in range (num_episodes ):
230+ if verbose :
231+ print (f"\n Starting episode { i_episode + 1 } /{ num_episodes } ..." )
232+ export_ok , final_obs = evaluate_episode (
233+ env = env ,
234+ context_manager = context_manager ,
235+ session_wrapper = session_wrapper ,
236+ max_num_steps = max_episode_steps ,
237+ verbose = verbose ,
238+ reset_from_onnx_counter_steps = reset_from_onnx_counter_steps ,
239+ atol = atol ,
240+ rtol = rtol ,
241+ pause_on_failure = pause_on_failure ,
242+ )
243+ return export_ok , final_obs
244+
245+
246+ def evaluate_episode (
247+ env : ExportableEnvironment ,
248+ context_manager : ContextManager ,
249+ session_wrapper : SessionWrapper ,
250+ max_num_steps : int | None = None ,
251+ verbose : bool = True ,
252+ reset_from_onnx_counter_steps : int = 50 ,
253+ atol : float = 1.0e-5 ,
254+ rtol : float = 1.0e-5 ,
255+ pause_on_failure : bool = True ,
256+ ):
257+ """Run evaluation for a single episode, comparing the ONNX model outputs against the environment.
258+
259+ Steps the environment and the ONNX session in lockstep, comparing observations, actions, and
260+ outputs at each step. Useful for fine-grained inspection of a single episode, e.g. in a
261+ debugging loop, without the outer episode iteration of :func:`evaluate`.
262+
263+ Args:
264+ env: The environment to run the evaluation in.
265+ context_manager: The context manager handling inputs and outputs.
266+ session_wrapper: An ONNX session wrapper.
267+ max_num_steps: The maximum number of steps to run. If ``None``, runs until the environment
268+ signals a reset.
269+ verbose: Whether to print verbose output during evaluation. Defaults to True.
270+ reset_from_onnx_counter_steps: Set after how many steps we should set memory inputs from
271+ ONNX instead of using the environment's state.
272+
273+ Note: we do this to avoid numerical error accumulation that would occur if we only
274+ ever use the ONNX inference outputs as memory fed back as ONNX inference inputs,
275+ while all other inputs are set directly from the environment's state.
276+
277+ Note: this value is chosen arbitrarily.
278+ atol: Absolute tolerance used to compare tensors.
279+ rtol: Relative tolerance used to compare tensors.
280+ pause_on_failure: Whether to pause on each failed step and wait for user input before
281+ continuing. Defaults to True.
226282
227- # Reset both the environment and the actor.
283+ Returns:
284+ A tuple containing a boolean indicating if the episode evaluation was successful and
285+ the final observations tensor.
286+ """
228287 obs = env .observations_reset ()
288+ session_wrapper .reset ()
229289
230290 actor = session_wrapper .get_actor ()
231291 if actor is None :
@@ -273,15 +333,8 @@ def update():
273333 """
274334 context_manager .read_inputs ()
275335
276- def reset ():
277- """Callback passed to the environment's evaluation hooks to reset the
278- context manager's inputs from the environment's state at each reset.
279- """
280- context_manager .read_inputs ()
281-
282336 env .register_evaluation_hooks (
283337 update = update ,
284- reset = reset ,
285338 evaluate_substep = evaluate_substep ,
286339 )
287340
@@ -291,7 +344,7 @@ def reset():
291344 reset_memory_from_env = False
292345 env .context_manager ().read_inputs ()
293346
294- while step_ctr < num_steps :
347+ while step_ctr < max_num_steps if max_num_steps is not None else True :
295348 reset_memory_from_env = (
296349 reset_memory_from_env or (step_ctr % reset_from_onnx_counter_steps ) == 0
297350 )
@@ -301,18 +354,9 @@ def reset():
301354
302355 # Check if the environment was reset.
303356 if is_reset_step :
304- # Re-read the ONNX inputs from the environment after a reset to avoid mismatch between
305- # ONNX inputs and environment state after reset.
306- env .context_manager ().read_inputs ()
307-
308- # Reset the actor state.
309- actor .reset (torch .tensor ([is_reset_step ], device = env_actions .device ))
310-
311- # We need to reset the memory inputs from the environment after a reset.
312- reset_memory_from_env = True
313-
314- # Reset the session wrapper results to avoid using stale outputs.
315- session_wrapper ._results = None
357+ if verbose :
358+ print (f"\n 🔄 Environment reset at step { step_ctr + 1 } ." )
359+ break
316360
317361 # Get onnx outputs if the session has been run.
318362 ort_outputs = (
@@ -383,16 +427,13 @@ def reset():
383427
384428 # Display progress bar.
385429 if verbose :
386- if step_ctr == 0 :
387- print ("\n \n Starting evaluation..." )
388430 if not step_export_ok :
389431 print (msg )
390432 _print_progress_bar (
391433 step_ctr = step_ctr ,
392- num_steps = num_steps ,
434+ num_steps = max_num_steps ,
393435 failed_steps = failed_steps ,
394436 step_export_ok = step_export_ok ,
395- is_reset_step = is_reset_step ,
396437 inference_times = inference_times ,
397438 )
398439
0 commit comments