Skip to content

Commit c00113c

Browse files
authored
Fix device tracing inside jit (#56)
1 parent baf2035 commit c00113c

4 files changed

Lines changed: 32 additions & 6 deletions

File tree

crazyflow/sim/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def create(
209209
class SimCore:
210210
freq: int = field(pytree_node=False)
211211
"""Frequency of the simulation."""
212+
device: Device = field(pytree_node=False)
213+
"""Device of the simulation."""
212214
steps: Array # (N, 1)
213215
"""Simulation steps taken since the last reset."""
214216
n_worlds: int = field(pytree_node=False)
@@ -238,6 +240,7 @@ def create(
238240
rng_key = jax.device_put(rng_key, device)
239241
return SimCore(
240242
freq=freq,
243+
device=device,
241244
steps=steps,
242245
n_worlds=n_worlds,
243246
n_drones=n_drones,

crazyflow/sim/functional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def state_control(data: SimData, controls: Array) -> SimData:
1818
"""State control function."""
1919
assert data.controls.mode == Control.state, f"control type {data.controls.mode} not enabled"
2020
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 13), "controls shape mismatch"
21-
controls = to_device(controls, data.core.steps.device)
21+
controls = to_device(controls, data.core.device)
2222
data = data.replace(
2323
controls=data.controls.replace(state=data.controls.state.replace(staged_cmd=controls))
2424
)
@@ -36,7 +36,7 @@ def attitude_control(data: SimData, controls: Array) -> SimData:
3636
"""
3737
assert data.controls.mode == Control.attitude, f"control type {data.controls.mode} not enabled"
3838
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
39-
controls = to_device(controls, data.core.steps.device)
39+
controls = to_device(controls, data.core.device)
4040
data = data.replace(
4141
controls=data.controls.replace(attitude=data.controls.attitude.replace(staged_cmd=controls))
4242
)
@@ -49,7 +49,7 @@ def force_torque_control(data: SimData, controls: Array) -> SimData:
4949
f"control type {data.controls.mode} not enabled"
5050
)
5151
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
52-
controls = to_device(controls, data.core.steps.device)
52+
controls = to_device(controls, data.core.device)
5353
data = data.replace(
5454
controls=data.controls.replace(
5555
force_torque=data.controls.force_torque.replace(staged_cmd=controls)
@@ -65,7 +65,7 @@ def rotor_vel_control(data: SimData, controls: Array) -> SimData:
6565
"""
6666
assert data.controls.mode == Control.rotor_vel, f"control type {data.controls.mode} not enabled"
6767
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
68-
controls = to_device(controls, data.core.steps.device)
68+
controls = to_device(controls, data.core.device)
6969
return data.replace(controls=data.controls.replace(rotor_vel=controls))
7070

7171

@@ -81,7 +81,7 @@ def controllable(data: SimData) -> Array:
8181
control_steps = controls.force_torque.steps
8282
control_freq = controls.force_torque.freq
8383
case Control.rotor_vel:
84-
return jnp.ones((data.core.n_worlds, 1), dtype=bool, device=data.core.steps.device)
84+
return jnp.ones((data.core.n_worlds, 1), dtype=bool, device=data.core.device)
8585
case _:
8686
raise NotImplementedError(f"Control mode {data.controls.mode} not implemented")
8787
return _controllable(data.core.steps, data.core.freq, control_steps, control_freq)

tests/unit/test_functional.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
5+
import jax
36
import jax.numpy as jnp
47
import numpy as np
58
import pytest
9+
from jax import Array
610

711
import crazyflow.sim.functional as F
812
from crazyflow.control import Control
913
from crazyflow.sim import Sim
1014

15+
if TYPE_CHECKING:
16+
from crazyflow.sim.data import SimData
17+
1118

1219
@pytest.mark.unit
1320
def test_functional_resets():
@@ -80,6 +87,14 @@ def test_functional_attitude_control(attitude_freq: int):
8087
if i == 0: # Make world 2 asynchronous
8188
data = reset_fn(data, default_data, np.array([False, True]))
8289

90+
# Check if we can apply inside of jax jit which does not permit device tracking. See
91+
# https://github.com/jax-ml/jax/issues/26000 for more context.
92+
@jax.jit
93+
def apply_control(data: SimData, cmd: Array) -> SimData:
94+
return F.attitude_control(data, cmd)
95+
96+
jax.block_until_ready(apply_control(data, cmd))
97+
8398

8499
@pytest.mark.unit
85100
def test_functional_attitude_control_device(device: str):
@@ -89,7 +104,7 @@ def test_functional_attitude_control_device(device: str):
89104
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 4)
90105
data = F.attitude_control(data, cmd)
91106
controls = data.controls.attitude
92-
assert isinstance(controls.staged_cmd, jnp.ndarray), "Buffers must remain JAX arrays"
107+
assert isinstance(controls.staged_cmd, Array), "Buffers must remain JAX arrays"
93108
assert jnp.all(controls.staged_cmd == cmd), "Buffers must match command"
94109

95110

@@ -133,6 +148,12 @@ def test_functional_state_control(state_freq: int):
133148
if i == 0: # Make world 2 asynchronous
134149
data = reset_fn(data, default_data, np.array([False, True]))
135150

151+
@jax.jit
152+
def apply_control(data: SimData, cmd: jnp.ndarray) -> SimData:
153+
return F.state_control(data, cmd)
154+
155+
jax.block_until_ready(apply_control(data, cmd))
156+
136157

137158
@pytest.mark.unit
138159
def test_functional_state_control_device(device: str):

tests/unit/test_sim.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ def assert_committed(obj0: Array | Any, path: str = "data"):
422422
elif isinstance(obj0, (list, tuple)): # Handle sequences
423423
for i, item0 in enumerate(obj0):
424424
assert_committed(item0, f"{path}[{i}]")
425+
elif isinstance(obj0, type(sim.data.core.device)): # Device objects
426+
pass # Devices themselves don't have committed attribute
425427
else:
426428
raise TypeError(f"Could not handle type {type(obj0)} at {path}")
427429

0 commit comments

Comments
 (0)