|
7 | 7 |
|
8 | 8 | import jax |
9 | 9 | import jax.numpy as jnp |
| 10 | +import mujoco |
10 | 11 | import numpy as np |
11 | 12 | import pytest |
12 | 13 | from jax import Array |
|
16 | 17 | from crazyflow.sim import Physics, Sim |
17 | 18 | from crazyflow.sim.data import ControlData |
18 | 19 | from crazyflow.sim.sim import sync_sim2mjx |
| 20 | +from crazyflow.sim.visualize import change_material |
19 | 21 |
|
20 | 22 | if TYPE_CHECKING: |
21 | 23 | from typing import Any |
@@ -493,3 +495,63 @@ def test_scan_results(physics: Physics): |
493 | 495 | assert np.all(pos_loop_steps[..., 2] > 0.1), "Drones should have moved" |
494 | 496 | assert np.allclose(pos_scan_steps, pos_loop_steps), "Scan results should be identical" |
495 | 497 | sim.close() |
| 498 | + |
| 499 | + |
| 500 | +@pytest.mark.unit |
| 501 | +@pytest.mark.parametrize("drone_model", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"]) |
| 502 | +@pytest.mark.parametrize("mat_name", ["led_top", "led_bot"]) |
| 503 | +def test_change_material(device: str, drone_model: str, mat_name: str): |
| 504 | + """change_material should broadcast RGBA/emission and update MuJoCo materials appropriately.""" |
| 505 | + n_drones = 2 |
| 506 | + |
| 507 | + sim = Sim(n_drones=n_drones, drone_model=drone_model, device=device) |
| 508 | + |
| 509 | + drone_ids = np.array([0, 1], dtype=int) |
| 510 | + rgba = 0.42 * np.ones((n_drones, 4), dtype=float) |
| 511 | + emission = 0.42 * np.ones((n_drones,), dtype=float) |
| 512 | + |
| 513 | + change_material(sim, mat_name=mat_name, drone_ids=drone_ids, rgba=rgba, emission=emission) |
| 514 | + |
| 515 | + mj_model = sim.mj_model |
| 516 | + mat0 = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, f"{mat_name}:0") |
| 517 | + mat1 = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_MATERIAL, f"{mat_name}:1") |
| 518 | + expected_rgba = 0.42 * np.ones((4,), dtype=float) |
| 519 | + expected_emission = 0.42 |
| 520 | + np.testing.assert_allclose(mj_model.mat_rgba[mat0], expected_rgba) |
| 521 | + np.testing.assert_allclose(mj_model.mat_rgba[mat1], expected_rgba) |
| 522 | + assert mj_model.mat_emission[mat0] == pytest.approx(expected_emission) |
| 523 | + assert mj_model.mat_emission[mat1] == pytest.approx(expected_emission) |
| 524 | + |
| 525 | + |
| 526 | +@pytest.mark.unit |
| 527 | +def test_change_material_errors(device: str): |
| 528 | + """Test that change_material raises the expected errors for bad inputs.""" |
| 529 | + n_drones = 2 |
| 530 | + sim = Sim(n_drones=n_drones, device=device) |
| 531 | + |
| 532 | + drone_ids = np.array([0, 1], dtype=int) |
| 533 | + rgba = np.ones((n_drones, 4), dtype=float) |
| 534 | + emission = np.ones((n_drones,), dtype=float) |
| 535 | + |
| 536 | + with pytest.raises(ValueError): |
| 537 | + change_material( |
| 538 | + sim, mat_name="bad_mat", drone_ids=drone_ids, rgba=rgba, emission=emission |
| 539 | + ) |
| 540 | + |
| 541 | + with pytest.raises(ValueError, match="drone_ids must be 1D array"): |
| 542 | + change_material( |
| 543 | + sim, |
| 544 | + mat_name="led_top", |
| 545 | + drone_ids=np.array(2, dtype=int), |
| 546 | + rgba=rgba, |
| 547 | + emission=emission, |
| 548 | + ) |
| 549 | + |
| 550 | + with pytest.raises(ValueError, match=r"drone_ids must be in range \[0, 1\]"): |
| 551 | + change_material( |
| 552 | + sim, |
| 553 | + mat_name="led_top", |
| 554 | + drone_ids=np.arange(3, dtype=int), |
| 555 | + rgba=rgba, |
| 556 | + emission=emission, |
| 557 | + ) |
0 commit comments