Skip to content

Commit be04a39

Browse files
committed
[WIP, broken] Switch to mellinger from drone-models
1 parent 5a28c10 commit be04a39

14 files changed

Lines changed: 383 additions & 333 deletions

File tree

crazyflow/constants.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,16 @@
66

77
# Drone constants
88
ARM_LEN: float = 0.0325 * np.sqrt(2)
9-
MIX_MATRIX: NDArray = np.array([[-0.5, -0.5, -1], [-0.5, 0.5, 1], [0.5, 0.5, -1], [0.5, -0.5, 1]])
9+
# fmt: off
10+
MIX_MATRIX: NDArray = np.array([[-0.5, -0.5, -1],
11+
[-0.5, 0.5, 1],
12+
[ 0.5, 0.5, -1],
13+
[ 0.5, -0.5, 1]])
14+
# fmt: on
1015
SIGN_MIX_MATRIX: NDArray = np.sign(MIX_MATRIX)
1116
# Crazyflie 2.1 mass as measured in the lab with battery included
1217
MASS: float = 0.033
1318
J: NDArray = np.array([[2.3951e-5, 0, 0], [0, 2.3951e-5, 0], [0, 0, 3.2347e-5]])
1419
J_INV: NDArray = np.linalg.inv(J)
20+
KF: float = 8.701227710666256e-10
21+
KM: float = 7.94e-12

crazyflow/control/control.py

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from jax import Array
2020
from jax.scipy.spatial.transform import Rotation as R
2121

22-
from crazyflow.constants import GRAVITY, MASS, MIX_MATRIX
22+
from crazyflow.constants import MIX_MATRIX
2323

2424

2525
class Control(str, Enum):
@@ -41,8 +41,8 @@ class Control(str, Enum):
4141
Note:
4242
Recommended frequency is >=100 Hz.
4343
"""
44-
thrust = "thrust"
45-
"""Thrust control takes [thrust1, thrust2, thrust3, thrust4] for each drone motor.
44+
force_torque = "force_torque"
45+
"""Force and torque control takes [fx, fy, fz, tx, ty, tz].
4646
4747
Note:
4848
Recommended frequency is >=500 Hz.
@@ -90,38 +90,6 @@ def controllable(step: Array, freq: int, control_steps: Array, control_freq: int
9090
THRUST_CURVE_C: float = 0.0209
9191

9292

93-
@partial(jnp.vectorize, signature="(3),(3),(4),(3),(3),(1),(3)->(4),(3)", excluded=[7])
94-
def state2attitude(
95-
pos: Array,
96-
vel: Array,
97-
quat: Array,
98-
des_pos: Array,
99-
des_vel: Array,
100-
des_yaw: Array,
101-
i_error: Array,
102-
dt: float,
103-
) -> tuple[Array, Array]:
104-
"""Compute the next desired collective thrust and roll/pitch/yaw of the drone."""
105-
pos_error, vel_error = des_pos - pos, des_vel - vel
106-
# Update integral error
107-
i_error = jnp.clip(i_error + pos_error * dt, -I_F_RANGE, I_F_RANGE)
108-
# Compute target thrust
109-
thrust = P_F * pos_error + I_F * i_error + D_F * vel_error
110-
thrust = thrust.at[2].add(MASS * GRAVITY)
111-
# Update z_axis to the current orientation of the drone
112-
z_axis = R.from_quat(quat).as_matrix()[:, 2]
113-
# Project the thrust onto the z-axis
114-
thrust_desired = jnp.clip(thrust @ z_axis, 0.3 * MASS * GRAVITY, 1.8 * MASS * GRAVITY)
115-
# Update the desired z-axis
116-
z_axis = thrust / jnp.linalg.norm(thrust)
117-
yaw_axis = jnp.concatenate([jnp.cos(des_yaw), jnp.sin(des_yaw), jnp.array([0.0])])
118-
y_axis = jnp.cross(z_axis, yaw_axis)
119-
y_axis = y_axis / jnp.linalg.norm(y_axis)
120-
x_axis = jnp.cross(y_axis, z_axis)
121-
euler_desired = R.from_matrix(jnp.vstack([x_axis, y_axis, z_axis]).T).as_euler("xyz")
122-
return jnp.concatenate([jnp.atleast_1d(thrust_desired), euler_desired]), i_error
123-
124-
12593
@partial(jnp.vectorize, signature="(4),(4),(3),(3)->(4),(3)", excluded=[4])
12694
def attitude2rpm(
12795
controls: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float

crazyflow/control/mellinger.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from __future__ import annotations
22

33
import jax.numpy as jnp
4-
from drone_models.controller.mellinger import MellingerStateParams
4+
from drone_models.controller.mellinger import (
5+
MellingerAttitudeParams,
6+
MellingerForceTorqueParams,
7+
MellingerStateParams,
8+
)
59
from flax.struct import dataclass, field
610
from jax import Array, Device
711

@@ -14,6 +18,8 @@ class MellingerStateData:
1418
A command consists of [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].
1519
We currently do not use the acceleration and angle rate components. This is subject to change.
1620
"""
21+
staged_cmd: Array # (N, M, 13)
22+
"""Staging buffer to store the most recent command until the next controller tick."""
1723
steps: Array # (N, 1)
1824
"""Last simulation steps that the state control command was applied."""
1925
freq: int = field(pytree_node=False)
@@ -29,39 +35,90 @@ def create(
2935
) -> MellingerStateData:
3036
"""Create a default set of state data for the simulation."""
3137
cmd = jnp.zeros((n_worlds, n_drones, 13), device=device)
32-
steps = jnp.zeros((n_worlds, 1), dtype=jnp.int32, device=device)
38+
steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device)
3339
pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device)
3440
params = MellingerStateParams.load(drone_model)
3541
return MellingerStateData(
36-
cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
42+
cmd=cmd, staged_cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
3743
)
3844

3945

40-
# @dataclass
41-
# class MellingerAttitudeData:
42-
# cmd: Array # (N, M, 4)
43-
# """Full attitude control command for the drone.
46+
@dataclass
47+
class MellingerAttitudeData:
48+
cmd: Array # (N, M, 4)
49+
"""Full attitude control command for the drone.
50+
51+
A command consists of [roll, pitch, yaw, collective thrust].
52+
"""
53+
staged_cmd: Array # (N, M, 4)
54+
"""Staging buffer to store the most recent command until the next controller tick."""
55+
steps: Array # (N, 1)
56+
"""Last simulation steps that the attitude control command was applied."""
57+
freq: int = field(pytree_node=False)
58+
"""Frequency of the attitude control command."""
59+
r_int_error: Array # (N, M, 3)
60+
"""Integral errors of the attitude control command."""
61+
last_ang_vel: Array # (N, M, 3)
62+
"""Last angular velocity of the drone."""
63+
# Parameters for the attitude controller
64+
params: MellingerAttitudeParams
65+
66+
@staticmethod
67+
def create(
68+
n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
69+
) -> MellingerAttitudeData:
70+
"""Create a default set of attitude data for the simulation."""
71+
cmd = jnp.zeros((n_worlds, n_drones, 4), device=device)
72+
steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device)
73+
zeros_3d = jnp.zeros((n_worlds, n_drones, 3), device=device)
74+
params = MellingerAttitudeParams.load(drone_model)
75+
return MellingerAttitudeData(
76+
cmd=cmd,
77+
staged_cmd=cmd,
78+
steps=steps,
79+
freq=freq,
80+
r_int_error=zeros_3d,
81+
last_ang_vel=zeros_3d,
82+
params=params,
83+
)
84+
4485

45-
# A command consists of [collective thrust, roll, pitch, yaw].
46-
# """
47-
# steps: Array # (N, 1)
48-
# """Last simulation steps that the attitude control command was applied."""
49-
# freq: int = field(pytree_node=False)
50-
# """Frequency of the attitude control command."""
51-
# pos_err_i: Array # (N, M, 3)
52-
# """Integral errors of the attitude control command."""
53-
# # Parameters for the attitude controller
54-
# params: MellingerAttitudeParams
86+
@dataclass
87+
class MellingerForceTorqueData:
88+
cmd_force: Array # (N, M, 1)
89+
"""Force command for the drone.
90+
91+
A command consists of [fz].
92+
"""
93+
cmd_torque: Array # (N, M, 3)
94+
"""Torque command for the drone.
5595
56-
# @staticmethod
57-
# def create(
58-
# n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
59-
# ) -> MellingerAttitudeData:
60-
# """Create a default set of attitude data for the simulation."""
61-
# cmd = jnp.zeros((n_worlds, n_drones, 4), device=device)
62-
# steps = jnp.zeros((n_worlds, 1), dtype=jnp.int32, device=device)
63-
# pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device)
64-
# params = MellingerAttitudeParams.load(drone_model)
65-
# return MellingerAttitudeData(
66-
# cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params
67-
# )
96+
A command consists of [tx, ty, tz].
97+
"""
98+
staged_cmd_force: Array # (N, M, 1)
99+
staged_cmd_torque: Array # (N, M, 3)
100+
"""Staging buffer to store the most recent command until the next controller tick."""
101+
steps: Array # (N, 1)
102+
"""Last simulation steps that the force and torque control command was applied."""
103+
freq: int = field(pytree_node=False)
104+
"""Frequency of the force and torque control command."""
105+
# Parameters for the force and torque controller
106+
params: MellingerForceTorqueParams
107+
108+
@staticmethod
109+
def create(
110+
n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device
111+
) -> MellingerForceTorqueData:
112+
zero_1d = jnp.zeros((n_worlds, n_drones, 1), device=device)
113+
zero_3d = jnp.zeros((n_worlds, n_drones, 3), device=device)
114+
steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device)
115+
params = MellingerForceTorqueParams.load(drone_model)
116+
return MellingerForceTorqueData(
117+
cmd_force=zero_1d,
118+
cmd_torque=zero_3d,
119+
staged_cmd_force=zero_1d,
120+
staged_cmd_torque=zero_3d,
121+
steps=steps,
122+
freq=freq,
123+
params=params,
124+
)

crazyflow/envs/drone_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def action_space(control_type: Control) -> spaces.Box:
3333
np.array([4 * MIN_THRUST, -np.pi / 2, -np.pi / 2, -np.pi / 2], dtype=np.float32),
3434
np.array([4 * MAX_THRUST, np.pi / 2, np.pi / 2, np.pi / 2], dtype=np.float32),
3535
)
36-
case Control.thrust:
37-
return spaces.Box(MIN_THRUST, MAX_THRUST, shape=(4,))
36+
case Control.force_torque:
37+
return spaces.Box(-1.0, 1.0, shape=(6,))
3838
case _:
3939
raise ValueError(f"Invalid control type {control_type}")
4040

@@ -128,8 +128,8 @@ def _apply_action(self, action: Array):
128128
raise NotImplementedError("State control currently not supported")
129129
case Control.attitude:
130130
self.sim.attitude_control(action)
131-
case Control.thrust:
132-
self.sim.thrust_control(action)
131+
case Control.force_torque:
132+
self.sim.force_torque_control(action)
133133
case _:
134134
raise ValueError(f"Invalid control type {self.sim.control}")
135135

crazyflow/randomize/randomize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ def _randomize_mass_params(data: SimData, mass: Array, mask: Array | None = None
5050
@jax.jit
5151
def _randomize_inertia_params(data: SimData, new_j: Array, mask: Array | None = None) -> SimData:
5252
new_j_inv = jnp.linalg.inv(new_j)
53-
return data.replace(params=leaf_replace(data.params, mask, J=new_j, J_INV=new_j_inv))
53+
return data.replace(params=leaf_replace(data.params, mask, J=new_j, J_inv=new_j_inv))

crazyflow/sim/physics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def surrogate_identified_collective_wrench(
4848
J: The drone's inertia matrix.
4949
dt: The simulation time step.
5050
"""
51-
collective_thrust, attitude = controls[0], controls[1:]
51+
attitude, collective_thrust = controls[:3], controls[3]
5252
rot = R.from_quat(quat)
5353
thrust = rot.apply(jnp.array([0, 0, collective_thrust]))
5454
drift = rot.apply(jnp.array([0, 0, 1]))
@@ -84,9 +84,9 @@ def collective_force2acceleration(force: Array, mass: Array) -> Array:
8484

8585

8686
@partial(vectorize, signature="(3),(4),(3,3)->(3)")
87-
def collective_torque2ang_vel_deriv(torque: Array, quat: Array, J_INV: Array) -> Array:
87+
def collective_torque2ang_vel_deriv(torque: Array, quat: Array, J_inv: Array) -> Array:
8888
"""Convert torques to ang_vel_deriv."""
89-
return J_INV @ R.from_quat(quat).apply(torque, inverse=True)
89+
return J_inv @ R.from_quat(quat).apply(torque, inverse=True)
9090

9191

9292
@partial(vectorize, signature="(4),(4),(3),(3,3)->(3),(3)")

0 commit comments

Comments
 (0)