1414 force_torque2rotor_vel ,
1515 state2attitude ,
1616)
17- from einops import rearrange
1817from gymnasium .envs .mujoco .mujoco_rendering import MujocoRenderer
1918from jax import Array , Device
2019
@@ -74,6 +73,7 @@ def __init__(
7473 device : str = "cpu" ,
7574 xml_path : Path | None = None ,
7675 rng_key : int = 0 ,
76+ fused_mjx_model : bool = False ,
7777 ):
7878 assert Physics (physics ) in Physics , f"Physics mode { physics } not implemented"
7979 assert Control (control ) in Control , f"Control mode { control } not implemented"
@@ -94,7 +94,8 @@ def __init__(
9494
9595 # Initialize MuJoCo world and data
9696 self ._xml_path = xml_path or Path (__file__ ).parents [1 ] / "scene.xml"
97- self .drone_path = Path (drone_models .__file__ ).parent / "data" / f"{ drone_model } .xml"
97+ model_file_name = f"{ drone_model } { '_fused' if fused_mjx_model else '' } .xml"
98+ self .drone_path = Path (drone_models .__file__ ).parent / "data" / model_file_name
9899 self .spec = self .build_mjx_spec ()
99100 self .mj_model , self .mj_data , self .mjx_model , self .mjx_data = self .build_mjx_model (self .spec )
100101 self .viewer : MujocoRenderer | None = None
@@ -216,10 +217,18 @@ def build_mjx_spec(self) -> mujoco.MjSpec:
216217 frame = spec .worldbody .add_frame (name = "world" )
217218 if (drone_body := drone_spec .body ("drone" )) is None :
218219 raise ValueError ("Drone body not found in drone spec" )
219- # Add drones and their actuators
220+ # Mocap bodies avoid the nv^2 cost of qM/qLD/efc_J. A single dummy slide joint keeps nv=1 so
221+ # mjx.kinematics doesn't error on a zero-DOF model.
222+ dummy = spec .worldbody .add_body ()
223+ dummy .name = "_dummy"
224+ dummy .mass = 1e-6
225+ dummy .inertia = jnp .full (3 , 1e-9 )
226+ dummy_joint = dummy .add_joint ()
227+ dummy_joint .name = "_dummy_joint"
228+ dummy_joint .type = mujoco .mjtJoint .mjJNT_SLIDE
229+ drone_body .mocap = True
220230 for i in range (self .n_drones ):
221- drone = frame .attach_body (drone_body , "" , f":{ i } " )
222- drone .add_freejoint ()
231+ frame .attach_body (drone_body , "" , f":{ i } " )
223232 return spec
224233
225234 def build_mjx_model (self , spec : mujoco .MjSpec ) -> tuple [Any , Any , Model , Data ]:
@@ -341,7 +350,9 @@ def init_data(
341350 self , state_freq : int , attitude_freq : int , force_torque_freq : int , rng_key : Array
342351 ) -> SimData :
343352 """Initialize the simulation data."""
344- drone_ids = [self .mj_model .body (f"drone:{ i } " ).id for i in range (self .n_drones )]
353+ drone_mocap_ids = [
354+ self .mj_model .body (f"drone:{ i } " ).mocapid .item () for i in range (self .n_drones )
355+ ]
345356 N , D = self .n_worlds , self .n_drones
346357 data = SimData (
347358 states = SimState .create (N , D , self .device ),
@@ -357,7 +368,7 @@ def init_data(
357368 self .device ,
358369 ),
359370 params = SimParams .create (N , D , self .physics , self .drone_model , self .device ),
360- core = SimCore .create (self .freq , N , D , drone_ids , rng_key , self .device ),
371+ core = SimCore .create (self .freq , N , D , drone_mocap_ids , rng_key , self .device ),
361372 )
362373 if D > 1 : # If multiple drones, arrange them in a grid
363374 grid = grid_2d (D )
@@ -497,12 +508,12 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
497508@jax .jit
498509def sync_sim2mjx (data : SimData , mjx_data : Data , mjx_model : Model ) -> tuple [SimData , Data ]:
499510 """Synchronize the simulation data with the MuJoCo model."""
500- states = data .states
501- pos , quat , vel , ang_vel = states . pos , states . quat , states . vel , states . ang_vel
502- quat = jnp . roll ( quat , 1 , axis = - 1 ) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
503- qpos = rearrange ( jnp . concat ([ pos , quat ], axis = - 1 ), "w d qpos -> w (d qpos)" )
504- qvel = rearrange ( jnp . concat ([ vel , ang_vel ], axis = - 1 ), "w d qvel -> w (d qvel)" )
505- mjx_data = mjx_data .replace (qpos = qpos , qvel = qvel )
511+ pos , quat = data .states . pos , data . states . quat
512+ quat_mjx = jnp . roll ( quat , 1 , axis = - 1 ) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
513+ ids = data . core . drone_mocap_ids
514+ mocap_pos = mjx_data . mocap_pos . at [:, ids , :]. set ( pos )
515+ mocap_quat = mjx_data . mocap_quat . at [:, ids , :]. set ( quat_mjx )
516+ mjx_data = mjx_data .replace (mocap_pos = mocap_pos , mocap_quat = mocap_quat )
506517 mjx_data = jax .vmap (mjx .kinematics , in_axes = (None , 0 ))(mjx_model , mjx_data )
507518 # Required for rendering w. ray casting
508519 mjx_data = jax .vmap (mjx .camlight , in_axes = (None , 0 ))(mjx_model , mjx_data )
0 commit comments