Skip to content

Commit 63cacaa

Browse files
jbruedigam-bdaiexploy-bot
authored andcommitted
Add mjlab exporter (#47)
# Pull Request ### What change is being made Please provide a detailed description of WHAT change is being made such that the reader is able to understand the change holistically. ### Why this change is being made Please provide the rationale behind the change and additional context like links to documents, related work or tickets. ### Tested Please provide a description how this change was tested, e.g. unit tests, hardware tests or commands you run. GitOrigin-RevId: a4174dfb9929d45864f0090b23fdf9fe8e9c72bb
1 parent 51515b8 commit 63cacaa

40 files changed

Lines changed: 3628 additions & 1858 deletions

.github/workflows/test.yml

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
9191
environments: isaaclab
9292

93-
- name: Run Isaac Lab Export Test
93+
- name: Run IsaacLab Export Test
9494
env:
9595
OMNI_KIT_ACCEPT_EULA: yes
9696
run: pixi run -e isaaclab export-isaaclab-ci
@@ -116,3 +116,41 @@ jobs:
116116
env:
117117
OMNI_KIT_ACCEPT_EULA: yes
118118
run: pixi run -e isaaclab test
119+
120+
test-mjlab-export:
121+
name: Test Export mjlab
122+
runs-on: github-gpu-runner
123+
timeout-minutes: 60
124+
steps:
125+
- name: Checkout
126+
uses: actions/checkout@v6
127+
128+
- name: Setup pixi
129+
uses: prefix-dev/setup-pixi@v0.9.4
130+
with:
131+
pixi-version: latest
132+
cache: true
133+
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
134+
environments: mjlab
135+
136+
- name: Run mjlab Export Test
137+
run: pixi run -e mjlab export-mjlab-ci
138+
139+
test-frameworks-mjlab:
140+
name: Test Frameworks - mjlab
141+
runs-on: github-gpu-runner
142+
timeout-minutes: 60
143+
steps:
144+
- name: Checkout
145+
uses: actions/checkout@v6
146+
147+
- name: Setup pixi
148+
uses: prefix-dev/setup-pixi@v0.9.4
149+
with:
150+
pixi-version: latest
151+
cache: true
152+
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
153+
environments: mjlab
154+
155+
- name: Run Frameworks mjlab Tests
156+
run: pixi run -e mjlab test

.ruff.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,8 @@ select = [
1919
ignore = [
2020
"E501", # line-too-long (formatter handles this)
2121
"E402", # module-level-import-not-at-top-of-file (needed for some frameworks)
22-
"F401", # unused-import (may be for re-export)
23-
"F841", # unused-variable (sometimes intentional)
2422
]
2523

26-
[lint.per-file-ignores]
27-
"__init__.py" = ["F401"] # Allow unused imports in __init__.py
28-
"examples/exporter_scripts/export_isaaclab.py" = ["E402", "F401", "F403"] # Allow import flexibility for isaaclab
29-
3024
[format]
3125
quote-style = "double"
3226
indent-style = "space"

.vscode/launch.json

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,20 @@
88
"name": "[Examples] Export IsaacLab task.",
99
"type": "debugpy",
1010
"request": "launch",
11-
"program": "${workspaceFolder}/examples/exporter_scripts/export_isaaclab.py",
11+
"program": "${workspaceFolder}/examples/exporter_scripts/isaaclab/export.py",
1212
"console": "integratedTerminal",
1313
"python": "${workspaceFolder}/.pixi/envs/isaaclab/bin/python",
1414
"justMyCode": false
1515
},
16+
{
17+
"name": "[Examples] Export mjlab task.",
18+
"type": "debugpy",
19+
"request": "launch",
20+
"program": "${workspaceFolder}/examples/exporter_scripts/mjlab/export.py",
21+
"console": "integratedTerminal",
22+
"python": "${workspaceFolder}/.pixi/envs/mjlab/bin/python",
23+
"justMyCode": false,
24+
},
1625
{
1726
"name": "[Core] Tests.",
1827
"type": "debugpy",
@@ -27,6 +36,6 @@
2736
"env": {
2837
"PYTHONPATH": "${workspaceFolder}"
2938
}
30-
},
39+
}
3140
]
3241
}

control/matcher.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,13 @@ std::vector<std::unique_ptr<Input>> DepthImageMatcher::createInputs() const {
295295
// --------------- Body matchers ------------------------------
296296
bool BodyPositionMatcher::matches(const Match& maybe_match) {
297297
std::smatch match;
298-
std::regex pattern = std::regex(
299-
fmt::format("obj\\.({})\\.bodies\\.({})\\.pos_b_rt_w_in_w", alphanumeric, alphanumeric));
298+
std::regex pattern =
299+
std::regex(fmt::format("obj\\.({})\\.({})\\.pos_b_rt_w_in_w", alphanumeric, alphanumeric));
300300
if (std::regex_match(maybe_match.name, match, pattern) && match.size() > 2) {
301+
// Exclude "base" - that's handled by BasePositionMatcher
302+
if (match[2].str() == "base") {
303+
return false;
304+
}
301305
found_matches_[match[2].str()] = maybe_match;
302306
return true;
303307
}
@@ -315,8 +319,12 @@ std::vector<std::unique_ptr<Input>> BodyPositionMatcher::createInputs() const {
315319
bool BodyOrientationMatcher::matches(const Match& maybe_match) {
316320
std::smatch match;
317321
std::regex pattern =
318-
std::regex(fmt::format("obj\\.({})\\.bodies\\.({})\\.w_Q_b", alphanumeric, alphanumeric));
322+
std::regex(fmt::format("obj\\.({})\\.({})\\.w_Q_b", alphanumeric, alphanumeric));
319323
if (std::regex_match(maybe_match.name, match, pattern) && match.size() > 2) {
324+
// Exclude "base" - that's handled by BaseOrientationMatcher
325+
if (match[2].str() == "base") {
326+
return false;
327+
}
320328
found_matches_[match[2].str()] = maybe_match;
321329
return true;
322330
}

control/test/components_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ TEST_F(OnnxComponentsTest, JointVelocityInput_InitAndRead) {
133133
}
134134

135135
TEST_F(OnnxComponentsTest, BodyOrientationInput_InitAndRead) {
136-
BodyOrientationInput body_input("obj.box1.bodies.box.w_Q_b", "test_body");
136+
BodyOrientationInput body_input("obj.box1.box.w_Q_b", "test_body");
137137

138138
// Test initialization
139139
EXPECT_CALL(state_mock_, initBodyOrientationW("test_body")).WillOnce(Return(true));

control/test/testdata/test_onnx_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
"sensor.ray_caster.trail.b",
4141
"sensor.depth_image.one",
4242
# body
43-
"obj.box1.bodies.box.pos_b_rt_w_in_w",
44-
"obj.box1.bodies.box.w_Q_b",
43+
"obj.box1.box.pos_b_rt_w_in_w",
44+
"obj.box1.box.w_Q_b",
4545
# memory
4646
"memory.output.joint_targets.jt1.pos.in",
4747
# step count

docs/api/frameworks/isaaclab.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
Isaac Lab Framework
1+
IsaacLab Framework
22
===================
33

4-
Isaac Lab integration for RL policy export.
4+
IsaacLab integration for RL policy export.
55

66
This module provides the ``IsaacLabExportableEnvironment`` class and utilities
7-
for exporting policies trained in NVIDIA Isaac Lab to ONNX format.
7+
for exporting policies trained in NVIDIA IsaacLab to ONNX format.
88

99
.. automodule:: exploy.exporter.frameworks.isaaclab
1010
:members:

examples/exporter_scripts/isaaclab/export_isaaclab.py renamed to examples/exporter_scripts/isaaclab/export.py

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
def make_simulation_app() -> tuple[SimulationApp, argparse.Namespace]:
1919
# Create argument parser for headless mode
20-
parser = argparse.ArgumentParser(description="Export Isaac Lab environment to ONNX")
20+
parser = argparse.ArgumentParser(description="Export IsaacLab environment to ONNX")
2121

2222
# Add custom arguments
2323
parser.add_argument(
2424
"--task",
2525
type=str,
26-
default="Isaac-Velocity-Rough-G1-Play-v0",
27-
help="Name of the Isaac Lab task to export (default: Isaac-Velocity-Rough-G1-Play-v0)",
26+
default="IsaacLab-Velocity-Rough-G1-Play-v0",
27+
help="Name of the IsaacLab task to export (default: IsaacLab-Velocity-Rough-G1-Play-v0)",
2828
)
2929
parser.add_argument(
3030
"--pause-on-failure",
@@ -61,20 +61,14 @@ def make_simulation_app() -> tuple[SimulationApp, argparse.Namespace]:
6161
from exploy.exporter.core.evaluator import evaluate
6262
from exploy.exporter.core.exporter import export_environment_as_onnx
6363
from exploy.exporter.core.session_wrapper import SessionWrapper
64-
from exploy.exporter.frameworks.isaaclab import (
65-
environments, # noqa: F401
66-
inputs,
67-
memory,
68-
outputs,
69-
)
70-
from exploy.exporter.frameworks.isaaclab.actor import make_exportable_actor
64+
from exploy.exporter.frameworks.isaaclab import environments # noqa: F401
7165
from exploy.exporter.frameworks.isaaclab.env import IsaacLabExportableEnvironment
66+
from exploy.exporter.frameworks.manager_based import inputs, memory, outputs
67+
from exploy.exporter.frameworks.manager_based.actor import make_exportable_actor
7268

7369

74-
def export_isaaclab(
75-
task_name: str = "Isaac-Velocity-Rough-G1-Play-v0", pause_on_failure: bool = False
76-
):
77-
"""Test Isaac Lab ONNX export and evaluation pipeline."""
70+
def export(task_name: str = "Isaac-Velocity-Rough-G1-Play-v0", pause_on_failure: bool = False):
71+
"""Test IsaacLab ONNX export and evaluation pipeline."""
7872
test_dir = pathlib.Path(__file__).parent / "exporter_tests"
7973

8074
task_device = "cpu"
@@ -93,50 +87,42 @@ def export_isaaclab(
9387
onnx_export_dir = test_dir
9488
onnx_export_file = "test_export.onnx"
9589

96-
exportable_env = IsaacLabExportableEnvironment(env.unwrapped)
90+
unwrapped_env = env.unwrapped
91+
exportable_env = IsaacLabExportableEnvironment(unwrapped_env)
9792

9893
# Get the policy and its normalizer.
9994
alg: PPO = runner.alg
10095
assert isinstance(alg, PPO), f"Expected PPO algorithm, got: {type(alg).__name__}"
10196
actor = make_exportable_actor(exportable_env, alg.policy, device=task_device)
10297

103-
articulations = env.unwrapped.scene.articulations
98+
articulations = unwrapped_env.unwrapped.scene.articulations
10499
context_manager = exportable_env.context_manager()
105100

106-
inputs.add_base_vel(
107-
articulations=articulations,
108-
context_manager=context_manager,
109-
)
110-
111-
inputs.add_body_pos_and_quat(
112-
articulations=articulations,
113-
context_manager=context_manager,
114-
)
101+
inputs.add_base_vel(articulations, context_manager)
115102

116-
inputs.add_commands(
117-
command_manager=env.unwrapped.command_manager,
118-
context_manager=context_manager,
119-
)
103+
inputs.add_body_pos_and_quat(articulations, context_manager)
120104

121-
inputs.add_joint_pos_and_vel(
122-
articulations=articulations,
123-
context_manager=context_manager,
105+
inputs.add_command(
106+
unwrapped_env,
107+
context_manager,
108+
command_name="base_velocity",
109+
command_type="se2_velocity",
124110
)
125111

126-
inputs.add_sensor_inputs(
127-
sensors=env.unwrapped.scene.sensors,
128-
context_manager=context_manager,
129-
)
112+
inputs.add_joint_pos_and_vel(articulations, context_manager)
130113

131-
memory.add_memory(
132-
env=env.unwrapped,
133-
context_manager=context_manager,
134-
)
114+
for sensor_name, sensor in unwrapped_env.scene.sensors.items():
115+
inputs.add_sensor_input(sensor_name, sensor, context_manager)
135116

136-
outputs.add_outputs(
137-
action_manager=env.unwrapped.action_manager,
138-
context_manager=context_manager,
139-
)
117+
memory.add_memory(unwrapped_env, context_manager, attr_name="action")
118+
for action_term_name in unwrapped_env.action_manager.active_terms:
119+
memory.add_memory(
120+
unwrapped_env,
121+
context_manager,
122+
attr_name="processed_actions",
123+
action_term_name=action_term_name,
124+
)
125+
outputs.add_output(unwrapped_env, context_manager, action_term_name=action_term_name)
140126

141127
export_environment_as_onnx(
142128
env=exportable_env,
@@ -181,7 +167,7 @@ def export_isaaclab(
181167
import sys
182168

183169
try:
184-
export_isaaclab(task_name=args.task, pause_on_failure=args.pause_on_failure)
170+
export(task_name=args.task, pause_on_failure=args.pause_on_failure)
185171
except Exception as e:
186172
print(f"❌ Test ERROR: {e}")
187173
import traceback

0 commit comments

Comments
 (0)