Skip to content

Commit f94da5a

Browse files
committed
Add functional API for controllers
1 parent b390ac6 commit f94da5a

6 files changed

Lines changed: 246 additions & 116 deletions

File tree

crazyflow/sim/data.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class ControlData(typing.Protocol):
102102

103103
@dataclass
104104
class SimControls:
105+
mode: Control = field(pytree_node=False)
106+
"""Control mode of the simulation."""
105107
state: ControlData | None
106108
"""State control data."""
107109
attitude: ControlData | None
@@ -136,7 +138,11 @@ def create(
136138
n_worlds, n_drones, force_torque_freq, drone_model, device
137139
)
138140
return SimControls(
139-
state=state, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
141+
mode=control,
142+
state=state,
143+
attitude=attitude,
144+
force_torque=force_torque,
145+
rotor_vel=rotor_vel,
140146
)
141147
case Control.attitude:
142148
attitude = attitude = MellingerAttitudeData.create(
@@ -146,14 +152,22 @@ def create(
146152
n_worlds, n_drones, force_torque_freq, drone_model, device
147153
)
148154
return SimControls(
149-
state=None, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
155+
mode=control,
156+
state=None,
157+
attitude=attitude,
158+
force_torque=force_torque,
159+
rotor_vel=rotor_vel,
150160
)
151161
case Control.force_torque:
152162
force_torque = MellingerForceTorqueData.create(
153163
n_worlds, n_drones, force_torque_freq, drone_model, device
154164
)
155165
return SimControls(
156-
state=None, attitude=None, force_torque=force_torque, rotor_vel=rotor_vel
166+
mode=control,
167+
state=None,
168+
attitude=None,
169+
force_torque=force_torque,
170+
rotor_vel=rotor_vel,
157171
)
158172
case _:
159173
raise ValueError(f"Control mode {control} not implemented")

crazyflow/sim/functional.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from crazyflow.control import Control
6+
from crazyflow.control.control import controllable as _controllable
7+
from crazyflow.utils import to_device
8+
9+
if TYPE_CHECKING:
10+
from jax import Array
11+
12+
from crazyflow.sim.data import SimData
13+
14+
15+
def state_control(data: SimData, controls: Array) -> SimData:
16+
"""State control function."""
17+
assert data.controls.mode == Control.state, f"control type {data.controls.mode} not enabled"
18+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 13), "controls shape mismatch"
19+
controls = to_device(controls, data.core.steps.device)
20+
data = data.replace(
21+
controls=data.controls.replace(state=data.controls.state.replace(staged_cmd=controls))
22+
)
23+
return data
24+
25+
26+
def attitude_control(data: SimData, controls: Array) -> SimData:
27+
"""Attitude control function.
28+
29+
We need to stage the attitude controls because the sys_id physics mode operates directly on
30+
the attitude controls. If we were to directly update the controls, this would effectively
31+
bypass the control frequency and run the attitude controller at the physics update rate. By
32+
staging the controls, we ensure that the physics module sees the old controls until the
33+
controller updates at its correct frequency.
34+
"""
35+
assert data.controls.mode == Control.attitude, f"control type {data.controls.mode} not enabled"
36+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
37+
controls = to_device(controls, data.core.steps.device)
38+
data = data.replace(
39+
controls=data.controls.replace(attitude=data.controls.attitude.replace(staged_cmd=controls))
40+
)
41+
return data
42+
43+
44+
def force_torque_control(data: SimData, controls: Array) -> SimData:
45+
"""Force-torque control function."""
46+
assert data.controls.mode == Control.force_torque, (
47+
f"control type {data.controls.mode} not enabled"
48+
)
49+
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
50+
controls = to_device(controls, data.core.steps.device)
51+
data = data.replace(
52+
controls=data.controls.replace(
53+
force_torque=data.controls.force_torque.replace(staged_cmd=controls)
54+
)
55+
)
56+
return data
57+
58+
59+
def controllable(data: SimData) -> Array:
60+
"""Check which worlds can currently update their controllers."""
61+
controls = data.controls
62+
match data.controls.mode:
63+
case Control.state:
64+
control_steps, control_freq = controls.state.steps, controls.state.freq
65+
case Control.attitude:
66+
control_steps, control_freq = controls.attitude.steps, controls.attitude.freq
67+
case Control.force_torque:
68+
control_steps = controls.force_torque.steps
69+
control_freq = controls.force_torque.freq
70+
case _:
71+
raise NotImplementedError(f"Control mode {data.controls.mode} not implemented")
72+
return _controllable(data.core.steps, data.core.freq, control_steps, control_freq)

crazyflow/sim/sim.py

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1919
from jax import Array, Device
2020

21+
import crazyflow.sim.functional as F
2122
from crazyflow.control.control import Control, controllable
2223
from crazyflow.exception import ConfigError, NotInitializedError
2324
from crazyflow.sim.data import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
@@ -29,7 +30,7 @@
2930
so_rpy_rotor_drag_physics,
3031
so_rpy_rotor_physics,
3132
)
32-
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace, to_device
33+
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace
3334

3435
if TYPE_CHECKING:
3536
from mujoco.mjx import Data, Model
@@ -134,45 +135,15 @@ def step(self, n_steps: int = 1):
134135

135136
def state_control(self, controls: Array):
136137
"""Set the desired state for all drones in all worlds."""
137-
assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
138-
assert self.control == Control.state, "State control is not enabled by the sim config"
139-
controls = to_device(controls, self.device)
140-
self.data = self.data.replace(
141-
controls=self.data.controls.replace(
142-
state=self.data.controls.state.replace(staged_cmd=controls)
143-
)
144-
)
138+
self.data = F.state_control(self.data, controls)
145139

146140
def attitude_control(self, controls: Array):
147-
"""Set the desired attitude for all drones in all worlds.
141+
"""Set the desired attitude for all drones in all worlds."""
142+
self.data = F.attitude_control(self.data, controls)
148143

149-
We need to stage the attitude controls because the sys_id physics mode operates directly on
150-
the attitude controls. If we were to directly update the controls, this would effectively
151-
bypass the control frequency and run the attitude controller at the physics update rate. By
152-
staging the controls, we ensure that the physics module sees the old controls until the
153-
controller updates at its correct frequency.
154-
"""
155-
assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch"
156-
assert self.control == Control.attitude, "Attitude control is not enabled by the sim config"
157-
controls = to_device(controls, self.device)
158-
self.data = self.data.replace(
159-
controls=self.data.controls.replace(
160-
attitude=self.data.controls.attitude.replace(staged_cmd=controls)
161-
)
162-
)
163-
164-
def force_torque_control(self, cmd: Array):
144+
def force_torque_control(self, controls: Array):
165145
"""Set the desired force and torque for all drones in all worlds."""
166-
assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch"
167-
assert self.control == Control.force_torque, (
168-
"Force-torque control is not enabled by the sim config"
169-
)
170-
controls = to_device(cmd, self.device)
171-
self.data = self.data.replace(
172-
controls=self.data.controls.replace(
173-
force_torque=self.data.controls.force_torque.replace(staged_cmd=controls)
174-
)
175-
)
146+
self.data = F.force_torque_control(self.data, controls)
176147

177148
@requires_mujoco_sync
178149
def render(
@@ -408,18 +379,7 @@ def controllable(self) -> Array:
408379
as soon as the controller frequency allows for an update. Successive control updates that
409380
happen before the staged buffers are applied overwrite the desired values.
410381
"""
411-
controls = self.data.controls
412-
match self.control:
413-
case Control.state:
414-
control_steps, control_freq = controls.state.steps, controls.state.freq
415-
case Control.attitude:
416-
control_steps, control_freq = controls.attitude.steps, controls.attitude.freq
417-
case Control.force_torque:
418-
control_steps = controls.force_torque.steps
419-
control_freq = controls.force_torque.freq
420-
case _:
421-
raise NotImplementedError(f"Control mode {self.control} not implemented")
422-
return controllable(self.data.core.steps, self.data.core.freq, control_steps, control_freq)
382+
return F.controllable(self.data)
423383

424384
@requires_mujoco_sync
425385
def contacts(self, body: str | None = None) -> Array:

crazyflow/sim/visualize.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,6 @@
66
from crazyflow.sim import Sim
77

88

9-
def draw_capsule(
10-
sim: Sim,
11-
p1: NDArray,
12-
p2: NDArray,
13-
radius: float = 0.05,
14-
rgba: NDArray | None = None,
15-
is_cylinder: bool = False,
16-
):
17-
"""Draw a capsule (pill) or cylinder between two points.
18-
19-
Args:
20-
sim: The simulation.
21-
p1: Start point [3,]
22-
p2: End point [3,]
23-
radius: The thickness of the geom.
24-
rgba: The color of the object.
25-
is_cylinder: If True, draws a flat-ended cylinder.
26-
If False, draws a pill-shaped capsule.
27-
"""
28-
if sim.viewer is None:
29-
return
30-
31-
# 1. Calculate Midpoint (Center of the geom)
32-
pos = (p1 + p2) / 2.0
33-
34-
# 2. Calculate Half-length (MuJoCo uses half-extents)
35-
dist = np.linalg.norm(p2 - p1)
36-
half_length = dist / 2.0
37-
38-
# 3. Define Size: [radius, radius, half_length]
39-
# Note: For capsules, size[2] is the half-length of the *cylindrical* part.
40-
# MuJoCo adds the hemispherical caps on top of this.
41-
size = np.array([radius, half_length, 0])
42-
43-
# 4. Get Rotation (Align Z-axis to the vector p2-p1)
44-
# Using your existing helper (wrapped in a list for the reshape)
45-
mat = _rotation_matrix_from_points(p1[None, :], p2[None, :]).as_matrix().flatten()
46-
47-
geom_type = mujoco.mjtGeom.mjGEOM_CYLINDER if is_cylinder else mujoco.mjtGeom.mjGEOM_CAPSULE
48-
rgba = rgba if rgba is not None else np.array([0, 1.0, 0, 1])
49-
50-
sim.viewer.viewer.add_marker(type=geom_type, pos=pos, size=size, mat=mat, rgba=rgba)
51-
52-
539
def draw_line(
5410
sim: Sim,
5511
points: NDArray,

0 commit comments

Comments
 (0)