Skip to content

Commit df64cad

Browse files
committed
Improve raycasting performance
1 parent 6e549a5 commit df64cad

2 files changed

Lines changed: 85 additions & 28 deletions

File tree

crazyflow/sim/sensors.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,92 @@
11
from functools import partial
2+
from typing import Callable
23

34
import jax
45
import jax.numpy as jnp
56
import mujoco.mjx as mjx
6-
import numpy as np
77
from jax import Array
88

99
from crazyflow.sim.sim import Sim, requires_mujoco_sync
1010

1111

1212
@requires_mujoco_sync
13-
def render_depth(sim: Sim, camera: int = 0, resolution: tuple[int, int] = (100, 100)) -> Array:
13+
def render_depth(
14+
sim: Sim, camera: int = 0, resolution: tuple[int, int] = (100, 100), include_drone: bool = False
15+
) -> Array:
1416
"""Render depth images using raycasting.
1517
1618
Note:
1719
Code has been adoped from
1820
https://github.com/Andrew-Luo1/jax_shac/blob/main/vision/2dof_ball.ipynb
1921
"""
20-
local_rays = camera_rays(resolution=resolution, fov_y=np.pi / 4)[None, ...]
21-
rays = to_mjx_frame(local_rays, sim.mjx_data.cam_xmat[:, camera])
22-
dist, _ = ray_fn(sim.mjx_model, sim.mjx_data, sim.mjx_data.cam_xpos[:, camera], rays)
23-
return dist
22+
return _render_depth(
23+
mjx_data=sim.mjx_data,
24+
mjx_model=sim.mjx_model,
25+
camera=camera,
26+
resolution=resolution,
27+
include_drone=include_drone,
28+
)
2429

2530

26-
@jax.jit
27-
def to_mjx_frame(x: Array, xmat: Array) -> Array:
28-
"""Transform points to a different frame given its rotation matrix."""
29-
return (xmat[:, None, None, ...] @ x[..., None])[..., 0]
31+
def build_render_depth_fn(
32+
mjx_model: mjx.Model,
33+
camera: int = 0,
34+
resolution: tuple[int, int] = (100, 100),
35+
geomgroup: tuple[int, ...] = (1, 1, 0, 0, 1, 1, 1, 1),
36+
) -> Callable[[Sim], Array]:
37+
"""Build a depth renderer function for given camera and resolution.
38+
39+
Compiles the mjx model and rays directly into the rendering function for higher performance. The
40+
returned function takes a Sim object as input and returns depth images.
41+
"""
42+
rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...]
43+
ray_fn = jax.jit(
44+
partial(_render_rays, mjx_model=mjx_model, camera=camera, geomgroup=geomgroup, rays=rays),
45+
static_argnames=("mjx_model", "camera", "geomgroup", "rays"),
46+
)
47+
48+
@requires_mujoco_sync
49+
def render_depth_fn(sim: Sim) -> Array:
50+
return ray_fn(mjx_data=sim.mjx_data)
51+
52+
return render_depth_fn
3053

3154

32-
_geomgroup = (1, 1, 1, 0, 1, 1, 1, 1) # Exclude collision geoms
33-
ray_fn = jax.jit(
34-
jax.vmap(
35-
jax.vmap(
36-
jax.vmap(partial(mjx.ray, geomgroup=_geomgroup), in_axes=(None, None, None, 0)),
37-
in_axes=(None, None, None, 0),
38-
),
55+
@jax.jit(static_argnames=("camera", "resolution", "include_drone"))
56+
def _render_depth(
57+
mjx_data: mjx.Data,
58+
mjx_model: mjx.Model,
59+
camera: int,
60+
resolution: tuple[int, int],
61+
include_drone: bool = False,
62+
) -> Array:
63+
"""Accelerates the dynamic rendering of depth images."""
64+
local_rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...]
65+
geomgroup = (1, 1, 1, 0, 1, 1, 1, 1) if include_drone else (1, 1, 0, 0, 1, 1, 1, 1)
66+
return _render_rays(
67+
mjx_data=mjx_data, mjx_model=mjx_model, camera=camera, rays=local_rays, geomgroup=geomgroup
68+
)
69+
70+
71+
def _render_rays(
72+
mjx_data: mjx.Data, mjx_model: mjx.Model, camera: int, rays: Array, geomgroup: tuple[int, ...]
73+
) -> Array:
74+
"""Render a given ray array using MuJoCo's raycasting."""
75+
rays = _to_mjx_frame(rays, mjx_data.cam_xmat[:, camera])
76+
ray_ax = (None, None, None, 0)
77+
ray = jax.vmap(
78+
jax.vmap(jax.vmap(partial(mjx.ray, geomgroup=geomgroup), in_axes=ray_ax), in_axes=ray_ax),
3979
in_axes=(None, 0, 0, 0),
4080
)
41-
)
81+
return ray(mjx_model, mjx_data, mjx_data.cam_xpos[:, camera], rays)[0]
82+
83+
84+
def _to_mjx_frame(x: Array, xmat: Array) -> Array:
85+
"""Transform points to a different frame given its rotation matrix."""
86+
return (xmat[:, None, None, ...] @ x[..., None])[..., 0]
4287

4388

44-
@partial(jax.jit, static_argnames=("resolution", "fov_y"))
45-
def camera_rays(resolution: tuple[int, int] = (100, 100), fov_y: float = jnp.pi / 4) -> Array:
89+
def _camera_rays(resolution: tuple[int, int] = (100, 100), fov_y: float = jnp.pi / 4) -> Array:
4690
"""Create an array of rays with a given field of view and resolution.
4791
4892
Args:

examples/raycasting.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,35 @@
22
import matplotlib.pyplot as plt
33

44
from crazyflow.sim import Sim
5-
from crazyflow.sim.sensors import render_depth
5+
from crazyflow.sim.sensors import build_render_depth_fn, render_depth
66

77

88
def main(plot: bool = False):
99
sim = Sim()
1010
sim.data = sim.data.replace(
1111
states=sim.data.states.replace(pos=sim.data.states.pos.at[..., 2].set(0.2))
1212
)
13-
dist = render_depth(sim, camera=0, resolution=(100, 100))
13+
# The easiest way to get depth images is to use the render_depth function
14+
dist = render_depth(sim, camera=0, resolution=(100, 100), include_drone=False)
1415
dist = dist.at[dist > 1.5].set(jnp.nan) # Cap max distance for better visualization
15-
if not plot:
16-
return
17-
plt.imshow(dist[0], cmap="viridis")
18-
plt.colorbar(label="Distance (m)")
19-
plt.title("Raycast Distance from Camera")
20-
plt.show()
16+
if plot:
17+
plt.imshow(dist[0], cmap="viridis")
18+
plt.colorbar(label="Distance (m)")
19+
plt.title("Raycast Distance from Camera")
20+
plt.show()
21+
# We can also build a depth renderer function for better performance if we need maximum speed or
22+
# more fine-grained control. Here we only render the drone collision geometry to avoid expensive
23+
# raycasting against the high-poly visual mesh of the drone.
24+
render_depth_fn = build_render_depth_fn(
25+
sim.mjx_model, camera=0, resolution=(200, 200), geomgroup=(1, 1, 0, 1, 1, 1, 1, 1)
26+
)
27+
dist_fn = render_depth_fn(sim)
28+
dist_fn = dist_fn.at[dist_fn > 1.5].set(jnp.nan) # Cap max distance for better visualization
29+
if plot:
30+
plt.imshow(dist_fn[0], cmap="viridis")
31+
plt.colorbar(label="Distance (m)")
32+
plt.title("Raycast Distance from Camera (Compiled)")
33+
plt.show()
2134

2235

2336
if __name__ == "__main__":

0 commit comments

Comments
 (0)