Skip to content

Commit 4056829

Browse files
committed
perf(ovrtx): batched tile extraction - single kernel launch for all envs
1 parent 0090622 commit 4056829

2 files changed

Lines changed: 67 additions & 38 deletions

File tree

source/isaaclab_ov/isaaclab_ov/renderers/ovrtx_renderer.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
from .ovrtx_renderer_kernels import (
4747
DEVICE,
4848
create_camera_transforms_kernel,
49-
extract_depth_tile_from_tiled_buffer_kernel,
50-
extract_tile_from_tiled_buffer_kernel,
49+
extract_all_depth_tiles_kernel,
50+
extract_all_rgba_tiles_kernel,
5151
generate_random_colors_from_ids_kernel,
5252
sync_newton_transforms_kernel,
5353
)
@@ -424,46 +424,38 @@ def _extract_rgba_tiles(
424424
buffer_key: str,
425425
suffix: str = "",
426426
) -> None:
427-
"""Extract per-env RGBA tiles from tiled buffer into output_buffers."""
428-
for env_idx in range(render_data.num_envs):
429-
tile_x = env_idx % render_data.num_cols
430-
tile_y = env_idx // render_data.num_cols
431-
wp.launch(
432-
kernel=extract_tile_from_tiled_buffer_kernel,
433-
dim=(render_data.height, render_data.width),
434-
inputs=[
435-
tiled_data,
436-
output_buffers[buffer_key][env_idx],
437-
tile_x,
438-
tile_y,
439-
render_data.width,
440-
render_data.height,
441-
],
442-
device=DEVICE,
443-
)
427+
"""Extract per-env RGBA tiles from tiled buffer into output_buffers (single kernel launch)."""
428+
wp.launch(
429+
kernel=extract_all_rgba_tiles_kernel,
430+
dim=(render_data.num_envs, render_data.height, render_data.width),
431+
inputs=[
432+
tiled_data,
433+
output_buffers[buffer_key],
434+
render_data.num_cols,
435+
render_data.width,
436+
render_data.height,
437+
],
438+
device=DEVICE,
439+
)
444440

445441
def _extract_depth_tiles(
446442
self, render_data: OVRTXRenderData, tiled_depth_data: wp.array, output_buffers: dict
447443
) -> None:
448-
"""Extract per-env depth tiles into output_buffers."""
449-
for env_idx in range(render_data.num_envs):
450-
tile_x = env_idx % render_data.num_cols
451-
tile_y = env_idx // render_data.num_cols
452-
for depth_type in ["depth", "distance_to_image_plane", "distance_to_camera"]:
453-
if depth_type in output_buffers:
454-
wp.launch(
455-
kernel=extract_depth_tile_from_tiled_buffer_kernel,
456-
dim=(render_data.height, render_data.width),
457-
inputs=[
458-
tiled_depth_data,
459-
output_buffers[depth_type][env_idx],
460-
tile_x,
461-
tile_y,
462-
render_data.width,
463-
render_data.height,
464-
],
465-
device=DEVICE,
466-
)
444+
"""Extract per-env depth tiles into output_buffers (single kernel launch)."""
445+
for depth_type in ["depth", "distance_to_image_plane", "distance_to_camera"]:
446+
if depth_type in output_buffers:
447+
wp.launch(
448+
kernel=extract_all_depth_tiles_kernel,
449+
dim=(render_data.num_envs, render_data.height, render_data.width),
450+
inputs=[
451+
tiled_depth_data,
452+
output_buffers[depth_type],
453+
render_data.num_cols,
454+
render_data.width,
455+
render_data.height,
456+
],
457+
device=DEVICE,
458+
)
467459

468460
def _process_render_frame(self, render_data: OVRTXRenderData, frame, output_buffers: dict) -> None:
469461
"""Extract RGB, depth, albedo, and semantic from a single render frame into output_buffers."""

source/isaaclab_ov/isaaclab_ov/renderers/ovrtx_renderer_kernels.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,43 @@ def extract_tile_from_tiled_buffer_kernel(
7373
tile_buffer[y, x, 3] = tiled_buffer[src_y, src_x, 3]
7474

7575

76+
@wp.kernel
77+
def extract_all_rgba_tiles_kernel(
78+
tiled_buffer: wp.array(dtype=wp.uint8, ndim=3), # type: ignore
79+
output_buffer: wp.array(dtype=wp.uint8, ndim=4), # type: ignore (num_envs, H, W, 4)
80+
num_cols: int,
81+
tile_width: int,
82+
tile_height: int,
83+
):
84+
"""Extract ALL RGBA tiles from a tiled buffer in a single kernel launch."""
85+
env_idx, y, x = wp.tid()
86+
tile_x = env_idx % num_cols
87+
tile_y = env_idx // num_cols
88+
src_x = tile_x * tile_width + x
89+
src_y = tile_y * tile_height + y
90+
output_buffer[env_idx, y, x, 0] = tiled_buffer[src_y, src_x, 0]
91+
output_buffer[env_idx, y, x, 1] = tiled_buffer[src_y, src_x, 1]
92+
output_buffer[env_idx, y, x, 2] = tiled_buffer[src_y, src_x, 2]
93+
output_buffer[env_idx, y, x, 3] = tiled_buffer[src_y, src_x, 3]
94+
95+
96+
@wp.kernel
97+
def extract_all_depth_tiles_kernel(
98+
tiled_buffer: wp.array(dtype=wp.float32, ndim=2), # type: ignore
99+
output_buffer: wp.array(dtype=wp.float32, ndim=4), # type: ignore (num_envs, H, W, 1)
100+
num_cols: int,
101+
tile_width: int,
102+
tile_height: int,
103+
):
104+
"""Extract ALL depth tiles from a tiled buffer in a single kernel launch."""
105+
env_idx, y, x = wp.tid()
106+
tile_x = env_idx % num_cols
107+
tile_y = env_idx // num_cols
108+
src_x = tile_x * tile_width + x
109+
src_y = tile_y * tile_height + y
110+
output_buffer[env_idx, y, x, 0] = tiled_buffer[src_y, src_x]
111+
112+
76113
@wp.kernel
77114
def extract_depth_tile_from_tiled_buffer_kernel(
78115
tiled_buffer: wp.array(dtype=wp.float32, ndim=2), # type: ignore

0 commit comments

Comments
 (0)