Skip to content

Commit 51515b8

Browse files
dbellicoso-bdai and exploy-bot
authored and committed
Add support for actor states (#55)
# Pull Request

### What change is being made
Added support for actor states. Added an example that uses the new features.

### Why this change is being made
Adds a missing feature.

### Tested
Usage is covered in the example script. Should be unit tested.

GitOrigin-RevId: 6009fff9166d1924002161b98c6e924309b1485e
1 parent 18eac1e commit 51515b8

12 files changed

Lines changed: 353 additions & 39 deletions

File tree

.vscode/launch.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,20 @@
1313
"python": "${workspaceFolder}/.pixi/envs/isaaclab/bin/python",
1414
"justMyCode": false
1515
},
16+
{
17+
"name": "[Core] Tests.",
18+
"type": "debugpy",
19+
"request": "launch",
20+
"module": "pytest",
21+
"args": [
22+
"exploy/exporter/core/tests/test_export_environment.py"
23+
],
24+
"console": "integratedTerminal",
25+
"python": "${workspaceFolder}/.pixi/envs/core/bin/python",
26+
"justMyCode": false,
27+
"env": {
28+
"PYTHONPATH": "${workspaceFolder}"
29+
}
30+
},
1631
]
1732
}

examples/exporter_scripts/export_isaaclab.py renamed to examples/exporter_scripts/isaaclab/export_isaaclab.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,19 @@ def make_simulation_app() -> tuple[SimulationApp, argparse.Namespace]:
5555
from isaaclab.sim import SimulationContext
5656
from isaaclab_rl.rsl_rl import RslRlVecEnvWrapper
5757
from isaaclab_tasks.utils import parse_env_cfg
58+
from rsl_rl.algorithms.ppo import PPO
5859
from rsl_rl.runners import OnPolicyRunner
5960

6061
from exploy.exporter.core.evaluator import evaluate
6162
from exploy.exporter.core.exporter import export_environment_as_onnx
6263
from exploy.exporter.core.session_wrapper import SessionWrapper
63-
from exploy.exporter.frameworks.isaaclab import inputs, memory, outputs
64+
from exploy.exporter.frameworks.isaaclab import (
65+
environments, # noqa: F401
66+
inputs,
67+
memory,
68+
outputs,
69+
)
70+
from exploy.exporter.frameworks.isaaclab.actor import make_exportable_actor
6471
from exploy.exporter.frameworks.isaaclab.env import IsaacLabExportableEnvironment
6572

6673

@@ -82,17 +89,17 @@ def export_isaaclab(
8289
env = RslRlVecEnvWrapper(gym.make(task_name, cfg=env_cfg, render_mode=None))
8390
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=test_dir, device=agent_cfg.device)
8491

85-
# Get the policy and its normalizer.
86-
policy = runner.alg.policy.actor.to(env.device)
87-
normalizer = runner.alg.policy.actor_obs_normalizer.to(env.device)
88-
actor = torch.nn.Sequential(normalizer, policy).eval()
89-
9092
# Export to ONNX.
9193
onnx_export_dir = test_dir
9294
onnx_export_file = "test_export.onnx"
9395

9496
exportable_env = IsaacLabExportableEnvironment(env.unwrapped)
9597

98+
# Get the policy and its normalizer.
99+
alg: PPO = runner.alg
100+
assert isinstance(alg, PPO), f"Expected PPO algorithm, got: {type(alg).__name__}"
101+
actor = make_exportable_actor(exportable_env, alg.policy, device=task_device)
102+
96103
articulations = env.unwrapped.scene.articulations
97104
context_manager = exportable_env.context_manager()
98105

@@ -145,7 +152,7 @@ def export_isaaclab(
145152
session_wrapper = SessionWrapper(
146153
onnx_folder=onnx_export_dir,
147154
onnx_file_name=onnx_export_file,
148-
policy=actor,
155+
actor=actor,
149156
optimize=True,
150157
)
151158

exploy.code-workspace

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
"${workspaceFolder}/.pixi/envs/isaaclab/lib/python3.11/site-packages/isaaclab/source/isaaclab",
1111
"${workspaceFolder}/.pixi/envs/isaaclab/lib/python3.11/site-packages/isaaclab/source/isaaclab_rl",
1212
"${workspaceFolder}/.pixi/envs/isaaclab/lib/python3.11/site-packages/isaaclab/source/isaaclab_tasks",
13-
"${workspaceFolder}/.pixi/envs/isaaclab/lib/python3.11/site-packages/isaacsim/exts/isaacsim.core.utils"
13+
"${workspaceFolder}/.pixi/envs/isaaclab/lib/python3.11/site-packages/isaacsim/exts/isaacsim.core.utils",
14+
"${workspaceFolder}/.pixi/envs/isaaclab/lib/python3.11/site-packages",
1415
]
1516
}
1617
}

exploy/exporter/core/actor.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) 2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
2+
3+
import abc
4+
from collections.abc import Callable
5+
6+
import torch
7+
8+
from exploy.exporter.core.components import Connection, Memory
9+
from exploy.exporter.core.context_manager import ContextManager
10+
11+
12+
class ExportableActor(torch.nn.Module, abc.ABC):
13+
"""Abstract interface for an actor that can be exported to ONNX."""
14+
15+
def __init__(self):
16+
super().__init__()
17+
18+
@abc.abstractmethod
19+
def forward(self, obs: torch.Tensor) -> torch.Tensor:
20+
"""Given a batch of observations, compute the corresponding actions.
21+
22+
Args:
23+
obs: A tensor of shape (batch_size, obs_dim) containing the observations."""
24+
raise NotImplementedError("forward() method must be implemented by subclasses.")
25+
26+
def reset(self, dones: torch.Tensor):
27+
"""Reset the actor's internal state (e.g., RNN hidden states) based on the done flags.
28+
29+
Args:
30+
dones: A tensor of shape (batch_size,) containing boolean flags indicating which
31+
environments have been reset.
32+
"""
33+
pass
34+
35+
def get_state(self) -> tuple[torch.Tensor, ...] | None:
36+
"""Get the actor's internal state as a tuple of tensors, or None if there is no state."""
37+
return None
38+
39+
40+
def make_exportable_actor(actor: torch.nn.Module) -> ExportableActor:
    """Wrap a plain `torch.nn.Module` actor in the `ExportableActor` interface.

    The returned wrapper forwards observations to `actor` unchanged and keeps
    the default `reset()` and `get_state()` behavior of `ExportableActor`.

    Args:
        actor: The actor to convert.

    Returns:
        An `ExportableActor` that delegates `forward()` to `actor`.
    """

    class Actor(ExportableActor):
        """Thin adapter that delegates `forward()` to the wrapped module."""

        def __init__(self, module: torch.nn.Module):
            super().__init__()
            # Assigning a module as an attribute registers it as a submodule.
            self._actor = module

        def forward(self, obs: torch.Tensor) -> torch.Tensor:
            actions = self._actor(obs)
            return actions

    wrapped = Actor(actor)
    return wrapped
56+
57+
58+
def add_actor_memory(
    context_manager: ContextManager,
    get_hidden_states_func: Callable[[], tuple[torch.Tensor, ...] | None],
) -> None:
    """Register the actor's hidden states with a context manager.

    For each hidden state tensor returned by `get_hidden_states_func`, two
    components are added to `context_manager`: a `Memory` component named
    `actor_hidden_state_<i>` and a `Connection` component named
    `connection_actor_hidden_state_<i>`. Nothing is added when the callback
    returns None.

    Args:
        context_manager: The context manager to add the components to.
        get_hidden_states_func: A function that returns a tuple of hidden state
            tensors, or None if the actor has no hidden state. It is called once
            here to determine the number of hidden states, and again each time a
            registered getter or setter runs.

    Raises:
        TypeError: If `get_hidden_states_func` returns something other than a
            tuple or None.
    """
    actor_state = get_hidden_states_func()
    if actor_state is None:
        return

    # Raise explicitly instead of using `assert`, so the check is not stripped
    # when Python runs with -O.
    if not isinstance(actor_state, tuple):
        raise TypeError(
            f"Expected actor hidden states to be a tuple of tensors, got: {type(actor_state).__name__}"
        )

    for i_hs in range(len(actor_state)):
        # The default arguments bind the current index and callback when each
        # closure is defined, avoiding late binding on the loop variable.
        def get_hidden_state(
            _i_hs: int = i_hs,
            _get_cb: Callable = get_hidden_states_func,
        ) -> torch.Tensor:
            return _get_cb()[_i_hs]

        def set_hidden_state(
            value: torch.Tensor,
            _i_hs: int = i_hs,
            _get_cb: Callable = get_hidden_states_func,
        ) -> None:
            # Write in place into the tensor the callback returns.
            _get_cb()[_i_hs][:] = value

        component_name = f"actor_hidden_state_{i_hs}"
        memory_comp = Memory(
            name=component_name,
            get_from_env_cb=get_hidden_state,
        )
        context_manager.add_component(memory_comp)
        context_manager.add_component(
            Connection(
                name=f"connection_{component_name}",
                getter=memory_comp.get_from_env_cb,
                setter=set_hidden_state,
            )
        )

exploy/exporter/core/evaluator.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,25 +184,23 @@ def evaluate(
184184
context_manager: ContextManager,
185185
session_wrapper: SessionWrapper,
186186
num_steps: int,
187-
observations: torch.Tensor | None = None,
188187
verbose: bool = True,
189188
reset_from_onnx_counter_steps: int = 50,
190189
atol: float = 1.0e-5,
191190
rtol: float = 1.0e-5,
192191
pause_on_failure: bool = True,
193192
) -> tuple[bool, torch.Tensor]:
194-
"""Evaluate an ONNX exported model against the original IsaacLab environment and torch policy.
193+
"""Evaluate an ONNX exported model against an `ExportableEnvironment` stepped through a `SessionWrapper`.
195194
196195
This function runs the simulation for a specified number of steps and compares the
197-
outputs of the ONNX model with the environment's state and the original torch model's
198-
outputs at each step. This is useful for verifying the correctness of the ONNX export.
196+
outputs of the ONNX model with the environment's state and actor's actions at each step.
197+
This is useful for verifying the correctness of the ONNX export.
199198
200199
Args:
201200
env: The environment to run the evaluation in.
202201
context_manager: The context manager handling inputs and outputs.
203202
session_wrapper: An ONNX session wrapper.
204203
num_steps: The number of steps to run the evaluation for.
205-
observations: The initial observations. If None, the environment is reset. Defaults to None.
206204
verbose: Whether to print verbose output during evaluation. Defaults to True.
207205
reset_from_onnx_counter_steps: Set after how many steps we should set memory inputs from ONNX instead of using
208206
the environment's state.
@@ -220,7 +218,15 @@ def evaluate(
220218
the final observations tensor.
221219
"""
222220

223-
obs = observations.clone() if observations is not None else env.observations_reset()
221+
# Reset both the environment and the actor.
222+
obs = env.observations_reset()
223+
224+
actor = session_wrapper.get_actor()
225+
if actor is None:
226+
raise ValueError(
227+
"Session wrapper has no actor. Cannot evaluate ONNX model without access to original actor for comparison."
228+
)
229+
actor.reset(torch.tensor([True], device=obs.device))
224230

225231
# Print ONNX graph structure if verbose
226232
if verbose:
@@ -259,9 +265,10 @@ def reset():
259265
)
260266

261267
# Compute actions for the initial observations.
262-
env_actions: torch.Tensor = session_wrapper.get_torch_model()(obs)
268+
env_actions: torch.Tensor = actor(obs)
263269

264270
reset_memory_from_env = False
271+
env.context_manager().read_inputs()
265272

266273
while step_ctr < num_steps:
267274
reset_memory_from_env = (
@@ -270,15 +277,16 @@ def reset():
270277
next_obs, is_reset_step = env.step(env_actions)
271278
# Use the environment's observations for the next step.
272279
obs[:] = next_obs
273-
# Compute actions from the new observations.
274-
env_actions = session_wrapper.get_torch_model()(obs)
275280

276281
# Check if the environment was reset.
277282
if is_reset_step:
278283
# Re-read the ONNX inputs from the environment after a reset to avoid mismatch between
279284
# ONNX inputs and environment state after reset.
280285
env.context_manager().read_inputs()
281286

287+
# Reset the actor state.
288+
actor.reset(torch.tensor([is_reset_step], device=env_actions.device))
289+
282290
# We need to reset the memory inputs from the environment after a reset.
283291
reset_memory_from_env = True
284292

@@ -328,6 +336,9 @@ def reset():
328336
for component in context_manager.get_output_components()
329337
}
330338

339+
# Compute actions from the new observations.
340+
env_actions = actor(obs)
341+
331342
# Compare outputs from environment and ONNX model.
332343
step_export_ok, msg = _compare_step_outputs(
333344
env_obs=obs,

exploy/exporter/core/exporter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Copyright (c) 2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
22

3-
import copy
43
import datetime
54
import json
65
import os
@@ -80,7 +79,7 @@ def __init__(
8079
self._env: ExportableEnvironment = env
8180

8281
self.verbose = verbose
83-
self.actor = copy.deepcopy(actor)
82+
self.actor = actor
8483

8584
self.export_mode = ExportMode.Default
8685

exploy/exporter/core/session_wrapper.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import numpy as np
66
import onnx
77
import onnxruntime as ort
8-
import torch
98
from onnx import helper
109

10+
from exploy.exporter.core.actor import ExportableActor
1111
from exploy.exporter.core.utils.paths import prepare_onnx_paths
1212

1313

@@ -18,15 +18,15 @@ def __init__(
1818
self,
1919
onnx_folder: pathlib.Path,
2020
onnx_file_name: str,
21-
policy: torch.nn.Module | None = None,
21+
actor: ExportableActor | None = None,
2222
optimize: bool = True,
2323
):
2424
"""Construct a `SessionWrapper` to use it for policy inference.
2525
2626
Args:
2727
onnx_folder: The folder containing an ONNX file to load.
2828
onnx_file_name: The name of the ONNX file contained in `ONNX_folder`.
29-
policy: A `torch.nn.Module` representing the actor.
29+
actor: An `ExportableActor` representing the actor.
3030
optimize: If true, optimize the ONNX graph, save it to file, and use it for inference.
3131
"""
3232
# Prepare file paths
@@ -59,7 +59,7 @@ def __init__(
5959
self.session = session
6060
self.input_names = [inp.name for inp in session.get_inputs()]
6161
self.output_names = [val.name for val in session.get_outputs()]
62-
self._policy = policy
62+
self._actor = actor
6363
self.metadata = session.get_modelmeta()
6464

6565
self._results = None
@@ -81,13 +81,13 @@ def __call__(self, **kwargs):
8181
self._results = self.session.run(self.output_names, in_kwargs)
8282
return self._results
8383

84-
def get_torch_model(self) -> torch.nn.Module:
85-
"""Get the original torch policy model.
84+
def get_actor(self) -> ExportableActor | None:
85+
"""Get the original `ExportableActor` object used by this session wrapper.
8686
8787
Returns:
88-
The torch.nn.Module representing the policy, or None if not provided.
88+
The `ExportableActor` representing the actor, or None if not provided.
8989
"""
90-
return self._policy
90+
return self._actor
9191

9292
def get_output_value(self, output_name: str):
9393
"""Get a specific output value from the last inference run.

0 commit comments

Comments
 (0)