@@ -75,7 +75,8 @@ def __init__(
7575 self .viewer : MujocoRenderer | None = None
7676
7777 self .data = self .init_data (state_freq , attitude_freq , thrust_freq , rng_key , mjx_data )
78- self .default_data = self .init_default_data ()
78+ self .default_data : SimData
79+ self .build_default_data ()
7980
8081 # Build the simulation pipeline and overwrite the default _step implementation with it
8182 self .reset_pipeline : tuple [Callable [[SimData , Array [bool ] | None ], SimData ], ...] = tuple ()
@@ -96,6 +97,88 @@ def __init__(
9697 self .build_reset_fn ()
9798 self .build_step_fn ()
9899
100+ def reset (self , mask : Array | None = None ):
101+ """Reset the simulation to the initial state.
102+
103+ Args:
104+ mask: Boolean array of shape (n_worlds, ) that indicates which worlds to reset. If None,
105+ all worlds are reset.
106+ """
107+ assert mask is None or mask .shape == (self .n_worlds ,), f"Mask shape mismatch { mask .shape } "
108+ self .data = self ._reset (self .data , self .default_data , mask )
109+
110+ def step (self , n_steps : int = 1 ):
111+ """Simulate all drones in all worlds for n time steps."""
112+ assert n_steps > 0 , "Number of steps must be positive"
113+ self .data = self ._step (self .data , n_steps = n_steps )
114+
115+ def attitude_control (self , controls : Array ):
116+ """Set the desired attitude for all drones in all worlds.
117+
118+ We need to stage the attitude controls because the sys_id physics mode operates directly on
119+ the attitude controls. If we were to directly update the controls, this would effectively
120+ bypass the control frequency and run the attitude controller at the physics update rate. By
121+ staging the controls, we ensure that the physics module sees the old controls until the
122+ controller updates at its correct frequency.
123+ """
124+ assert controls .shape == (self .n_worlds , self .n_drones , 4 ), "controls shape mismatch"
125+ assert self .control == Control .attitude , "Attitude control is not enabled by the sim config"
126+ controls = to_device (controls , self .device )
127+ self .data = self .data .replace (controls = self .data .controls .replace (staged_attitude = controls ))
128+
129+ def state_control (self , controls : Array ):
130+ """Set the desired state for all drones in all worlds."""
131+ assert controls .shape == (self .n_worlds , self .n_drones , 13 ), "controls shape mismatch"
132+ assert self .control == Control .state , "State control is not enabled by the sim config"
133+ controls = to_device (controls , self .device )
134+ self .data = self .data .replace (controls = self .data .controls .replace (state = controls ))
135+
136+ def thrust_control (self , cmd : Array ):
137+ """Set the desired thrust for all drones in all worlds."""
138+ assert cmd .shape == (self .n_worlds , self .n_drones , 4 ), "Command shape mismatch"
139+ assert self .control == Control .thrust , "Thrust control is not enabled by the sim config"
140+ controls = to_device (cmd , self .device )
141+ self .data = self .data .replace (controls = self .data .controls .replace (thrust = controls ))
142+
143+ def render (
144+ self ,
145+ mode : str | None = "human" ,
146+ world : int = 0 ,
147+ default_cam_config : dict | None = None ,
148+ width : int = 640 ,
149+ height : int = 480 ,
150+ ) -> NDArray | None :
151+ if self .viewer is None :
152+ patch_viewer ()
153+ self .mj_model .vis .global_ .offwidth = width
154+ self .mj_model .vis .global_ .offheight = height
155+ self .viewer = MujocoRenderer (
156+ self .mj_model ,
157+ self .mj_data ,
158+ max_geom = self .max_visual_geom ,
159+ default_cam_config = default_cam_config ,
160+ height = height ,
161+ width = width ,
162+ )
163+ self .mj_data .qpos [:] = self .data .mjx_data .qpos [world , :]
164+ self .mj_data .mocap_pos [:] = self .data .mjx_data .mocap_pos [world , :]
165+ self .mj_data .mocap_quat [:] = self .data .mjx_data .mocap_quat [world , :]
166+ mujoco .mj_forward (self .mj_model , self .mj_data )
167+ return self .viewer .render (mode )
168+
169+ def seed (self , seed : int ):
170+ """Set the JAX rng key for the simulation.
171+
172+ Args:
173+ seed: The seed for the JAX rng.
174+ """
175+ self .data = seed_sim (self .data , seed , self .device )
176+
177+ def close (self ):
178+ if self .viewer is not None :
179+ self .viewer .close ()
180+ self .viewer = None
181+
99182 def build_mjx_spec (self ) -> mujoco .MjSpec :
100183 """Build the MuJoCo model specification for the simulation."""
101184 assert self ._xml_path .exists (), f"Model file { self ._xml_path } does not exist"
@@ -167,6 +250,42 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
167250
168251 self ._step = step
169252
253+ def build_reset_fn (self ):
254+ """Build the reset function for the current simulation configuration."""
255+ pipeline = self .reset_pipeline
256+
257+ @jax .jit
258+ def reset (data : SimData , default_data : SimData , mask : Array | None = None ) -> SimData :
259+ data = pytree_replace (data , default_data , mask ) # Does not overwrite rng_key
260+ for fn in pipeline :
261+ data = fn (data , mask )
262+ data = self .sync_sim2mjx (data , self .mjx_model )
263+ return data
264+
265+ self ._reset = reset
266+
267+ def build_data (self ):
268+ self .data = self .init_data (
269+ self .data .controls .state_freq ,
270+ self .data .controls .attitude_freq ,
271+ self .data .controls .thrust_freq ,
272+ self .data .core .rng_key ,
273+ self .data .mjx_data ,
274+ )
275+
276+ def build_default_data (self ):
277+ """Initialize the default data for the simulation."""
278+ self .default_data = self .data .replace ()
279+
280+ def build_mjx (self ):
281+ if self .viewer is not None :
282+ self .viewer .close ()
283+ self .viewer = None
284+ self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .build_mjx_model (self .spec )
285+ self .data = self .data .replace (mjx_data = mjx_data )
286+ self .data = self .sync_sim2mjx (self .data , self .mjx_model )
287+ self .default_data = self .default_data .replace (mjx_data = mjx_data )
288+
170289 def init_data (
171290 self , state_freq : int , attitude_freq : int , thrust_freq : int , rng_key : Array , mjx_data : Data
172291 ) -> tuple [SimData , SimData ]:
@@ -189,161 +308,6 @@ def init_data(
189308 data = self .sync_sim2mjx (data , self .mjx_model )
190309 return data
191310
192- def init_default_data (self ) -> SimData :
193- """Initialize the default data for the simulation.
194-
195- Todo:
196- Only save the data of one world.
197- """
198- return self .data .replace ()
199-
200- def build_reset_fn (self ):
201- """Build the reset function for the current simulation configuration."""
202- pipeline = self .reset_pipeline
203-
204- @jax .jit
205- def reset (data : SimData , default_data : SimData , mask : Array | None = None ) -> SimData :
206- data = pytree_replace (data , default_data , mask ) # Does not overwrite rng_key
207- for fn in pipeline :
208- data = fn (data , mask )
209- data = self .sync_sim2mjx (data , self .mjx_model )
210- return data
211-
212- self ._reset = reset
213-
214- def build (
215- self ,
216- * ,
217- mjx : bool = True ,
218- data : bool = True ,
219- default_data : bool = True ,
220- reset : bool = True ,
221- step : bool = True ,
222- ):
223- """Build the simulation pipeline.
224-
225- This method is used to (re)build the simulation pipeline after changing the MuJoCo
226- model specification or any of the default functions that are used in the compiled step
227- function.
228-
229- Warning:
230- Depending on what you build, you reset the simulation state. For example, rebuilding the
231- simulation data will reset the drone states.
232-
233- Args:
234- mjx: Flag to (re)build the MuJoCo model and data structures.
235- data: Flag to (re)build the simulation data.
236- default_data: Flag to (re)build the default data. Useful for setting the reset state to
237- the current state.
238- reset: Flag to (re)build the reset function.
239- step: Flag to (re)build the simulation step function.
240- """
241- # TODO: Write tests for all options
242- if mjx :
243- if self .viewer is not None :
244- self .viewer .close ()
245- self .viewer = None
246- self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .build_mjx_model (self .spec )
247- self .data = self .data .replace (mjx_data = mjx_data )
248- self .data = self .sync_sim2mjx (self .data , self .mjx_model )
249- self .default_data = self .default_data .replace (mjx_data = mjx_data )
250- if data :
251- self .data = self .init_data (
252- self .data .controls .state_freq ,
253- self .data .controls .attitude_freq ,
254- self .data .controls .thrust_freq ,
255- self .data .core .rng_key ,
256- self .data .mjx_data if not mjx else mjx_data ,
257- )
258- if default_data :
259- self .default_data = self .init_default_data ()
260- if reset :
261- self .build_reset_fn ()
262- if step :
263- self .build_step_fn ()
264-
265- def reset (self , mask : Array | None = None ):
266- """Reset the simulation to the initial state.
267-
268- Args:
269- mask: Boolean array of shape (n_worlds, ) that indicates which worlds to reset. If None,
270- all worlds are reset.
271- """
272- assert mask is None or mask .shape == (self .n_worlds ,), f"Mask shape mismatch { mask .shape } "
273- self .data = self ._reset (self .data , self .default_data , mask )
274-
275- def step (self , n_steps : int = 1 ):
276- """Simulate all drones in all worlds for n time steps."""
277- assert n_steps > 0 , "Number of steps must be positive"
278- self .data = self ._step (self .data , n_steps = n_steps )
279-
280- def attitude_control (self , controls : Array ):
281- """Set the desired attitude for all drones in all worlds.
282-
283- We need to stage the attitude controls because the sys_id physics mode operates directly on
284- the attitude controls. If we were to directly update the controls, this would effectively
285- bypass the control frequency and run the attitude controller at the physics update rate. By
286- staging the controls, we ensure that the physics module sees the old controls until the
287- controller updates at its correct frequency.
288- """
289- assert controls .shape == (self .n_worlds , self .n_drones , 4 ), "controls shape mismatch"
290- assert self .control == Control .attitude , "Attitude control is not enabled by the sim config"
291- controls = to_device (controls , self .device )
292- self .data = self .data .replace (controls = self .data .controls .replace (staged_attitude = controls ))
293-
294- def state_control (self , controls : Array ):
295- """Set the desired state for all drones in all worlds."""
296- assert controls .shape == (self .n_worlds , self .n_drones , 13 ), "controls shape mismatch"
297- assert self .control == Control .state , "State control is not enabled by the sim config"
298- controls = to_device (controls , self .device )
299- self .data = self .data .replace (controls = self .data .controls .replace (state = controls ))
300-
301- def thrust_control (self , cmd : Array ):
302- """Set the desired thrust for all drones in all worlds."""
303- assert cmd .shape == (self .n_worlds , self .n_drones , 4 ), "Command shape mismatch"
304- assert self .control == Control .thrust , "Thrust control is not enabled by the sim config"
305- controls = to_device (cmd , self .device )
306- self .data = self .data .replace (controls = self .data .controls .replace (thrust = controls ))
307-
308- def render (
309- self ,
310- mode : str | None = "human" ,
311- world : int = 0 ,
312- default_cam_config : dict | None = None ,
313- width : int = 640 ,
314- height : int = 480 ,
315- ) -> NDArray | None :
316- if self .viewer is None :
317- patch_viewer ()
318- self .mj_model .vis .global_ .offwidth = width
319- self .mj_model .vis .global_ .offheight = height
320- self .viewer = MujocoRenderer (
321- self .mj_model ,
322- self .mj_data ,
323- max_geom = self .max_visual_geom ,
324- default_cam_config = default_cam_config ,
325- height = height ,
326- width = width ,
327- )
328- self .mj_data .qpos [:] = self .data .mjx_data .qpos [world , :]
329- self .mj_data .mocap_pos [:] = self .data .mjx_data .mocap_pos [world , :]
330- self .mj_data .mocap_quat [:] = self .data .mjx_data .mocap_quat [world , :]
331- mujoco .mj_forward (self .mj_model , self .mj_data )
332- return self .viewer .render (mode )
333-
334- def seed (self , seed : int ):
335- """Set the JAX rng key for the simulation.
336-
337- Args:
338- seed: The seed for the JAX rng.
339- """
340- self .data = seed_sim (self .data , seed , self .device )
341-
342- def close (self ):
343- if self .viewer is not None :
344- self .viewer .close ()
345- self .viewer = None
346-
347311 @property
348312 def time (self ) -> Array :
349313 return self .data .core .steps / self .data .core .freq
0 commit comments