Skip to content

Commit a32ad15

Browse files
committed
Fix device tracing inside jit
1 parent baf2035 commit a32ad15

3 files changed

Lines changed: 28 additions & 5 deletions

File tree

crazyflow/sim/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def create(
209209
class SimCore:
210210
freq: int = field(pytree_node=False)
211211
"""Frequency of the simulation."""
212+
device: Device = field(pytree_node=False)
213+
"""Device of the simulation."""
212214
steps: Array # (N, 1)
213215
"""Simulation steps taken since the last reset."""
214216
n_worlds: int = field(pytree_node=False)
@@ -238,6 +240,7 @@ def create(
238240
rng_key = jax.device_put(rng_key, device)
239241
return SimCore(
240242
freq=freq,
243+
device=device,
241244
steps=steps,
242245
n_worlds=n_worlds,
243246
n_drones=n_drones,

crazyflow/sim/functional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def state_control(data: SimData, controls: Array) -> SimData:
1818
"""State control function."""
1919
assert data.controls.mode == Control.state, f"control type {data.controls.mode} not enabled"
2020
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 13), "controls shape mismatch"
21-
controls = to_device(controls, data.core.steps.device)
21+
controls = to_device(controls, data.core.device)
2222
data = data.replace(
2323
controls=data.controls.replace(state=data.controls.state.replace(staged_cmd=controls))
2424
)
@@ -36,7 +36,7 @@ def attitude_control(data: SimData, controls: Array) -> SimData:
3636
"""
3737
assert data.controls.mode == Control.attitude, f"control type {data.controls.mode} not enabled"
3838
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
39-
controls = to_device(controls, data.core.steps.device)
39+
controls = to_device(controls, data.core.device)
4040
data = data.replace(
4141
controls=data.controls.replace(attitude=data.controls.attitude.replace(staged_cmd=controls))
4242
)
@@ -49,7 +49,7 @@ def force_torque_control(data: SimData, controls: Array) -> SimData:
4949
f"control type {data.controls.mode} not enabled"
5050
)
5151
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
52-
controls = to_device(controls, data.core.steps.device)
52+
controls = to_device(controls, data.core.device)
5353
data = data.replace(
5454
controls=data.controls.replace(
5555
force_torque=data.controls.force_torque.replace(staged_cmd=controls)
@@ -65,7 +65,7 @@ def rotor_vel_control(data: SimData, controls: Array) -> SimData:
6565
"""
6666
assert data.controls.mode == Control.rotor_vel, f"control type {data.controls.mode} not enabled"
6767
assert controls.shape == (data.core.n_worlds, data.core.n_drones, 4), "controls shape mismatch"
68-
controls = to_device(controls, data.core.steps.device)
68+
controls = to_device(controls, data.core.device)
6969
return data.replace(controls=data.controls.replace(rotor_vel=controls))
7070

7171

@@ -81,7 +81,7 @@ def controllable(data: SimData) -> Array:
8181
control_steps = controls.force_torque.steps
8282
control_freq = controls.force_torque.freq
8383
case Control.rotor_vel:
84-
return jnp.ones((data.core.n_worlds, 1), dtype=bool, device=data.core.steps.device)
84+
return jnp.ones((data.core.n_worlds, 1), dtype=bool, device=data.core.device)
8585
case _:
8686
raise NotImplementedError(f"Control mode {data.controls.mode} not implemented")
8787
return _controllable(data.core.steps, data.core.freq, control_steps, control_freq)

tests/unit/test_functional.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
5+
import jax
36
import jax.numpy as jnp
47
import numpy as np
58
import pytest
@@ -8,6 +11,9 @@
811
from crazyflow.control import Control
912
from crazyflow.sim import Sim
1013

14+
if TYPE_CHECKING:
15+
from crazyflow.sim.data import SimData
16+
1117

1218
@pytest.mark.unit
1319
def test_functional_resets():
@@ -80,6 +86,14 @@ def test_functional_attitude_control(attitude_freq: int):
8086
if i == 0: # Make world 2 asynchronous
8187
data = reset_fn(data, default_data, np.array([False, True]))
8288

89+
# Check if we can apply inside of jax jit which does not permit device tracking. See
90+
# https://github.com/jax-ml/jax/issues/26000 for more context.
91+
@jax.jit
92+
def apply_control(data: SimData, cmd: jnp.ndarray) -> SimData:
93+
return F.attitude_control(data, cmd)
94+
95+
jax.block_until_ready(apply_control(data, cmd))
96+
8397

8498
@pytest.mark.unit
8599
def test_functional_attitude_control_device(device: str):
@@ -133,6 +147,12 @@ def test_functional_state_control(state_freq: int):
133147
if i == 0: # Make world 2 asynchronous
134148
data = reset_fn(data, default_data, np.array([False, True]))
135149

150+
@jax.jit
151+
def apply_control(data: SimData, cmd: jnp.ndarray) -> SimData:
152+
return F.state_control(data, cmd)
153+
154+
jax.block_until_ready(apply_control(data, cmd))
155+
136156

137157
@pytest.mark.unit
138158
def test_functional_state_control_device(device: str):

0 commit comments

Comments
 (0)