11from __future__ import annotations
22
3- from functools import partial
3+ from functools import partial , wraps
44from pathlib import Path
55from typing import TYPE_CHECKING , Any , Callable
66
1313from jax import Array , Device
1414from jax .scipy .spatial .transform import Rotation as R
1515
16- from crazyflow .constants import J_INV , MASS , SIGN_MIX_MATRIX , J
16+ from crazyflow .constants import J_INV , MASS , J
1717from crazyflow .control .control import Control , attitude2rpm , pwm2rpm , state2attitude , thrust2pwm
1818from crazyflow .exception import ConfigError , NotInitializedError
1919from crazyflow .sim .integration import Integrator , euler , rk4 , symplectic_euler
2222 collective_force2acceleration ,
2323 collective_torque2ang_vel_deriv ,
2424 rpms2collective_wrench ,
25- rpms2motor_forces ,
26- rpms2motor_torques ,
2725 surrogate_identified_collective_wrench ,
2826)
2927from crazyflow .sim .structs import SimControls , SimCore , SimData , SimParams , SimState , SimStateDeriv
3432 from numpy .typing import NDArray
3533
3634
35+ def requires_mujoco_sync (fn : Callable [[SimData ], SimData ]) -> Callable [[SimData ], SimData ]:
36+ """Decorator to ensure that the simulation data is synchronized with the MuJoCo mjx data."""
37+
38+ @wraps (fn )
39+ def wrapper (sim : Sim , * args : Any , ** kwargs : Any ) -> SimData :
40+ if not sim .data .core .mjx_synced :
41+ sim .data , sim .mjx_data = sync_sim2mjx (sim .data , sim .mjx_data , sim .mjx_model )
42+ return fn (sim , * args , ** kwargs )
43+
44+ return wrapper
45+
46+
3747class Sim :
3848 default_path = Path (__file__ ).parents [1 ] / "models/cf2/scene.xml"
3949 drone_path = Path (__file__ ).parents [1 ] / "models/cf2/cf2.xml"
@@ -71,10 +81,10 @@ def __init__(
7181 # Initialize MuJoCo world and data
7282 self ._xml_path = xml_path or self .default_path
7383 self .spec = self .build_mjx_spec ()
74- self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .build_mjx_model (self .spec )
84+ self .mj_model , self .mj_data , self .mjx_model , self . mjx_data = self .build_mjx_model (self .spec )
7585 self .viewer : MujocoRenderer | None = None
7686
77- self .data = self .init_data (state_freq , attitude_freq , thrust_freq , rng_key , mjx_data )
87+ self .data = self .init_data (state_freq , attitude_freq , thrust_freq , rng_key )
7888 self .default_data : SimData
7989 self .build_default_data ()
8090
@@ -90,9 +100,6 @@ def __init__(
90100 # We never drop below -0.001 (drones can't pass through the floor). We use -0.001 to
91101 # enable checks for negative z sign
92102 self .step_pipeline += (clip_floor_pos ,)
93- # MuJoCo needs to sync after every physics step so that the next step control, wrench
94- # and disturbance functions see the correct state.
95- self .step_pipeline += (select_sync_fn (self .physics ),)
96103
97104 self .build_reset_fn ()
98105 self .build_step_fn ()
@@ -140,6 +147,7 @@ def thrust_control(self, cmd: Array):
140147 controls = to_device (cmd , self .device )
141148 self .data = self .data .replace (controls = self .data .controls .replace (thrust = controls ))
142149
150+ @requires_mujoco_sync
143151 def render (
144152 self ,
145153 mode : str | None = "human" ,
@@ -160,9 +168,9 @@ def render(
160168 height = height ,
161169 width = width ,
162170 )
163- self .mj_data .qpos [:] = self .data . mjx_data .qpos [world , :]
164- self .mj_data .mocap_pos [:] = self .data . mjx_data .mocap_pos [world , :]
165- self .mj_data .mocap_quat [:] = self .data . mjx_data .mocap_quat [world , :]
171+ self .mj_data .qpos [:] = self .mjx_data .qpos [world , :]
172+ self .mj_data .mocap_pos [:] = self .mjx_data .mocap_pos [world , :]
173+ self .mj_data .mocap_quat [:] = self .mjx_data .mocap_quat [world , :]
166174 mujoco .mj_forward (self .mj_model , self .mj_data )
167175 return self .viewer .render (mode )
168176
@@ -232,20 +240,8 @@ def single_step(data: SimData, _: None) -> tuple[SimData, None]:
232240 # always use the same n_steps value for successive calls.
233241 @partial (jax .jit , static_argnames = "n_steps" )
234242 def step (data : SimData , n_steps : int = 1 ) -> SimData :
235- # Performance optimization: When step is called, jax checks if it can reuse a previously
236- # compiled version of the function. This check flattens the sim.data PyTree and compares
237- # the metadata of each leaf with the cached metadata. The more leaves contained in
238- # sim.data, the more time is spent on the cache lookup even if the function has already
239- # been compiled. Since mjx_model contains many PyTree nodes and it is not used by
240- # physics modes other than mujoco with domain randomization, we set it to None and
241- # capture the current sim.mjx_model in the step function's closure. Changes to the
242- # params are synced to mjx_model at the start of the step function.
243- if optimize_mjx_model := (data .mjx_model is None ):
244- data = data .replace (mjx_model = self .mjx_model )
245- data = self .sync_sim2mjx (data )
246243 data , _ = jax .lax .scan (single_step , data , length = n_steps , unroll = 1 )
247- if optimize_mjx_model :
248- data = data .replace (mjx_model = None )
244+ data = data .replace (core = data .core .replace (mjx_synced = False )) # Flag mjx data as stale
249245 return data
250246
251247 self ._step = step
@@ -259,7 +255,7 @@ def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> Si
259255 data = pytree_replace (data , default_data , mask ) # Does not overwrite rng_key
260256 for fn in pipeline :
261257 data = fn (data , mask )
262- data = self . sync_sim2mjx ( data , self . mjx_model )
258+ data = data . replace ( core = data . core . replace ( mjx_synced = False )) # Flag mjx data as stale
263259 return data
264260
265261 self ._reset = reset
@@ -270,7 +266,6 @@ def build_data(self):
270266 self .data .controls .attitude_freq ,
271267 self .data .controls .thrust_freq ,
272268 self .data .core .rng_key ,
273- self .data .mjx_data ,
274269 )
275270
276271 def build_default_data (self ):
@@ -281,13 +276,10 @@ def build_mjx(self):
281276 if self .viewer is not None :
282277 self .viewer .close ()
283278 self .viewer = None
284- self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .build_mjx_model (self .spec )
285- self .data = self .data .replace (mjx_data = mjx_data )
286- self .data = self .sync_sim2mjx (self .data , self .mjx_model )
287- self .default_data = self .default_data .replace (mjx_data = mjx_data )
279+ self .mj_model , self .mj_data , self .mjx_model , self .mjx_data = self .build_mjx_model (self .spec )
288280
289281 def init_data (
290- self , state_freq : int , attitude_freq : int , thrust_freq : int , rng_key : Array , mjx_data : Data
282+ self , state_freq : int , attitude_freq : int , thrust_freq : int , rng_key : Array
291283 ) -> tuple [SimData , SimData ]:
292284 """Initialize the simulation data."""
293285 drone_ids = [self .mj_model .body (f"drone:{ i } " ).id for i in range (self .n_drones )]
@@ -298,14 +290,11 @@ def init_data(
298290 controls = SimControls .create (N , D , state_freq , attitude_freq , thrust_freq , self .device ),
299291 params = SimParams .create (N , D , MASS , J , J_INV , self .device ),
300292 core = SimCore .create (self .freq , N , D , drone_ids , rng_key , self .device ),
301- mjx_data = mjx_data ,
302- mjx_model = None ,
303293 )
304294 if D > 1 : # If multiple drones, arrange them in a grid
305295 grid = grid_2d (D )
306296 states = data .states .replace (pos = data .states .pos .at [..., :2 ].set (grid ))
307297 data = data .replace (states = states )
308- data = self .sync_sim2mjx (data , self .mjx_model )
309298 return data
310299
311300 @property
@@ -343,6 +332,7 @@ def controllable(self) -> Array:
343332 raise NotImplementedError (f"Control mode { self .control } not implemented" )
344333 return controllable (self .data .core .steps , self .data .core .freq , control_steps , control_freq )
345334
335+ @requires_mujoco_sync
346336 def contacts (self , body : str | None = None ) -> Array :
347337 """Get contact information from the simulation.
348338
@@ -353,45 +343,11 @@ def contacts(self, body: str | None = None) -> Array:
353343 An boolean array of shape (n_worlds,) that is True if any contact is present.
354344 """
355345 if body is None :
356- return self .data . mjx_data ._impl .contact .dist < 0
346+ return self .mjx_data ._impl .contact .dist < 0
357347 body_id = self .mj_model .body (body ).id
358348 geom_start = self .mj_model .body_geomadr [body_id ]
359349 geom_count = self .mj_model .body_geomnum [body_id ]
360- return contacts (geom_start , geom_count , self .data .mjx_data )
361-
362- @staticmethod
363- @jax .jit
364- def sync_sim2mjx (data : SimData , mjx_model : Model | None = None ) -> SimData :
365- states = data .states
366- pos , quat , vel , ang_vel = states .pos , states .quat , states .vel , states .ang_vel
367- quat = quat [..., [3 , 0 , 1 , 2 ]] # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
368- qpos = rearrange (jnp .concat ([pos , quat ], axis = - 1 ), "w d qpos -> w (d qpos)" )
369- qvel = rearrange (jnp .concat ([vel , ang_vel ], axis = - 1 ), "w d qvel -> w (d qvel)" )
370- mjx_data = data .mjx_data
371- mjx_model = data .mjx_model if mjx_model is None else mjx_model
372- assert mjx_model is not None , "MuJoCo model is not initialized"
373- mjx_data = mjx_data .replace (qpos = qpos , qvel = qvel )
374- mjx_data = jax .vmap (mjx .kinematics , in_axes = (None , 0 ))(mjx_model , mjx_data )
375- mjx_data = jax .vmap (mjx .collision , in_axes = (None , 0 ))(mjx_model , mjx_data )
376- data = data .replace (mjx_data = mjx_data )
377- if data .mjx_model is None : # Only modify model if it is part of data
378- return data
379- # Sync model parameters such as mass and inertia for domain randomization
380- # This is currently not supported. See https://github.com/google-deepmind/mujoco/issues/1607
381- # TODO: Implement once mjx supports batching single model fields.
382- return data
383-
384- @staticmethod
385- @jax .jit
386- def sync_mjx2sim (data : SimData ) -> SimData :
387- mjx_data = data .mjx_data
388- qpos = mjx_data .qpos .reshape (data .core .n_worlds , data .core .n_drones , 7 )
389- qvel = mjx_data .qvel .reshape (data .core .n_worlds , data .core .n_drones , 6 )
390- pos , quat = jnp .split (qpos , [3 ], axis = - 1 )
391- vel , ang_vel = jnp .split (qvel , [3 ], axis = - 1 )
392- quat = quat [..., [1 , 2 , 3 , 0 ]] # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
393- states = data .states .replace (pos = pos , quat = quat , vel = vel , ang_vel = ang_vel )
394- return data .replace (states = states )
350+ return contacts (geom_start , geom_count , self .mjx_data )
395351
396352 @staticmethod
397353 def _reset (data : SimData , default_data : SimData , mask : Array | None = None ) -> SimData :
@@ -422,8 +378,6 @@ def select_wrench_fn(physics: Physics) -> Callable[[SimData], SimData]:
422378 return analytical_wrench
423379 case Physics .sys_id :
424380 return identified_wrench
425- case Physics .mujoco :
426- return mujoco_wrench
427381 case _:
428382 raise NotImplementedError (f"Physics mode { physics } not implemented" )
429383
@@ -451,37 +405,14 @@ def select_integrate_fn(physics: Physics, integrator: Integrator) -> Callable[[S
451405 case _:
452406 raise NotImplementedError (f"Integrator { integrator } not implemented" )
453407
454- match physics :
455- case Physics .sys_id | Physics .analytical :
456- derivative_fn = select_derivative_fn (physics )
457-
458- def integrate (data : SimData ) -> SimData :
459- data = integrate_fn (data , derivative_fn )
460- data = data .replace (core = data .core .replace (steps = data .core .steps + 1 ))
461- return data
462-
463- return integrate
464- case Physics .mujoco :
465-
466- def integrate (data : SimData ) -> SimData :
467- data = mjx_physics_fn (data )
468- data = data .replace (core = data .core .replace (steps = data .core .steps + 1 ))
469- return data
470-
471- return integrate
472- case _:
473- raise NotImplementedError (f"Physics mode { physics } not implemented" )
408+ derivative_fn = select_derivative_fn (physics )
474409
410+ def integrate (data : SimData ) -> SimData :
411+ data = integrate_fn (data , derivative_fn )
412+ data = data .replace (core = data .core .replace (steps = data .core .steps + 1 ))
413+ return data
475414
476- def select_sync_fn (physics : Physics ) -> Callable [[SimData ], SimData ]:
477- """Select the sync function for the given physics mode."""
478- match physics :
479- case Physics .sys_id | Physics .analytical :
480- return Sim .sync_sim2mjx
481- case Physics .mujoco :
482- return Sim .sync_mjx2sim
483- case _:
484- raise NotImplementedError (f"Physics mode { physics } not implemented" )
415+ return integrate
485416
486417
487418@jax .jit
@@ -511,6 +442,21 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
511442 return (data .contact .dist < 0 ) & (geom1_valid | geom2_valid )
512443
513444
445+ @jax .jit
446+ def sync_sim2mjx (data : SimData , mjx_data : Data , mjx_model : Model ) -> tuple [SimData , Data ]:
447+ """Synchronize the simulation data with the MuJoCo model."""
448+ states = data .states
449+ pos , quat , vel , ang_vel = states .pos , states .quat , states .vel , states .ang_vel
450+ quat = jnp .roll (quat , 1 , axis = - 1 ) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
451+ qpos = rearrange (jnp .concat ([pos , quat ], axis = - 1 ), "w d qpos -> w (d qpos)" )
452+ qvel = rearrange (jnp .concat ([vel , ang_vel ], axis = - 1 ), "w d qvel -> w (d qvel)" )
453+ mjx_data = mjx_data .replace (qpos = qpos , qvel = qvel )
454+ mjx_data = jax .vmap (mjx .kinematics , in_axes = (None , 0 ))(mjx_model , mjx_data )
455+ mjx_data = jax .vmap (mjx .collision , in_axes = (None , 0 ))(mjx_model , mjx_data )
456+ data = data .replace (core = data .core .replace (mjx_synced = True ))
457+ return data , mjx_data
458+
459+
514460def step_state_controller (data : SimData ) -> SimData :
515461 """Compute the updated controls for the state controller."""
516462 states , controls = data .states , data .controls
@@ -585,33 +531,6 @@ def identified_wrench(data: SimData) -> SimData:
585531identified_derivative = analytical_derivative # We can use the same derivative function for both
586532
587533
588- def mujoco_wrench (data : SimData ) -> SimData :
589- """Compute the wrench from the MuJoCo dynamics model."""
590- forces = rpms2motor_forces (data .controls .rpms )
591- torques = SIGN_MIX_MATRIX [..., 2 ] * rpms2motor_torques (data .controls .rpms )
592- # Zero out external forces and torques to avoid summation over multiple steps
593- states = data .states
594- force , torque = jnp .zeros_like (states .force ), jnp .zeros_like (states .torque )
595- states = states .replace (motor_forces = forces , motor_torques = torques , force = force , torque = torque )
596- return data .replace (states = states )
597-
598-
599- batched_mjx_step = jax .vmap (mjx .step , in_axes = (None , 0 ))
600-
601-
602- def mjx_physics_fn (data : SimData ) -> SimData :
603- """Step the MuJoCo simulation."""
604- force_torques = jnp .concatenate ([data .states .motor_forces , data .states .motor_torques ], axis = - 1 )
605- force_torques = rearrange (force_torques , "w d ft -> w (d ft)" )
606- mjx_data = data .mjx_data .replace (ctrl = force_torques )
607- # Add disturbances from data.states.force/torque with mjx_data.xfrc_applied
608- xfrc = jnp .concatenate ([data .states .force , data .states .torque ], axis = - 1 )
609- xfrc_applied = data .mjx_data .xfrc_applied .at [:, data .core .drone_ids , :].set (xfrc )
610- mjx_data = mjx_data .replace (xfrc_applied = xfrc_applied )
611- mjx_data = batched_mjx_step (data .mjx_model , mjx_data )
612- return data .replace (mjx_data = mjx_data )
613-
614-
615534def identity (data : SimData , * args : Any , ** kwargs : Any ) -> SimData :
616535 """Identity function for the simulation pipeline.
617536
0 commit comments