Skip to content

Commit 4946714

Browse files
committed
Add functional API for controllers
1 parent cd3bd4c commit 4946714

6 files changed

Lines changed: 1053 additions & 649 deletions

File tree

crazyflow/sim/data.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class ControlData(typing.Protocol):
102102

103103
@dataclass
104104
class SimControls:
105+
mode: Control = field(pytree_node=False)
106+
"""Control mode of the simulation."""
105107
state: ControlData | None
106108
"""State control data."""
107109
attitude: ControlData | None
@@ -136,7 +138,11 @@ def create(
136138
n_worlds, n_drones, force_torque_freq, drone_model, device
137139
)
138140
return SimControls(
139-
state=state, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
141+
mode=control,
142+
state=state,
143+
attitude=attitude,
144+
force_torque=force_torque,
145+
rotor_vel=rotor_vel,
140146
)
141147
case Control.attitude:
142148
attitude = attitude = MellingerAttitudeData.create(
@@ -146,14 +152,22 @@ def create(
146152
n_worlds, n_drones, force_torque_freq, drone_model, device
147153
)
148154
return SimControls(
149-
state=None, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
155+
mode=control,
156+
state=None,
157+
attitude=attitude,
158+
force_torque=force_torque,
159+
rotor_vel=rotor_vel,
150160
)
151161
case Control.force_torque:
152162
force_torque = MellingerForceTorqueData.create(
153163
n_worlds, n_drones, force_torque_freq, drone_model, device
154164
)
155165
return SimControls(
156-
state=None, attitude=None, force_torque=force_torque, rotor_vel=rotor_vel
166+
mode=control,
167+
state=None,
168+
attitude=None,
169+
force_torque=force_torque,
170+
rotor_vel=rotor_vel,
157171
)
158172
case _:
159173
raise ValueError(f"Control mode {control} not implemented")

crazyflow/sim/functional.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from crazyflow.control import Control
6+
from crazyflow.control.control import controllable as _controllable
7+
from crazyflow.utils import to_device
8+
9+
if TYPE_CHECKING:
10+
from jax import Array
11+
12+
from crazyflow.sim.data import SimData
13+
14+
15+
def state_control(data: SimData, controls: Array) -> SimData:
16+
"""State control function."""
17+
assert data.controls.mode == Control.state, f"control type {data.controls.mode} not enabled"
18+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 13), "controls shape mismatch"
19+
controls = to_device(controls, data.core.steps.device)
20+
data = data.replace(
21+
controls=data.controls.replace(state=data.controls.state.replace(staged_cmd=controls))
22+
)
23+
return data
24+
25+
26+
def attitude_control(data: SimData, controls: Array) -> SimData:
27+
"""Attitude control function.
28+
29+
We need to stage the attitude controls because the sys_id physics mode operates directly on
30+
the attitude controls. If we were to directly update the controls, this would effectively
31+
bypass the control frequency and run the attitude controller at the physics update rate. By
32+
staging the controls, we ensure that the physics module sees the old controls until the
33+
controller updates at its correct frequency.
34+
"""
35+
assert data.controls.mode == Control.attitude, f"control type {data.controls.mode} not enabled"
36+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
37+
controls = to_device(controls, data.core.steps.device)
38+
data = data.replace(
39+
controls=data.controls.replace(attitude=data.controls.attitude.replace(staged_cmd=controls))
40+
)
41+
return data
42+
43+
44+
def force_torque_control(data: SimData, controls: Array) -> SimData:
45+
"""Force-torque control function."""
46+
assert data.controls.mode == Control.force_torque, (
47+
f"control type {data.controls.mode} not enabled"
48+
)
49+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
50+
controls = to_device(controls, data.core.steps.device)
51+
data = data.replace(
52+
controls=data.controls.replace(
53+
force_torque=data.controls.force_torque.replace(staged_cmd=controls)
54+
)
55+
)
56+
return data
57+
58+
59+
def controllable(data: SimData) -> Array:
60+
"""Check which worlds can currently update their controllers."""
61+
controls = data.controls
62+
match data.controls.mode:
63+
case Control.state:
64+
control_steps, control_freq = controls.state.steps, controls.state.freq
65+
case Control.attitude:
66+
control_steps, control_freq = controls.attitude.steps, controls.attitude.freq
67+
case Control.force_torque:
68+
control_steps = controls.force_torque.steps
69+
control_freq = controls.force_torque.freq
70+
case _:
71+
raise NotImplementedError(f"Control mode {data.controls.mode} not implemented")
72+
return _controllable(data.core.steps, data.core.freq, control_steps, control_freq)

crazyflow/sim/sim.py

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1919
from jax import Array, Device
2020

21+
import crazyflow.sim.functional as F
2122
from crazyflow.control.control import Control, controllable
2223
from crazyflow.exception import ConfigError, NotInitializedError
2324
from crazyflow.sim.data import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
@@ -29,7 +30,7 @@
2930
so_rpy_rotor_drag_physics,
3031
so_rpy_rotor_physics,
3132
)
32-
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace, to_device
33+
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace
3334

3435
if TYPE_CHECKING:
3536
from mujoco.mjx import Data, Model
@@ -134,45 +135,15 @@ def step(self, n_steps: int = 1):
134135

135136
def state_control(self, controls: Array):
136137
"""Set the desired state for all drones in all worlds."""
137-
assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
138-
assert self.control == Control.state, "State control is not enabled by the sim config"
139-
controls = to_device(controls, self.device)
140-
self.data = self.data.replace(
141-
controls=self.data.controls.replace(
142-
state=self.data.controls.state.replace(staged_cmd=controls)
143-
)
144-
)
138+
self.data = F.state_control(self.data, controls)
145139

146140
def attitude_control(self, controls: Array):
147-
"""Set the desired attitude for all drones in all worlds.
141+
"""Set the desired attitude for all drones in all worlds."""
142+
self.data = F.attitude_control(self.data, controls)
148143

149-
We need to stage the attitude controls because the sys_id physics mode operates directly on
150-
the attitude controls. If we were to directly update the controls, this would effectively
151-
bypass the control frequency and run the attitude controller at the physics update rate. By
152-
staging the controls, we ensure that the physics module sees the old controls until the
153-
controller updates at its correct frequency.
154-
"""
155-
assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch"
156-
assert self.control == Control.attitude, "Attitude control is not enabled by the sim config"
157-
controls = to_device(controls, self.device)
158-
self.data = self.data.replace(
159-
controls=self.data.controls.replace(
160-
attitude=self.data.controls.attitude.replace(staged_cmd=controls)
161-
)
162-
)
163-
164-
def force_torque_control(self, cmd: Array):
144+
def force_torque_control(self, controls: Array):
165145
"""Set the desired force and torque for all drones in all worlds."""
166-
assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch"
167-
assert self.control == Control.force_torque, (
168-
"Force-torque control is not enabled by the sim config"
169-
)
170-
controls = to_device(cmd, self.device)
171-
self.data = self.data.replace(
172-
controls=self.data.controls.replace(
173-
force_torque=self.data.controls.force_torque.replace(staged_cmd=controls)
174-
)
175-
)
146+
self.data = F.force_torque_control(self.data, controls)
176147

177148
@requires_mujoco_sync
178149
def render(
@@ -408,18 +379,7 @@ def controllable(self) -> Array:
408379
as soon as the controller frequency allows for an update. Successive control updates that
409380
happen before the staged buffers are applied overwrite the desired values.
410381
"""
411-
controls = self.data.controls
412-
match self.control:
413-
case Control.state:
414-
control_steps, control_freq = controls.state.steps, controls.state.freq
415-
case Control.attitude:
416-
control_steps, control_freq = controls.attitude.steps, controls.attitude.freq
417-
case Control.force_torque:
418-
control_steps = controls.force_torque.steps
419-
control_freq = controls.force_torque.freq
420-
case _:
421-
raise NotImplementedError(f"Control mode {self.control} not implemented")
422-
return controllable(self.data.core.steps, self.data.core.freq, control_steps, control_freq)
382+
return F.controllable(self.data)
423383

424384
@requires_mujoco_sync
425385
def contacts(self, body: str | None = None) -> Array:

0 commit comments

Comments
 (0)