Skip to content

Commit cb2622c

Browse files
committed
added some import gaurds, improved code readability
1 parent 0904877 commit cb2622c

4 files changed

Lines changed: 34 additions & 14 deletions

File tree

scripts/reinforcement_learning/rsl_rl/export.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
import time
1616
from collections.abc import Mapping
1717

18-
import leapp
1918
import torch
20-
from leapp import annotate
19+
20+
try:
21+
import leapp
22+
from leapp import annotate
23+
except ImportError as e:
24+
raise ImportError("LEAPP package is required for policy export. Install with: pip install leapp") from e
2125

2226
# Disable TorchScript before importing task/environment modules so any
2327
# @torch.jit.script helpers resolve to plain Python functions during export.

source/isaaclab/isaaclab/envs/direct_deployment_env.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121

2222
import torch
2323
import yaml
24-
from leapp import InferenceManager
24+
25+
try:
26+
from leapp import InferenceManager
27+
except ImportError as e:
28+
raise ImportError("LEAPP package is required for policy deployment testing.Install with: pip install leapp") from e
2529

2630
from isaaclab.managers import CommandManager, EventManager
2731
from isaaclab.scene import InteractiveScene

source/isaaclab/isaaclab/utils/leapp/export_annotator.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@
5858
from isaaclab.envs import ManagerBasedEnv
5959

6060

61+
VARIABLE_IMPEDANCE_MODES = frozenset({"variable", "variable_kp"})
62+
63+
6164
# ══════════════════════════════════════════════════════════════════
6265
# ExportPatcher
6366
# ══════════════════════════════════════════════════════════════════
@@ -139,24 +142,33 @@ def _disable_training_managers(unwrapped):
139142
"""
140143
num_envs = unwrapped.num_envs
141144
device = unwrapped.device
145+
_zero_reward = torch.zeros(num_envs, device=device)
146+
_no_termination = torch.zeros(num_envs, dtype=torch.bool, device=device)
147+
148+
def _noop_curriculum(env_ids=None):
149+
return None
150+
151+
def _zero_reward_compute(dt):
152+
return _zero_reward
153+
154+
def _no_termination_compute():
155+
return _no_termination
156+
157+
def _noop(*args, **kwargs):
158+
return None
142159

143160
if hasattr(unwrapped, "curriculum_manager"):
144-
unwrapped.curriculum_manager.compute = lambda env_ids=None: None
161+
unwrapped.curriculum_manager.compute = _noop_curriculum
145162

146163
if hasattr(unwrapped, "reward_manager"):
147-
_zero_reward = torch.zeros(num_envs, device=device)
148-
unwrapped.reward_manager.compute = lambda dt: _zero_reward
164+
unwrapped.reward_manager.compute = _zero_reward_compute
149165

150166
if hasattr(unwrapped, "termination_manager"):
151-
_no_termination = torch.zeros(num_envs, dtype=torch.bool, device=device)
152-
unwrapped.termination_manager.compute = lambda: _no_termination
167+
unwrapped.termination_manager.compute = _no_termination_compute
153168

154169
if hasattr(unwrapped, "recorder_manager"):
155170
rm = unwrapped.recorder_manager
156171

157-
def _noop(*args, **kwargs):
158-
return None
159-
160172
rm.record_pre_step = _noop
161173
rm.record_post_step = _noop
162174
rm.record_pre_reset = _noop
@@ -415,7 +427,7 @@ def _collect_action_outputs(self, action_manager) -> list[TensorSemantics]:
415427
tensors: list[TensorSemantics] = []
416428
for term_name, term in action_manager._terms.items():
417429
osc = getattr(term, "_osc", None)
418-
if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]:
430+
if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in VARIABLE_IMPEDANCE_MODES:
419431
asset = getattr(term, "_asset", None)
420432
real_asset = getattr(asset, "_real_asset", asset)
421433
joint_ids = getattr(term, "_joint_ids", None)
@@ -490,7 +502,7 @@ def _collect_action_static_outputs(
490502
if skip_terms and term_name in skip_terms:
491503
continue
492504
osc = getattr(term, "_osc", None)
493-
if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]:
505+
if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in VARIABLE_IMPEDANCE_MODES:
494506
continue
495507
asset = getattr(term, "_asset", None)
496508
real_asset = getattr(asset, "_real_asset", asset)

source/isaaclab_rl/test/export/test_rsl_rl_direct_export_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_direct_env_export_flow():
170170
cwd=_REPO_ROOT,
171171
capture_output=True,
172172
text=True,
173-
timeout=600,
173+
timeout=6000,
174174
)
175175

176176
if result.returncode != 0:

0 commit comments

Comments
 (0)