|
58 | 58 | from isaaclab.envs import ManagerBasedEnv |
59 | 59 |
|
60 | 60 |
|
| 61 | +VARIABLE_IMPEDANCE_MODES = frozenset({"variable", "variable_kp"}) |
| 62 | + |
| 63 | + |
61 | 64 | # ══════════════════════════════════════════════════════════════════ |
62 | 65 | # ExportPatcher |
63 | 66 | # ══════════════════════════════════════════════════════════════════ |
@@ -139,24 +142,33 @@ def _disable_training_managers(unwrapped): |
139 | 142 | """ |
140 | 143 | num_envs = unwrapped.num_envs |
141 | 144 | device = unwrapped.device |
| 145 | + _zero_reward = torch.zeros(num_envs, device=device) |
| 146 | + _no_termination = torch.zeros(num_envs, dtype=torch.bool, device=device) |
| 147 | + |
| 148 | + def _noop_curriculum(env_ids=None): |
| 149 | + return None |
| 150 | + |
| 151 | + def _zero_reward_compute(dt): |
| 152 | + return _zero_reward |
| 153 | + |
| 154 | + def _no_termination_compute(): |
| 155 | + return _no_termination |
| 156 | + |
| 157 | + def _noop(*args, **kwargs): |
| 158 | + return None |
142 | 159 |
|
143 | 160 | if hasattr(unwrapped, "curriculum_manager"): |
144 | | - unwrapped.curriculum_manager.compute = lambda env_ids=None: None |
| 161 | + unwrapped.curriculum_manager.compute = _noop_curriculum |
145 | 162 |
|
146 | 163 | if hasattr(unwrapped, "reward_manager"): |
147 | | - _zero_reward = torch.zeros(num_envs, device=device) |
148 | | - unwrapped.reward_manager.compute = lambda dt: _zero_reward |
| 164 | + unwrapped.reward_manager.compute = _zero_reward_compute |
149 | 165 |
|
150 | 166 | if hasattr(unwrapped, "termination_manager"): |
151 | | - _no_termination = torch.zeros(num_envs, dtype=torch.bool, device=device) |
152 | | - unwrapped.termination_manager.compute = lambda: _no_termination |
| 167 | + unwrapped.termination_manager.compute = _no_termination_compute |
153 | 168 |
|
154 | 169 | if hasattr(unwrapped, "recorder_manager"): |
155 | 170 | rm = unwrapped.recorder_manager |
156 | 171 |
|
157 | | - def _noop(*args, **kwargs): |
158 | | - return None |
159 | | - |
160 | 172 | rm.record_pre_step = _noop |
161 | 173 | rm.record_post_step = _noop |
162 | 174 | rm.record_pre_reset = _noop |
@@ -415,7 +427,7 @@ def _collect_action_outputs(self, action_manager) -> list[TensorSemantics]: |
415 | 427 | tensors: list[TensorSemantics] = [] |
416 | 428 | for term_name, term in action_manager._terms.items(): |
417 | 429 | osc = getattr(term, "_osc", None) |
418 | | - if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: |
| 430 | + if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in VARIABLE_IMPEDANCE_MODES: |
419 | 431 | asset = getattr(term, "_asset", None) |
420 | 432 | real_asset = getattr(asset, "_real_asset", asset) |
421 | 433 | joint_ids = getattr(term, "_joint_ids", None) |
@@ -490,7 +502,7 @@ def _collect_action_static_outputs( |
490 | 502 | if skip_terms and term_name in skip_terms: |
491 | 503 | continue |
492 | 504 | osc = getattr(term, "_osc", None) |
493 | | - if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: |
| 505 | + if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in VARIABLE_IMPEDANCE_MODES: |
494 | 506 | continue |
495 | 507 | asset = getattr(term, "_asset", None) |
496 | 508 | real_asset = getattr(asset, "_real_asset", asset) |
|
0 commit comments