88import jax .numpy as jnp
99import mujoco
1010import mujoco .mjx as mjx
11+ from drone_models .controller .mellinger import state2attitude
1112from einops import rearrange
1213from gymnasium .envs .mujoco .mujoco_rendering import MujocoRenderer
1314from jax import Array , Device
1415from jax .scipy .spatial .transform import Rotation as R
1516
1617from crazyflow .constants import J_INV , MASS , J
17- from crazyflow .control .control import Control , attitude2rpm , pwm2rpm , state2attitude , thrust2pwm
18+ from crazyflow .control .control import Control , attitude2rpm , controllable , pwm2rpm , thrust2pwm
19+ from crazyflow .control .control import state2attitude as state2attitude_legacy
1820from crazyflow .exception import ConfigError , NotInitializedError
1921from crazyflow .sim .integration import Integrator , euler , rk4 , symplectic_euler
2022from crazyflow .sim .physics import (
2426 rpms2collective_wrench ,
2527 surrogate_identified_collective_wrench ,
2628)
27- from crazyflow .sim .structs import SimControls , SimCore , SimData , SimParams , SimState , SimStateDeriv
29+ from crazyflow .sim .structs import (
30+ SimControls ,
31+ SimControlsNew ,
32+ SimCore ,
33+ SimData ,
34+ SimParams ,
35+ SimState ,
36+ SimStateDeriv ,
37+ )
2838from crazyflow .utils import grid_2d , leaf_replace , pytree_replace , to_device
2939
3040if TYPE_CHECKING :
3141 from mujoco .mjx import Data , Model
3242 from numpy .typing import NDArray
3343
44+ from crazyflow .control .mellinger import MellingerStateData
45+
3446Params = ParamSpec ("Params" ) # Represents arbitrary parameters
3547Return = TypeVar ("Return" ) # Represents the return type
3648
@@ -141,7 +153,11 @@ def state_control(self, controls: Array):
141153 assert controls .shape == (self .n_worlds , self .n_drones , 13 ), "controls shape mismatch"
142154 assert self .control == Control .state , "State control is not enabled by the sim config"
143155 controls = to_device (controls , self .device )
144- self .data = self .data .replace (controls = self .data .controls .replace (state = controls ))
156+ self .data = self .data .replace (
157+ new_controls = self .data .new_controls .replace (
158+ state = self .data .new_controls .state .replace (cmd = controls )
159+ )
160+ )
145161
146162 def thrust_control (self , cmd : Array ):
147163 """Set the desired thrust for all drones in all worlds."""
@@ -182,7 +198,7 @@ def seed(self, seed: int):
182198 Args:
183199 seed: The seed for the JAX rng.
184200 """
185- self .data = seed_sim (self .data , seed , self .device )
201+ self .data : SimData = seed_sim (self .data , seed , self .device )
186202
187203 def close (self ):
188204 if self .viewer is not None :
@@ -282,7 +298,7 @@ def build_mjx(self):
282298
283299 def init_data (
284300 self , state_freq : int , attitude_freq : int , thrust_freq : int , rng_key : Array
285- ) -> tuple [ SimData , SimData ] :
301+ ) -> SimData :
286302 """Initialize the simulation data."""
287303 drone_ids = [self .mj_model .body (f"drone:{ i } " ).id for i in range (self .n_drones )]
288304 N , D = self .n_worlds , self .n_drones
@@ -292,6 +308,9 @@ def init_data(
292308 controls = SimControls .create (N , D , state_freq , attitude_freq , thrust_freq , self .device ),
293309 params = SimParams .create (N , D , MASS , J , J_INV , self .device ),
294310 core = SimCore .create (self .freq , N , D , drone_ids , rng_key , self .device ),
311+ new_controls = SimControlsNew .create (
312+ N , D , self .control , state_freq , attitude_freq , thrust_freq , self .device
313+ ),
295314 )
296315 if D > 1 : # If multiple drones, arrange them in a grid
297316 grid = grid_2d (D )
@@ -417,23 +436,6 @@ def integrate(data: SimData) -> SimData:
417436 return integrate
418437
419438
420- @jax .jit
421- def controllable (step : Array , freq : int , control_steps : Array , control_freq : int ) -> Array :
422- """Check which worlds can currently update their controllers.
423-
424- Args:
425- step: The current step of the simulation.
426- freq: The frequency of the simulation.
427- control_steps: The steps at which the controllers were last updated.
428- control_freq: The frequency of the controllers.
429-
430- Returns:
431- A boolean mask of shape (n_worlds,) that is True at the worlds where the controllers can be
432- updated.
433- """
434- return ((step - control_steps ) >= (freq / control_freq )) | (control_steps == - 1 )
435-
436-
437439@jax .jit
438440def contacts (geom_start : int , geom_count : int , data : Data ) -> Array :
439441 """Filter contacts from MuJoCo data."""
@@ -461,18 +463,53 @@ def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimDa
461463
462464def step_state_controller (data : SimData ) -> SimData :
463465 """Compute the updated controls for the state controller."""
464- states , controls = data .states , data .controls
465- mask = controllable (data .core .steps , data .core .freq , controls .state_steps , controls .state_freq )
466- des_pos , des_vel = controls .state [..., :3 ], controls .state [..., 3 :6 ]
467- des_yaw = controls .state [..., [9 ]] # Keep (N, M, 1) shape for broadcasting
468- dt = 1 / data .controls .state_freq
469- attitude , pos_err_i = state2attitude (
470- states .pos , states .vel , states .quat , des_pos , des_vel , des_yaw , controls .pos_err_i , dt
466+ states , ctrl_state = data .states , data .new_controls .state
467+ assert ctrl_state is not None , "Using state controller without initialized state control data"
468+ ctrl_state : MellingerStateData
469+ mask = controllable (data .core .steps , data .core .freq , ctrl_state .steps , ctrl_state .freq )
470+ jax .debug .print ("Ctrl cmd: {cmd}" , cmd = ctrl_state .cmd )
471+ attitude , (pos_err_i ,) = state2attitude (
472+ states .pos ,
473+ states .quat ,
474+ states .vel ,
475+ states .ang_vel ,
476+ ctrl_state .cmd ,
477+ ctrl_freq = ctrl_state .freq ,
478+ ctrl_errors = (ctrl_state .pos_err_i ,),
479+ ** ctrl_state .params ._asdict (),
471480 )
472- controls = leaf_replace (
473- controls , mask , state_steps = data .core .steps , staged_attitude = attitude , pos_err_i = pos_err_i
481+ jax .debug .print ("Attitude: {attitude}" , attitude = attitude )
482+ ctrl_state = leaf_replace (ctrl_state , mask , steps = data .core .steps , pos_err_i = pos_err_i )
483+ data = data .replace (
484+ controls = data .controls .replace (staged_attitude = attitude ),
485+ new_controls = data .new_controls .replace (state = ctrl_state ),
474486 )
475- return data .replace (controls = controls )
487+ return data
488+
489+
490+ # def step_state_controller(data: SimData) -> SimData:
491+ # """Compute the updated controls for the state controller."""
492+ # states, ctrl_state = data.states, data.new_controls.state
493+ # assert ctrl_state is not None, "Using state controller without initialized state control data"
494+ # mask = controllable(data.core.steps, data.core.freq, ctrl_state.steps, ctrl_state.freq)
495+ # attitude, pos_err_i = state2attitude_legacy(
496+ # states.pos,
497+ # states.vel,
498+ # states.quat,
499+ # ctrl_state.cmd[..., :3],
500+ # ctrl_state.cmd[..., 3:6],
501+ # ctrl_state.cmd[..., [9]],
502+ # ctrl_state.pos_err_i,
503+ # 1 / ctrl_state.freq,
504+ # )
505+ # attitude = jnp.roll(attitude, -1, axis=-1)
506+ # jax.debug.print("Attitude: {attitude}", attitude=attitude)
507+ # ctrl_state = leaf_replace(ctrl_state, mask, steps=data.core.steps, pos_err_i=pos_err_i)
508+ # data = data.replace(
509+ # controls=data.controls.replace(staged_attitude=attitude),
510+ # new_controls=data.new_controls.replace(state=ctrl_state),
511+ # )
512+ # return data
476513
477514
478515def step_attitude_controller (data : SimData ) -> SimData :
@@ -533,14 +570,6 @@ def identified_wrench(data: SimData) -> SimData:
533570identified_derivative = analytical_derivative # We can use the same derivative function for both
534571
535572
536- def identity (data : SimData , * args : Any , ** kwargs : Any ) -> SimData :
537- """Identity function for the simulation pipeline.
538-
539- Used as default function for optional pipeline steps.
540- """
541- return data
542-
543-
544573def clip_floor_pos (data : SimData ) -> SimData :
545574 """Clip the position of the drone to the floor."""
546575 clip = data .states .pos [..., 2 ] < - 0.001
0 commit comments