Skip to content

Commit 375cca1

Browse files
committed
Remove mujoco physics backend
1 parent b1541a1 commit 375cca1

7 files changed

Lines changed: 78 additions & 158 deletions

File tree

crazyflow/sim/physics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
class Physics(str, Enum):
2323
"""Physics mode for the simulation."""
2424

25-
mujoco = "mujoco"
2625
analytical = "analytical"
2726
sys_id = "sys_id"
2827
default = analytical

crazyflow/sim/sim.py

Lines changed: 48 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import partial
3+
from functools import partial, wraps
44
from pathlib import Path
55
from typing import TYPE_CHECKING, Any, Callable
66

@@ -13,7 +13,7 @@
1313
from jax import Array, Device
1414
from jax.scipy.spatial.transform import Rotation as R
1515

16-
from crazyflow.constants import J_INV, MASS, SIGN_MIX_MATRIX, J
16+
from crazyflow.constants import J_INV, MASS, J
1717
from crazyflow.control.control import Control, attitude2rpm, pwm2rpm, state2attitude, thrust2pwm
1818
from crazyflow.exception import ConfigError, NotInitializedError
1919
from crazyflow.sim.integration import Integrator, euler, rk4, symplectic_euler
@@ -22,8 +22,6 @@
2222
collective_force2acceleration,
2323
collective_torque2ang_vel_deriv,
2424
rpms2collective_wrench,
25-
rpms2motor_forces,
26-
rpms2motor_torques,
2725
surrogate_identified_collective_wrench,
2826
)
2927
from crazyflow.sim.structs import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
@@ -34,6 +32,18 @@
3432
from numpy.typing import NDArray
3533

3634

35+
def requires_mujoco_sync(fn: Callable[[SimData], SimData]) -> Callable[[SimData], SimData]:
36+
"""Decorator to ensure that the simulation data is synchronized with the MuJoCo mjx data."""
37+
38+
@wraps(fn)
39+
def wrapper(sim: Sim, *args: Any, **kwargs: Any) -> SimData:
40+
if not sim.data.core.mjx_synced:
41+
sim.data, sim.mjx_data = sync_sim2mjx(sim.data, sim.mjx_data, sim.mjx_model)
42+
return fn(sim, *args, **kwargs)
43+
44+
return wrapper
45+
46+
3747
class Sim:
3848
default_path = Path(__file__).parents[1] / "models/cf2/scene.xml"
3949
drone_path = Path(__file__).parents[1] / "models/cf2/cf2.xml"
@@ -71,10 +81,10 @@ def __init__(
7181
# Initialize MuJoCo world and data
7282
self._xml_path = xml_path or self.default_path
7383
self.spec = self.build_mjx_spec()
74-
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.build_mjx_model(self.spec)
84+
self.mj_model, self.mj_data, self.mjx_model, self.mjx_data = self.build_mjx_model(self.spec)
7585
self.viewer: MujocoRenderer | None = None
7686

77-
self.data = self.init_data(state_freq, attitude_freq, thrust_freq, rng_key, mjx_data)
87+
self.data = self.init_data(state_freq, attitude_freq, thrust_freq, rng_key)
7888
self.default_data: SimData
7989
self.build_default_data()
8090

@@ -90,9 +100,6 @@ def __init__(
90100
# We never drop below -0.001 (drones can't pass through the floor). We use -0.001 to
91101
# enable checks for negative z sign
92102
self.step_pipeline += (clip_floor_pos,)
93-
# MuJoCo needs to sync after every physics step so that the next step control, wrench
94-
# and disturbance functions see the correct state.
95-
self.step_pipeline += (select_sync_fn(self.physics),)
96103

97104
self.build_reset_fn()
98105
self.build_step_fn()
@@ -140,6 +147,7 @@ def thrust_control(self, cmd: Array):
140147
controls = to_device(cmd, self.device)
141148
self.data = self.data.replace(controls=self.data.controls.replace(thrust=controls))
142149

150+
@requires_mujoco_sync
143151
def render(
144152
self,
145153
mode: str | None = "human",
@@ -160,9 +168,9 @@ def render(
160168
height=height,
161169
width=width,
162170
)
163-
self.mj_data.qpos[:] = self.data.mjx_data.qpos[world, :]
164-
self.mj_data.mocap_pos[:] = self.data.mjx_data.mocap_pos[world, :]
165-
self.mj_data.mocap_quat[:] = self.data.mjx_data.mocap_quat[world, :]
171+
self.mj_data.qpos[:] = self.mjx_data.qpos[world, :]
172+
self.mj_data.mocap_pos[:] = self.mjx_data.mocap_pos[world, :]
173+
self.mj_data.mocap_quat[:] = self.mjx_data.mocap_quat[world, :]
166174
mujoco.mj_forward(self.mj_model, self.mj_data)
167175
return self.viewer.render(mode)
168176

@@ -232,20 +240,8 @@ def single_step(data: SimData, _: None) -> tuple[SimData, None]:
232240
# always use the same n_steps value for successive calls.
233241
@partial(jax.jit, static_argnames="n_steps")
234242
def step(data: SimData, n_steps: int = 1) -> SimData:
235-
# Performance optimization: When step is called, jax checks if it can reuse a previously
236-
# compiled version of the function. This check flattens the sim.data PyTree and compares
237-
# the metadata of each leaf with the cached metadata. The more leaves contained in
238-
# sim.data, the more time is spent on the cache lookup even if the function has already
239-
# been compiled. Since mjx_model contains many PyTree nodes and it is not used by
240-
# physics modes other than mujoco with domain randomization, we set it to None and
241-
# capture the current sim.mjx_model in the step function's closure. Changes to the
242-
# params are synced to mjx_model at the start of the step function.
243-
if optimize_mjx_model := (data.mjx_model is None):
244-
data = data.replace(mjx_model=self.mjx_model)
245-
data = self.sync_sim2mjx(data)
246243
data, _ = jax.lax.scan(single_step, data, length=n_steps, unroll=1)
247-
if optimize_mjx_model:
248-
data = data.replace(mjx_model=None)
244+
data = data.replace(core=data.core.replace(mjx_synced=False)) # Flag mjx data as stale
249245
return data
250246

251247
self._step = step
@@ -259,7 +255,7 @@ def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> Si
259255
data = pytree_replace(data, default_data, mask) # Does not overwrite rng_key
260256
for fn in pipeline:
261257
data = fn(data, mask)
262-
data = self.sync_sim2mjx(data, self.mjx_model)
258+
data = data.replace(core=data.core.replace(mjx_synced=False)) # Flag mjx data as stale
263259
return data
264260

265261
self._reset = reset
@@ -270,7 +266,6 @@ def build_data(self):
270266
self.data.controls.attitude_freq,
271267
self.data.controls.thrust_freq,
272268
self.data.core.rng_key,
273-
self.data.mjx_data,
274269
)
275270

276271
def build_default_data(self):
@@ -281,13 +276,10 @@ def build_mjx(self):
281276
if self.viewer is not None:
282277
self.viewer.close()
283278
self.viewer = None
284-
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.build_mjx_model(self.spec)
285-
self.data = self.data.replace(mjx_data=mjx_data)
286-
self.data = self.sync_sim2mjx(self.data, self.mjx_model)
287-
self.default_data = self.default_data.replace(mjx_data=mjx_data)
279+
self.mj_model, self.mj_data, self.mjx_model, self.mjx_data = self.build_mjx_model(self.spec)
288280

289281
def init_data(
290-
self, state_freq: int, attitude_freq: int, thrust_freq: int, rng_key: Array, mjx_data: Data
282+
self, state_freq: int, attitude_freq: int, thrust_freq: int, rng_key: Array
291283
) -> tuple[SimData, SimData]:
292284
"""Initialize the simulation data."""
293285
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(self.n_drones)]
@@ -298,14 +290,11 @@ def init_data(
298290
controls=SimControls.create(N, D, state_freq, attitude_freq, thrust_freq, self.device),
299291
params=SimParams.create(N, D, MASS, J, J_INV, self.device),
300292
core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
301-
mjx_data=mjx_data,
302-
mjx_model=None,
303293
)
304294
if D > 1: # If multiple drones, arrange them in a grid
305295
grid = grid_2d(D)
306296
states = data.states.replace(pos=data.states.pos.at[..., :2].set(grid))
307297
data = data.replace(states=states)
308-
data = self.sync_sim2mjx(data, self.mjx_model)
309298
return data
310299

311300
@property
@@ -343,6 +332,7 @@ def controllable(self) -> Array:
343332
raise NotImplementedError(f"Control mode {self.control} not implemented")
344333
return controllable(self.data.core.steps, self.data.core.freq, control_steps, control_freq)
345334

335+
@requires_mujoco_sync
346336
def contacts(self, body: str | None = None) -> Array:
347337
"""Get contact information from the simulation.
348338
@@ -353,45 +343,11 @@ def contacts(self, body: str | None = None) -> Array:
353343
An boolean array of shape (n_worlds,) that is True if any contact is present.
354344
"""
355345
if body is None:
356-
return self.data.mjx_data._impl.contact.dist < 0
346+
return self.mjx_data._impl.contact.dist < 0
357347
body_id = self.mj_model.body(body).id
358348
geom_start = self.mj_model.body_geomadr[body_id]
359349
geom_count = self.mj_model.body_geomnum[body_id]
360-
return contacts(geom_start, geom_count, self.data.mjx_data)
361-
362-
@staticmethod
363-
@jax.jit
364-
def sync_sim2mjx(data: SimData, mjx_model: Model | None = None) -> SimData:
365-
states = data.states
366-
pos, quat, vel, ang_vel = states.pos, states.quat, states.vel, states.ang_vel
367-
quat = quat[..., [3, 0, 1, 2]] # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
368-
qpos = rearrange(jnp.concat([pos, quat], axis=-1), "w d qpos -> w (d qpos)")
369-
qvel = rearrange(jnp.concat([vel, ang_vel], axis=-1), "w d qvel -> w (d qvel)")
370-
mjx_data = data.mjx_data
371-
mjx_model = data.mjx_model if mjx_model is None else mjx_model
372-
assert mjx_model is not None, "MuJoCo model is not initialized"
373-
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
374-
mjx_data = jax.vmap(mjx.kinematics, in_axes=(None, 0))(mjx_model, mjx_data)
375-
mjx_data = jax.vmap(mjx.collision, in_axes=(None, 0))(mjx_model, mjx_data)
376-
data = data.replace(mjx_data=mjx_data)
377-
if data.mjx_model is None: # Only modify model if it is part of data
378-
return data
379-
# Sync model parameters such as mass and inertia for domain randomization
380-
# This is currently not supported. See https://github.com/google-deepmind/mujoco/issues/1607
381-
# TODO: Implement once mjx supports batching single model fields.
382-
return data
383-
384-
@staticmethod
385-
@jax.jit
386-
def sync_mjx2sim(data: SimData) -> SimData:
387-
mjx_data = data.mjx_data
388-
qpos = mjx_data.qpos.reshape(data.core.n_worlds, data.core.n_drones, 7)
389-
qvel = mjx_data.qvel.reshape(data.core.n_worlds, data.core.n_drones, 6)
390-
pos, quat = jnp.split(qpos, [3], axis=-1)
391-
vel, ang_vel = jnp.split(qvel, [3], axis=-1)
392-
quat = quat[..., [1, 2, 3, 0]] # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
393-
states = data.states.replace(pos=pos, quat=quat, vel=vel, ang_vel=ang_vel)
394-
return data.replace(states=states)
350+
return contacts(geom_start, geom_count, self.mjx_data)
395351

396352
@staticmethod
397353
def _reset(data: SimData, default_data: SimData, mask: Array | None = None) -> SimData:
@@ -422,8 +378,6 @@ def select_wrench_fn(physics: Physics) -> Callable[[SimData], SimData]:
422378
return analytical_wrench
423379
case Physics.sys_id:
424380
return identified_wrench
425-
case Physics.mujoco:
426-
return mujoco_wrench
427381
case _:
428382
raise NotImplementedError(f"Physics mode {physics} not implemented")
429383

@@ -451,37 +405,14 @@ def select_integrate_fn(physics: Physics, integrator: Integrator) -> Callable[[S
451405
case _:
452406
raise NotImplementedError(f"Integrator {integrator} not implemented")
453407

454-
match physics:
455-
case Physics.sys_id | Physics.analytical:
456-
derivative_fn = select_derivative_fn(physics)
457-
458-
def integrate(data: SimData) -> SimData:
459-
data = integrate_fn(data, derivative_fn)
460-
data = data.replace(core=data.core.replace(steps=data.core.steps + 1))
461-
return data
462-
463-
return integrate
464-
case Physics.mujoco:
465-
466-
def integrate(data: SimData) -> SimData:
467-
data = mjx_physics_fn(data)
468-
data = data.replace(core=data.core.replace(steps=data.core.steps + 1))
469-
return data
470-
471-
return integrate
472-
case _:
473-
raise NotImplementedError(f"Physics mode {physics} not implemented")
408+
derivative_fn = select_derivative_fn(physics)
474409

410+
def integrate(data: SimData) -> SimData:
411+
data = integrate_fn(data, derivative_fn)
412+
data = data.replace(core=data.core.replace(steps=data.core.steps + 1))
413+
return data
475414

476-
def select_sync_fn(physics: Physics) -> Callable[[SimData], SimData]:
477-
"""Select the sync function for the given physics mode."""
478-
match physics:
479-
case Physics.sys_id | Physics.analytical:
480-
return Sim.sync_sim2mjx
481-
case Physics.mujoco:
482-
return Sim.sync_mjx2sim
483-
case _:
484-
raise NotImplementedError(f"Physics mode {physics} not implemented")
415+
return integrate
485416

486417

487418
@jax.jit
@@ -511,6 +442,21 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
511442
return (data.contact.dist < 0) & (geom1_valid | geom2_valid)
512443

513444

445+
@jax.jit
446+
def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimData, Data]:
447+
"""Synchronize the simulation data with the MuJoCo model."""
448+
states = data.states
449+
pos, quat, vel, ang_vel = states.pos, states.quat, states.vel, states.ang_vel
450+
quat = jnp.roll(quat, 1, axis=-1) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
451+
qpos = rearrange(jnp.concat([pos, quat], axis=-1), "w d qpos -> w (d qpos)")
452+
qvel = rearrange(jnp.concat([vel, ang_vel], axis=-1), "w d qvel -> w (d qvel)")
453+
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
454+
mjx_data = jax.vmap(mjx.kinematics, in_axes=(None, 0))(mjx_model, mjx_data)
455+
mjx_data = jax.vmap(mjx.collision, in_axes=(None, 0))(mjx_model, mjx_data)
456+
data = data.replace(core=data.core.replace(mjx_synced=True))
457+
return data, mjx_data
458+
459+
514460
def step_state_controller(data: SimData) -> SimData:
515461
"""Compute the updated controls for the state controller."""
516462
states, controls = data.states, data.controls
@@ -585,33 +531,6 @@ def identified_wrench(data: SimData) -> SimData:
585531
identified_derivative = analytical_derivative # We can use the same derivative function for both
586532

587533

588-
def mujoco_wrench(data: SimData) -> SimData:
589-
"""Compute the wrench from the MuJoCo dynamics model."""
590-
forces = rpms2motor_forces(data.controls.rpms)
591-
torques = SIGN_MIX_MATRIX[..., 2] * rpms2motor_torques(data.controls.rpms)
592-
# Zero out external forces and torques to avoid summation over multiple steps
593-
states = data.states
594-
force, torque = jnp.zeros_like(states.force), jnp.zeros_like(states.torque)
595-
states = states.replace(motor_forces=forces, motor_torques=torques, force=force, torque=torque)
596-
return data.replace(states=states)
597-
598-
599-
batched_mjx_step = jax.vmap(mjx.step, in_axes=(None, 0))
600-
601-
602-
def mjx_physics_fn(data: SimData) -> SimData:
603-
"""Step the MuJoCo simulation."""
604-
force_torques = jnp.concatenate([data.states.motor_forces, data.states.motor_torques], axis=-1)
605-
force_torques = rearrange(force_torques, "w d ft -> w (d ft)")
606-
mjx_data = data.mjx_data.replace(ctrl=force_torques)
607-
# Add disturbances from data.states.force/torque with mjx_data.xfrc_applied
608-
xfrc = jnp.concatenate([data.states.force, data.states.torque], axis=-1)
609-
xfrc_applied = data.mjx_data.xfrc_applied.at[:, data.core.drone_ids, :].set(xfrc)
610-
mjx_data = mjx_data.replace(xfrc_applied=xfrc_applied)
611-
mjx_data = batched_mjx_step(data.mjx_model, mjx_data)
612-
return data.replace(mjx_data=mjx_data)
613-
614-
615534
def identity(data: SimData, *args: Any, **kwargs: Any) -> SimData:
616535
"""Identity function for the simulation pipeline.
617536

crazyflow/sim/structs.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ class SimCore:
186186
"""MuJoCo IDs of the drones in the simulation."""
187187
rng_key: Array # (N, 1)
188188
"""Random number generator key for the simulation."""
189+
mjx_synced: bool = field(pytree_node=False)
190+
"""Whether the simulation data is synchronized with the MuJoCo model."""
189191

190192
@staticmethod
191193
def create(
@@ -208,6 +210,7 @@ def create(
208210
n_drones=n_drones,
209211
drone_ids=jnp.array(drone_ids, dtype=jnp.int32, device=device),
210212
rng_key=rng_key,
213+
mjx_synced=False,
211214
)
212215

213216

@@ -223,10 +226,3 @@ class SimData:
223226
"""Drone parameters."""
224227
core: SimCore
225228
"""Core parameters of the simulation."""
226-
mjx_data: Data
227-
"""MuJoCo data structure."""
228-
mjx_model: Model | None
229-
"""MuJoCo model structure.
230-
231-
Can be set to None for performance optimizations. See `Sim.build_step` for more details.
232-
"""

0 commit comments

Comments
 (0)