Skip to content

Commit 05eb527

Browse files
dbellicoso-bdaiexploy-bot
authored and committed
Evaluator API (#100)
### What change is being made
Split evaluation into episodes.

### Why this change is being made
Clean up how resets are handled when evaluating policies.

### Tested
Covered by existing tests.

GitOrigin-RevId: 93670bfb7a3f7dbc979f2851a1ef9c4a8dc3caa0
1 parent b1b82d4 commit 05eb527

9 files changed

Lines changed: 139 additions & 103 deletions

File tree

docs/tutorial/exporter/exporter_tutorial.md

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import torch
3333

3434
from exploy.exporter.core.actor import ExportableActor, add_actor_memory
3535
from exploy.exporter.core.context_manager import Group, Input, Memory, Output
36-
from exploy.exporter.core.evaluator import evaluate
36+
from exploy.exporter.core.evaluator import evaluate, evaluate_episode
3737
from exploy.exporter.core.exportable_environment import ExportableEnvironment
3838
from exploy.exporter.core.exporter import export_environment_as_onnx
3939
from exploy.exporter.core.session_wrapper import SessionWrapper
@@ -212,7 +212,7 @@ class ExportableEnv(ExportableEnvironment):
212212
def metadata(self) -> dict:
213213
return {"env_name": "Env", "version": "1.0"}
214214

215-
def register_evaluation_hooks(self, update, reset, evaluate_substep):
215+
def register_evaluation_hooks(self, update, evaluate_substep):
216216
pass
217217

218218
def get_observation_names(self) -> list[str]:
@@ -497,7 +497,8 @@ with torch.inference_mode():
497497
env=exp_env,
498498
context_manager=exp_env.context_manager(),
499499
session_wrapper=session_wrapper,
500-
num_steps=20,
500+
num_episodes=1,
501+
max_episode_steps=20,
501502
verbose=True,
502503
pause_on_failure=False,
503504
)
@@ -511,6 +512,23 @@ with torch.inference_mode():
511512
If `export_ok` is `False`, the evaluator prints a detailed diagnostic showing which outputs
512513
diverged and at which step.
513514

515+
Under the hood, {py:func}`evaluate() <exploy.exporter.core.evaluator.evaluate>` calls
516+
{py:func}`evaluate_episode() <exploy.exporter.core.evaluator.evaluate_episode>` once per episode.
517+
You can call `evaluate_episode()` directly when you want finer control — for example, to run and
518+
inspect a single episode in a tight debugging loop without the outer episode iteration:
519+
520+
```python
521+
with torch.inference_mode():
522+
episode_ok, observations = evaluate_episode(
523+
env=exp_env,
524+
context_manager=exp_env.context_manager(),
525+
session_wrapper=session_wrapper,
526+
max_num_steps=20,
527+
verbose=True,
528+
pause_on_failure=False,
529+
)
530+
```
531+
514532
---
515533

516534
## Advanced: Using Torch Modules in Observations
@@ -649,7 +667,8 @@ def export_and_evaluate(
649667
exp_env: ExportableEnv,
650668
actor: ExportableActor,
651669
onnx_file_name: str,
652-
num_eval_steps: int,
670+
num_eval_episodes: int,
671+
max_eval_steps_per_episode: int,
653672
) -> bool:
654673
# Register inputs, outputs, and memory.
655674
exp_env.context_manager().add_components(
@@ -714,7 +733,8 @@ def export_and_evaluate(
714733
env=exp_env,
715734
context_manager=exp_env.context_manager(),
716735
session_wrapper=session_wrapper,
717-
num_steps=num_eval_steps,
736+
num_episodes=num_eval_episodes,
737+
max_episode_steps=max_eval_steps_per_episode,
718738
verbose=False,
719739
pause_on_failure=False,
720740
)
@@ -730,7 +750,7 @@ env = Environment(data_source=data_source)
730750
exp_env = ExportableEnv(env=env)
731751
actor = Actor(num_obs=env.num_obs, num_act=env.num_act).eval()
732752

733-
assert export_and_evaluate(exp_env, actor, "policy.onnx", num_eval_steps=20)
753+
assert export_and_evaluate(exp_env, actor, "policy.onnx", num_eval_episodes=1, max_eval_steps_per_episode=20)
734754
```
735755

736756
### Environment with a torch module
@@ -742,7 +762,7 @@ exp_env = ExportableEnv(env=env)
742762
actor = Actor(num_obs=env.num_obs, num_act=env.num_act).eval()
743763
exp_env.context_manager().add_module(env.module)
744764

745-
assert export_and_evaluate(exp_env, actor, "policy_with_module.onnx", num_eval_steps=20)
765+
assert export_and_evaluate(exp_env, actor, "policy_with_module.onnx", num_eval_episodes=1, max_eval_steps_per_episode=20)
746766
```
747767

748768
### Environment with a torch module and an RNN actor
@@ -760,5 +780,5 @@ add_actor_memory(
760780
get_hidden_states_func=actor.get_state,
761781
)
762782

763-
assert export_and_evaluate(exp_env, actor, "policy_with_rnn.onnx", num_eval_steps=20)
783+
assert export_and_evaluate(exp_env, actor, "policy_with_rnn.onnx", num_eval_episodes=1, max_eval_steps_per_episode=20)
764784
```

examples/exporter_scripts/isaaclab/export_isaaclab.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,15 @@ def export_isaaclab(
164164
)
165165

166166
# Evaluate.
167-
evaluate_steps = 200
167+
evaluate_episodes = 2
168+
evaluate_steps = 100
168169
with torch.inference_mode():
169170
export_ok, _ = evaluate(
170171
env=exportable_env,
171172
context_manager=exportable_env.context_manager(),
172173
session_wrapper=session_wrapper,
173-
num_steps=evaluate_steps,
174+
num_episodes=evaluate_episodes,
175+
max_episode_steps=evaluate_steps,
174176
verbose=True,
175177
pause_on_failure=pause_on_failure,
176178
)

python/exploy/exporter/core/evaluator.py

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def _print_progress_bar(
1818
num_steps: int,
1919
failed_steps: int,
2020
step_export_ok: bool,
21-
is_reset_step: bool,
2221
inference_times: list[float],
2322
) -> None:
2423
"""Print progress bar with step information.
@@ -28,7 +27,6 @@ def _print_progress_bar(
2827
num_steps: Total number of steps.
2928
failed_steps: Number of failed steps so far.
3029
step_export_ok: Whether current step passed validation.
31-
is_reset_step: Whether environment was reset this step.
3230
inference_times: List of inference times.
3331
"""
3432
status_emoji = "🔴" if not step_export_ok else "🟢"
@@ -41,8 +39,6 @@ def _print_progress_bar(
4139
mean_time = np.mean(inference_times) * 1.0e3
4240
std_time = np.std(inference_times) * 1.0e3
4341
extra_info.append(f"⏱️ μ={mean_time:.3f}ms σ={std_time:.3f}ms")
44-
if is_reset_step:
45-
extra_info.append("RESET")
4642
extra_str = " | ".join(extra_info)
4743

4844
print(
@@ -189,7 +185,8 @@ def evaluate(
189185
env: ExportableEnvironment,
190186
context_manager: ContextManager,
191187
session_wrapper: SessionWrapper,
192-
num_steps: int,
188+
num_episodes: int,
189+
max_episode_steps: int | None = None,
193190
verbose: bool = True,
194191
reset_from_onnx_counter_steps: int = 50,
195192
atol: float = 1.0e-5,
@@ -198,15 +195,16 @@ def evaluate(
198195
) -> tuple[bool, torch.Tensor]:
199196
"""Evaluate an ONNX exported model against an `ExportableEnvironment` stepped through a `SessionWrapper`.
200197
201-
This function runs the simulation for a specified number of steps and compares the
202-
outputs of the ONNX model with the environment's state and actor's actions at each step.
203-
This is useful for verifying the correctness of the ONNX export.
198+
This function runs the simulation for a specified number of episodes, each with a maximum number
199+
of steps, and compares the outputs of the ONNX model with the environment's state and actor's
200+
actions at each step. This is useful for verifying the correctness of the ONNX export.
204201
205202
Args:
206203
env: The environment to run the evaluation in.
207204
context_manager: The context manager handling inputs and outputs.
208205
session_wrapper: An ONNX session wrapper.
209-
num_steps: The number of steps to run the evaluation for.
206+
num_episodes: The number of episodes to run the evaluation for.
207+
max_episode_steps: The maximum number of steps per episode. If ``None``, each episode runs until the environment signals a reset.
210208
verbose: Whether to print verbose output during evaluation. Defaults to True.
211209
reset_from_onnx_counter_steps: Set after how many steps we should set memory inputs from ONNX instead of using
212210
the environment's state.
@@ -218,14 +216,76 @@ def evaluate(
218216
Note: this value is chosen arbitrarily.
219217
atol: Absolute tolerance used to compare tensors.
220218
rtol: Relative tolerance used to compare tensors.
219+
pause_on_failure: Whether to pause on each failed step and wait for user input before
220+
continuing. Defaults to True.
221221
222222
Returns:
223223
A tuple containing a boolean indicating if the evaluation was successful and
224224
the final observations tensor.
225225
"""
226+
if verbose:
227+
print("Starting evaluation...")
228+
229+
for i_episode in range(num_episodes):
230+
if verbose:
231+
print(f"\nStarting episode {i_episode + 1}/{num_episodes}...")
232+
export_ok, final_obs = evaluate_episode(
233+
env=env,
234+
context_manager=context_manager,
235+
session_wrapper=session_wrapper,
236+
max_num_steps=max_episode_steps,
237+
verbose=verbose,
238+
reset_from_onnx_counter_steps=reset_from_onnx_counter_steps,
239+
atol=atol,
240+
rtol=rtol,
241+
pause_on_failure=pause_on_failure,
242+
)
243+
return export_ok, final_obs
244+
245+
246+
def evaluate_episode(
247+
env: ExportableEnvironment,
248+
context_manager: ContextManager,
249+
session_wrapper: SessionWrapper,
250+
max_num_steps: int | None = None,
251+
verbose: bool = True,
252+
reset_from_onnx_counter_steps: int = 50,
253+
atol: float = 1.0e-5,
254+
rtol: float = 1.0e-5,
255+
pause_on_failure: bool = True,
256+
):
257+
"""Run evaluation for a single episode, comparing the ONNX model outputs against the environment.
258+
259+
Steps the environment and the ONNX session in lockstep, comparing observations, actions, and
260+
outputs at each step. Useful for fine-grained inspection of a single episode, e.g. in a
261+
debugging loop, without the outer episode iteration of :func:`evaluate`.
262+
263+
Args:
264+
env: The environment to run the evaluation in.
265+
context_manager: The context manager handling inputs and outputs.
266+
session_wrapper: An ONNX session wrapper.
267+
max_num_steps: The maximum number of steps to run. If ``None``, runs until the environment
268+
signals a reset.
269+
verbose: Whether to print verbose output during evaluation. Defaults to True.
270+
reset_from_onnx_counter_steps: Set after how many steps we should set memory inputs from
271+
ONNX instead of using the environment's state.
272+
273+
Note: we do this to avoid numerical error accumulation that would occur if we only
274+
ever use the ONNX inference outputs as memory fed back as ONNX inference inputs,
275+
while all other inputs are set directly from the environment's state.
276+
277+
Note: this value is chosen arbitrarily.
278+
atol: Absolute tolerance used to compare tensors.
279+
rtol: Relative tolerance used to compare tensors.
280+
pause_on_failure: Whether to pause on each failed step and wait for user input before
281+
continuing. Defaults to True.
226282
227-
# Reset both the environment and the actor.
283+
Returns:
284+
A tuple containing a boolean indicating if the episode evaluation was successful and
285+
the final observations tensor.
286+
"""
228287
obs = env.observations_reset()
288+
session_wrapper.reset()
229289

230290
actor = session_wrapper.get_actor()
231291
if actor is None:
@@ -273,15 +333,8 @@ def update():
273333
"""
274334
context_manager.read_inputs()
275335

276-
def reset():
277-
"""Callback passed to the environment's evaluation hooks to reset the
278-
context manager's inputs from the environment's state at each reset.
279-
"""
280-
context_manager.read_inputs()
281-
282336
env.register_evaluation_hooks(
283337
update=update,
284-
reset=reset,
285338
evaluate_substep=evaluate_substep,
286339
)
287340

@@ -291,7 +344,7 @@ def reset():
291344
reset_memory_from_env = False
292345
env.context_manager().read_inputs()
293346

294-
while step_ctr < num_steps:
347+
while step_ctr < max_num_steps if max_num_steps is not None else True:
295348
reset_memory_from_env = (
296349
reset_memory_from_env or (step_ctr % reset_from_onnx_counter_steps) == 0
297350
)
@@ -301,18 +354,9 @@ def reset():
301354

302355
# Check if the environment was reset.
303356
if is_reset_step:
304-
# Re-read the ONNX inputs from the environment after a reset to avoid mismatch between
305-
# ONNX inputs and environment state after reset.
306-
env.context_manager().read_inputs()
307-
308-
# Reset the actor state.
309-
actor.reset(torch.tensor([is_reset_step], device=env_actions.device))
310-
311-
# We need to reset the memory inputs from the environment after a reset.
312-
reset_memory_from_env = True
313-
314-
# Reset the session wrapper results to avoid using stale outputs.
315-
session_wrapper._results = None
357+
if verbose:
358+
print(f"\n🔄 Environment reset at step {step_ctr + 1}.")
359+
break
316360

317361
# Get onnx outputs if the session has been run.
318362
ort_outputs = (
@@ -383,16 +427,13 @@ def reset():
383427

384428
# Display progress bar.
385429
if verbose:
386-
if step_ctr == 0:
387-
print("\n\nStarting evaluation...")
388430
if not step_export_ok:
389431
print(msg)
390432
_print_progress_bar(
391433
step_ctr=step_ctr,
392-
num_steps=num_steps,
434+
num_steps=max_num_steps,
393435
failed_steps=failed_steps,
394436
step_export_ok=step_export_ok,
395-
is_reset_step=is_reset_step,
396437
inference_times=inference_times,
397438
)
398439

python/exploy/exporter/core/exportable_environment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def decimation(self) -> int:
5959
def register_evaluation_hooks(
6060
self,
6161
update: Callable[[], None],
62-
reset: Callable[[], None],
6362
evaluate_substep: Callable[[int], None],
6463
):
6564
"""Register evaluation hooks for this environment."""

python/exploy/exporter/core/session_wrapper.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import pathlib
44

5-
import numpy as np
65
import onnxruntime as ort
76

87
from exploy.exporter.core.actor import ExportableActor
@@ -105,4 +104,4 @@ def get_output_value(self, output_name: str):
105104

106105
def reset(self):
107106
"""Clear the internal results (set them to ``None``) to avoid stale data at environment reset."""
108-
self._results = [np.zeros_like(output) for output in self._results]
107+
self._results = None

python/exploy/exporter/core/tests/test_export_environment.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def get_observation_names(self) -> list[str]:
213213
def observations_reset(self) -> torch.Tensor:
214214
return self.env.compute_obs()
215215

216-
def register_evaluation_hooks(self, update, reset, evaluate_substep):
216+
def register_evaluation_hooks(self, update, evaluate_substep):
217217
pass
218218

219219
def metadata(self) -> dict:
@@ -309,7 +309,8 @@ def export_and_evaluate_env(
309309
exp_env: ExportableEnv,
310310
actor: ExportableActor,
311311
onnx_file_name: str,
312-
num_eval_steps: int,
312+
num_eval_episodes: int,
313+
max_eval_steps_per_episode: int,
313314
) -> bool:
314315
"""Helper function to export an environment and evaluate it using the exported ONNX graph."""
315316
exp_env.context_manager().add_components(
@@ -372,13 +373,13 @@ def export_and_evaluate_env(
372373
)
373374

374375
# Evaluate.
375-
evaluate_steps = num_eval_steps
376376
with torch.inference_mode():
377377
export_ok, _ = evaluate(
378378
env=exp_env,
379379
context_manager=exp_env.context_manager(),
380380
session_wrapper=session_wrapper,
381-
num_steps=evaluate_steps,
381+
num_episodes=num_eval_episodes,
382+
max_episode_steps=max_eval_steps_per_episode,
382383
verbose=False,
383384
pause_on_failure=False,
384385
)
@@ -397,7 +398,8 @@ def test_env(self):
397398
exp_env=exp_env,
398399
actor=actor,
399400
onnx_file_name="test_export_env.onnx",
400-
num_eval_steps=20,
401+
num_eval_episodes=2,
402+
max_eval_steps_per_episode=20,
401403
)
402404
assert export_ok, "ONNX export validation failed"
403405

@@ -414,7 +416,8 @@ def test_env_with_module(self):
414416
exp_env=exp_env,
415417
actor=actor,
416418
onnx_file_name="test_export_env_with_module.onnx",
417-
num_eval_steps=20,
419+
num_eval_episodes=2,
420+
max_eval_steps_per_episode=20,
418421
)
419422
assert export_ok, "ONNX export validation failed"
420423

@@ -440,6 +443,7 @@ def test_env_with_module_and_rnn_actor(self):
440443
exp_env=exp_env,
441444
actor=actor,
442445
onnx_file_name="test_export_env_with_rnn_actor.onnx",
443-
num_eval_steps=20,
446+
num_eval_episodes=2,
447+
max_eval_steps_per_episode=20,
444448
)
445449
assert export_ok, "ONNX export validation failed"

0 commit comments

Comments
 (0)