Skip to content

Commit 3b28b44

Browse files
authored
Add batched raycasting sensor support to crazyflow (#48)
1 parent c92445c commit 3b28b44

5 files changed

Lines changed: 203 additions & 3 deletions

File tree

.github/copilot-instructions.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# General
2+
- Prefer failing fast over capturing exceptions and continuing execution.
3+
- Use pure functions where possible.
4+
- Prefer protocols over inheritance.
5+
- Use pathlib instead of os.path for file path manipulations.
6+
- Use napoleon-type docstrings for documenting functions and classes.
7+
- Use type hints for function signatures. Do not put types into the docstrings.

crazyflow/sim/sensors.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from functools import partial
2+
from typing import Callable
3+
4+
import jax
5+
import jax.numpy as jnp
6+
import mujoco.mjx as mjx
7+
from jax import Array
8+
9+
from crazyflow.sim.sim import Sim, requires_mujoco_sync
10+
11+
12+
@requires_mujoco_sync
13+
def render_depth(
14+
sim: Sim, camera: int = 0, resolution: tuple[int, int] = (100, 100), include_drone: bool = False
15+
) -> Array:
16+
"""Render depth images using raycasting.
17+
18+
Note:
19+
Code has been adoped from
20+
https://github.com/Andrew-Luo1/jax_shac/blob/main/vision/2dof_ball.ipynb
21+
"""
22+
return _render_depth(
23+
mjx_data=sim.mjx_data,
24+
mjx_model=sim.mjx_model,
25+
camera=camera,
26+
resolution=resolution,
27+
include_drone=include_drone,
28+
)
29+
30+
31+
def build_render_depth_fn(
32+
mjx_model: mjx.Model,
33+
camera: int = 0,
34+
resolution: tuple[int, int] = (100, 100),
35+
geomgroup: tuple[int, ...] = (1, 1, 0, 0, 1, 1, 1, 1),
36+
) -> Callable[[Sim], Array]:
37+
"""Build a depth renderer function for given camera and resolution.
38+
39+
Compiles the mjx model and rays directly into the rendering function for higher performance. The
40+
returned function takes a Sim object as input and returns depth images.
41+
"""
42+
rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...]
43+
ray_fn = jax.jit(
44+
partial(_render_rays, mjx_model=mjx_model, camera=camera, geomgroup=geomgroup, rays=rays),
45+
static_argnames=("mjx_model", "camera", "geomgroup", "rays"),
46+
)
47+
48+
@requires_mujoco_sync
49+
def render_depth_fn(sim: Sim) -> Array:
50+
return ray_fn(mjx_data=sim.mjx_data)
51+
52+
return render_depth_fn
53+
54+
55+
@jax.jit(static_argnames=("camera", "resolution", "include_drone"))
56+
def _render_depth(
57+
mjx_data: mjx.Data,
58+
mjx_model: mjx.Model,
59+
camera: int,
60+
resolution: tuple[int, int],
61+
include_drone: bool = False,
62+
) -> Array:
63+
"""Accelerates the dynamic rendering of depth images."""
64+
local_rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...]
65+
geomgroup = (1, 1, 1, 0, 1, 1, 1, 1) if include_drone else (1, 1, 0, 0, 1, 1, 1, 1)
66+
return _render_rays(
67+
mjx_data=mjx_data, mjx_model=mjx_model, camera=camera, rays=local_rays, geomgroup=geomgroup
68+
)
69+
70+
71+
def _render_rays(
72+
mjx_data: mjx.Data, mjx_model: mjx.Model, camera: int, rays: Array, geomgroup: tuple[int, ...]
73+
) -> Array:
74+
"""Render a given ray array using MuJoCo's raycasting."""
75+
rays = _to_mjx_frame(rays, mjx_data.cam_xmat[:, camera])
76+
ray_ax = (None, None, None, 0)
77+
ray = jax.vmap(
78+
jax.vmap(jax.vmap(partial(mjx.ray, geomgroup=geomgroup), in_axes=ray_ax), in_axes=ray_ax),
79+
in_axes=(None, 0, 0, 0),
80+
)
81+
return ray(mjx_model, mjx_data, mjx_data.cam_xpos[:, camera], rays)[0]
82+
83+
84+
def _to_mjx_frame(x: Array, xmat: Array) -> Array:
85+
"""Transform points to a different frame given its rotation matrix."""
86+
return (xmat[:, None, None, ...] @ x[..., None])[..., 0]
87+
88+
89+
def _camera_rays(resolution: tuple[int, int] = (100, 100), fov_y: float = jnp.pi / 4) -> Array:
90+
"""Create an array of rays with a given field of view and resolution.
91+
92+
Args:
93+
resolution: Image resolution as (width, height).
94+
fov_y: Vertical field of view in radians.
95+
"""
96+
image_height = jnp.tan(fov_y / 2) * 2
97+
image_width = image_height * (resolution[0] / resolution[1]) # Square pixels.
98+
delta = image_width / (2 * resolution[0])
99+
x = jnp.linspace(-image_width / 2 + delta, image_width / 2 - delta, resolution[0])
100+
y = jnp.flip(jnp.linspace(-image_height / 2 + delta, image_height / 2 - delta, resolution[1]))
101+
X, Y = jnp.meshgrid(x, y)
102+
rays = jnp.stack([X, Y, -jnp.ones_like(X)], axis=-1)
103+
return rays / jnp.linalg.norm(rays, axis=-1, keepdims=True)

crazyflow/sim/sim.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,10 @@ def build_mjx_spec(self) -> mujoco.MjSpec:
232232
spec.copy_during_attach = True
233233
drone_spec = mujoco.MjSpec.from_file(str(self.drone_path))
234234
frame = spec.worldbody.add_frame(name="world")
235+
if (drone_body := drone_spec.body("drone")) is None:
236+
raise ValueError("Drone body not found in drone spec")
235237
# Add drones and their actuators
236238
for i in range(self.n_drones):
237-
drone_body = drone_spec.body("drone")
238-
if drone_body is None:
239-
raise ValueError("Drone body not found in drone spec")
240239
drone = frame.attach_body(drone_body, "", f":{i}")
241240
drone.add_freejoint()
242241
return spec
@@ -534,6 +533,8 @@ def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimDa
534533
qvel = rearrange(jnp.concat([vel, ang_vel], axis=-1), "w d qvel -> w (d qvel)")
535534
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
536535
mjx_data = jax.vmap(mjx.kinematics, in_axes=(None, 0))(mjx_model, mjx_data)
536+
# Required for rendering w. ray casting
537+
mjx_data = jax.vmap(mjx.camlight, in_axes=(None, 0))(mjx_model, mjx_data)
537538
mjx_data = jax.vmap(mjx.collision, in_axes=(None, 0))(mjx_model, mjx_data)
538539
data = data.replace(core=data.core.replace(mjx_synced=True))
539540
return data, mjx_data

examples/raycasting.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import jax.numpy as jnp
2+
import matplotlib.pyplot as plt
3+
4+
from crazyflow.sim import Sim
5+
from crazyflow.sim.sensors import build_render_depth_fn, render_depth
6+
7+
8+
def main(plot: bool = False):
9+
sim = Sim()
10+
sim.data = sim.data.replace(
11+
states=sim.data.states.replace(pos=sim.data.states.pos.at[..., 2].set(0.2))
12+
)
13+
# The easiest way to get depth images is to use the render_depth function
14+
dist = render_depth(sim, camera=0, resolution=(100, 100), include_drone=False)
15+
dist = dist.at[dist > 1.5].set(jnp.nan) # Cap max distance for better visualization
16+
if plot:
17+
plt.imshow(dist[0], cmap="viridis")
18+
plt.colorbar(label="Distance (m)")
19+
plt.title("Raycast Distance from Camera")
20+
plt.show()
21+
# We can also build a depth renderer function for better performance if we need maximum speed or
22+
# more fine-grained control. Here we only render the drone collision geometry to avoid expensive
23+
# raycasting against the high-poly visual mesh of the drone.
24+
render_depth_fn = build_render_depth_fn(
25+
sim.mjx_model, camera=0, resolution=(200, 200), geomgroup=(1, 1, 0, 1, 1, 1, 1, 1)
26+
)
27+
dist_fn = render_depth_fn(sim)
28+
dist_fn = dist_fn.at[dist_fn > 1.5].set(jnp.nan) # Cap max distance for better visualization
29+
if plot:
30+
plt.imshow(dist_fn[0], cmap="viridis")
31+
plt.colorbar(label="Distance (m)")
32+
plt.title("Raycast Distance from Camera (Compiled)")
33+
plt.show()
34+
35+
36+
if __name__ == "__main__":
37+
main(plot=True)

tests/unit/test_sensors.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Unit tests for the sensors module."""
2+
3+
from __future__ import annotations
4+
5+
import jax
6+
import jax.numpy as jnp
7+
import numpy as np
8+
import pytest
9+
10+
from crazyflow.sim import Sim
11+
from crazyflow.sim.sensors import _camera_rays, build_render_depth_fn, render_depth
12+
13+
14+
@pytest.mark.unit
15+
def test_camera_rays():
16+
"""Test that camera_rays produces arrays with correct shape and device."""
17+
resolution = (64, 48)
18+
rays = _camera_rays(resolution=resolution)
19+
# Check shape: should be (height, width, 3)
20+
expected_shape = (resolution[1], resolution[0], 3)
21+
assert rays.shape == expected_shape, f"Expected shape {expected_shape}, got {rays.shape}"
22+
# Check that rays are normalized
23+
norm = jnp.linalg.norm(rays, axis=-1)
24+
assert jnp.allclose(norm, 1.0, atol=1e-6), "Rays should be normalized"
25+
# Check that rays respect the FOV
26+
rays_narrow = _camera_rays(fov_y=np.pi / 6)
27+
rays_wide = _camera_rays(fov_y=np.pi / 3)
28+
# Corner rays should have different angles for different FOV
29+
# Check the top corner ray y-component (wider FOV should have larger y-component)
30+
corner_y_narrow = abs(rays_narrow[0, 0, 1]) # Top-left corner
31+
corner_y_wide = abs(rays_wide[0, 0, 1])
32+
assert corner_y_wide > corner_y_narrow, "Wider FOV should produce rays with larger y-components"
33+
34+
35+
@pytest.mark.unit
36+
def test_render_depth(device: str):
37+
"""Test render_depth with different resolutions."""
38+
sim = Sim(n_worlds=2, device=device)
39+
dist = render_depth(sim, camera=0, resolution=(10, 10))
40+
assert dist.shape == (2, 10, 10), f"Expected shape (2, 10, 10), got {dist.shape}"
41+
assert dist.device == jax.devices(device)[0], f"Expected device {device}, got {dist.device}"
42+
43+
44+
@pytest.mark.unit
45+
def test_build_render_depth_fn():
46+
"""Test build_render_depth_fn produces a callable that returns correct shapes."""
47+
sim = Sim(n_worlds=3)
48+
render_depth_fn = build_render_depth_fn(
49+
sim.mjx_model, camera=0, resolution=(20, 15), geomgroup=(1, 1, 0, 1, 1, 1, 1, 1)
50+
)
51+
dist = render_depth_fn(sim)
52+
assert dist.shape == (3, 15, 20), f"Expected shape (3, 15, 20), got {dist.shape}"

0 commit comments

Comments
 (0)