|
| 1 | +from functools import partial |
| 2 | +from typing import Callable |
| 3 | + |
| 4 | +import jax |
| 5 | +import jax.numpy as jnp |
| 6 | +import mujoco.mjx as mjx |
| 7 | +from jax import Array |
| 8 | + |
| 9 | +from crazyflow.sim.sim import Sim, requires_mujoco_sync |
| 10 | + |
| 11 | + |
| 12 | +@requires_mujoco_sync |
| 13 | +def render_depth( |
| 14 | + sim: Sim, camera: int = 0, resolution: tuple[int, int] = (100, 100), include_drone: bool = False |
| 15 | +) -> Array: |
| 16 | + """Render depth images using raycasting. |
| 17 | +
|
| 18 | + Note: |
| 19 | + Code has been adoped from |
| 20 | + https://github.com/Andrew-Luo1/jax_shac/blob/main/vision/2dof_ball.ipynb |
| 21 | + """ |
| 22 | + return _render_depth( |
| 23 | + mjx_data=sim.mjx_data, |
| 24 | + mjx_model=sim.mjx_model, |
| 25 | + camera=camera, |
| 26 | + resolution=resolution, |
| 27 | + include_drone=include_drone, |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +def build_render_depth_fn( |
| 32 | + mjx_model: mjx.Model, |
| 33 | + camera: int = 0, |
| 34 | + resolution: tuple[int, int] = (100, 100), |
| 35 | + geomgroup: tuple[int, ...] = (1, 1, 0, 0, 1, 1, 1, 1), |
| 36 | +) -> Callable[[Sim], Array]: |
| 37 | + """Build a depth renderer function for given camera and resolution. |
| 38 | +
|
| 39 | + Compiles the mjx model and rays directly into the rendering function for higher performance. The |
| 40 | + returned function takes a Sim object as input and returns depth images. |
| 41 | + """ |
| 42 | + rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...] |
| 43 | + ray_fn = jax.jit( |
| 44 | + partial(_render_rays, mjx_model=mjx_model, camera=camera, geomgroup=geomgroup, rays=rays), |
| 45 | + static_argnames=("mjx_model", "camera", "geomgroup", "rays"), |
| 46 | + ) |
| 47 | + |
| 48 | + @requires_mujoco_sync |
| 49 | + def render_depth_fn(sim: Sim) -> Array: |
| 50 | + return ray_fn(mjx_data=sim.mjx_data) |
| 51 | + |
| 52 | + return render_depth_fn |
| 53 | + |
| 54 | + |
| 55 | +@jax.jit(static_argnames=("camera", "resolution", "include_drone")) |
| 56 | +def _render_depth( |
| 57 | + mjx_data: mjx.Data, |
| 58 | + mjx_model: mjx.Model, |
| 59 | + camera: int, |
| 60 | + resolution: tuple[int, int], |
| 61 | + include_drone: bool = False, |
| 62 | +) -> Array: |
| 63 | + """Accelerates the dynamic rendering of depth images.""" |
| 64 | + local_rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...] |
| 65 | + geomgroup = (1, 1, 1, 0, 1, 1, 1, 1) if include_drone else (1, 1, 0, 0, 1, 1, 1, 1) |
| 66 | + return _render_rays( |
| 67 | + mjx_data=mjx_data, mjx_model=mjx_model, camera=camera, rays=local_rays, geomgroup=geomgroup |
| 68 | + ) |
| 69 | + |
| 70 | + |
| 71 | +def _render_rays( |
| 72 | + mjx_data: mjx.Data, mjx_model: mjx.Model, camera: int, rays: Array, geomgroup: tuple[int, ...] |
| 73 | +) -> Array: |
| 74 | + """Render a given ray array using MuJoCo's raycasting.""" |
| 75 | + rays = _to_mjx_frame(rays, mjx_data.cam_xmat[:, camera]) |
| 76 | + ray_ax = (None, None, None, 0) |
| 77 | + ray = jax.vmap( |
| 78 | + jax.vmap(jax.vmap(partial(mjx.ray, geomgroup=geomgroup), in_axes=ray_ax), in_axes=ray_ax), |
| 79 | + in_axes=(None, 0, 0, 0), |
| 80 | + ) |
| 81 | + return ray(mjx_model, mjx_data, mjx_data.cam_xpos[:, camera], rays)[0] |
| 82 | + |
| 83 | + |
| 84 | +def _to_mjx_frame(x: Array, xmat: Array) -> Array: |
| 85 | + """Transform points to a different frame given its rotation matrix.""" |
| 86 | + return (xmat[:, None, None, ...] @ x[..., None])[..., 0] |
| 87 | + |
| 88 | + |
| 89 | +def _camera_rays(resolution: tuple[int, int] = (100, 100), fov_y: float = jnp.pi / 4) -> Array: |
| 90 | + """Create an array of rays with a given field of view and resolution. |
| 91 | +
|
| 92 | + Args: |
| 93 | + resolution: Image resolution as (width, height). |
| 94 | + fov_y: Vertical field of view in radians. |
| 95 | + """ |
| 96 | + image_height = jnp.tan(fov_y / 2) * 2 |
| 97 | + image_width = image_height * (resolution[0] / resolution[1]) # Square pixels. |
| 98 | + delta = image_width / (2 * resolution[0]) |
| 99 | + x = jnp.linspace(-image_width / 2 + delta, image_width / 2 - delta, resolution[0]) |
| 100 | + y = jnp.flip(jnp.linspace(-image_height / 2 + delta, image_height / 2 - delta, resolution[1])) |
| 101 | + X, Y = jnp.meshgrid(x, y) |
| 102 | + rays = jnp.stack([X, Y, -jnp.ones_like(X)], axis=-1) |
| 103 | + return rays / jnp.linalg.norm(rays, axis=-1, keepdims=True) |
0 commit comments