Skip to content

Commit b84a431

Browse files
committed
Fix type hints
1 parent a32ad15 commit b84a431

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

tests/unit/test_functional.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from crazyflow.sim import Sim
1313

1414
if TYPE_CHECKING:
15+
from jax import Array
16+
1517
from crazyflow.sim.data import SimData
1618

1719

@@ -89,7 +91,7 @@ def test_functional_attitude_control(attitude_freq: int):
8991
# Check if we can apply inside of jax jit which does not permit device tracking. See
9092
# https://github.com/jax-ml/jax/issues/26000 for more context.
9193
@jax.jit
92-
def apply_control(data: SimData, cmd: jnp.ndarray) -> SimData:
94+
def apply_control(data: SimData, cmd: Array) -> SimData:
9395
return F.attitude_control(data, cmd)
9496

9597
jax.block_until_ready(apply_control(data, cmd))
@@ -103,7 +105,7 @@ def test_functional_attitude_control_device(device: str):
103105
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 4)
104106
data = F.attitude_control(data, cmd)
105107
controls = data.controls.attitude
106-
assert isinstance(controls.staged_cmd, jnp.ndarray), "Buffers must remain JAX arrays"
108+
assert isinstance(controls.staged_cmd, Array), "Buffers must remain JAX arrays"
107109
assert jnp.all(controls.staged_cmd == cmd), "Buffers must match command"
108110

109111

0 commit comments

Comments
 (0)