Skip to content

Commit d63ca97

Browse files
committed
Fix linting. Add tests
1 parent 65a4c77 commit d63ca97

2 files changed

Lines changed: 29 additions & 5 deletions

File tree

crazyflow/sim/sim.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
from functools import partial
24
from pathlib import Path
3-
from typing import Any, Callable, Optional
5+
from typing import TYPE_CHECKING, Any, Callable
46

57
import jax
68
import jax.numpy as jnp
@@ -10,7 +12,6 @@
1012
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1113
from jax import Array, Device
1214
from jax.scipy.spatial.transform import Rotation as R
13-
from mujoco.mjx import Data, Model
1415

1516
from crazyflow.constants import J_INV, MASS, SIGN_MIX_MATRIX, J
1617
from crazyflow.control.control import Control, attitude2rpm, pwm2rpm, state2attitude, thrust2pwm
@@ -28,6 +29,10 @@
2829
from crazyflow.sim.structs import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
2930
from crazyflow.utils import grid_2d, leaf_replace, patch_viewer, pytree_replace, to_device
3031

32+
if TYPE_CHECKING:
33+
from mujoco.mjx import Data, Model
34+
from numpy.typing import NDArray
35+
3136

3237
class Sim:
3338
default_path = Path(__file__).parents[1] / "models/cf2/scene.xml"
@@ -300,10 +305,19 @@ def thrust_control(self, cmd: Array):
300305
controls = to_device(cmd, self.device)
301306
self.data = self.data.replace(controls=self.data.controls.replace(thrust=controls))
302307

303-
def render(self, mode: Optional[str] = "human", world: Optional[int] = 0, default_cam_config: Optional[dict] = None):
308+
def render(
309+
self, mode: str | None = "human", world: int = 0, default_cam_config: dict | None = None
310+
) -> NDArray | None:
304311
if self.viewer is None:
305312
patch_viewer()
306-
self.viewer = MujocoRenderer(self.mj_model, self.mj_data, max_geom=self.max_visual_geom, default_cam_config=default_cam_config, height=480, width=640)
313+
self.viewer = MujocoRenderer(
314+
self.mj_model,
315+
self.mj_data,
316+
max_geom=self.max_visual_geom,
317+
default_cam_config=default_cam_config,
318+
height=480,
319+
width=640,
320+
)
307321
self.mj_data.qpos[:] = self.data.mjx_data.qpos[world, :]
308322
self.mj_data.mocap_pos[:] = self.data.mjx_data.mocap_pos[world, :]
309323
self.mj_data.mocap_quat[:] = self.data.mjx_data.mocap_quat[world, :]

tests/unit/test_sim.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,23 @@ def test_sim_state_control_device(device: str):
258258

259259
@pytest.mark.parametrize("device", ["gpu", "cpu"])
260260
@pytest.mark.render
261-
def test_render(device: str):
261+
def test_render_human(device: str):
262262
skip_unavailable_device(device)
263263
sim = Sim(device=device)
264264
sim.render()
265265
sim.viewer.close()
266266

267267

268+
# Do not mark as render to ensure it runs by default. This function will not open a viewer.
269+
@pytest.mark.parametrize("device", ["gpu", "cpu"])
270+
def test_render_rgb_array(device: str):
271+
skip_unavailable_device(device)
272+
sim = Sim(n_worlds=2, device=device)
273+
img = sim.render(mode="rgb_array")
274+
assert isinstance(img, np.ndarray), "Image must be a numpy array"
275+
assert img.ndim == 3, "Image must be 3D"
276+
277+
268278
@pytest.mark.unit
269279
@pytest.mark.parametrize("device", ["gpu", "cpu"])
270280
def test_device(device: str):

0 commit comments

Comments
 (0)