|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from functools import partial |
2 | 4 | from pathlib import Path |
3 | | -from typing import Any, Callable, Optional |
| 5 | +from typing import TYPE_CHECKING, Any, Callable |
4 | 6 |
|
5 | 7 | import jax |
6 | 8 | import jax.numpy as jnp |
|
10 | 12 | from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer |
11 | 13 | from jax import Array, Device |
12 | 14 | from jax.scipy.spatial.transform import Rotation as R |
13 | | -from mujoco.mjx import Data, Model |
14 | 15 |
|
15 | 16 | from crazyflow.constants import J_INV, MASS, SIGN_MIX_MATRIX, J |
16 | 17 | from crazyflow.control.control import Control, attitude2rpm, pwm2rpm, state2attitude, thrust2pwm |
|
28 | 29 | from crazyflow.sim.structs import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv |
29 | 30 | from crazyflow.utils import grid_2d, leaf_replace, patch_viewer, pytree_replace, to_device |
30 | 31 |
|
| 32 | +if TYPE_CHECKING: |
| 33 | + from mujoco.mjx import Data, Model |
| 34 | + from numpy.typing import NDArray |
| 35 | + |
31 | 36 |
|
32 | 37 | class Sim: |
33 | 38 | default_path = Path(__file__).parents[1] / "models/cf2/scene.xml" |
@@ -300,10 +305,19 @@ def thrust_control(self, cmd: Array): |
300 | 305 | controls = to_device(cmd, self.device) |
301 | 306 | self.data = self.data.replace(controls=self.data.controls.replace(thrust=controls)) |
302 | 307 |
|
303 | | - def render(self, mode: Optional[str] = "human", world: Optional[int] = 0, default_cam_config: Optional[dict] = None): |
| 308 | + def render( |
| 309 | + self, mode: str | None = "human", world: int = 0, default_cam_config: dict | None = None |
| 310 | + ) -> NDArray | None: |
304 | 311 | if self.viewer is None: |
305 | 312 | patch_viewer() |
306 | | - self.viewer = MujocoRenderer(self.mj_model, self.mj_data, max_geom=self.max_visual_geom, default_cam_config=default_cam_config, height=480, width=640) |
| 313 | + self.viewer = MujocoRenderer( |
| 314 | + self.mj_model, |
| 315 | + self.mj_data, |
| 316 | + max_geom=self.max_visual_geom, |
| 317 | + default_cam_config=default_cam_config, |
| 318 | + height=480, |
| 319 | + width=640, |
| 320 | + ) |
307 | 321 | self.mj_data.qpos[:] = self.data.mjx_data.qpos[world, :] |
308 | 322 | self.mj_data.mocap_pos[:] = self.data.mjx_data.mocap_pos[world, :] |
309 | 323 | self.mj_data.mocap_quat[:] = self.data.mjx_data.mocap_quat[world, :] |
|
0 commit comments