Skip to content

Commit a6f6458

Browse files
committed
Improve build interface
1 parent 3705f9b commit a6f6458

3 files changed

Lines changed: 122 additions & 158 deletions

File tree

crazyflow/sim/sim.py

Lines changed: 120 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def __init__(
7575
self.viewer: MujocoRenderer | None = None
7676

7777
self.data = self.init_data(state_freq, attitude_freq, thrust_freq, rng_key, mjx_data)
78-
self.default_data = self.init_default_data()
78+
self.default_data: SimData
79+
self.build_default_data()
7980

8081
# Build the simulation pipeline and overwrite the default _step implementation with it
8182
self.reset_pipeline: tuple[Callable[[SimData, Array[bool] | None], SimData], ...] = tuple()
@@ -96,6 +97,88 @@ def __init__(
9697
self.build_reset_fn()
9798
self.build_step_fn()
9899

100+
def reset(self, mask: Array | None = None):
101+
"""Reset the simulation to the initial state.
102+
103+
Args:
104+
mask: Boolean array of shape (n_worlds, ) that indicates which worlds to reset. If None,
105+
all worlds are reset.
106+
"""
107+
assert mask is None or mask.shape == (self.n_worlds,), f"Mask shape mismatch {mask.shape}"
108+
self.data = self._reset(self.data, self.default_data, mask)
109+
110+
def step(self, n_steps: int = 1):
111+
"""Simulate all drones in all worlds for n time steps."""
112+
assert n_steps > 0, "Number of steps must be positive"
113+
self.data = self._step(self.data, n_steps=n_steps)
114+
115+
def attitude_control(self, controls: Array):
116+
"""Set the desired attitude for all drones in all worlds.
117+
118+
We need to stage the attitude controls because the sys_id physics mode operates directly on
119+
the attitude controls. If we were to directly update the controls, this would effectively
120+
bypass the control frequency and run the attitude controller at the physics update rate. By
121+
staging the controls, we ensure that the physics module sees the old controls until the
122+
controller updates at its correct frequency.
123+
"""
124+
assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch"
125+
assert self.control == Control.attitude, "Attitude control is not enabled by the sim config"
126+
controls = to_device(controls, self.device)
127+
self.data = self.data.replace(controls=self.data.controls.replace(staged_attitude=controls))
128+
129+
def state_control(self, controls: Array):
130+
"""Set the desired state for all drones in all worlds."""
131+
assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
132+
assert self.control == Control.state, "State control is not enabled by the sim config"
133+
controls = to_device(controls, self.device)
134+
self.data = self.data.replace(controls=self.data.controls.replace(state=controls))
135+
136+
def thrust_control(self, cmd: Array):
137+
"""Set the desired thrust for all drones in all worlds."""
138+
assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch"
139+
assert self.control == Control.thrust, "Thrust control is not enabled by the sim config"
140+
controls = to_device(cmd, self.device)
141+
self.data = self.data.replace(controls=self.data.controls.replace(thrust=controls))
142+
143+
def render(
144+
self,
145+
mode: str | None = "human",
146+
world: int = 0,
147+
default_cam_config: dict | None = None,
148+
width: int = 640,
149+
height: int = 480,
150+
) -> NDArray | None:
151+
if self.viewer is None:
152+
patch_viewer()
153+
self.mj_model.vis.global_.offwidth = width
154+
self.mj_model.vis.global_.offheight = height
155+
self.viewer = MujocoRenderer(
156+
self.mj_model,
157+
self.mj_data,
158+
max_geom=self.max_visual_geom,
159+
default_cam_config=default_cam_config,
160+
height=height,
161+
width=width,
162+
)
163+
self.mj_data.qpos[:] = self.data.mjx_data.qpos[world, :]
164+
self.mj_data.mocap_pos[:] = self.data.mjx_data.mocap_pos[world, :]
165+
self.mj_data.mocap_quat[:] = self.data.mjx_data.mocap_quat[world, :]
166+
mujoco.mj_forward(self.mj_model, self.mj_data)
167+
return self.viewer.render(mode)
168+
169+
def seed(self, seed: int):
170+
"""Set the JAX rng key for the simulation.
171+
172+
Args:
173+
seed: The seed for the JAX rng.
174+
"""
175+
self.data = seed_sim(self.data, seed, self.device)
176+
177+
def close(self):
178+
if self.viewer is not None:
179+
self.viewer.close()
180+
self.viewer = None
181+
99182
def build_mjx_spec(self) -> mujoco.MjSpec:
100183
"""Build the MuJoCo model specification for the simulation."""
101184
assert self._xml_path.exists(), f"Model file {self._xml_path} does not exist"
@@ -167,6 +250,42 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
167250

168251
self._step = step
169252

253+
def build_reset_fn(self):
254+
"""Build the reset function for the current simulation configuration."""
255+
pipeline = self.reset_pipeline
256+
257+
@jax.jit
258+
def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> SimData:
259+
data = pytree_replace(data, default_data, mask) # Does not overwrite rng_key
260+
for fn in pipeline:
261+
data = fn(data, mask)
262+
data = self.sync_sim2mjx(data, self.mjx_model)
263+
return data
264+
265+
self._reset = reset
266+
267+
def build_data(self):
268+
self.data = self.init_data(
269+
self.data.controls.state_freq,
270+
self.data.controls.attitude_freq,
271+
self.data.controls.thrust_freq,
272+
self.data.core.rng_key,
273+
self.data.mjx_data,
274+
)
275+
276+
def build_default_data(self):
277+
"""Initialize the default data for the simulation."""
278+
self.default_data = self.data.replace()
279+
280+
def build_mjx(self):
281+
if self.viewer is not None:
282+
self.viewer.close()
283+
self.viewer = None
284+
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.build_mjx_model(self.spec)
285+
self.data = self.data.replace(mjx_data=mjx_data)
286+
self.data = self.sync_sim2mjx(self.data, self.mjx_model)
287+
self.default_data = self.default_data.replace(mjx_data=mjx_data)
288+
170289
def init_data(
171290
self, state_freq: int, attitude_freq: int, thrust_freq: int, rng_key: Array, mjx_data: Data
172291
) -> tuple[SimData, SimData]:
@@ -189,161 +308,6 @@ def init_data(
189308
data = self.sync_sim2mjx(data, self.mjx_model)
190309
return data
191310

192-
def init_default_data(self) -> SimData:
193-
"""Initialize the default data for the simulation.
194-
195-
Todo:
196-
Only save the data of one world.
197-
"""
198-
return self.data.replace()
199-
200-
def build_reset_fn(self):
201-
"""Build the reset function for the current simulation configuration."""
202-
pipeline = self.reset_pipeline
203-
204-
@jax.jit
205-
def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> SimData:
206-
data = pytree_replace(data, default_data, mask) # Does not overwrite rng_key
207-
for fn in pipeline:
208-
data = fn(data, mask)
209-
data = self.sync_sim2mjx(data, self.mjx_model)
210-
return data
211-
212-
self._reset = reset
213-
214-
def build(
215-
self,
216-
*,
217-
mjx: bool = True,
218-
data: bool = True,
219-
default_data: bool = True,
220-
reset: bool = True,
221-
step: bool = True,
222-
):
223-
"""Build the simulation pipeline.
224-
225-
This method is used to (re)build the simulation pipeline after changing the MuJoCo
226-
model specification or any of the default functions that are used in the compiled step
227-
function.
228-
229-
Warning:
230-
Depending on what you build, you reset the simulation state. For example, rebuilding the
231-
simulation data will reset the drone states.
232-
233-
Args:
234-
mjx: Flag to (re)build the MuJoCo model and data structures.
235-
data: Flag to (re)build the simulation data.
236-
default_data: Flag to (re)build the default data. Useful for setting the reset state to
237-
the current state.
238-
reset: Flag to (re)build the reset function.
239-
step: Flag to (re)build the simulation step function.
240-
"""
241-
# TODO: Write tests for all options
242-
if mjx:
243-
if self.viewer is not None:
244-
self.viewer.close()
245-
self.viewer = None
246-
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.build_mjx_model(self.spec)
247-
self.data = self.data.replace(mjx_data=mjx_data)
248-
self.data = self.sync_sim2mjx(self.data, self.mjx_model)
249-
self.default_data = self.default_data.replace(mjx_data=mjx_data)
250-
if data:
251-
self.data = self.init_data(
252-
self.data.controls.state_freq,
253-
self.data.controls.attitude_freq,
254-
self.data.controls.thrust_freq,
255-
self.data.core.rng_key,
256-
self.data.mjx_data if not mjx else mjx_data,
257-
)
258-
if default_data:
259-
self.default_data = self.init_default_data()
260-
if reset:
261-
self.build_reset_fn()
262-
if step:
263-
self.build_step_fn()
264-
265-
def reset(self, mask: Array | None = None):
266-
"""Reset the simulation to the initial state.
267-
268-
Args:
269-
mask: Boolean array of shape (n_worlds, ) that indicates which worlds to reset. If None,
270-
all worlds are reset.
271-
"""
272-
assert mask is None or mask.shape == (self.n_worlds,), f"Mask shape mismatch {mask.shape}"
273-
self.data = self._reset(self.data, self.default_data, mask)
274-
275-
def step(self, n_steps: int = 1):
276-
"""Simulate all drones in all worlds for n time steps."""
277-
assert n_steps > 0, "Number of steps must be positive"
278-
self.data = self._step(self.data, n_steps=n_steps)
279-
280-
def attitude_control(self, controls: Array):
281-
"""Set the desired attitude for all drones in all worlds.
282-
283-
We need to stage the attitude controls because the sys_id physics mode operates directly on
284-
the attitude controls. If we were to directly update the controls, this would effectively
285-
bypass the control frequency and run the attitude controller at the physics update rate. By
286-
staging the controls, we ensure that the physics module sees the old controls until the
287-
controller updates at its correct frequency.
288-
"""
289-
assert controls.shape == (self.n_worlds, self.n_drones, 4), "controls shape mismatch"
290-
assert self.control == Control.attitude, "Attitude control is not enabled by the sim config"
291-
controls = to_device(controls, self.device)
292-
self.data = self.data.replace(controls=self.data.controls.replace(staged_attitude=controls))
293-
294-
def state_control(self, controls: Array):
295-
"""Set the desired state for all drones in all worlds."""
296-
assert controls.shape == (self.n_worlds, self.n_drones, 13), "controls shape mismatch"
297-
assert self.control == Control.state, "State control is not enabled by the sim config"
298-
controls = to_device(controls, self.device)
299-
self.data = self.data.replace(controls=self.data.controls.replace(state=controls))
300-
301-
def thrust_control(self, cmd: Array):
302-
"""Set the desired thrust for all drones in all worlds."""
303-
assert cmd.shape == (self.n_worlds, self.n_drones, 4), "Command shape mismatch"
304-
assert self.control == Control.thrust, "Thrust control is not enabled by the sim config"
305-
controls = to_device(cmd, self.device)
306-
self.data = self.data.replace(controls=self.data.controls.replace(thrust=controls))
307-
308-
def render(
309-
self,
310-
mode: str | None = "human",
311-
world: int = 0,
312-
default_cam_config: dict | None = None,
313-
width: int = 640,
314-
height: int = 480,
315-
) -> NDArray | None:
316-
if self.viewer is None:
317-
patch_viewer()
318-
self.mj_model.vis.global_.offwidth = width
319-
self.mj_model.vis.global_.offheight = height
320-
self.viewer = MujocoRenderer(
321-
self.mj_model,
322-
self.mj_data,
323-
max_geom=self.max_visual_geom,
324-
default_cam_config=default_cam_config,
325-
height=height,
326-
width=width,
327-
)
328-
self.mj_data.qpos[:] = self.data.mjx_data.qpos[world, :]
329-
self.mj_data.mocap_pos[:] = self.data.mjx_data.mocap_pos[world, :]
330-
self.mj_data.mocap_quat[:] = self.data.mjx_data.mocap_quat[world, :]
331-
mujoco.mj_forward(self.mj_model, self.mj_data)
332-
return self.viewer.render(mode)
333-
334-
def seed(self, seed: int):
335-
"""Set the JAX rng key for the simulation.
336-
337-
Args:
338-
seed: The seed for the JAX rng.
339-
"""
340-
self.data = seed_sim(self.data, seed, self.device)
341-
342-
def close(self):
343-
if self.viewer is not None:
344-
self.viewer.close()
345-
self.viewer = None
346-
347311
@property
348312
def time(self) -> Array:
349313
return self.data.core.steps / self.data.core.freq

examples/disturbance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def main(plot: bool = False):
3939
# inspect the step pipeline with
4040
# print(sim.step_pipeline)
4141
sim.step_pipeline = sim.step_pipeline[:2] + (disturbance_fn,) + sim.step_pipeline[2:]
42-
sim.build(mjx=False, data=False, default_data=False, step=True)
42+
sim.build_step_fn()
4343
pos_disturbed, rpy_disturbed = [], []
4444
sim.reset()
4545
for _ in range(3 * sim.control_freq):

tests/integration/test_disturbance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_disturbance(physics: Physics):
2929

3030
sim.reset()
3131
sim.step_pipeline = sim.step_pipeline[:2] + (disturbance_fn,) + sim.step_pipeline[2:]
32-
sim.build(mjx=False, data=False, default_data=False)
32+
sim.build_step_fn()
3333
for _ in range(sim.control_freq):
3434
sim.state_control(control)
3535
sim.step(sim.freq // sim.control_freq)

0 commit comments

Comments
 (0)