Skip to content

Commit cb8a847

Browse files
author
Lenny Wang
committed
review changes
1 parent 86939aa commit cb8a847

5 files changed

Lines changed: 372 additions & 350 deletions

File tree

benchmarks/voxelization_benchmark.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,9 @@
88
import click
99
import torch
1010

11-
from conch.ops.vision.voxelization import (
12-
VoxelizationParameter,
13-
collect_point_features,
14-
generate_voxels,
15-
voxelization_stable,
16-
)
11+
from conch.ops.vision.voxelization import VoxelizationParameter, generate_voxels
1712
from conch.platforms import current_platform
13+
from conch.reference.vision.voxelization import collect_point_features, voxelization_stable
1814
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
1915

2016

@@ -82,6 +78,11 @@
8278
is_flag=True,
8379
help="Flag for printing results in CSV format",
8480
)
81+
@click.option(
82+
"--compile-ref",
83+
is_flag=True,
84+
help="Flag to torch.compile() the reference impl",
85+
)
8586
@click.option(
8687
"--cuda-ref",
8788
is_flag=True,
@@ -97,6 +98,7 @@ def main(
9798
warmup_time_ms: int,
9899
gpu: str,
99100
csv: bool,
101+
compile_ref: bool,
100102
cuda_ref: bool,
101103
) -> None:
102104
"""Benchmark voxelization.
@@ -106,11 +108,12 @@ def main(
106108
max_num_points_per_voxel: Max number of points per voxel for output feature tensor.
107109
voxel_dim: Voxel dimensions for x,y,z
108110
grid_range: Grid boundary for x,y,z
109-
torch_ref: Flag to enable torch reference implementation.
111+
torch_ref: Flag to use pure torch reference implementation instead of hybrid triton/torch.
110112
iteration_time_ms: Time in milliseconds to run benchmark.
111113
warmup_time_ms: Time in milliseconds to warmup before recording times.
112114
gpu: Which gpu to run on.
113115
csv: Flag to indicate whether or not to print results in CSV format.
116+
compile_ref: Flag to torch.compile() the pure torch reference implementation.
114117
cuda_ref: Flag to enable CUDA reference implementation.
115118
"""
116119

@@ -137,7 +140,7 @@ def main(
137140
def generate_voxels_torch(
138141
points: torch.Tensor, param: VoxelizationParameter
139142
) -> tuple[torch.tensor, torch.tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
140-
"""triton/torch hybrid, 2-step, stable voxelization first then generate a feature tensor."""
143+
"""reference triton/torch hybrid, 2-step, stable voxelization first then generate a feature tensor."""
141144
use_triton = not torch_ref
142145
actual_num_points_per_voxel, point_raw_indices, flat_voxel_indices = voxelization_stable(
143146
points, param, use_triton=use_triton
@@ -223,11 +226,29 @@ def generate_voxels_torch(
223226
warmup_time_ms=warmup_time_ms,
224227
)
225228

229+
reference_compiled_result = None
230+
reference_compiled_fn = None
231+
232+
if compile_ref and torch_ref:
233+
# Compile the reference implementation if requested
234+
reference_compiled_fn = torch.compile(generate_voxels_torch)
235+
236+
if reference_compiled_fn:
237+
baseline_result = benchmark_it(
238+
lambda: reference_compiled_fn(*args),
239+
tag="Baseline (Torch compiled)",
240+
metadata=metadata,
241+
iteration_time_ms=iteration_time_ms,
242+
warmup_time_ms=warmup_time_ms,
243+
)
244+
226245
conch_result.print_parameters(csv=csv)
227246
conch_result.print_results(csv=csv)
228247
baseline_result.print_results(csv=csv)
229248
if reference_cuda_result:
230249
reference_cuda_result.print_results(csv=csv)
250+
if reference_compiled_result:
251+
reference_compiled_result.print_results(csv=csv)
231252

232253

233254
if __name__ == "__main__":

conch/kernels/vision/voxelization.py

Lines changed: 1 addition & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -8,59 +8,6 @@
88
import triton.language as tl
99

1010

11-
@triton.jit
12-
def filter_and_label_points_triton_kernel( # noqa: PLR0913, D417
13-
# input
14-
points_ptr: torch.Tensor,
15-
num_points: int,
16-
num_features_per_point: int,
17-
# parameters
18-
min_x: float,
19-
min_y: float,
20-
min_z: float,
21-
max_x: float,
22-
max_y: float,
23-
max_z: float,
24-
voxel_dim_x: float,
25-
voxel_dim_y: float,
26-
voxel_dim_z: float,
27-
grid_dim_x: int,
28-
grid_dim_y: int,
29-
grid_dim_z: int,
30-
max_num_voxels: int,
31-
# output
32-
point_voxel_indices_ptr: torch.Tensor,
33-
# Constants
34-
cxpr_block_size: tl.constexpr,
35-
) -> None:
36-
"""Filter valid points and label each with a voxel index.
37-
38-
Args:
39-
points_ptr: input points, shape (num_points, num_features_per_point).
40-
voxelization parameters
41-
point_voxel_indices_ptr: output per point flattened voxel indices, shape (num_points).
42-
"""
43-
block_idx = tl.program_id(axis=0)
44-
point_idx = block_idx * cxpr_block_size + tl.arange(0, cxpr_block_size)
45-
point_mask = point_idx < num_points
46-
47-
point_x = tl.load(points_ptr + point_idx * num_features_per_point + 0, mask=point_mask, other=max_x + voxel_dim_x)
48-
point_y = tl.load(points_ptr + point_idx * num_features_per_point + 1, mask=point_mask, other=max_y + voxel_dim_y)
49-
point_z = tl.load(points_ptr + point_idx * num_features_per_point + 2, mask=point_mask, other=max_z + voxel_dim_z)
50-
51-
voxel_x = tl.floor((point_x - min_x) / voxel_dim_x).to(tl.int32)
52-
voxel_y = tl.floor((point_y - min_y) / voxel_dim_y).to(tl.int32)
53-
voxel_z = tl.floor((point_z - min_z) / voxel_dim_z).to(tl.int32)
54-
55-
valid_x = (voxel_x >= 0) & (voxel_x < grid_dim_x)
56-
valid_y = (voxel_y >= 0) & (voxel_y < grid_dim_y)
57-
valid_z = (voxel_z >= 0) & (voxel_z < grid_dim_z)
58-
valid_point = point_mask & valid_x & valid_y & valid_z
59-
60-
flat_voxel_idx = tl.where(valid_point, ((voxel_z * grid_dim_y + voxel_y) * grid_dim_x + voxel_x), max_num_voxels)
61-
tl.store(point_voxel_indices_ptr + point_idx, flat_voxel_idx, mask=point_mask)
62-
63-
6411
@triton.jit
6512
def generate_dense_voxels_triton_kernel( # noqa: PLR0913, D417
6613
# input
@@ -187,7 +134,7 @@ def generate_voxels_triton_kernel( # noqa: PLR0913, D417
187134
tl.store(voxel_indices_ptr + voxel_idx * 4 + 2, voxel_z, mask=valid_voxel)
188135

189136
# store all feature points, including padded 0s
190-
for point_idx in range(0, max_num_points_per_voxel, 1):
137+
for point_idx in range(0, max_num_points_per_voxel):
191138
input_idx = flat_voxel_idx * max_num_points_per_voxel + point_idx
192139
point_x = tl.load(dense_point_features_ptr + input_idx * 4 + 0, mask=valid_voxel)
193140
point_y = tl.load(dense_point_features_ptr + input_idx * 4 + 1, mask=valid_voxel)
@@ -199,53 +146,3 @@ def generate_voxels_triton_kernel( # noqa: PLR0913, D417
199146
tl.store(point_features_ptr + output_idx * 4 + 1, point_y, mask=valid_voxel)
200147
tl.store(point_features_ptr + output_idx * 4 + 2, point_z, mask=valid_voxel)
201148
tl.store(point_features_ptr + output_idx * 4 + 3, point_w, mask=valid_voxel)
202-
203-
204-
@triton.jit
205-
def collect_point_features_triton_kernel( # noqa: PLR0913, D417
206-
# input
207-
points_ptr: torch.Tensor,
208-
num_features_per_point: int,
209-
segment_offsets_ptr: torch.Tensor,
210-
num_filled_voxels: int,
211-
point_indices_ptr: torch.Tensor,
212-
# parameters
213-
max_num_points_per_voxel: int,
214-
# output
215-
point_features_ptr: torch.Tensor,
216-
capped_num_points_per_voxel_ptr: torch.Tensor,
217-
# Constants
218-
cxpr_block_size: tl.constexpr,
219-
) -> None:
220-
"""Group valid points into dense voxels.
221-
222-
Args:
223-
points_ptr: input points tensor, shape (num_points, num_features_per_point)
224-
segment_offsets_ptr: input segment end offsets, shape (num_filled_voxels)
225-
point_indices_ptr: input raw point indices, shape (num_valid_points)
226-
voxelization parameters
227-
point_features_ptr: output voxel point features, shape (num_filled_voxels, max_num_points_per_voxel, num_features_per_point)
228-
capped_num_points_per_voxel_ptr: output number of points per voxel tensor after capping, shape (num_filled_voxels)
229-
"""
230-
block_idx = tl.program_id(axis=0)
231-
voxel_idx = block_idx * cxpr_block_size + tl.arange(0, cxpr_block_size)
232-
voxel_mask = voxel_idx < num_filled_voxels
233-
234-
# top n filtering
235-
segment_start = tl.load(segment_offsets_ptr + voxel_idx - 1, mask=(voxel_mask & (voxel_idx > 0)), other=0)
236-
segment_end = tl.load(segment_offsets_ptr + voxel_idx, mask=voxel_mask, other=0)
237-
num_points_in_voxel = segment_end - segment_start
238-
num_points_in_voxel = tl.minimum(num_points_in_voxel, max_num_points_per_voxel)
239-
tl.store(capped_num_points_per_voxel_ptr + voxel_idx, num_points_in_voxel, mask=voxel_mask)
240-
241-
for voxel_point_idx in range(0, max_num_points_per_voxel, 1):
242-
# this mask is sufficient since other num_points_in_voxel == 0
243-
per_voxel_mask = voxel_point_idx < num_points_in_voxel
244-
245-
raw_point_idx = tl.load(point_indices_ptr + segment_start + voxel_point_idx, mask=per_voxel_mask)
246-
output_idx = voxel_idx * max_num_points_per_voxel + voxel_point_idx
247-
for feature_idx in range(0, num_features_per_point, 1):
248-
value = tl.load(
249-
points_ptr + raw_point_idx * num_features_per_point + feature_idx, mask=per_voxel_mask, other=0
250-
)
251-
tl.store(point_features_ptr + output_idx * num_features_per_point + feature_idx, value, mask=voxel_mask)

0 commit comments

Comments
 (0)