Skip to content

Commit 54a4048

Browse files
committed
Add functional API. Include formatting check in CI
1 parent ee01f2c commit 54a4048

7 files changed

Lines changed: 118 additions & 49 deletions

File tree

.github/workflows/ruff.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
name: Ruff
22
on: [ push, pull_request ]
33
jobs:
4-
ruff:
4+
ruff-check:
55
runs-on: ubuntu-latest
66
steps:
77
- uses: actions/checkout@v4
8-
- uses: astral-sh/ruff-action@v1
8+
- uses: astral-sh/ruff-action@v1
9+
10+
ruff-format:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
- uses: astral-sh/ruff-action@v1
15+
with:
16+
args: "format --check --diff"

crazyflow/sim/sim.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ def __init__(
9898
self.viewer: MujocoRenderer | None = None
9999

100100
self.data = self.init_data(state_freq, attitude_freq, force_torque_freq, rng_key)
101-
self.default_data: SimData
102-
self.build_default_data()
101+
self.default_data: SimData = self.build_default_data()
103102

104103
# Build the simulation pipeline and overwrite the default _step implementation with it
105104
self.reset_pipeline: tuple[Callable[[SimData, Array[bool] | None], SimData], ...] = tuple()
@@ -115,8 +114,8 @@ def __init__(
115114
# enable checks for negative z sign
116115
self.step_pipeline += (clip_floor_pos,)
117116

118-
self.build_reset_fn()
119-
self.build_step_fn()
117+
self._reset = self.build_reset_fn()
118+
self._step = self.build_step_fn()
120119

121120
def reset(self, mask: Array | None = None):
122121
"""Reset the simulation to the initial state.
@@ -251,15 +250,23 @@ def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
251250
mjx_data = jax.vmap(lambda _: mjx_data)(jnp.arange(self.n_worlds))
252251
return mj_model, mj_data, mjx_model, mjx_data
253252

254-
def build_step_fn(self):
253+
def build_step_fn(self) -> Callable[[SimData, int], SimData]:
255254
"""Setup the chain of functions that are called in Sim.step().
256255
257256
We know all the functions that are called in succession since the simulation is configured
258257
at initialization time. Instead of branching through options at runtime, we construct a step
259258
function at initialization that selects the correct functions based on the settings.
260259
260+
Note:
261+
This function both changes the underlying implementation of Sim.step() in-place to the
262+
current pipeline and returns the function for pure functional style programming.
263+
261264
Warning:
262265
If any settings change, the pipeline of functions needs to be reconstructed.
266+
267+
Returns:
268+
The pure JAX function that steps through the simulation. It takes the current SimData
269+
and the number of steps to simulate, and returns the updated SimData.
263270
"""
264271
pipeline = self.step_pipeline
265272

@@ -271,9 +278,9 @@ def single_step(data: SimData, _: None) -> tuple[SimData, None]:
271278

272279
# ``scan`` allows us control over loop unrolling for single steps from a single WhileOp to
273280
# complete unrolling, reducing either compilation times or fusing the loops to give XLA
274-
# maximum freedom to reorder operations and jointly optimize the pipeline. This is especially
275-
# relevant for the common use case of running multiple sim steps in an outer loop, e.g. in
276-
# gym environments.
281+
# maximum freedom to reorder operations and jointly optimize the pipeline. This is
282+
# especially relevant for the common use case of running multiple sim steps in an outer
283+
# loop, e.g. in gym environments.
277284
# Having n_steps as a static argument is fine, since patterns with n_steps > 1 will almost
278285
# always use the same n_steps value for successive calls.
279286
@partial(jax.jit, static_argnames="n_steps")
@@ -283,9 +290,19 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
283290
return data
284291

285292
self._step = step
293+
return step
294+
295+
def build_reset_fn(self) -> Callable[[SimData, SimData, Array | None], SimData]:
296+
"""Build the reset function for the current simulation configuration.
286297
287-
def build_reset_fn(self):
288-
"""Build the reset function for the current simulation configuration."""
298+
Note:
299+
This function both changes the underlying implementation of Sim.reset() in-place to the
300+
current pipeline and returns the function for pure functional style programming.
301+
302+
Returns:
303+
The pure JAX function that resets simulation data. It takes the current SimData, default
304+
SimData, and an optional mask for worlds to reset, returning the updated SimData.
305+
"""
289306
pipeline = self.reset_pipeline
290307

291308
@jax.jit
@@ -297,18 +314,43 @@ def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> Si
297314
return data
298315

299316
self._reset = reset
317+
return reset
318+
319+
def build_data(self) -> SimData:
320+
"""Build the simulation data for the current configuration.
300321
301-
def build_data(self):
322+
Note:
323+
This function re-initializes the simulation data according to the current configuration.
324+
It also returns the constructed data for use with pure functions.
325+
326+
Returns:
327+
The simulation data as a single PyTree that can be passed to the pure simulation
328+
functions for stepping and resetting.
329+
"""
330+
state_freq = self.data.controls.state.freq if self.data.controls.state is not None else 0
331+
attitude_freq = (
332+
self.data.controls.attitude.freq if self.data.controls.attitude is not None else 0
333+
)
334+
force_torque_freq = self.data.controls.force_torque.freq
302335
self.data = self.init_data(
303-
self.data.controls.state_freq,
304-
self.data.controls.attitude_freq,
305-
self.data.controls.force_torque_freq,
306-
self.data.core.rng_key,
336+
state_freq, attitude_freq, force_torque_freq, self.data.core.rng_key
307337
)
338+
return self.data
308339

309-
def build_default_data(self):
310-
"""Initialize the default data for the simulation."""
340+
def build_default_data(self) -> SimData:
341+
"""Initialize the default data for the simulation.
342+
343+
Note:
344+
This function initializes the default data used as a reference in the reset function to
345+
reset the simulation to. It also returns the constructed data for use with pure
346+
functions.
347+
348+
Returns:
349+
The default simulation data used as a reference in the reset function to reset the
350+
simulation to.
351+
"""
311352
self.default_data = self.data.replace()
353+
return self.default_data
312354

313355
def build_mjx(self):
314356
if self.viewer is not None:

examples/gymnasium_env.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,10 @@ def main():
1919
action = np.zeros((20, 4), dtype=np.float32)
2020
action[..., 0] = 0.4
2121

22-
# Environments provide reset parameters that can be used to set the initial state of the environment.
22+
# Environments provide reset parameters that can be used to set the initial state of the
23+
# environment.
2324
obs, info = envs.reset(
2425
options={
25-
"pos_min": np.array([-1.0, 1.0, 1.0]),
26-
"pos_max": np.array([-1.0, 1.0, 1.0]),
27-
"vel_min": 0.0,
28-
"vel_max": 0.0,
2926
"goal_pos_min": np.array([-1.0, 1.0, 1.0]),
3027
"goal_pos_max": np.array([-1.0, 1.0, 1.0]),
3128
}

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,14 @@ indent-width = 4
8181
target-version = "py312"
8282

8383
[tool.ruff.lint]
84-
select = ["E4", "E7", "E9", "F", "I", "D", "TCH", "ANN"]
84+
select = ["E", "F", "I", "D", "TCH", "ANN"]
8585
ignore = ["ANN401"]
8686
fixable = ["ALL"]
8787
unfixable = []
8888

89+
[tool.ruff.lint.isort] # Prevent ruff from reporting conflicting settings with isort
90+
split-on-trailing-comma = false
91+
8992
[tool.ruff.lint.per-file-ignores]
9093
"benchmark/*" = ["D100", "D103"]
9194
"tests/*" = ["D100", "D103", "D104"]

tests/unit/test_sim.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from crazyflow.control import Control
1616
from crazyflow.exception import ConfigError
1717
from crazyflow.sim import Physics, Sim
18-
from crazyflow.sim.data import ControlData
18+
from crazyflow.sim.data import ControlData, SimData
1919
from crazyflow.sim.sim import sync_sim2mjx
2020
from crazyflow.sim.visualize import change_material
2121

@@ -534,24 +534,43 @@ def test_change_material_errors(device: str):
534534
emission = np.ones((n_drones,), dtype=float)
535535

536536
with pytest.raises(ValueError):
537-
change_material(
538-
sim, mat_name="bad_mat", drone_ids=drone_ids, rgba=rgba, emission=emission
539-
)
537+
change_material(sim, mat_name="bad_mat", drone_ids=drone_ids, rgba=rgba, emission=emission)
540538

541539
with pytest.raises(ValueError, match="drone_ids must be 1D array"):
542540
change_material(
543-
sim,
544-
mat_name="led_top",
545-
drone_ids=np.array(2, dtype=int),
546-
rgba=rgba,
547-
emission=emission,
541+
sim, mat_name="led_top", drone_ids=np.array(2, dtype=int), rgba=rgba, emission=emission
548542
)
549543

550544
with pytest.raises(ValueError, match=r"drone_ids must be in range \[0, 1\]"):
551545
change_material(
552-
sim,
553-
mat_name="led_top",
554-
drone_ids=np.arange(3, dtype=int),
555-
rgba=rgba,
556-
emission=emission,
546+
sim, mat_name="led_top", drone_ids=np.arange(3, dtype=int), rgba=rgba, emission=emission
557547
)
548+
549+
550+
@pytest.mark.unit
551+
@pytest.mark.parametrize("control", Control)
552+
def test_build_data(control: Control):
553+
sim = Sim(control=control)
554+
data = sim.build_data()
555+
assert isinstance(data, SimData), "build_data() must return a SimData instance"
556+
default_data = sim.build_default_data()
557+
assert isinstance(default_data, SimData), "build_default_data() must return a SimData instance"
558+
559+
560+
@pytest.mark.unit
561+
def test_functional_api():
562+
"""Test that the functional API works as expected."""
563+
sim = Sim()
564+
reset_fn = sim.build_reset_fn()
565+
step_fn = sim.build_step_fn()
566+
# Test types
567+
assert callable(reset_fn), "reset_fn must be a pure function"
568+
assert not hasattr(reset_fn, "__self__"), "reset_fn must not be a bound method"
569+
assert callable(step_fn), "step_fn must be a pure function"
570+
assert not hasattr(step_fn, "__self__"), "step_fn must not be a bound method"
571+
# Test the functions run as expected
572+
data, default_data = sim.build_data(), sim.build_default_data()
573+
data = reset_fn(data, default_data, None)
574+
data = reset_fn(data, default_data, jnp.array([True] * sim.n_worlds))
575+
data = step_fn(data, 1)
576+
data = step_fn(data, 2)

tests/unit/test_utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@ def test_enable_cache(enable_xla: bool):
2727
)
2828

2929
assert cache_path == jax.config.jax_compilation_cache_dir, "Cache path not set correctly"
30-
assert (
31-
min_size == jax.config.jax_persistent_cache_min_entry_size_bytes
32-
), "Min size not set correctly"
33-
assert (
34-
min_time == jax.config.jax_persistent_cache_min_compile_time_secs
35-
), "Min time not set correctly"
30+
assert min_size == jax.config.jax_persistent_cache_min_entry_size_bytes, (
31+
"Min size not set correctly"
32+
)
33+
assert min_time == jax.config.jax_persistent_cache_min_compile_time_secs, (
34+
"Min time not set correctly"
35+
)
3636
expected_xla = "all" if enable_xla else orig_xla
37-
assert (
38-
expected_xla == jax.config.jax_persistent_cache_enable_xla_caches
39-
), "XLA caches not set correctly"
37+
assert expected_xla == jax.config.jax_persistent_cache_enable_xla_caches, (
38+
"XLA caches not set correctly"
39+
)
4040

4141
finally:
4242
if orig_cache_dir is not None:

0 commit comments

Comments
 (0)