|
18 | 18 | from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer |
19 | 19 | from jax import Array, Device |
20 | 20 |
|
| 21 | +import crazyflow.sim.functional as F |
21 | 22 | from crazyflow.control.control import Control, controllable |
22 | 23 | from crazyflow.exception import ConfigError, NotInitializedError |
23 | 24 | from crazyflow.sim.data import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv |
|
29 | 30 | so_rpy_rotor_drag_physics, |
30 | 31 | so_rpy_rotor_physics, |
31 | 32 | ) |
32 | | -from crazyflow.utils import grid_2d, leaf_replace, pytree_replace, to_device |
| 33 | +from crazyflow.utils import grid_2d, leaf_replace, pytree_replace |
33 | 34 |
|
34 | 35 | if TYPE_CHECKING: |
35 | 36 | from mujoco.mjx import Data, Model |
@@ -134,45 +135,15 @@ def step(self, n_steps: int = 1): |
134 | 135 |
|
135 | 136 | def state_control(self, controls: Array): |
136 | 137 | """Set the desired state for all drones in all worlds.""" |
137 | | - assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch" |
138 | | - assert self.control == Control.state, "State control is not enabled by the sim config" |
139 | | - controls = to_device(controls, self.device) |
140 | | - self.data = self.data.replace( |
141 | | - controls=self.data.controls.replace( |
142 | | - state=self.data.controls.state.replace(staged_cmd=controls) |
143 | | - ) |
144 | | - ) |
| 138 | + self.data = F.state_control(self.data, controls) |
145 | 139 |
|
146 | 140 | def attitude_control(self, controls: Array): |
147 | | - """Set the desired attitude for all drones in all worlds. |
| 141 | + """Set the desired attitude for all drones in all worlds.""" |
| 142 | + self.data = F.attitude_control(self.data, controls) |
148 | 143 |
|
149 | | - We need to stage the attitude controls because the sys_id physics mode operates directly on |
150 | | - the attitude controls. If we were to directly update the controls, this would effectively |
151 | | - bypass the control frequency and run the attitude controller at the physics update rate. By |
152 | | - staging the controls, we ensure that the physics module sees the old controls until the |
153 | | - controller updates at its correct frequency. |
154 | | - """ |
155 | | - assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch" |
156 | | - assert self.control == Control.attitude, "Attitude control is not enabled by the sim config" |
157 | | - controls = to_device(controls, self.device) |
158 | | - self.data = self.data.replace( |
159 | | - controls=self.data.controls.replace( |
160 | | - attitude=self.data.controls.attitude.replace(staged_cmd=controls) |
161 | | - ) |
162 | | - ) |
163 | | - |
164 | | - def force_torque_control(self, cmd: Array): |
| 144 | + def force_torque_control(self, controls: Array): |
165 | 145 | """Set the desired force and torque for all drones in all worlds.""" |
166 | | - assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch" |
167 | | - assert self.control == Control.force_torque, ( |
168 | | - "Force-torque control is not enabled by the sim config" |
169 | | - ) |
170 | | - controls = to_device(cmd, self.device) |
171 | | - self.data = self.data.replace( |
172 | | - controls=self.data.controls.replace( |
173 | | - force_torque=self.data.controls.force_torque.replace(staged_cmd=controls) |
174 | | - ) |
175 | | - ) |
| 146 | + self.data = F.force_torque_control(self.data, controls) |
176 | 147 |
|
177 | 148 | @requires_mujoco_sync |
178 | 149 | def render( |
@@ -408,18 +379,7 @@ def controllable(self) -> Array: |
408 | 379 | as soon as the controller frequency allows for an update. Successive control updates that |
409 | 380 | happen before the staged buffers are applied overwrite the desired values. |
410 | 381 | """ |
411 | | - controls = self.data.controls |
412 | | - match self.control: |
413 | | - case Control.state: |
414 | | - control_steps, control_freq = controls.state.steps, controls.state.freq |
415 | | - case Control.attitude: |
416 | | - control_steps, control_freq = controls.attitude.steps, controls.attitude.freq |
417 | | - case Control.force_torque: |
418 | | - control_steps = controls.force_torque.steps |
419 | | - control_freq = controls.force_torque.freq |
420 | | - case _: |
421 | | - raise NotImplementedError(f"Control mode {self.control} not implemented") |
422 | | - return controllable(self.data.core.steps, self.data.core.freq, control_steps, control_freq) |
| 382 | + return F.controllable(self.data) |
423 | 383 |
|
424 | 384 | @requires_mujoco_sync |
425 | 385 | def contacts(self, body: str | None = None) -> Array: |
|
0 commit comments