|
7 | 7 |
|
8 | 8 | import jax |
9 | 9 | import jax.numpy as jnp |
| 10 | +from drone_models.core import load_params |
10 | 11 | from drone_models.first_principles import dynamics as first_principles_dynamics |
11 | | -from drone_models.first_principles.params import FirstPrinciplesParams |
12 | 12 | from drone_models.so_rpy import dynamics as so_rpy_dynamics |
13 | | -from drone_models.so_rpy.params import SoRpyParams |
14 | 13 | from drone_models.so_rpy_rotor import dynamics as so_rpy_rotor_dynamics |
15 | | -from drone_models.so_rpy_rotor.params import SoRpyRotorParams |
16 | 14 | from drone_models.so_rpy_rotor_drag import dynamics as so_rpy_rotor_drag_dynamics |
17 | | -from drone_models.so_rpy_rotor_drag.params import SoRpyRotorDragParams |
18 | 15 | from flax.struct import dataclass |
19 | 16 | from jax import Array |
20 | 17 |
|
@@ -60,18 +57,18 @@ def create( |
60 | 57 | n_worlds: int, n_drones: int, drone_model: str, device: Device |
61 | 58 | ) -> FirstPrinciplesData: |
62 | 59 | """Create a default set of parameters for the simulation.""" |
63 | | - p = FirstPrinciplesParams.load(drone_model) |
64 | | - J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
| 60 | + p = load_params("first_principles", drone_model) |
| 61 | + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
65 | 62 | return FirstPrinciplesData( |
66 | | - mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device), |
67 | | - gravity_vec=jnp.asarray(p.gravity_vec, device=device), |
| 63 | + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), |
| 64 | + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), |
68 | 65 | J=J, |
69 | 66 | J_inv=jnp.linalg.inv(J), |
70 | | - KF=jnp.asarray(p.KF, device=device), |
71 | | - KM=jnp.asarray(p.KM, device=device), |
72 | | - L=jnp.asarray(p.L, device=device), |
73 | | - mixing_matrix=jnp.asarray(p.mixing_matrix, device=device), |
74 | | - thrust_tau=jnp.asarray(p.thrust_tau, device=device), |
| 67 | + KF=jnp.asarray(p["KF"], device=device), |
| 68 | + KM=jnp.asarray(p["KM"], device=device), |
| 69 | + L=jnp.asarray(p["L"], device=device), |
| 70 | + mixing_matrix=jnp.asarray(p["mixing_matrix"], device=device), |
| 71 | + thrust_tau=jnp.asarray(p["thrust_tau"], device=device), |
75 | 72 | ) |
76 | 73 |
|
77 | 74 |
|
@@ -127,18 +124,18 @@ class SoRpyData: |
127 | 124 | @staticmethod |
128 | 125 | def create(n_worlds: int, n_drones: int, drone_model: str, device: Device) -> SoRpyData: |
129 | 126 | """Create a default set of parameters for the simulation.""" |
130 | | - p = SoRpyParams.load(drone_model) |
131 | | - J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
| 127 | + p = load_params("so_rpy", drone_model) |
| 128 | + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
132 | 129 | return SoRpyData( |
133 | | - mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device), |
134 | | - gravity_vec=jnp.asarray(p.gravity_vec, device=device), |
| 130 | + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), |
| 131 | + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), |
135 | 132 | J=J, |
136 | 133 | J_inv=jnp.linalg.inv(J), |
137 | | - acc_coef=jnp.asarray(p.acc_coef, device=device), |
138 | | - cmd_f_coef=jnp.asarray(p.cmd_f_coef, device=device), |
139 | | - rpy_coef=jnp.asarray(p.rpy_coef, device=device), |
140 | | - rpy_rates_coef=jnp.asarray(p.rpy_rates_coef, device=device), |
141 | | - cmd_rpy_coef=jnp.asarray(p.cmd_rpy_coef, device=device), |
| 134 | + acc_coef=jnp.asarray(p["acc_coef"], device=device), |
| 135 | + cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), |
| 136 | + rpy_coef=jnp.asarray(p["rpy_coef"], device=device), |
| 137 | + rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), |
| 138 | + cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), |
142 | 139 | ) |
143 | 140 |
|
144 | 141 |
|
@@ -199,21 +196,21 @@ class SoRpyRotorData: |
199 | 196 | @staticmethod |
200 | 197 | def create(n_worlds: int, n_drones: int, drone_model: str, device: Device) -> SoRpyRotorData: |
201 | 198 | """Create a default set of parameters for the simulation.""" |
202 | | - p = SoRpyRotorParams.load(drone_model) |
203 | | - J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
| 199 | + p = load_params("so_rpy_rotor", drone_model) |
| 200 | + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
204 | 201 | return SoRpyRotorData( |
205 | | - mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device), |
206 | | - gravity_vec=jnp.asarray(p.gravity_vec, device=device), |
| 202 | + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), |
| 203 | + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), |
207 | 204 | J=J, |
208 | 205 | J_inv=jnp.linalg.inv(J), |
209 | | - KF=jnp.asarray(p.KF, device=device), |
210 | | - KM=jnp.asarray(p.KM, device=device), |
211 | | - rotor_coef=jnp.asarray(p.rotor_coef, device=device), |
212 | | - acc_coef=jnp.asarray(p.acc_coef, device=device), |
213 | | - cmd_f_coef=jnp.asarray(p.cmd_f_coef, device=device), |
214 | | - rpy_coef=jnp.asarray(p.rpy_coef, device=device), |
215 | | - rpy_rates_coef=jnp.asarray(p.rpy_rates_coef, device=device), |
216 | | - cmd_rpy_coef=jnp.asarray(p.cmd_rpy_coef, device=device), |
| 206 | + KF=jnp.asarray(p["KF"], device=device), |
| 207 | + KM=jnp.asarray(p["KM"], device=device), |
| 208 | + rotor_coef=jnp.asarray(p["rotor_coef"], device=device), |
| 209 | + acc_coef=jnp.asarray(p["acc_coef"], device=device), |
| 210 | + cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), |
| 211 | + rpy_coef=jnp.asarray(p["rpy_coef"], device=device), |
| 212 | + rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), |
| 213 | + cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), |
217 | 214 | ) |
218 | 215 |
|
219 | 216 |
|
@@ -280,27 +277,28 @@ class SoRpyRotorDragData: |
280 | 277 | drag_square_coef: Array # (N, M, 1) |
281 | 278 | """Square drag coefficient of the drone.""" |
282 | 279 |
|
| 280 | + @staticmethod |
283 | 281 | def create( |
284 | 282 | n_worlds: int, n_drones: int, drone_model: str, device: Device |
285 | 283 | ) -> SoRpyRotorDragData: |
286 | 284 | """Create a default set of parameters for the simulation.""" |
287 | | - p = SoRpyRotorDragParams.load(drone_model) |
288 | | - J = jax.device_put(jnp.tile(p.J[None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
| 285 | + p = load_params("so_rpy_rotor_drag", drone_model) |
| 286 | + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) |
289 | 287 | return SoRpyRotorDragData( |
290 | | - mass=jnp.full((n_worlds, n_drones, 1), p.mass, device=device), |
291 | | - gravity_vec=jnp.asarray(p.gravity_vec, device=device), |
| 288 | + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), |
| 289 | + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), |
292 | 290 | J=J, |
293 | 291 | J_inv=jnp.linalg.inv(J), |
294 | | - KF=jnp.asarray(p.KF, device=device), |
295 | | - KM=jnp.asarray(p.KM, device=device), |
296 | | - thrust_time_coef=jnp.asarray(p.thrust_time_coef, device=device), |
297 | | - acc_coef=jnp.asarray(p.acc_coef, device=device), |
298 | | - cmd_f_coef=jnp.asarray(p.cmd_f_coef, device=device), |
299 | | - rpy_coef=jnp.asarray(p.rpy_coef, device=device), |
300 | | - rpy_rates_coef=jnp.asarray(p.rpy_rates_coef, device=device), |
301 | | - cmd_rpy_coef=jnp.asarray(p.cmd_rpy_coef, device=device), |
302 | | - drag_linear_coef=jnp.asarray(p.drag_linear_coef, device=device), |
303 | | - drag_square_coef=jnp.asarray(p.drag_square_coef, device=device), |
| 292 | + KF=jnp.asarray(p["KF"], device=device), |
| 293 | + KM=jnp.asarray(p["KM"], device=device), |
| 294 | + thrust_time_coef=jnp.asarray(p["thrust_time_coef"], device=device), |
| 295 | + acc_coef=jnp.asarray(p["acc_coef"], device=device), |
| 296 | + cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), |
| 297 | + rpy_coef=jnp.asarray(p["rpy_coef"], device=device), |
| 298 | + rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), |
| 299 | + cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), |
| 300 | + drag_linear_coef=jnp.asarray(p["drag_linear_coef"], device=device), |
| 301 | + drag_square_coef=jnp.asarray(p["drag_square_coef"], device=device), |
304 | 302 | ) |
305 | 303 |
|
306 | 304 |
|
|
0 commit comments