Skip to content

Commit 4d93467

Browse files
committed
Cover rotor_vel in controllable
1 parent 9e29a40 commit 4d93467

2 files changed

Lines changed: 17 additions & 0 deletions

File tree

crazyflow/sim/functional.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
import jax.numpy as jnp
6+
57
from crazyflow.control import Control
68
from crazyflow.control.control import controllable as _controllable
79
from crazyflow.utils import to_device
@@ -78,6 +80,8 @@ def controllable(data: SimData) -> Array:
7880
case Control.force_torque:
7981
control_steps = controls.force_torque.steps
8082
control_freq = controls.force_torque.freq
83+
case Control.rotor_vel:
84+
return jnp.ones((data.core.n_worlds, 1), dtype=bool, device=data.core.steps.device)
8185
case _:
8286
raise NotImplementedError(f"Control mode {data.controls.mode} not implemented")
8387
return _controllable(data.core.steps, data.core.freq, control_steps, control_freq)

tests/unit/test_functional.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,16 @@ def test_functional_state_control_device(device: str):
145145
assert isinstance(controls.cmd, jnp.ndarray), "Buffers must remain JAX arrays"
146146
assert isinstance(controls.staged_cmd, jnp.ndarray), "Buffers must remain JAX arrays"
147147
assert jnp.all(controls.staged_cmd == cmd), "Buffers must match command"
148+
149+
150+
@pytest.mark.unit
151+
@pytest.mark.parametrize("control", Control)
152+
def test_functional_controllable(control: Control):
153+
"""Test that functional controllable function works correctly."""
154+
sim = Sim(n_worlds=2, n_drones=3, control=control)
155+
data = sim.build_data()
156+
controllable = F.controllable(data)
157+
assert isinstance(controllable, jnp.ndarray), "Controllable must be a JAX array"
158+
shape = controllable.shape
159+
des_shape = (sim.n_worlds, 1)
160+
assert shape == des_shape, f"Controllable shape must be {des_shape}, got {shape}"

0 commit comments

Comments
 (0)