Skip to content

Commit 057cd6c

Browse files
authored
Add optional visualization for drone lights. (#45)
1 parent b89e1a6 commit 057cd6c

4 files changed

Lines changed: 154 additions & 1 deletion

File tree

crazyflow/sim/visualize.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,49 @@ def draw_points(sim: Sim, points: NDArray, rgba: NDArray | None = None, size: fl
6969
)
7070

7171

72+
def change_material(
73+
sim: Sim,
74+
mat_name: str,
75+
drone_ids: NDArray,
76+
rgba: NDArray | None = None,
77+
emission: NDArray | None = None,
78+
):
79+
"""Change the material of specified drones.
80+
81+
Args:
82+
sim: The simulation.
83+
mat_name: The name of the material to change.
84+
drone_ids: Array of drone indices to modify, shape (n,), dtype=int.
85+
rgba: The RGBA color to set, should be of shape (n, 4) or (4,) to be auto-broadcasted.
86+
emission: The emission value of material, should be of shape (n,) or scalar.
87+
"""
88+
if drone_ids.ndim != 1:
89+
raise ValueError(f"drone_ids must be 1D array, got shape {drone_ids.shape}")
90+
if np.any(drone_ids < 0) or np.any(drone_ids >= sim.n_drones):
91+
raise ValueError(f"drone_ids must be in range [0, {sim.n_drones - 1}], got {drone_ids}")
92+
93+
if rgba is not None:
94+
rgba = np.broadcast_to(rgba, (len(drone_ids), 4))
95+
96+
if emission is not None:
97+
emission = np.broadcast_to(emission, (len(drone_ids),))
98+
99+
mj_model = sim.mj_model
100+
mat_ids = []
101+
for i, drone_id in enumerate(drone_ids):
102+
full_mat_name = f"{mat_name}:{drone_id}"
103+
mat_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, full_mat_name)
104+
if mat_id < 0:
105+
raise ValueError(f"Material '{full_mat_name}' not found in MuJoCo model.")
106+
mat_ids.append(mat_id)
107+
108+
if rgba is not None:
109+
mj_model.mat_rgba[mat_ids] = rgba
110+
111+
if emission is not None:
112+
mj_model.mat_emission[mat_ids] = emission
113+
114+
72115
def _rotation_matrix_from_points(p1: NDArray, p2: NDArray) -> R:
73116
"""Generate rotation matrices that align their z-axis to p2-p1."""
74117
p1, p2 = p1.copy(), p2.copy() # Make sure we don't modify the original arrays

examples/led_deck.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
3+
from crazyflow.control.control import Control
4+
from crazyflow.sim import Sim
5+
from crazyflow.sim.visualize import change_material
6+
7+
8+
def main():
9+
"""Spawn 25 drones in one world and activate led decks."""
10+
sim = Sim(n_drones=25, drone_model="cf21B_500", control=Control.state)
11+
fps = 60
12+
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
13+
cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81
14+
rgbas = np.random.default_rng(0).uniform(0, 1, (sim.n_drones, 4))
15+
rgbas[..., 3] = 1.0
16+
17+
init_pos = np.array(sim.data.states.pos[0, :, :])
18+
cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
19+
cmd[:, :, :3] = init_pos
20+
cmd[:, :, 2] += 1.5
21+
22+
for i in range(int(10 * sim.control_freq)):
23+
sim.state_control(cmd)
24+
sim.step(sim.freq // sim.control_freq)
25+
if ((i * fps) % sim.control_freq) < fps:
26+
even_ids = np.arange(0, sim.n_drones, 2)
27+
odd_ids = np.arange(1, sim.n_drones, 2)
28+
emission = np.sin(i / sim.control_freq * np.pi)
29+
change_material(
30+
sim,
31+
mat_name="led_top",
32+
drone_ids=even_ids,
33+
rgba=rgbas[even_ids, :],
34+
emission=emission,
35+
)
36+
change_material(
37+
sim,
38+
mat_name="led_bot",
39+
drone_ids=odd_ids,
40+
rgba=rgbas[odd_ids, :],
41+
emission=emission,
42+
)
43+
sim.render()
44+
sim.close()
45+
46+
47+
if __name__ == "__main__":
48+
main()

submodules/drone-models

tests/unit/test_sim.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import jax
99
import jax.numpy as jnp
10+
import mujoco
1011
import numpy as np
1112
import pytest
1213
from jax import Array
@@ -16,6 +17,7 @@
1617
from crazyflow.sim import Physics, Sim
1718
from crazyflow.sim.data import ControlData
1819
from crazyflow.sim.sim import sync_sim2mjx
20+
from crazyflow.sim.visualize import change_material
1921

2022
if TYPE_CHECKING:
2123
from typing import Any
@@ -493,3 +495,63 @@ def test_scan_results(physics: Physics):
493495
assert np.all(pos_loop_steps[..., 2] > 0.1), "Drones should have moved"
494496
assert np.allclose(pos_scan_steps, pos_loop_steps), "Scan results should be identical"
495497
sim.close()
498+
499+
500+
@pytest.mark.unit
501+
@pytest.mark.parametrize("drone_model", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"])
502+
@pytest.mark.parametrize("mat_name", ["led_top", "led_bot"])
503+
def test_change_material(device: str, drone_model: str, mat_name: str):
504+
"""change_material should broadcast RGBA/emission and update MuJoCo materials appropriately."""
505+
n_drones = 2
506+
507+
sim = Sim(n_drones=n_drones, drone_model=drone_model, device=device)
508+
509+
drone_ids = np.array([0, 1], dtype=int)
510+
rgba = 0.42 * np.ones((n_drones, 4), dtype=float)
511+
emission = 0.42 * np.ones((n_drones,), dtype=float)
512+
513+
change_material(sim, mat_name=mat_name, drone_ids=drone_ids, rgba=rgba, emission=emission)
514+
515+
mj_model = sim.mj_model
516+
mat0 = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, f"{mat_name}:0")
517+
mat1 = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, f"{mat_name}:1")
518+
expected_rgba = 0.42 * np.ones((4,), dtype=float)
519+
expected_emission = 0.42
520+
np.testing.assert_allclose(mj_model.mat_rgba[mat0], expected_rgba)
521+
np.testing.assert_allclose(mj_model.mat_rgba[mat1], expected_rgba)
522+
assert mj_model.mat_emission[mat0] == pytest.approx(expected_emission)
523+
assert mj_model.mat_emission[mat1] == pytest.approx(expected_emission)
524+
525+
526+
@pytest.mark.unit
527+
def test_change_material_errors(device: str):
528+
"""Test that change_material raises the expected errors for bad inputs."""
529+
n_drones = 2
530+
sim = Sim(n_drones=n_drones, device=device)
531+
532+
drone_ids = np.array([0, 1], dtype=int)
533+
rgba = np.ones((n_drones, 4), dtype=float)
534+
emission = np.ones((n_drones,), dtype=float)
535+
536+
with pytest.raises(ValueError):
537+
change_material(
538+
sim, mat_name="bad_mat", drone_ids=drone_ids, rgba=rgba, emission=emission
539+
)
540+
541+
with pytest.raises(ValueError, match="drone_ids must be 1D array"):
542+
change_material(
543+
sim,
544+
mat_name="led_top",
545+
drone_ids=np.array(2, dtype=int),
546+
rgba=rgba,
547+
emission=emission,
548+
)
549+
550+
with pytest.raises(ValueError, match=r"drone_ids must be in range \[0, 1\]"):
551+
change_material(
552+
sim,
553+
mat_name="led_top",
554+
drone_ids=np.arange(3, dtype=int),
555+
rgba=rgba,
556+
emission=emission,
557+
)

0 commit comments

Comments
 (0)