Skip to content

Commit 5e97f42

Browse files
committed
Use more generic parameter loading
1 parent eecaa08 commit 5e97f42

1 file changed

Lines changed: 46 additions & 48 deletions

File tree

crazyflow/sim/physics.py

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
import jax
99
import jax.numpy as jnp
10+
from drone_models.core import load_params
1011
from drone_models.first_principles import dynamics as first_principles_dynamics
11-
from drone_models.first_principles.params import FirstPrinciplesParams
1212
from drone_models.so_rpy import dynamics as so_rpy_dynamics
13-
from drone_models.so_rpy.params import SoRpyParams
1413
from drone_models.so_rpy_rotor import dynamics as so_rpy_rotor_dynamics
15-
from drone_models.so_rpy_rotor.params import SoRpyRotorParams
1614
from drone_models.so_rpy_rotor_drag import dynamics as so_rpy_rotor_drag_dynamics
17-
from drone_models.so_rpy_rotor_drag.params import SoRpyRotorDragParams
1815
from flax.struct import dataclass
1916
from jax import Array
2017

@@ -60,18 +57,18 @@ def create(
6057
n_worlds: int, n_drones: int, drone_model: str, device: Device
6158
) -> FirstPrinciplesData:
6259
"""Create a default set of parameters for the simulation."""
63-
p = FirstPrinciplesParams.load(drone_model)
64-
J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
60+
p = load_params("first_principles", drone_model)
61+
J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
6562
return FirstPrinciplesData(
66-
mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device),
67-
gravity_vec=jnp.asarray(p.gravity_vec, device=device),
63+
mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device),
64+
gravity_vec=jnp.asarray(p["gravity_vec"], device=device),
6865
J=J,
6966
J_inv=jnp.linalg.inv(J),
70-
KF=jnp.asarray(p.KF, device=device),
71-
KM=jnp.asarray(p.KM, device=device),
72-
L=jnp.asarray(p.L, device=device),
73-
mixing_matrix=jnp.asarray(p.mixing_matrix, device=device),
74-
thrust_tau=jnp.asarray(p.thrust_tau, device=device),
67+
KF=jnp.asarray(p["KF"], device=device),
68+
KM=jnp.asarray(p["KM"], device=device),
69+
L=jnp.asarray(p["L"], device=device),
70+
mixing_matrix=jnp.asarray(p["mixing_matrix"], device=device),
71+
thrust_tau=jnp.asarray(p["thrust_tau"], device=device),
7572
)
7673

7774

@@ -127,18 +124,18 @@ class SoRpyData:
127124
@staticmethod
128125
def create(n_worlds: int, n_drones: int, drone_model: str, device: Device) -> SoRpyData:
129126
"""Create a default set of parameters for the simulation."""
130-
p = SoRpyParams.load(drone_model)
131-
J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
127+
p = load_params("so_rpy", drone_model)
128+
J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
132129
return SoRpyData(
133-
mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device),
134-
gravity_vec=jnp.asarray(p.gravity_vec, device=device),
130+
mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device),
131+
gravity_vec=jnp.asarray(p["gravity_vec"], device=device),
135132
J=J,
136133
J_inv=jnp.linalg.inv(J),
137-
acc_coef=jnp.asarray(p.acc_coef, device=device),
138-
cmd_f_coef=jnp.asarray(p.cmd_f_coef, device=device),
139-
rpy_coef=jnp.asarray(p.rpy_coef, device=device),
140-
rpy_rates_coef=jnp.asarray(p.rpy_rates_coef, device=device),
141-
cmd_rpy_coef=jnp.asarray(p.cmd_rpy_coef, device=device),
134+
acc_coef=jnp.asarray(p["acc_coef"], device=device),
135+
cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device),
136+
rpy_coef=jnp.asarray(p["rpy_coef"], device=device),
137+
rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device),
138+
cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device),
142139
)
143140

144141

@@ -199,21 +196,21 @@ class SoRpyRotorData:
199196
@staticmethod
200197
def create(n_worlds: int, n_drones: int, drone_model: str, device: Device) -> SoRpyRotorData:
201198
"""Create a default set of parameters for the simulation."""
202-
p = SoRpyRotorParams.load(drone_model)
203-
J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
199+
p = load_params("so_rpy_rotor", drone_model)
200+
J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
204201
return SoRpyRotorData(
205-
mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device),
206-
gravity_vec=jnp.asarray(p.gravity_vec, device=device),
202+
mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device),
203+
gravity_vec=jnp.asarray(p["gravity_vec"], device=device),
207204
J=J,
208205
J_inv=jnp.linalg.inv(J),
209-
KF=jnp.asarray(p.KF, device=device),
210-
KM=jnp.asarray(p.KM, device=device),
211-
rotor_coef=jnp.asarray(p.rotor_coef, device=device),
212-
acc_coef=jnp.asarray(p.acc_coef, device=device),
213-
cmd_f_coef=jnp.asarray(p.cmd_f_coef, device=device),
214-
rpy_coef=jnp.asarray(p.rpy_coef, device=device),
215-
rpy_rates_coef=jnp.asarray(p.rpy_rates_coef, device=device),
216-
cmd_rpy_coef=jnp.asarray(p.cmd_rpy_coef, device=device),
206+
KF=jnp.asarray(p["KF"], device=device),
207+
KM=jnp.asarray(p["KM"], device=device),
208+
rotor_coef=jnp.asarray(p["rotor_coef"], device=device),
209+
acc_coef=jnp.asarray(p["acc_coef"], device=device),
210+
cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device),
211+
rpy_coef=jnp.asarray(p["rpy_coef"], device=device),
212+
rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device),
213+
cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device),
217214
)
218215

219216

@@ -280,27 +277,28 @@ class SoRpyRotorDragData:
280277
drag_square_coef: Array # (N, M, 1)
281278
"""Square drag coefficient of the drone."""
282279

280+
@staticmethod
283281
def create(
284282
n_worlds: int, n_drones: int, drone_model: str, device: Device
285283
) -> SoRpyRotorDragData:
286284
"""Create a default set of parameters for the simulation."""
287-
p = SoRpyRotorDragParams.load(drone_model)
288-
J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
285+
p = load_params("so_rpy_rotor_drag", drone_model)
286+
J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device)
289287
return SoRpyRotorDragData(
290-
mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device),
291-
gravity_vec=jnp.asarray(p.gravity_vec, device=device),
288+
mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device),
289+
gravity_vec=jnp.asarray(p["gravity_vec"], device=device),
292290
J=J,
293291
J_inv=jnp.linalg.inv(J),
294-
KF=jnp.asarray(p.KF, device=device),
295-
KM=jnp.asarray(p.KM, device=device),
296-
thrust_time_coef=jnp.asarray(p.thrust_time_coef, device=device),
297-
acc_coef=jnp.asarray(p.acc_coef, device=device),
298-
cmd_f_coef=jnp.asarray(p.cmd_f_coef, device=device),
299-
rpy_coef=jnp.asarray(p.rpy_coef, device=device),
300-
rpy_rates_coef=jnp.asarray(p.rpy_rates_coef, device=device),
301-
cmd_rpy_coef=jnp.asarray(p.cmd_rpy_coef, device=device),
302-
drag_linear_coef=jnp.asarray(p.drag_linear_coef, device=device),
303-
drag_square_coef=jnp.asarray(p.drag_square_coef, device=device),
292+
KF=jnp.asarray(p["KF"], device=device),
293+
KM=jnp.asarray(p["KM"], device=device),
294+
thrust_time_coef=jnp.asarray(p["thrust_time_coef"], device=device),
295+
acc_coef=jnp.asarray(p["acc_coef"], device=device),
296+
cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device),
297+
rpy_coef=jnp.asarray(p["rpy_coef"], device=device),
298+
rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device),
299+
cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device),
300+
drag_linear_coef=jnp.asarray(p["drag_linear_coef"], device=device),
301+
drag_square_coef=jnp.asarray(p["drag_square_coef"], device=device),
304302
)
305303

306304

0 commit comments

Comments
 (0)