File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1212from crazyflow .sim import Sim
1313
1414if TYPE_CHECKING :
15+ from jax import Array
16+
1517 from crazyflow .sim .data import SimData
1618
1719
@@ -89,7 +91,7 @@ def test_functional_attitude_control(attitude_freq: int):
8991 # Check if we can apply inside of jax jit which does not permit device tracking. See
9092 # https://github.com/jax-ml/jax/issues/26000 for more context.
9193 @jax .jit
92- def apply_control (data : SimData , cmd : jnp . ndarray ) -> SimData :
94+ def apply_control (data : SimData , cmd : Array ) -> SimData :
9395 return F .attitude_control (data , cmd )
9496
9597 jax .block_until_ready (apply_control (data , cmd ))
@@ -103,7 +105,7 @@ def test_functional_attitude_control_device(device: str):
103105 cmd = np .random .rand (sim .n_worlds , sim .n_drones , 4 )
104106 data = F .attitude_control (data , cmd )
105107 controls = data .controls .attitude
106- assert isinstance (controls .staged_cmd , jnp . ndarray ), "Buffers must remain JAX arrays"
108+ assert isinstance (controls .staged_cmd , Array ), "Buffers must remain JAX arrays"
107109 assert jnp .all (controls .staged_cmd == cmd ), "Buffers must match command"
108110
109111
You can’t perform that action at this time.
0 commit comments