@@ -98,8 +98,7 @@ def __init__(
9898 self .viewer : MujocoRenderer | None = None
9999
100100 self .data = self .init_data (state_freq , attitude_freq , force_torque_freq , rng_key )
101- self .default_data : SimData
102- self .build_default_data ()
101+ self .default_data : SimData = self .build_default_data ()
103102
104103 # Build the simulation pipeline and overwrite the default _step implementation with it
105104 self .reset_pipeline : tuple [Callable [[SimData , Array [bool ] | None ], SimData ], ...] = tuple ()
@@ -115,8 +114,8 @@ def __init__(
115114 # enable checks for negative z sign
116115 self .step_pipeline += (clip_floor_pos ,)
117116
118- self .build_reset_fn ()
119- self .build_step_fn ()
117+ self ._reset = self . build_reset_fn ()
118+ self ._step = self . build_step_fn ()
120119
121120 def reset (self , mask : Array | None = None ):
122121 """Reset the simulation to the initial state.
@@ -251,15 +250,23 @@ def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
251250 mjx_data = jax .vmap (lambda _ : mjx_data )(jnp .arange (self .n_worlds ))
252251 return mj_model , mj_data , mjx_model , mjx_data
253252
254- def build_step_fn (self ):
253+ def build_step_fn (self ) -> Callable [[ SimData , int ], SimData ] :
255254 """Setup the chain of functions that are called in Sim.step().
256255
257256 We know all the functions that are called in succession since the simulation is configured
258257 at initialization time. Instead of branching through options at runtime, we construct a step
259258 function at initialization that selects the correct functions based on the settings.
260259
260+ Note:
261+ This function both changes the underlying implementation of Sim.step() in-place to the
262+ current pipeline and returns the function for pure functional style programming.
263+
261264 Warning:
262265 If any settings change, the pipeline of functions needs to be reconstructed.
266+
267+ Returns:
268+ The pure JAX function that steps through the simulation. It takes the current SimData
269+ and the number of steps to simulate, and returns the updated SimData.
263270 """
264271 pipeline = self .step_pipeline
265272
@@ -271,9 +278,9 @@ def single_step(data: SimData, _: None) -> tuple[SimData, None]:
271278
272279 # ``scan`` allows us control over loop unrolling for single steps from a single WhileOp to
273280 # complete unrolling, reducing either compilation times or fusing the loops to give XLA
274- # maximum freedom to reorder operations and jointly optimize the pipeline. This is especially
275- # relevant for the common use case of running multiple sim steps in an outer loop, e.g. in
276- # gym environments.
281+ # maximum freedom to reorder operations and jointly optimize the pipeline. This is
282+ # especially relevant for the common use case of running multiple sim steps in an outer
283+ # loop, e.g. in gym environments.
277284 # Having n_steps as a static argument is fine, since patterns with n_steps > 1 will almost
278285 # always use the same n_steps value for successive calls.
279286 @partial (jax .jit , static_argnames = "n_steps" )
@@ -283,9 +290,19 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
283290 return data
284291
285292 self ._step = step
293+ return step
294+
295+ def build_reset_fn (self ) -> Callable [[SimData , SimData , Array | None ], SimData ]:
296+ """Build the reset function for the current simulation configuration.
286297
287- def build_reset_fn (self ):
288- """Build the reset function for the current simulation configuration."""
298+ Note:
299+ This function both changes the underlying implementation of Sim.reset() in-place to the
300+ current pipeline and returns the function for pure functional style programming.
301+
302+ Returns:
303+ The pure JAX function that resets simulation data. It takes the current SimData, default
304+ SimData, and an optional mask for worlds to reset, returning the updated SimData.
305+ """
289306 pipeline = self .reset_pipeline
290307
291308 @jax .jit
@@ -297,18 +314,43 @@ def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> Si
297314 return data
298315
299316 self ._reset = reset
317+ return reset
318+
319+ def build_data (self ) -> SimData :
320+ """Build the simulation data for the current configuration.
300321
301- def build_data (self ):
322+ Note:
323+ This function re-initializes the simulation data according to the current configuration.
324+ It also returns the constructed data for use with pure functions.
325+
326+ Returns:
327+ The simulation data as a single PyTree that can be passed to the pure simulation
328+ functions for stepping and resetting.
329+ """
330+ state_freq = self .data .controls .state .freq if self .data .controls .state is not None else 0
331+ attitude_freq = (
332+ self .data .controls .attitude .freq if self .data .controls .attitude is not None else 0
333+ )
334+ force_torque_freq = self .data .controls .force_torque .freq
302335 self .data = self .init_data (
303- self .data .controls .state_freq ,
304- self .data .controls .attitude_freq ,
305- self .data .controls .force_torque_freq ,
306- self .data .core .rng_key ,
336+ state_freq , attitude_freq , force_torque_freq , self .data .core .rng_key
307337 )
338+ return self .data
308339
309- def build_default_data (self ):
310- """Initialize the default data for the simulation."""
340+ def build_default_data (self ) -> SimData :
341+ """Initialize the default data for the simulation.
342+
343+ Note:
344+ This function initializes the default data used as a reference in the reset function to
345+ reset the simulation to. It also returns the constructed data for use with pure
346+ functions.
347+
348+ Returns:
349+ The default simulation data used as a reference in the reset function to reset the
350+ simulation to.
351+ """
311352 self .default_data = self .data .replace ()
353+ return self .default_data
312354
313355 def build_mjx (self ):
314356 if self .viewer is not None :
0 commit comments