Skip to content

Commit 30d09ee

Browse files
committed
Add fused models for swarm optimization
1 parent 564389d commit 30d09ee

5 files changed

Lines changed: 73 additions & 60 deletions

File tree

crazyflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121
from crazyflow.sim import Physics, Sim
2222

2323
__all__ = ["Sim", "Physics", "Control"]
24-
__version__ = "0.1.0"
24+
__version__ = "0.2.0"

crazyflow/sim/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ class SimCore:
217217
"""Number of worlds in the simulation."""
218218
n_drones: int = field(pytree_node=False)
219219
"""Number of drones in the simulation."""
220-
drone_ids: Array # (1, M)
221-
"""MuJoCo IDs of the drones in the simulation."""
220+
drone_mocap_ids: Array # (M,)
221+
"""MuJoCo mocap IDs of the drone bodies."""
222222
rng_key: Array # (N, 1)
223223
"""Random number generator key for the simulation."""
224224
mjx_synced: Array # (1,)
@@ -229,7 +229,7 @@ def create(
229229
freq: int,
230230
n_worlds: int,
231231
n_drones: int,
232-
drone_ids: Array,
232+
drone_mocap_ids: Array,
233233
rng_key: int | Array,
234234
device: Device,
235235
) -> SimCore:
@@ -244,7 +244,7 @@ def create(
244244
steps=steps,
245245
n_worlds=n_worlds,
246246
n_drones=n_drones,
247-
drone_ids=jnp.array(drone_ids, dtype=jnp.int32, device=device),
247+
drone_mocap_ids=jnp.array(drone_mocap_ids, dtype=jnp.int32, device=device),
248248
rng_key=rng_key,
249249
mjx_synced=jnp.array(False, dtype=jnp.bool_, device=device),
250250
)

crazyflow/sim/sim.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
force_torque2rotor_vel,
1515
state2attitude,
1616
)
17-
from einops import rearrange
1817
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1918
from jax import Array, Device
2019

@@ -74,6 +73,7 @@ def __init__(
7473
device: str = "cpu",
7574
xml_path: Path | None = None,
7675
rng_key: int = 0,
76+
fused_mjx_model: bool = False,
7777
):
7878
assert Physics(physics) in Physics, f"Physics mode {physics} not implemented"
7979
assert Control(control) in Control, f"Control mode {control} not implemented"
@@ -94,7 +94,8 @@ def __init__(
9494

9595
# Initialize MuJoCo world and data
9696
self._xml_path = xml_path or Path(__file__).parents[1] / "scene.xml"
97-
self.drone_path = Path(drone_models.__file__).parent / "data" / f"{drone_model}.xml"
97+
model_file_name = f"{drone_model}{'_fused' if fused_mjx_model else ''}.xml"
98+
self.drone_path = Path(drone_models.__file__).parent / "data" / model_file_name
9899
self.spec = self.build_mjx_spec()
99100
self.mj_model, self.mj_data, self.mjx_model, self.mjx_data = self.build_mjx_model(self.spec)
100101
self.viewer: MujocoRenderer | None = None
@@ -216,10 +217,18 @@ def build_mjx_spec(self) -> mujoco.MjSpec:
216217
frame = spec.worldbody.add_frame(name="world")
217218
if (drone_body := drone_spec.body("drone")) is None:
218219
raise ValueError("Drone body not found in drone spec")
219-
# Add drones and their actuators
220+
# Mocap bodies avoid the nv^2 cost of qM/qLD/efc_J. A single dummy slide joint keeps nv=1 so
221+
# mjx.kinematics doesn't error on a zero-DOF model.
222+
dummy = spec.worldbody.add_body()
223+
dummy.name = "_dummy"
224+
dummy.mass = 1e-6
225+
dummy.inertia = jnp.full(3, 1e-9)
226+
dummy_joint = dummy.add_joint()
227+
dummy_joint.name = "_dummy_joint"
228+
dummy_joint.type = mujoco.mjtJoint.mjJNT_SLIDE
229+
drone_body.mocap = True
220230
for i in range(self.n_drones):
221-
drone = frame.attach_body(drone_body, "", f":{i}")
222-
drone.add_freejoint()
231+
frame.attach_body(drone_body, "", f":{i}")
223232
return spec
224233

225234
def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
@@ -341,7 +350,9 @@ def init_data(
341350
self, state_freq: int, attitude_freq: int, force_torque_freq: int, rng_key: Array
342351
) -> SimData:
343352
"""Initialize the simulation data."""
344-
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(self.n_drones)]
353+
drone_mocap_ids = [
354+
self.mj_model.body(f"drone:{i}").mocapid.item() for i in range(self.n_drones)
355+
]
345356
N, D = self.n_worlds, self.n_drones
346357
data = SimData(
347358
states=SimState.create(N, D, self.device),
@@ -357,7 +368,7 @@ def init_data(
357368
self.device,
358369
),
359370
params=SimParams.create(N, D, self.physics, self.drone_model, self.device),
360-
core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
371+
core=SimCore.create(self.freq, N, D, drone_mocap_ids, rng_key, self.device),
361372
)
362373
if D > 1: # If multiple drones, arrange them in a grid
363374
grid = grid_2d(D)
@@ -497,12 +508,12 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
497508
@jax.jit
498509
def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimData, Data]:
499510
"""Synchronize the simulation data with the MuJoCo model."""
500-
states = data.states
501-
pos, quat, vel, ang_vel = states.pos, states.quat, states.vel, states.ang_vel
502-
quat = jnp.roll(quat, 1, axis=-1) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
503-
qpos = rearrange(jnp.concat([pos, quat], axis=-1), "w d qpos -> w (d qpos)")
504-
qvel = rearrange(jnp.concat([vel, ang_vel], axis=-1), "w d qvel -> w (d qvel)")
505-
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
511+
pos, quat = data.states.pos, data.states.quat
512+
quat_mjx = jnp.roll(quat, 1, axis=-1) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w]
513+
ids = data.core.drone_mocap_ids
514+
mocap_pos = mjx_data.mocap_pos.at[:, ids, :].set(pos)
515+
mocap_quat = mjx_data.mocap_quat.at[:, ids, :].set(quat_mjx)
516+
mjx_data = mjx_data.replace(mocap_pos=mocap_pos, mocap_quat=mocap_quat)
506517
mjx_data = jax.vmap(mjx.kinematics, in_axes=(None, 0))(mjx_model, mjx_data)
507518
# Required for rendering w. ray casting
508519
mjx_data = jax.vmap(mjx.camlight, in_axes=(None, 0))(mjx_model, mjx_data)

0 commit comments

Comments
 (0)