Skip to content

Commit 5a28c10

Browse files
committed
[WIP] Add drone-models. Switch to scipy PR branch. Add new controllers
1 parent d07577d commit 5a28c10

9 files changed

Lines changed: 345 additions & 111 deletions

File tree

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "submodules/drone-models"]
2+
path = submodules/drone-models
3+
url = https://github.com/utiasDSL/drone-models.git

crazyflow/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import os
2+
3+
os.environ["SCIPY_ARRAY_API"] = "1"
4+
15
import crazyflow.envs # noqa: F401, ensure gymnasium envs are registered
26
from crazyflow.control import Control
37
from crazyflow.sim import Physics, Sim

crazyflow/control/control.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,23 @@ class Control(str, Enum):
5050
default = attitude
5151

5252

53+
@jax.jit
54+
def controllable(step: Array, freq: int, control_steps: Array, control_freq: int) -> Array:
55+
"""Check which worlds can currently update their controllers.
56+
57+
Args:
58+
step: The current step of the simulation.
59+
freq: The frequency of the simulation.
60+
control_steps: The steps at which the controllers were last updated.
61+
control_freq: The frequency of the controllers.
62+
63+
Returns:
64+
A boolean mask of shape (n_worlds,) that is True at the worlds where the controllers can be
65+
updated.
66+
"""
67+
return ((step - control_steps) >= (freq / control_freq)) | (control_steps == -1)
68+
69+
5370
KF: float = 3.16e-10
5471
KM: float = 7.94e-12
5572
P_F: Array = np.array([0.4, 0.4, 1.25])
@@ -111,7 +128,7 @@ def attitude2rpm(
111128
) -> tuple[Array, Array]:
112129
"""Convert the desired collective thrust and attitude into motor RPMs."""
113130
rot = R.from_quat(quat)
114-
target_rot = R.from_euler("xyz", controls[1:])
131+
target_rot = R.from_euler("xyz", controls[:3])
115132
drot = (target_rot.inv() * rot).as_matrix()
116133
# Extract the anti-symmetric part of the relative rotation matrix.
117134
rot_e = jnp.array([drot[2, 1] - drot[1, 2], drot[0, 2] - drot[2, 0], drot[1, 0] - drot[0, 1]])
@@ -123,7 +140,7 @@ def attitude2rpm(
123140
# PID target torques.
124141
target_torques = -P_T * rot_e + D_T * rpy_rates_e + I_T * rpy_err_i
125142
target_torques = jnp.clip(target_torques, -3200, 3200)
126-
thrust_per_motor = jnp.atleast_1d(controls[0]) / 4
143+
thrust_per_motor = jnp.atleast_1d(controls[3]) / 4
127144
pwm = jnp.clip(thrust2pwm(thrust_per_motor) + MIX_MATRIX @ target_torques, MIN_PWM, MAX_PWM)
128145
return pwm2rpm(pwm), rpy_err_i
129146

crazyflow/control/mellinger.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
import jax.numpy as jnp
4+
from drone_models.controller.mellinger import MellingerStateParams
5+
from flax.struct import dataclass, field
6+
from jax import Array, Device
7+
8+
9+
@dataclass
10+
class MellingerStateData:
11+
cmd: Array # (N, M, 13)
12+
"""Full state control command for the drone.
13+
14+
A command consists of [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].
15+
We currently do not use the acceleration and angle rate components. This is subject to change.
16+
"""
17+
steps: Array # (N, 1)
18+
"""Last simulation steps that the state control command was applied."""
19+
freq: int = field(pytree_node=False)
20+
"""Frequency of the state control command."""
21+
pos_err_i: Array # (N, M, 3)
22+
"""Integral errors of the state control command."""
23+
# Parameters for the state controller
24+
params: MellingerStateParams
25+
26+
@staticmethod
27+
def create(
28+
n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
29+
) -> MellingerStateData:
30+
"""Create a default set of state data for the simulation."""
31+
cmd = jnp.zeros((n_worlds, n_drones, 13), device=device)
32+
steps = jnp.zeros((n_worlds, 1), dtype=jnp.int32, device=device)
33+
pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device)
34+
params = MellingerStateParams.load(drone_model)
35+
return MellingerStateData(
36+
cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
37+
)
38+
39+
40+
# @dataclass
41+
# class MellingerAttitudeData:
42+
# cmd: Array # (N, M, 4)
43+
# """Full attitude control command for the drone.
44+
45+
# A command consists of [collective thrust, roll, pitch, yaw].
46+
# """
47+
# steps: Array # (N, 1)
48+
# """Last simulation steps that the attitude control command was applied."""
49+
# freq: int = field(pytree_node=False)
50+
# """Frequency of the attitude control command."""
51+
# pos_err_i: Array # (N, M, 3)
52+
# """Integral errors of the attitude control command."""
53+
# # Parameters for the attitude controller
54+
# params: MellingerAttitudeParams
55+
56+
# @staticmethod
57+
# def create(
58+
# n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
59+
# ) -> MellingerAttitudeData:
60+
# """Create a default set of attitude data for the simulation."""
61+
# cmd = jnp.zeros((n_worlds, n_drones, 4), device=device)
62+
# steps = jnp.zeros((n_worlds, 1), dtype=jnp.int32, device=device)
63+
# pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device)
64+
# params = MellingerAttitudeParams.load(drone_model)
65+
# return MellingerAttitudeData(
66+
# cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
67+
# )

crazyflow/sim/sim.py

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
import jax.numpy as jnp
99
import mujoco
1010
import mujoco.mjx as mjx
11+
from drone_models.controller.mellinger import state2attitude
1112
from einops import rearrange
1213
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1314
from jax import Array, Device
1415
from jax.scipy.spatial.transform import Rotation as R
1516

1617
from crazyflow.constants import J_INV, MASS, J
17-
from crazyflow.control.control import Control, attitude2rpm, pwm2rpm, state2attitude, thrust2pwm
18+
from crazyflow.control.control import Control, attitude2rpm, controllable, pwm2rpm, thrust2pwm
19+
from crazyflow.control.control import state2attitude as state2attitude_legacy
1820
from crazyflow.exception import ConfigError, NotInitializedError
1921
from crazyflow.sim.integration import Integrator, euler, rk4, symplectic_euler
2022
from crazyflow.sim.physics import (
@@ -24,13 +26,23 @@
2426
rpms2collective_wrench,
2527
surrogate_identified_collective_wrench,
2628
)
27-
from crazyflow.sim.structs import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
29+
from crazyflow.sim.structs import (
30+
SimControls,
31+
SimControlsNew,
32+
SimCore,
33+
SimData,
34+
SimParams,
35+
SimState,
36+
SimStateDeriv,
37+
)
2838
from crazyflow.utils import grid_2d, leaf_replace, pytree_replace, to_device
2939

3040
if TYPE_CHECKING:
3141
from mujoco.mjx import Data, Model
3242
from numpy.typing import NDArray
3343

44+
from crazyflow.control.mellinger import MellingerStateData
45+
3446
Params = ParamSpec("Params") # Represents arbitrary parameters
3547
Return = TypeVar("Return") # Represents the return type
3648

@@ -141,7 +153,11 @@ def state_control(self, controls: Array):
141153
assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
142154
assert self.control == Control.state, "State control is not enabled by the sim config"
143155
controls = to_device(controls, self.device)
144-
self.data = self.data.replace(controls=self.data.controls.replace(state=controls))
156+
self.data = self.data.replace(
157+
new_controls=self.data.new_controls.replace(
158+
state=self.data.new_controls.state.replace(cmd=controls)
159+
)
160+
)
145161

146162
def thrust_control(self, cmd: Array):
147163
"""Set the desired thrust for all drones in all worlds."""
@@ -182,7 +198,7 @@ def seed(self, seed: int):
182198
Args:
183199
seed: The seed for the JAX rng.
184200
"""
185-
self.data = seed_sim(self.data, seed, self.device)
201+
self.data: SimData = seed_sim(self.data, seed, self.device)
186202

187203
def close(self):
188204
if self.viewer is not None:
@@ -282,7 +298,7 @@ def build_mjx(self):
282298

283299
def init_data(
284300
self, state_freq: int, attitude_freq: int, thrust_freq: int, rng_key: Array
285-
) -> tuple[SimData, SimData]:
301+
) -> SimData:
286302
"""Initialize the simulation data."""
287303
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(self.n_drones)]
288304
N, D = self.n_worlds, self.n_drones
@@ -292,6 +308,9 @@ def init_data(
292308
controls=SimControls.create(N, D, state_freq, attitude_freq, thrust_freq, self.device),
293309
params=SimParams.create(N, D, MASS, J, J_INV, self.device),
294310
core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
311+
new_controls=SimControlsNew.create(
312+
N, D, self.control, state_freq, attitude_freq, thrust_freq, self.device
313+
),
295314
)
296315
if D > 1: # If multiple drones, arrange them in a grid
297316
grid = grid_2d(D)
@@ -417,23 +436,6 @@ def integrate(data: SimData) -> SimData:
417436
return integrate
418437

419438

420-
@jax.jit
421-
def controllable(step: Array, freq: int, control_steps: Array, control_freq: int) -> Array:
422-
"""Check which worlds can currently update their controllers.
423-
424-
Args:
425-
step: The current step of the simulation.
426-
freq: The frequency of the simulation.
427-
control_steps: The steps at which the controllers were last updated.
428-
control_freq: The frequency of the controllers.
429-
430-
Returns:
431-
A boolean mask of shape (n_worlds,) that is True at the worlds where the controllers can be
432-
updated.
433-
"""
434-
return ((step - control_steps) >= (freq / control_freq)) | (control_steps == -1)
435-
436-
437439
@jax.jit
438440
def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
439441
"""Filter contacts from MuJoCo data."""
@@ -461,18 +463,53 @@ def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimDa
461463

462464
def step_state_controller(data: SimData) -> SimData:
463465
"""Compute the updated controls for the state controller."""
464-
states, controls = data.states, data.controls
465-
mask = controllable(data.core.steps, data.core.freq, controls.state_steps, controls.state_freq)
466-
des_pos, des_vel = controls.state[..., :3], controls.state[..., 3:6]
467-
des_yaw = controls.state[..., [9]] # Keep (N, M, 1) shape for broadcasting
468-
dt = 1 / data.controls.state_freq
469-
attitude, pos_err_i = state2attitude(
470-
states.pos, states.vel, states.quat, des_pos, des_vel, des_yaw, controls.pos_err_i, dt
466+
states, ctrl_state = data.states, data.new_controls.state
467+
assert ctrl_state is not None, "Using state controller without initialized state control data"
468+
ctrl_state: MellingerStateData
469+
mask = controllable(data.core.steps, data.core.freq, ctrl_state.steps, ctrl_state.freq)
470+
jax.debug.print("Ctrl cmd: {cmd}", cmd=ctrl_state.cmd)
471+
attitude, (pos_err_i,) = state2attitude(
472+
states.pos,
473+
states.quat,
474+
states.vel,
475+
states.ang_vel,
476+
ctrl_state.cmd,
477+
ctrl_freq=ctrl_state.freq,
478+
ctrl_errors=(ctrl_state.pos_err_i,),
479+
**ctrl_state.params._asdict(),
471480
)
472-
controls = leaf_replace(
473-
controls, mask, state_steps=data.core.steps, staged_attitude=attitude, pos_err_i=pos_err_i
481+
jax.debug.print("Attitude: {attitude}", attitude=attitude)
482+
ctrl_state = leaf_replace(ctrl_state, mask, steps=data.core.steps, pos_err_i=pos_err_i)
483+
data = data.replace(
484+
controls=data.controls.replace(staged_attitude=attitude),
485+
new_controls=data.new_controls.replace(state=ctrl_state),
474486
)
475-
return data.replace(controls=controls)
487+
return data
488+
489+
490+
# def step_state_controller(data: SimData) -> SimData:
491+
# """Compute the updated controls for the state controller."""
492+
# states, ctrl_state = data.states, data.new_controls.state
493+
# assert ctrl_state is not None, "Using state controller without initialized state control data"
494+
# mask = controllable(data.core.steps, data.core.freq, ctrl_state.steps, ctrl_state.freq)
495+
# attitude, pos_err_i = state2attitude_legacy(
496+
# states.pos,
497+
# states.vel,
498+
# states.quat,
499+
# ctrl_state.cmd[..., :3],
500+
# ctrl_state.cmd[..., 3:6],
501+
# ctrl_state.cmd[..., [9]],
502+
# ctrl_state.pos_err_i,
503+
# 1 / ctrl_state.freq,
504+
# )
505+
# attitude = jnp.roll(attitude, -1, axis=-1)
506+
# jax.debug.print("Attitude: {attitude}", attitude=attitude)
507+
# ctrl_state = leaf_replace(ctrl_state, mask, steps=data.core.steps, pos_err_i=pos_err_i)
508+
# data = data.replace(
509+
# controls=data.controls.replace(staged_attitude=attitude),
510+
# new_controls=data.new_controls.replace(state=ctrl_state),
511+
# )
512+
# return data
476513

477514

478515
def step_attitude_controller(data: SimData) -> SimData:
@@ -533,14 +570,6 @@ def identified_wrench(data: SimData) -> SimData:
533570
identified_derivative = analytical_derivative # We can use the same derivative function for both
534571

535572

536-
def identity(data: SimData, *args: Any, **kwargs: Any) -> SimData:
537-
"""Identity function for the simulation pipeline.
538-
539-
Used as default function for optional pipeline steps.
540-
"""
541-
return data
542-
543-
544573
def clip_floor_pos(data: SimData) -> SimData:
545574
"""Clip the position of the drone to the floor."""
546575
clip = data.states.pos[..., 2] < -0.001

crazyflow/sim/structs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

3+
import typing
4+
35
import jax
46
import jax.numpy as jnp
57
from flax.struct import dataclass, field
68
from jax import Array, Device
79

10+
from crazyflow.control import Control
11+
from crazyflow.control.mellinger import MellingerStateData
12+
813

914
@dataclass
1015
class SimState:
@@ -146,6 +151,54 @@ def create(
146151
)
147152

148153

154+
class ControlData(typing.Protocol):
155+
cmd: Array # (N, M, X)
156+
"""Control command for the drone."""
157+
steps: Array # (N, 1)
158+
"""Last simulation steps that the state control command was applied."""
159+
freq: int
160+
"""Frequency of the state control command."""
161+
# Parameters for the controller
162+
params: typing.Any
163+
164+
165+
@dataclass
166+
class SimControlsNew:
167+
state: ControlData | None = None
168+
"""State control data."""
169+
attitude: ControlData | None = None
170+
"""Attitude control data."""
171+
thrust: ControlData | None = None
172+
"""Thrust control data."""
173+
174+
@staticmethod
175+
def create(
176+
n_worlds: int,
177+
n_drones: int,
178+
control: Control,
179+
state_freq: int | None,
180+
attitude_freq: int | None,
181+
thrust_freq: int | None,
182+
device: Device,
183+
) -> SimControlsNew:
184+
"""Create a default set of controls for the simulation."""
185+
match control:
186+
case Control.state:
187+
state = MellingerStateData.create(n_worlds, n_drones, state_freq, "", device)
188+
attitude = None # MellingerAttitudeData.create(n_worlds, n_drones, device)
189+
thrust = None # MellingerThrustData.create(n_worlds, n_drones, device)
190+
return SimControlsNew(state=state, attitude=attitude, thrust=thrust)
191+
case Control.attitude:
192+
attitude = None # MellingerAttitudeData.create(n_worlds, n_drones, device)
193+
thrust = None # MellingerThrustData.create(n_worlds, n_drones, device)
194+
return SimControlsNew(attitude=attitude, thrust=thrust)
195+
case Control.thrust:
196+
thrust = None # MellingerThrustData.create(n_worlds, n_drones, device)
197+
return SimControlsNew(thrust=thrust)
198+
case _:
199+
raise ValueError(f"Control mode {control} not implemented")
200+
201+
149202
@dataclass
150203
class SimParams:
151204
mass: Array # (N, M, 1)
@@ -217,6 +270,8 @@ class SimData:
217270
"""Derivative of the state of the simulation."""
218271
controls: SimControls
219272
"""Drone control values."""
273+
new_controls: SimControlsNew
274+
"""New style control data TODO improve this"""
220275
params: SimParams
221276
"""Drone parameters."""
222277
core: SimCore

0 commit comments

Comments
 (0)