Skip to content

Commit fc016fc

Browse files
committed
perf(ovrtx): batched tile extraction - single kernel launch for all envs
1 parent e7f34eb commit fc016fc

2 files changed

Lines changed: 67 additions & 36 deletions

File tree

source/isaaclab_ov/isaaclab_ov/renderers/ovrtx_renderer.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from .ovrtx_renderer_kernels import (
4747
DEVICE,
4848
create_camera_transforms_kernel,
49+
extract_all_depth_tiles_kernel,
50+
extract_all_rgba_tiles_kernel,
4951
extract_depth_tile_from_tiled_buffer_kernel,
5052
extract_tile_from_tiled_buffer_kernel,
5153
generate_random_colors_from_ids_kernel,
@@ -424,46 +426,38 @@ def _extract_rgba_tiles(
424426
buffer_key: str,
425427
suffix: str = "",
426428
) -> None:
427-
"""Extract per-env RGBA tiles from tiled buffer into output_buffers."""
428-
for env_idx in range(render_data.num_envs):
429-
tile_x = env_idx % render_data.num_cols
430-
tile_y = env_idx // render_data.num_cols
431-
wp.launch(
432-
kernel=extract_tile_from_tiled_buffer_kernel,
433-
dim=(render_data.height, render_data.width),
434-
inputs=[
435-
tiled_data,
436-
output_buffers[buffer_key][env_idx],
437-
tile_x,
438-
tile_y,
439-
render_data.width,
440-
render_data.height,
441-
],
442-
device=DEVICE,
443-
)
429+
"""Extract per-env RGBA tiles from tiled buffer into output_buffers (single kernel launch)."""
430+
wp.launch(
431+
kernel=extract_all_rgba_tiles_kernel,
432+
dim=(render_data.num_envs, render_data.height, render_data.width),
433+
inputs=[
434+
tiled_data,
435+
output_buffers[buffer_key],
436+
render_data.num_cols,
437+
render_data.width,
438+
render_data.height,
439+
],
440+
device=DEVICE,
441+
)
444442

445443
def _extract_depth_tiles(
446444
self, render_data: OVRTXRenderData, tiled_depth_data: wp.array, output_buffers: dict
447445
) -> None:
448-
"""Extract per-env depth tiles into output_buffers."""
449-
for env_idx in range(render_data.num_envs):
450-
tile_x = env_idx % render_data.num_cols
451-
tile_y = env_idx // render_data.num_cols
452-
for depth_type in ["depth", "distance_to_image_plane", "distance_to_camera"]:
453-
if depth_type in output_buffers:
454-
wp.launch(
455-
kernel=extract_depth_tile_from_tiled_buffer_kernel,
456-
dim=(render_data.height, render_data.width),
457-
inputs=[
458-
tiled_depth_data,
459-
output_buffers[depth_type][env_idx],
460-
tile_x,
461-
tile_y,
462-
render_data.width,
463-
render_data.height,
464-
],
465-
device=DEVICE,
466-
)
446+
"""Extract per-env depth tiles into output_buffers (single kernel launch)."""
447+
for depth_type in ["depth", "distance_to_image_plane", "distance_to_camera"]:
448+
if depth_type in output_buffers:
449+
wp.launch(
450+
kernel=extract_all_depth_tiles_kernel,
451+
dim=(render_data.num_envs, render_data.height, render_data.width),
452+
inputs=[
453+
tiled_depth_data,
454+
output_buffers[depth_type],
455+
render_data.num_cols,
456+
render_data.width,
457+
render_data.height,
458+
],
459+
device=DEVICE,
460+
)
467461

468462
def _process_render_frame(self, render_data: OVRTXRenderData, frame, output_buffers: dict) -> None:
469463
"""Extract RGB, depth, albedo, and semantic from a single render frame into output_buffers."""

source/isaaclab_ov/isaaclab_ov/renderers/ovrtx_renderer_kernels.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,43 @@ def extract_tile_from_tiled_buffer_kernel(
7373
tile_buffer[y, x, 3] = tiled_buffer[src_y, src_x, 3]
7474

7575

76+
@wp.kernel
77+
def extract_all_rgba_tiles_kernel(
78+
tiled_buffer: wp.array(dtype=wp.uint8, ndim=3), # type: ignore
79+
output_buffer: wp.array(dtype=wp.uint8, ndim=4), # type: ignore (num_envs, H, W, 4)
80+
num_cols: int,
81+
tile_width: int,
82+
tile_height: int,
83+
):
84+
"""Extract ALL RGBA tiles from a tiled buffer in a single kernel launch."""
85+
env_idx, y, x = wp.tid()
86+
tile_x = env_idx % num_cols
87+
tile_y = env_idx // num_cols
88+
src_x = tile_x * tile_width + x
89+
src_y = tile_y * tile_height + y
90+
output_buffer[env_idx, y, x, 0] = tiled_buffer[src_y, src_x, 0]
91+
output_buffer[env_idx, y, x, 1] = tiled_buffer[src_y, src_x, 1]
92+
output_buffer[env_idx, y, x, 2] = tiled_buffer[src_y, src_x, 2]
93+
output_buffer[env_idx, y, x, 3] = tiled_buffer[src_y, src_x, 3]
94+
95+
96+
@wp.kernel
97+
def extract_all_depth_tiles_kernel(
98+
tiled_buffer: wp.array(dtype=wp.float32, ndim=2), # type: ignore
99+
output_buffer: wp.array(dtype=wp.float32, ndim=4), # type: ignore (num_envs, H, W, 1)
100+
num_cols: int,
101+
tile_width: int,
102+
tile_height: int,
103+
):
104+
"""Extract ALL depth tiles from a tiled buffer in a single kernel launch."""
105+
env_idx, y, x = wp.tid()
106+
tile_x = env_idx % num_cols
107+
tile_y = env_idx // num_cols
108+
src_x = tile_x * tile_width + x
109+
src_y = tile_y * tile_height + y
110+
output_buffer[env_idx, y, x, 0] = tiled_buffer[src_y, src_x]
111+
112+
76113
@wp.kernel
77114
def extract_depth_tile_from_tiled_buffer_kernel(
78115
tiled_buffer: wp.array(dtype=wp.float32, ndim=2), # type: ignore

0 commit comments

Comments
 (0)