Skip to content

Commit b1b82d4

Browse files
dbellicoso-bdaiexploy-bot
authored andcommitted
Fix output evaluation (#96)
### What change is being made Fix output evaluation. ### Why this change is being made N/A ### Tested N/A GitOrigin-RevId: 0d650ef25596ce22ac407cddc01098a6939ac67b
1 parent 3e08a9f commit b1b82d4

1 file changed

Lines changed: 36 additions & 8 deletions

File tree

python/exploy/exporter/core/evaluator.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def _compare_step_outputs(
109109
)
110110
msg += "\n • Review compute_observation() implementation for data flow correctness"
111111
return False, msg
112+
else:
113+
msg += "\n✅ Observations match between environment and ONNX model."
112114

113115
step_export_ok = step_export_ok and obs_ok
114116

@@ -133,6 +135,8 @@ def _compare_step_outputs(
133135
msg += "\n • Verify actor network matches between env and ONNX"
134136
msg += "\n • Ensure action normalizer parameters are correctly exported"
135137
return False, msg
138+
else:
139+
msg += "\n✅ Actions match between environment and ONNX model."
136140

137141
step_export_ok = step_export_ok and actions_ok
138142

@@ -173,6 +177,8 @@ def _compare_step_outputs(
173177
"\n • Review process_action() and apply_action() implementations for consistency"
174178
)
175179
return False, msg
180+
else:
181+
msg += "\n✅ Outputs match between environment and ONNX model."
176182

177183
step_export_ok = step_export_ok and outputs_ok
178184

@@ -233,8 +239,21 @@ def evaluate(
233239
failed_steps = 0
234240
inference_times = []
235241

236-
# Evaluate a single substep at sim dt.
242+
# Hold a dictionary of environment outputs to compare against ONNX outputs at each step.
243+
# This is populated at each decimation step of each environment update, and later
244+
# compared against the ONNX outputs after running inference.
245+
env_outputs = {}
246+
237247
def evaluate_substep(step_ctr: int):
248+
"""Evaluate a single substep at sim dt.
249+
250+
Args:
251+
step_ctr: The current decimation step counter.
252+
"""
253+
# Get the environment's outputs.
254+
for component in context_manager.get_output_components():
255+
env_outputs[component.output_name] = component.get_from_env_cb().clone().cpu()
256+
238257
# Skip first step, as we evaluate the policy in the main evaluation loop before calling env.step().
239258
# Skip if we have not run the session yet.
240259
if step_ctr == 0 or session_wrapper._results is None:
@@ -249,9 +268,15 @@ def evaluate_substep(step_ctr: int):
249268
session_wrapper(**onnx_inputs)
250269

251270
def update():
271+
"""Callback passed to the environment's evaluation hooks to update the
272+
context manager's inputs from the environment's state at each step.
273+
"""
252274
context_manager.read_inputs()
253275

254276
def reset():
277+
"""Callback passed to the environment's evaluation hooks to reset the
278+
context manager's inputs from the environment's state at each reset.
279+
"""
255280
context_manager.read_inputs()
256281

257282
env.register_evaluation_hooks(
@@ -321,16 +346,17 @@ def reset():
321346
t_inference_s = time.perf_counter() - t_start
322347
inference_times.append(t_inference_s)
323348

324-
# Get observations and actions. Needs to be called before env.step() to get them
325-
# from the full model.
349+
# Get observations and actions.
326350
ort_observations = torch.from_numpy(session_wrapper.get_output_value("obs")).clone()
327351
ort_actions = torch.from_numpy(session_wrapper.get_output_value("actions")).clone()
328352

329-
# Get the environment's outputs.
330-
env_outputs = {
331-
component.output_name: component.get_from_env_cb().clone().cpu()
332-
for component in context_manager.get_output_components()
333-
}
353+
if not env_outputs:
354+
# The env_outputs dict is empty. This happens if the exportable environment
355+
# does not register its evaluation hooks. Populate the environment outputs here.
356+
env_outputs = {
357+
component.output_name: component.get_from_env_cb().clone().cpu()
358+
for component in context_manager.get_output_components()
359+
}
334360

335361
# Compute actions from the new observations.
336362
env_actions = actor(obs)
@@ -349,6 +375,8 @@ def reset():
349375
rtol=rtol,
350376
)
351377

378+
env_outputs = {}
379+
352380
export_ok = export_ok and step_export_ok
353381
if not step_export_ok:
354382
failed_steps += 1

0 commit comments

Comments
 (0)