Skip to content

Commit bf1e621

Browse files
Merge pull request #38 from stackav-oss/feature/voxelization
Add Voxelization
2 parents e4ac60c + d7458d8 commit bf1e621

8 files changed

Lines changed: 1246 additions & 0 deletions

File tree

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
# Copyright 2025 Stack AV Co.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Voxelization benchmark."""
5+
6+
from typing import Final
7+
8+
import click
9+
import torch
10+
11+
from conch.ops.vision.voxelization import VoxelizationParameter, generate_voxels
12+
from conch.platforms import current_platform
13+
from conch.reference.vision.voxelization import collect_point_features, voxelization_stable
14+
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
15+
16+
17+
# cuda/triton kernels parallelize on either num_points or max/actual number of voxels
18+
# change num_points and grid_range to adjust point filtering & grid occupancy
19+
# keep max_num_points_per_voxel less than 64 for ideal performance for this impl
20+
# triton/cuda impl requires num_features_per_point = 4
21+
@click.command()
@click.option(
    "--num-points",
    required=False,
    type=int,
    default=500000,
    help="Number of input points",
)
@click.option(
    "--max-num-points-per-voxel",
    required=False,
    type=int,
    default=4,
    help="Max number of points per voxel",
)
@click.option(
    "--voxel-dim",
    required=False,
    type=float,
    default=2.5,
    help="Voxel dimension same for x,y,z",
)
@click.option(
    "--grid-range",
    required=False,
    type=float,
    default=50,
    help="Grid boundary from -range to range, same for x,y,z",
)
@click.option(
    "--torch-ref",
    is_flag=True,
    help="Flag to enable Torch reference implementation for stable runs",
)
@click.option(
    "--iteration-time-ms",
    required=False,
    type=int,
    default=100,
    help="Time in milliseconds to run benchmark",
)
@click.option(
    "--warmup-time-ms",
    required=False,
    type=int,
    default=10,
    help="Time in milliseconds to warmup before recording times",
)
@click.option(
    "--gpu",
    required=False,
    type=str,
    default=current_platform.device,
    help="Device to run on",
)
@click.option(
    "--csv",
    is_flag=True,
    help="Flag for printing results in CSV format",
)
@click.option(
    "--compile-ref",
    is_flag=True,
    help="Flag to torch.compile() the reference impl",
)
@click.option(
    "--cuda-ref",
    is_flag=True,
    help="Flag to enable CUDA reference implementation",
)
def main(
    num_points: int,
    max_num_points_per_voxel: int,
    voxel_dim: float,
    grid_range: float,
    torch_ref: bool,
    iteration_time_ms: int,
    warmup_time_ms: int,
    gpu: str,
    csv: bool,
    compile_ref: bool,
    cuda_ref: bool,
) -> None:
    """Benchmark voxelization.

    Generates random points, runs the two-step reference once to print occupancy
    statistics, then times the reference ("Baseline"), the Conch implementation, and
    optionally a CUDA extension and a torch.compile()'d reference. Each timed variant
    is reported as its own result line.

    Args:
        num_points: Number of input points.
        max_num_points_per_voxel: Max number of points per voxel for output feature tensor.
        voxel_dim: Voxel dimensions for x,y,z
        grid_range: Grid boundary for x,y,z
        torch_ref: Flag to use pure torch reference implementation instead of hybrid triton/torch.
        iteration_time_ms: Time in milliseconds to run benchmark.
        warmup_time_ms: Time in milliseconds to warmup before recording times.
        gpu: Which gpu to run on.
        csv: Flag to indicate whether or not to print results in CSV format.
        compile_ref: Flag to torch.compile() the pure torch reference implementation.
            Only takes effect when ``torch_ref`` is also set.
        cuda_ref: Flag to enable CUDA reference implementation.
    """

    device: Final = torch.device(gpu)
    torch.set_default_device(device)

    # init points & parameters
    num_features_per_point: Final = 4
    points = torch.randn((num_points, num_features_per_point), device=device) * grid_range
    param = VoxelizationParameter(
        min_range=(-grid_range, -grid_range, -grid_range),
        max_range=(grid_range, grid_range, grid_range),
        voxel_dim=(voxel_dim, voxel_dim, voxel_dim),
        max_num_points_per_voxel=max_num_points_per_voxel,
    )

    print(f"Number of points: {num_points}")
    print(f"Grid dimensions: {param.grid_dim}")
    print(f"Max number of voxels: {param.max_num_voxels}")
    print(f"Max number of points per voxel: {param.max_num_points_per_voxel}")

    args = (points, param)

    def generate_voxels_torch(
        points: torch.Tensor, param: VoxelizationParameter
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the two-step reference: stable voxelization, then point-feature collection.

        Uses the triton-backed helpers unless ``--torch-ref`` was given, in which case
        both steps are called with ``use_triton=False``.

        Returns:
            Tuple of (actual_num_points_per_voxel, point_raw_indices, flat_voxel_indices,
            point_features, capped_num_points_per_voxel).
        """
        use_triton = not torch_ref
        actual_num_points_per_voxel, point_raw_indices, flat_voxel_indices = voxelization_stable(
            points, param, use_triton=use_triton
        )
        point_features, capped_num_points_per_voxel = collect_point_features(
            points, actual_num_points_per_voxel, point_raw_indices, param, use_triton=use_triton
        )
        return (
            actual_num_points_per_voxel,
            point_raw_indices,
            flat_voxel_indices,
            point_features,
            capped_num_points_per_voxel,
        )

    # run base version and report voxelization stats
    actual_num_points_per_voxel, point_raw_indices, _, _, _ = generate_voxels_torch(*args)
    print("voxelization done, stats:")
    print(f"number of points within grid boundary: {point_raw_indices.shape[0]}")
    print(f"number of filled voxels: {actual_num_points_per_voxel.shape[0]}")
    print(f"Avg number of points per voxel: {torch.mean(actual_num_points_per_voxel, dtype=torch.float32)}")
    print(f"Max number of points per voxel: {torch.max(actual_num_points_per_voxel)}")
    overflow_count = (actual_num_points_per_voxel > param.max_num_points_per_voxel).int().sum()
    print(f"Number of voxels with overflowing points: {overflow_count}")

    # Benchmark implementations
    metadata = BenchmarkMetadata(
        platform=current_platform.name(),
        params={
            "num_points": num_points,
            "max_num_points_per_voxel": max_num_points_per_voxel,
            "voxel_dim": voxel_dim,
            "grid_range": grid_range,
        },
    )

    baseline_result = benchmark_it(
        lambda: generate_voxels_torch(*args),
        tag="Baseline",
        metadata=metadata,
        iteration_time_ms=iteration_time_ms,
        warmup_time_ms=warmup_time_ms,
    )

    conch_result = benchmark_it(
        lambda: generate_voxels(*args),
        tag="Conch",
        metadata=metadata,
        iteration_time_ms=iteration_time_ms,
        warmup_time_ms=warmup_time_ms,
    )

    reference_cuda_result = None
    if cuda_ref:
        # Imported lazily so the optional CUDA extension is only required when requested.
        from conch_cuda_ext.ops.vision.voxelization.voxelization import generate_voxels as generate_voxels_cuda

        # The CUDA entry point takes the parameter fields as individual positional arguments.
        args_cuda = (
            points,
            param.min_range[0],
            param.min_range[1],
            param.min_range[2],
            param.max_range[0],
            param.max_range[1],
            param.max_range[2],
            param.voxel_dim[0],
            param.voxel_dim[1],
            param.voxel_dim[2],
            param.grid_dim[0],
            param.grid_dim[1],
            param.grid_dim[2],
            param.max_num_points_per_voxel,
            param.max_num_voxels,
        )
        reference_cuda_result = benchmark_it(
            lambda: generate_voxels_cuda(*args_cuda),
            tag="CUDA",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    reference_compiled_result = None
    if compile_ref and torch_ref:
        # Compile the reference implementation if requested
        reference_compiled_fn = torch.compile(generate_voxels_torch)

        # Keep this in its own variable: storing it in baseline_result would discard the
        # eager baseline measurement and leave the compiled result unreported below.
        reference_compiled_result = benchmark_it(
            lambda: reference_compiled_fn(*args),
            tag="Baseline (Torch compiled)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    conch_result.print_parameters(csv=csv)
    conch_result.print_results(csv=csv)
    baseline_result.print_results(csv=csv)
    if reference_cuda_result is not None:
        reference_cuda_result.print_results(csv=csv)
    if reference_compiled_result is not None:
        reference_compiled_result.print_results(csv=csv)
252+
253+
254+
# Script entry point: click parses the command-line options and invokes the benchmark.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)