Skip to content

Commit 45c5cb4

Browse files
committed
Fix most tests
1 parent 866d376 commit 45c5cb4

15 files changed

Lines changed: 167 additions & 208 deletions

File tree

crazyflow/control/mellinger.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from __future__ import annotations
22

33
import jax.numpy as jnp
4-
from drone_models.controller.mellinger import (
5-
MellingerAttitudeParams,
6-
MellingerForceTorqueParams,
7-
MellingerStateParams,
8-
)
4+
from drone_models.controller.mellinger.params import AttitudeParams, ForceTorqueParams, StateParams
95
from flax.struct import dataclass, field
106
from jax import Array, Device
117

8+
from crazyflow.utils import named_tuple2device
9+
1210

1311
@dataclass
1412
class MellingerStateData:
@@ -27,7 +25,7 @@ class MellingerStateData:
2725
pos_err_i: Array # (N, M, 3)
2826
"""Integral errors of the state control command."""
2927
# Parameters for the state controller
30-
params: MellingerStateParams
28+
params: StateParams
3129

3230
@staticmethod
3331
def create(
@@ -37,7 +35,7 @@ def create(
3735
cmd = jnp.zeros((n_worlds, n_drones, 13), device=device)
3836
steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device)
3937
pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device)
40-
params = MellingerStateParams.load(drone_model)
38+
params = named_tuple2device(StateParams.load(drone_model), device)
4139
return MellingerStateData(
4240
cmd=cmd, staged_cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
4341
)
@@ -61,7 +59,7 @@ class MellingerAttitudeData:
6159
last_ang_vel: Array # (N, M, 3)
6260
"""Last angular velocity of the drone."""
6361
# Parameters for the attitude controller
64-
params: MellingerAttitudeParams
62+
params: AttitudeParams
6563

6664
@staticmethod
6765
def create(
@@ -71,7 +69,7 @@ def create(
7169
cmd = jnp.zeros((n_worlds, n_drones, 4), device=device)
7270
steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device)
7371
zeros_3d = jnp.zeros((n_worlds, n_drones, 3), device=device)
74-
params = MellingerAttitudeParams.load(drone_model)
72+
params = named_tuple2device(AttitudeParams.load(drone_model), device)
7573
return MellingerAttitudeData(
7674
cmd=cmd,
7775
staged_cmd=cmd,
@@ -85,40 +83,27 @@ def create(
8583

8684
@dataclass
8785
class MellingerForceTorqueData:
88-
cmd_force: Array # (N, M, 1)
89-
"""Force command for the drone.
90-
91-
A command consists of [fz].
92-
"""
93-
cmd_torque: Array # (N, M, 3)
94-
"""Torque command for the drone.
86+
cmd: Array # (N, M, 4)
87+
"""Force-torque command for the drone.
9588
96-
A command consists of [tx, ty, tz].
89+
A command consists of [fz, tx, ty, tz].
9790
"""
98-
staged_cmd_force: Array # (N, M, 1)
99-
staged_cmd_torque: Array # (N, M, 3)
91+
staged_cmd: Array # (N, M, 4)
10092
"""Staging buffer to store the most recent command until the next controller tick."""
10193
steps: Array # (N, 1)
10294
"""Last simulation steps that the force and torque control command was applied."""
10395
freq: int = field(pytree_node=False)
10496
"""Frequency of the force and torque control command."""
10597
# Parameters for the force and torque controller
106-
params: MellingerForceTorqueParams
98+
params: ForceTorqueParams
10799

108100
@staticmethod
109101
def create(
110102
n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
111103
) -> MellingerForceTorqueData:
112-
zero_1d = jnp.zeros((n_worlds, n_drones, 1), device=device)
113-
zero_3d = jnp.zeros((n_worlds, n_drones, 3), device=device)
104+
zero_4d = jnp.zeros((n_worlds, n_drones, 4), device=device)
114105
steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device)
115-
params = MellingerForceTorqueParams.load(drone_model)
106+
params = named_tuple2device(ForceTorqueParams.load(drone_model), device)
116107
return MellingerForceTorqueData(
117-
cmd_force=zero_1d,
118-
cmd_torque=zero_3d,
119-
staged_cmd_force=zero_1d,
120-
staged_cmd_torque=zero_3d,
121-
steps=steps,
122-
freq=freq,
123-
params=params,
108+
cmd=zero_4d, staged_cmd=zero_4d, steps=steps, freq=freq, params=params
124109
)

crazyflow/sim/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from crazyflow.sim.physics import Physics
22
from crazyflow.sim.sim import Sim
3-
from crazyflow.sim.symbolic import symbolic_attitude, symbolic_from_sim, symbolic_thrust
3+
from crazyflow.sim.symbolic import symbolic_from_sim
44

5-
__all__ = ["Sim", "Physics", "symbolic_attitude", "symbolic_from_sim", "symbolic_thrust"]
5+
__all__ = ["Sim", "Physics", "symbolic_from_sim"]

crazyflow/sim/sim.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
self,
7575
n_worlds: int = 1,
7676
n_drones: int = 1,
77+
drone_model: str = "cf2x_L250",
7778
physics: Physics = Physics.default,
7879
control: Control = Control.default,
7980
integrator: Integrator = Integrator.default,
@@ -93,6 +94,7 @@ def __init__(
9394
raise ConfigError("Double precision mode is required for high frequency simulations")
9495
self.physics = physics
9596
self.control = control
97+
self.drone_model = drone_model
9698
self.integrator = integrator
9799
self.device = jax.devices(device)[0]
98100
self.n_worlds = n_worlds
@@ -179,9 +181,7 @@ def force_torque_control(self, cmd: Array):
179181
controls = to_device(cmd, self.device)
180182
self.data = self.data.replace(
181183
controls=self.data.controls.replace(
182-
force_torque=self.data.controls.force_torque.replace(
183-
staged_cmd_force=controls[..., [0]], staged_cmd_torque=controls[..., 1:]
184-
)
184+
force_torque=self.data.controls.force_torque.replace(staged_cmd=controls)
185185
)
186186
)
187187

@@ -325,7 +325,14 @@ def init_data(
325325
states=SimState.create(N, D, self.device),
326326
states_deriv=SimStateDeriv.create(N, D, self.device),
327327
controls=SimControls.create(
328-
N, D, self.control, state_freq, attitude_freq, force_torque_freq, self.device
328+
N,
329+
D,
330+
self.control,
331+
self.drone_model,
332+
state_freq,
333+
attitude_freq,
334+
force_torque_freq,
335+
self.device,
329336
),
330337
params=SimParams.create(N, D, MASS, J, J_INV, self.device),
331338
constants=SimConstants.create(
@@ -501,7 +508,7 @@ def step_state_controller(data: SimData) -> SimData:
501508
**state_ctrl.params._asdict(),
502509
)
503510
state_ctrl = leaf_replace(state_ctrl, mask, steps=data.core.steps, pos_err_i=pos_err_i)
504-
attitude_ctrl = data.controls.attitude.replace(staged_cmd=rpyt)
511+
attitude_ctrl = leaf_replace(data.controls.attitude, mask, staged_cmd=rpyt)
505512
data = data.replace(controls=data.controls.replace(state=state_ctrl, attitude=attitude_ctrl))
506513
return data
507514

@@ -520,7 +527,6 @@ def step_attitude_controller(data: SimData) -> SimData:
520527
states.ang_vel,
521528
attitude_ctrl.cmd,
522529
ctrl_errors=(attitude_ctrl.r_int_error,),
523-
ctrl_info=(attitude_ctrl.last_ang_vel,),
524530
ctrl_freq=attitude_ctrl.freq,
525531
**attitude_ctrl.params._asdict(),
526532
)
@@ -531,16 +537,14 @@ def step_attitude_controller(data: SimData) -> SimData:
531537
last_ang_vel=states.ang_vel,
532538
steps=data.core.steps,
533539
)
534-
force_torque_ctrl = data.controls.force_torque.replace(
535-
staged_cmd_force=force, staged_cmd_torque=torque
536-
)
540+
ft_ctrl = data.controls.force_torque.replace(staged_cmd=jnp.concat([force, torque], axis=-1))
537541
# TODO: Remove. Set the force and torque directly into the physics step.
538542
r = R.from_quat(states.quat)
539543
torque = r.apply(torque)
540544
force = r.apply(jnp.zeros_like(torque).at[..., 2].set(force[..., 0]))
541-
data = data.replace(states=data.states.replace(force=force, torque=torque))
545+
states = leaf_replace(states, mask, force=force, torque=torque)
542546
return data.replace(
543-
controls=data.controls.replace(attitude=attitude_ctrl, force_torque=force_torque_ctrl)
547+
states=states, controls=data.controls.replace(attitude=attitude_ctrl, force_torque=ft_ctrl)
544548
)
545549

546550

@@ -550,19 +554,17 @@ def step_force_torque_controller(data: SimData) -> SimData:
550554
assert ft_ctrl is not None, "Using force torque controller without initialized data"
551555
ft_ctrl: MellingerForceTorqueData
552556
mask = controllable(data.core.steps, data.core.freq, ft_ctrl.steps, ft_ctrl.freq)
553-
ft_ctrl = leaf_replace(
554-
ft_ctrl, mask, cmd_force=ft_ctrl.staged_cmd_force, cmd_torque=ft_ctrl.staged_cmd_torque
555-
)
557+
ft_ctrl = leaf_replace(ft_ctrl, mask, cmd=ft_ctrl.staged_cmd)
556558
rotor_vel = force_torque2rotor_vel(
557-
ft_ctrl.cmd_force, ft_ctrl.cmd_torque, **ft_ctrl.params._asdict()
559+
ft_ctrl.cmd[..., [0]], ft_ctrl.cmd[..., 1:], **ft_ctrl.params._asdict()
558560
)
559561
ft_ctrl = leaf_replace(ft_ctrl, mask, steps=data.core.steps)
560562
data = data.replace(controls=data.controls.replace(rotor_vel=rotor_vel, force_torque=ft_ctrl))
561563
# TODO: Remove
562564
r = R.from_quat(data.states.quat)
563-
cmd_force = jnp.zeros_like(ft_ctrl.cmd_torque)
564-
cmd_force = cmd_force.at[..., 2].set(ft_ctrl.cmd_force[..., 0])
565-
force, torque = r.apply(cmd_force), r.apply(ft_ctrl.cmd_torque)
565+
cmd_force = jnp.zeros_like(ft_ctrl.cmd[..., 1:])
566+
cmd_force = cmd_force.at[..., 2].set(ft_ctrl.cmd[..., 0])
567+
force, torque = r.apply(cmd_force), r.apply(ft_ctrl.cmd[..., 1:])
566568
data = data.replace(states=data.states.replace(force=force, torque=torque))
567569
return data
568570

crazyflow/sim/structs.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def create(n_worlds: int, n_drones: int, device: Device) -> SimStateDeriv:
7171
return SimStateDeriv(dpos=dpos, drot=drot, dvel=dvel, dang_vel=dang_vel)
7272

7373

74+
@typing.runtime_checkable
7475
class ControlData(typing.Protocol):
7576
staged_cmd: Array # (N, M, X)
7677
"""Staged control command for the drone.
@@ -106,6 +107,7 @@ def create(
106107
n_worlds: int,
107108
n_drones: int,
108109
control: Control,
110+
drone_model: str,
109111
state_freq: int | None,
110112
attitude_freq: int | None,
111113
force_torque_freq: int | None,
@@ -115,29 +117,31 @@ def create(
115117
rotor_vel = jnp.zeros((n_worlds, n_drones, 4), device=device)
116118
match control:
117119
case Control.state:
118-
state = MellingerStateData.create(n_worlds, n_drones, state_freq, "", device)
120+
state = MellingerStateData.create(
121+
n_worlds, n_drones, state_freq, drone_model, device
122+
)
119123
attitude = MellingerAttitudeData.create(
120-
n_worlds, n_drones, attitude_freq, "", device
124+
n_worlds, n_drones, attitude_freq, drone_model, device
121125
)
122126
force_torque = MellingerForceTorqueData.create(
123-
n_worlds, n_drones, force_torque_freq, "", device
127+
n_worlds, n_drones, force_torque_freq, drone_model, device
124128
)
125129
return SimControls(
126130
state=state, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
127131
)
128132
case Control.attitude:
129133
attitude = attitude = MellingerAttitudeData.create(
130-
n_worlds, n_drones, attitude_freq, "", device
134+
n_worlds, n_drones, attitude_freq, drone_model, device
131135
)
132136
force_torque = MellingerForceTorqueData.create(
133-
n_worlds, n_drones, force_torque_freq, "", device
137+
n_worlds, n_drones, force_torque_freq, drone_model, device
134138
)
135139
return SimControls(
136140
state=None, attitude=attitude, force_torque=force_torque, rotor_vel=rotor_vel
137141
)
138142
case Control.force_torque:
139143
force_torque = MellingerForceTorqueData.create(
140-
n_worlds, n_drones, force_torque_freq, "", device
144+
n_worlds, n_drones, force_torque_freq, drone_model, device
141145
)
142146
return SimControls(
143147
state=None, attitude=None, force_torque=force_torque, rotor_vel=rotor_vel
@@ -183,9 +187,11 @@ def create(
183187
L: float, mixing_matrix: Array, KF: float, KM: float, device: Device
184188
) -> SimConstants:
185189
"""Create a default set of constants for the simulation."""
186-
return SimConstants(
187-
L=L, MIXING_MATRIX=jnp.array(mixing_matrix, device=device), KF=KF, KM=KM
188-
)
190+
L = jnp.array(L, device=device, dtype=jnp.float32)
191+
mixing_matrix = jnp.array(mixing_matrix, device=device, dtype=jnp.float32)
192+
KF = jnp.array(KF, device=device, dtype=jnp.float32)
193+
KM = jnp.array(KM, device=device, dtype=jnp.float32)
194+
return SimConstants(L=L, MIXING_MATRIX=mixing_matrix, KF=KF, KM=KM)
189195

190196

191197
@dataclass

crazyflow/sim/symbolic.py

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -224,73 +224,6 @@ def symbolic_attitude(dt: float, params: dict | None = None) -> SymbolicModel:
224224
return SymbolicModel(dynamics=dynamics, cost=cost, dt=dt)
225225

226226

227-
def symbolic_thrust(mass: float, J: NDArray, dt: float) -> SymbolicModel:
228-
"""Create symbolic (CasADi) models for dynamics, observation, and cost of a quadcopter.
229-
230-
This model is based on the analytical model of Luis, Carlos, and Jérôme Le Ny. "Design of a
231-
trajectory tracking controller for a nanoquadcopter." arXiv preprint arXiv:1608.05786 (2016).
232-
233-
Returns:
234-
The CasADi symbolic model of the environment.
235-
"""
236-
# Define states.
237-
z = MX.sym("z")
238-
z_dot = MX.sym("z_dot")
239-
240-
# Set up the dynamics model for a 3D quadrotor.
241-
nx, nu = 12, 4
242-
Ixx, Iyy, Izz = J.diagonal()
243-
J = cs.blockcat([[Ixx, 0.0, 0.0], [0.0, Iyy, 0.0], [0.0, 0.0, Izz]])
244-
Jinv = cs.blockcat([[1.0 / Ixx, 0.0, 0.0], [0.0, 1.0 / Iyy, 0.0], [0.0, 0.0, 1.0 / Izz]])
245-
gamma = KM / KF
246-
# System state variables
247-
x, y = MX.sym("x"), MX.sym("y")
248-
x_dot, y_dot = MX.sym("x_dot"), MX.sym("y_dot")
249-
phi, theta, psi = MX.sym("phi"), MX.sym("theta"), MX.sym("psi")
250-
p, q, r = MX.sym("p"), MX.sym("q"), MX.sym("r")
251-
# Rotation matrix transforming a vector in the body frame to the world frame. PyBullet Euler
252-
# angles use the SDFormat for rotation matrices.
253-
Rob = csRotXYZ(phi, theta, psi)
254-
# Define state variables.
255-
X = cs.vertcat(x, y, z, phi, theta, psi, x_dot, y_dot, z_dot, p, q, r)
256-
# Define inputs.
257-
f1, f2, f3, f4 = MX.sym("f1"), MX.sym("f2"), MX.sym("f3"), MX.sym("f4")
258-
U = cs.vertcat(f1, f2, f3, f4)
259-
260-
# Defining the dynamics function.
261-
# We are using the velocity of the base wrt to the world frame expressed in the world frame.
262-
# Note that the reference expresses this in the body frame.
263-
pos_ddot = Rob @ cs.vertcat(0, 0, f1 + f2 + f3 + f4) / mass - cs.vertcat(0, 0, GRAVITY)
264-
pos_dot = cs.vertcat(x_dot, y_dot, z_dot)
265-
# We use the spin directions (signs) from the mix matrix used in the simulation.
266-
sx, sy, sz = SIGN_MIX_MATRIX[..., 0], SIGN_MIX_MATRIX[..., 1], SIGN_MIX_MATRIX[..., 2]
267-
Mb = cs.vertcat(
268-
ARM_LEN / cs.sqrt(2.0) * (sx[0] * f1 + sx[1] * f2 + sx[2] * f3 + sx[3] * f4),
269-
ARM_LEN / cs.sqrt(2.0) * (sy[0] * f1 + sy[1] * f2 + sy[2] * f3 + sy[3] * f4),
270-
gamma * (sz[0] * f1 + sz[1] * f2 + sz[2] * f3 + sz[3] * f4),
271-
)
272-
rate_dot = Jinv @ (Mb - (cs.skew(cs.vertcat(p, q, r)) @ J @ cs.vertcat(p, q, r)))
273-
ang_dot = cs.blockcat(
274-
[
275-
[1, cs.sin(phi) * cs.tan(theta), cs.cos(phi) * cs.tan(theta)],
276-
[0, cs.cos(phi), -cs.sin(phi)],
277-
[0, cs.sin(phi) / cs.cos(theta), cs.cos(phi) / cs.cos(theta)],
278-
]
279-
) @ cs.vertcat(p, q, r)
280-
X_dot = cs.vertcat(pos_dot, ang_dot, pos_ddot, rate_dot)
281-
282-
Y = cs.vertcat(x, x_dot, y, y_dot, z, z_dot, phi, theta, psi, p, q, r)
283-
284-
# Define cost (quadratic form).
285-
Q, R = MX.sym("Q", nx, nx), MX.sym("R", nu, nu)
286-
Xr, Ur = MX.sym("Xr", nx, 1), MX.sym("Ur", nu, 1)
287-
cost_func = 0.5 * (X - Xr).T @ Q @ (X - Xr) + 0.5 * (U - Ur).T @ R @ (U - Ur)
288-
# Define dynamics and cost dictionaries.
289-
dynamics = {"dyn_eqn": X_dot, "obs_eqn": Y, "vars": {"X": X, "U": U}}
290-
cost = {"cost_func": cost_func, "vars": {"X": X, "U": U, "Xr": Xr, "Ur": Ur, "Q": Q, "R": R}}
291-
return SymbolicModel(dynamics=dynamics, cost=cost, dt=dt)
292-
293-
294227
def symbolic_from_sim(sim: Sim) -> SymbolicModel:
295228
"""Create a symbolic model from a Sim instance.
296229
@@ -308,8 +241,7 @@ def symbolic_from_sim(sim: Sim) -> SymbolicModel:
308241
case Control.attitude:
309242
return symbolic_attitude(1 / sim.control_freq)
310243
case Control.force_torque:
311-
mass, J = sim.default_data.params.mass[0, 0], sim.default_data.params.J[0, 0]
312-
return symbolic_force_torque(mass, J, 1 / sim.control_freq)
244+
raise NotImplementedError("Symbolic model for force torque control is not implemented")
313245
case _:
314246
raise ValueError(f"Unsupported control type for symbolic model: {sim.control}")
315247

crazyflow/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def to_device(data: Array, device: str) -> Array:
6868
return jnp.array(data, device=device)
6969

7070

71+
def named_tuple2device(data: T, device: str) -> T:
72+
"""Turn a named tuple into a jax array on the specified device."""
73+
return data._replace(**{f: jnp.asarray(getattr(data, f), device=device) for f in data._fields})
74+
75+
7176
def enable_cache(
7277
cache_path: Path = Path("/tmp/jax_cache"),
7378
min_entry_size_bytes: int = -1,

examples/force_torque.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def main():
1111
duration = 5.0
1212
fps = 60
1313

14-
cmd = np.ones((sim.n_worlds, sim.n_drones, 4)) # [fz, tx, ty, tz]
14+
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4)) # [fz, tx, ty, tz]
1515
cmd[..., 0] = (MASS + 1e-4) * GRAVITY # Plus a small margin to accelerate slightly
1616
for i in range(int(duration * sim.control_freq)):
1717
sim.force_torque_control(cmd)

examples/gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def step(cmd: NDArray, data: SimData) -> jax.Array:
1919
)
2020
data = sim_step(data, sim.freq // sim.control_freq)
2121
# Quadratic cost to reach 1m height
22-
return (data.states.pos[0, 0, 2] - 1.0) ** 2 - 1e-3 * jnp.sum(cmd**2)
22+
return (data.states.pos[0, 0, 2] - 1.0) ** 2
2323

2424
step_grad = jax.jit(jax.grad(step))
2525

0 commit comments

Comments
 (0)