Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions crazyflow/sim/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,50 @@ def draw_points(sim: Sim, points: NDArray, rgba: NDArray | None = None, size: fl
)


def change_material(
sim: Sim,
mat_name: str,
drone_ids: NDArray,
rgba: NDArray | None = None,
emission: NDArray | None = None,
):
"""Change the material of specified drones.

Args:
sim: The simulation.
mat_name: The name of the material to change.
drone_ids: Array of drone indices to modify, shape (n,), dtype=int.
rgba: The RGBA color to set, should be of shape (n, 4) or (4,) to be auto-broadcasted.
emission: The emission value of material, should be of shape (n,) or scalar.
"""
if drone_ids.ndim != 1:
raise ValueError(f"drone_ids must be 1D array, got shape {drone_ids.shape}")
if np.any(drone_ids < 0) or np.any(drone_ids >= sim.n_drones):
raise ValueError(f"drone_ids must be in range [0, {sim.n_drones - 1}], got {drone_ids}")

if rgba is not None:
# this returns itself if rgba is already the right shape
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# this returns itself if rgba is already the right shape

rgba = np.broadcast_to(rgba, (len(drone_ids), 4))

if emission is not None:
emission = np.broadcast_to(emission, (len(drone_ids),))

mj_model = sim.mj_model
mat_ids = []
for i, drone_id in enumerate(drone_ids):
full_mat_name = f"{mat_name}:{drone_id}"
mat_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, full_mat_name)
if mat_id < 0:
raise ValueError(f"Material '{full_mat_name}' not found in MuJoCo model.")
mat_ids.append(mat_id)

if rgba is not None:
mj_model.mat_rgba[mat_ids, :] = rgba
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is [mat_ids, :] necessary?

Suggested change
mj_model.mat_rgba[mat_ids, :] = rgba
mj_model.mat_rgba[mat_ids] = rgba


if emission is not None:
mj_model.mat_emission[mat_ids] = emission


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests for this function! Even if we are not rendering in the tests, we should at least try to see if this errors as expected and succeeds for the correct input.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added 2 small tests under test_sim.py
One mainly to make sure all the drone models have "led_top" and "led_bot" in xml file.
One to test the error output of the function.

def _rotation_matrix_from_points(p1: NDArray, p2: NDArray) -> R:
"""Generate rotation matrices that align their z-axis to p2-p1."""
p1, p2 = p1.copy(), p2.copy() # Make sure we don't modify the original arrays
Expand Down
48 changes: 48 additions & 0 deletions examples/led_deck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np

from crazyflow.control.control import Control
from crazyflow.sim import Sim
from crazyflow.sim.visualize import change_material


def main():
"""Spawn 25 drones in one world and activate led decks."""
sim = Sim(n_drones=25, drone_model="cf21B_500", control=Control.state)
fps = 60
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81
rgbas = np.random.default_rng(0).uniform(0, 1, (sim.n_drones, 4))
rgbas[..., 3] = 1.0

init_pos = np.array(sim.data.states.pos[0, :, :])
cmd = np.zeros((sim.n_worlds, sim.n_drones, 13))
cmd[:, :, :3] = init_pos
cmd[:, :, 2] += 1.5

for i in range(int(10 * sim.control_freq)):
sim.state_control(cmd)
sim.step(sim.freq // sim.control_freq)
if ((i * fps) % sim.control_freq) < fps:
even_ids = np.arange(0, sim.n_drones, 2)
odd_ids = np.arange(1, sim.n_drones, 2)
emission = np.sin(i / sim.control_freq * np.pi)
change_material(
sim,
mat_name="led_top",
drone_ids=even_ids,
rgba=rgbas[even_ids, :],
emission=emission,
)
change_material(
sim,
mat_name="led_bot",
drone_ids=odd_ids,
rgba=rgbas[odd_ids, :],
emission=emission,
)
sim.render()
sim.close()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion submodules/drone-models
62 changes: 62 additions & 0 deletions tests/unit/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jax
import jax.numpy as jnp
import mujoco
import numpy as np
import pytest
from jax import Array
Expand All @@ -16,6 +17,7 @@
from crazyflow.sim import Physics, Sim
from crazyflow.sim.data import ControlData
from crazyflow.sim.sim import sync_sim2mjx
from crazyflow.sim.visualize import change_material

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -493,3 +495,63 @@ def test_scan_results(physics: Physics):
assert np.all(pos_loop_steps[..., 2] > 0.1), "Drones should have moved"
assert np.allclose(pos_scan_steps, pos_loop_steps), "Scan results should be identical"
sim.close()


@pytest.mark.unit
@pytest.mark.parametrize("drone_model", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"])
@pytest.mark.parametrize("mat_name", ["led_top", "led_bot"])
def test_change_material(device: str, drone_model: str, mat_name: str):
"""change_material should broadcast RGBA/emission and update MuJoCo materials appropriately."""
n_drones = 2

sim = Sim(n_drones=n_drones, drone_model=drone_model, device=device)

drone_ids = np.array([0, 1], dtype=int)
rgba = 0.42 * np.ones((n_drones, 4), dtype=float)
emission = 0.42 * np.ones((n_drones,), dtype=float)

change_material(sim, mat_name=mat_name, drone_ids=drone_ids, rgba=rgba, emission=emission)

mj_model = sim.mj_model
mat0 = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, f"{mat_name}:0")
mat1 = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, f"{mat_name}:1")
expected_rgba = 0.42 * np.ones((4,), dtype=float)
expected_emission = 0.42
np.testing.assert_allclose(mj_model.mat_rgba[mat0], expected_rgba)
np.testing.assert_allclose(mj_model.mat_rgba[mat1], expected_rgba)
assert mj_model.mat_emission[mat0] == pytest.approx(expected_emission)
assert mj_model.mat_emission[mat1] == pytest.approx(expected_emission)


@pytest.mark.unit
def test_change_material_errors(device: str):
"""Test that change_material raises the expected errors for bad inputs."""
n_drones = 2
sim = Sim(n_drones=n_drones, device=device)

drone_ids = np.array([0, 1], dtype=int)
rgba = np.ones((n_drones, 4), dtype=float)
emission = np.ones((n_drones,), dtype=float)

with pytest.raises(ValueError):
change_material(
sim, mat_name="bad_mat", drone_ids=drone_ids, rgba=rgba, emission=emission
)

with pytest.raises(ValueError, match=r"drone_ids must be 1D array"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need the r-string. But please check.

Suggested change
with pytest.raises(ValueError, match=r"drone_ids must be 1D array"):
with pytest.raises(ValueError, match="drone_ids must be 1D array"):

change_material(
sim,
mat_name="led_top",
drone_ids=np.array(2, dtype=int),
rgba=rgba,
emission=emission,
)

with pytest.raises(ValueError, match=r"drone_ids must be in range \[0, 1\]"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, but again, please check.

Suggested change
with pytest.raises(ValueError, match=r"drone_ids must be in range \[0, 1\]"):
with pytest.raises(ValueError, match="drone_ids must be in range [0, 1]"):

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked up, so "r" s needed when "[]" appears in a regular expression.
I'll leave it for the second check.

change_material(
sim,
mat_name="led_top",
drone_ids=np.arange(3, dtype=int),
rgba=rgba,
emission=emission,
)