|
1 | 1 | from functools import partial |
| 2 | +from typing import Callable |
2 | 3 |
|
3 | 4 | import jax |
4 | 5 | import jax.numpy as jnp |
5 | 6 | import mujoco.mjx as mjx |
6 | | -import numpy as np |
7 | 7 | from jax import Array |
8 | 8 |
|
9 | 9 | from crazyflow.sim.sim import Sim, requires_mujoco_sync |
10 | 10 |
|
11 | 11 |
|
12 | 12 | @requires_mujoco_sync |
13 | | -def render_depth(sim: Sim, camera: int = 0, resolution: tuple[int, int] = (100, 100)) -> Array: |
| 13 | +def render_depth( |
| 14 | + sim: Sim, camera: int = 0, resolution: tuple[int, int] = (100, 100), include_drone: bool = False |
| 15 | +) -> Array: |
14 | 16 | """Render depth images using raycasting. |
15 | 17 |
|
16 | 18 | Note: |
17 | 19 | Code has been adoped from |
18 | 20 | https://github.com/Andrew-Luo1/jax_shac/blob/main/vision/2dof_ball.ipynb |
19 | 21 | """ |
20 | | - local_rays = camera_rays(resolution=resolution, fov_y=np.pi / 4)[None, ...] |
21 | | - rays = to_mjx_frame(local_rays, sim.mjx_data.cam_xmat[:, camera]) |
22 | | - dist, _ = ray_fn(sim.mjx_model, sim.mjx_data, sim.mjx_data.cam_xpos[:, camera], rays) |
23 | | - return dist |
| 22 | + return _render_depth( |
| 23 | + mjx_data=sim.mjx_data, |
| 24 | + mjx_model=sim.mjx_model, |
| 25 | + camera=camera, |
| 26 | + resolution=resolution, |
| 27 | + include_drone=include_drone, |
| 28 | + ) |
24 | 29 |
|
25 | 30 |
|
26 | | -@jax.jit |
27 | | -def to_mjx_frame(x: Array, xmat: Array) -> Array: |
28 | | - """Transform points to a different frame given its rotation matrix.""" |
29 | | - return (xmat[:, None, None, ...] @ x[..., None])[..., 0] |
| 31 | +def build_render_depth_fn( |
| 32 | + mjx_model: mjx.Model, |
| 33 | + camera: int = 0, |
| 34 | + resolution: tuple[int, int] = (100, 100), |
| 35 | + geomgroup: tuple[int, ...] = (1, 1, 0, 0, 1, 1, 1, 1), |
| 36 | +) -> Callable[[Sim], Array]: |
| 37 | + """Build a depth renderer function for given camera and resolution. |
| 38 | +
|
| 39 | + Compiles the mjx model and rays directly into the rendering function for higher performance. The |
| 40 | + returned function takes a Sim object as input and returns depth images. |
| 41 | + """ |
| 42 | + rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...] |
| 43 | + ray_fn = jax.jit( |
| 44 | + partial(_render_rays, mjx_model=mjx_model, camera=camera, geomgroup=geomgroup, rays=rays), |
| 45 | + static_argnames=("mjx_model", "camera", "geomgroup", "rays"), |
| 46 | + ) |
| 47 | + |
| 48 | + @requires_mujoco_sync |
| 49 | + def render_depth_fn(sim: Sim) -> Array: |
| 50 | + return ray_fn(mjx_data=sim.mjx_data) |
| 51 | + |
| 52 | + return render_depth_fn |
30 | 53 |
|
31 | 54 |
|
32 | | -_geomgroup = (1, 1, 1, 0, 1, 1, 1, 1) # Exclude collision geoms |
33 | | -ray_fn = jax.jit( |
34 | | - jax.vmap( |
35 | | - jax.vmap( |
36 | | - jax.vmap(partial(mjx.ray, geomgroup=_geomgroup), in_axes=(None, None, None, 0)), |
37 | | - in_axes=(None, None, None, 0), |
38 | | - ), |
| 55 | +@jax.jit(static_argnames=("camera", "resolution", "include_drone")) |
| 56 | +def _render_depth( |
| 57 | + mjx_data: mjx.Data, |
| 58 | + mjx_model: mjx.Model, |
| 59 | + camera: int, |
| 60 | + resolution: tuple[int, int], |
| 61 | + include_drone: bool = False, |
| 62 | +) -> Array: |
| 63 | + """Accelerates the dynamic rendering of depth images.""" |
| 64 | + local_rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...] |
| 65 | + geomgroup = (1, 1, 1, 0, 1, 1, 1, 1) if include_drone else (1, 1, 0, 0, 1, 1, 1, 1) |
| 66 | + return _render_rays( |
| 67 | + mjx_data=mjx_data, mjx_model=mjx_model, camera=camera, rays=local_rays, geomgroup=geomgroup |
| 68 | + ) |
| 69 | + |
| 70 | + |
| 71 | +def _render_rays( |
| 72 | + mjx_data: mjx.Data, mjx_model: mjx.Model, camera: int, rays: Array, geomgroup: tuple[int, ...] |
| 73 | +) -> Array: |
| 74 | + """Render a given ray array using MuJoCo's raycasting.""" |
| 75 | + rays = _to_mjx_frame(rays, mjx_data.cam_xmat[:, camera]) |
| 76 | + ray_ax = (None, None, None, 0) |
| 77 | + ray = jax.vmap( |
| 78 | + jax.vmap(jax.vmap(partial(mjx.ray, geomgroup=geomgroup), in_axes=ray_ax), in_axes=ray_ax), |
39 | 79 | in_axes=(None, 0, 0, 0), |
40 | 80 | ) |
41 | | -) |
| 81 | + return ray(mjx_model, mjx_data, mjx_data.cam_xpos[:, camera], rays)[0] |
| 82 | + |
| 83 | + |
| 84 | +def _to_mjx_frame(x: Array, xmat: Array) -> Array: |
| 85 | + """Transform points to a different frame given its rotation matrix.""" |
| 86 | + return (xmat[:, None, None, ...] @ x[..., None])[..., 0] |
42 | 87 |
|
43 | 88 |
|
44 | | -@partial(jax.jit, static_argnames=("resolution", "fov_y")) |
45 | | -def camera_rays(resolution: tuple[int, int] = (100, 100), fov_y: float = jnp.pi / 4) -> Array: |
| 89 | +def _camera_rays(resolution: tuple[int, int] = (100, 100), fov_y: float = jnp.pi / 4) -> Array: |
46 | 90 | """Create an array of rays with a given field of view and resolution. |
47 | 91 |
|
48 | 92 | Args: |
|
0 commit comments