Skip to content

Commit e4ac60c

Browse files
Merge pull request #37 from stackav-oss/feature/jmanning/bevpool
Add BEVPool Kernel
2 parents fa647b8 + 791669d commit e4ac60c

16 files changed

Lines changed: 1808 additions & 1 deletion

File tree

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@ dist/
1010
results/
1111

1212
.coverage
13+
1314
.vscode
15+
16+
# Ignore any shared libraries
17+
*.so
18+
19+
# Ignore HIP-ified CUDA kernels
20+
conch_cuda_ext/**/*_hip.cc
21+
conch_cuda_ext/**/*.hip

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ Check out the [developer instructions](./docs/getting_started/developer_environm
9999

100100
We were inspired by and leverage components of the following libraries:
101101

102+
- [BEVFusion](https://github.com/mit-han-lab/bevfusion)
102103
- [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes)
103104
- [GemLite](https://github.com/mobiusml/gemlite)
104105
- [Torchvision](https://github.com/pytorch/vision)
Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
# Copyright 2025 Stack AV Co.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""BEV Pool backward pass benchmark."""
5+
6+
import sys
7+
from typing import Final
8+
9+
import click
10+
import torch
11+
12+
from conch.ops.vision.bev_pool import bev_pool_backward as bev_pool_backward_conch
13+
from conch.platforms import current_platform
14+
from conch.reference.vision.bev_pool import bev_pool_backward as bev_pool_backward_ref
15+
from conch.third_party.vllm.utils import seed_everything
16+
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
17+
18+
19+
def _create_bev_pool_backward_data(
20+
num_points: int,
21+
num_channels: int,
22+
batch_size: int,
23+
grid_cells_z: int,
24+
grid_cells_x: int,
25+
grid_cells_y: int,
26+
device: torch.device,
27+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
28+
"""Create test data for BEV Pool backward operation.
29+
30+
Args:
31+
num_points: Number of input points.
32+
num_channels: Number of feature channels per point.
33+
batch_size: Number of batches.
34+
grid_cells_z: Number of Z grid cells.
35+
grid_cells_x: Number of X grid cells.
36+
grid_cells_y: Number of Y grid cells.
37+
device: Device to create tensors on.
38+
39+
Returns:
40+
Tuple of (grad_output, geom_feats, interval_starts, interval_lengths).
41+
"""
42+
# Create random gradient output in the shape of BEV pool output
43+
grad_output = torch.randn(
44+
batch_size, grid_cells_z, grid_cells_x, grid_cells_y, num_channels, device=device, dtype=torch.float32
45+
)
46+
47+
# Create geometry features with random but valid coordinates
48+
geom_feats = torch.stack(
49+
[
50+
torch.randint(0, grid_cells_x, (num_points,), device=device), # X coordinate
51+
torch.randint(0, grid_cells_y, (num_points,), device=device), # Y coordinate
52+
torch.randint(0, grid_cells_z, (num_points,), device=device), # Z coordinate
53+
torch.randint(0, batch_size, (num_points,), device=device), # Batch index
54+
],
55+
dim=1,
56+
).to(torch.int32)
57+
58+
# Create a linear index for sorting and grouping
59+
linear_indices = (
60+
geom_feats[:, 3] * (grid_cells_z * grid_cells_x * grid_cells_y) # batch
61+
+ geom_feats[:, 2] * (grid_cells_x * grid_cells_y) # z
62+
+ geom_feats[:, 1] * grid_cells_y # x
63+
+ geom_feats[:, 0] # y
64+
)
65+
66+
# Sort by linear indices to group points in same voxels
67+
sorted_indices = torch.argsort(linear_indices)
68+
sorted_linear_indices = linear_indices[sorted_indices]
69+
70+
# Find unique voxels and create intervals
71+
unique_indices, counts = torch.unique_consecutive(sorted_linear_indices, return_counts=True)
72+
num_intervals = len(unique_indices)
73+
74+
# Create interval starts and lengths
75+
interval_starts = torch.zeros(num_intervals, device=device, dtype=torch.int32)
76+
interval_lengths = counts.to(torch.int32)
77+
78+
current_start = 0
79+
for i in range(num_intervals):
80+
interval_starts[i] = current_start
81+
current_start += interval_lengths[i]
82+
83+
# Reorder geometry by sorted indices
84+
geom_feats = geom_feats[sorted_indices]
85+
86+
return grad_output, geom_feats, interval_starts, interval_lengths
87+
88+
89+
@click.command()
@click.option(
    "--num-points",
    required=False,
    type=int,
    default=6000000,
    help="Number of input points",
)
@click.option(
    "--num-channels",
    required=False,
    type=int,
    default=64,
    help="Number of feature channels per point",
)
@click.option(
    "--batch-size",
    required=False,
    type=int,
    default=1,
    help="Batch size",
)
@click.option(
    "--grid-cells-z",
    required=False,
    type=int,
    default=20,
    help="Number of Z grid cells",
)
@click.option(
    "--grid-cells-x",
    required=False,
    type=int,
    default=800,
    help="Number of X grid cells",
)
@click.option(
    "--grid-cells-y",
    required=False,
    type=int,
    default=800,
    help="Number of Y grid cells",
)
@click.option(
    "--iteration-time-ms",
    required=False,
    type=int,
    default=10000,
    help="Time in milliseconds to run benchmark",
)
@click.option(
    "--warmup-time-ms",
    required=False,
    type=int,
    default=1000,
    help="Time in milliseconds to warmup before recording times",
)
@click.option(
    "--absolute-tolerance",
    required=False,
    type=float,
    default=1e-3,
    help="Absolute tolerance to match with",
)
@click.option(
    "--verbose",
    is_flag=True,
    help="Flag for printing verbose output",
)
@click.option(
    "--gpu",
    required=False,
    type=str,
    default=current_platform.device,
    help="Device to run on",
)
@click.option(
    "--csv",
    is_flag=True,
    help="Flag for printing results in CSV format",
)
@click.option(
    "--compile-ref",
    is_flag=True,
    help="Flag to torch.compile() the reference impl",
)
@click.option(
    "--compile-conch",
    is_flag=True,
    help="Flag to torch.compile() the Conch impl",
)
@click.option(
    "--cuda-ref",
    is_flag=True,
    help="Flag to enable CUDA reference implementation",
)
def main(
    num_points: int,
    num_channels: int,
    batch_size: int,
    grid_cells_z: int,
    grid_cells_x: int,
    grid_cells_y: int,
    iteration_time_ms: int,
    warmup_time_ms: int,
    absolute_tolerance: float,
    verbose: bool,
    gpu: str,
    csv: bool,
    compile_ref: bool,
    compile_conch: bool,
    cuda_ref: bool,
) -> None:
    """Benchmark BEV Pool backward pass.

    Args:
        num_points: Number of input points.
        num_channels: Number of feature channels per point.
        batch_size: Batch size.
        grid_cells_z: Number of Z grid cells.
        grid_cells_x: Number of X grid cells.
        grid_cells_y: Number of Y grid cells.
        iteration_time_ms: Time in milliseconds to run benchmark.
        warmup_time_ms: Time in milliseconds to warmup before recording times.
        absolute_tolerance: Absolute tolerance used to check accuracy.
        verbose: Flag to indicate whether or not to print verbose output.
        gpu: Which gpu to run on.
        csv: Flag to indicate whether or not to print results in CSV format.
        compile_ref: Flag to torch.compile() the reference implementation.
        compile_conch: Flag to torch.compile() the Conch implementation.
        cuda_ref: Flag to enable CUDA reference implementation.
    """
    # NOTE: click renders the docstring above as the command's --help text, so edits to it
    # are user-visible.
    seed: Final = 0
    seed_everything(seed)

    device: Final = torch.device(gpu)
    torch.set_default_device(device)

    metadata = BenchmarkMetadata(
        platform=current_platform.name(),
        params={
            "num_points": num_points,
            "num_channels": num_channels,
            "batch_size": batch_size,
            "grid_cells_z": grid_cells_z,
            "grid_cells_x": grid_cells_x,
            "grid_cells_y": grid_cells_y,
        },
    )

    # Create test data
    grad_output, geom_feats, interval_starts, interval_lengths = _create_bev_pool_backward_data(
        num_points=num_points,
        num_channels=num_channels,
        batch_size=batch_size,
        grid_cells_z=grid_cells_z,
        grid_cells_x=grid_cells_x,
        grid_cells_y=grid_cells_y,
        device=device,
    )

    # Optional implementations; each stays None unless its flag was passed
    bev_pool_backward_compiled_fn = None
    bev_pool_backward_cuda_fn = None
    bev_pool_backward_conch_compiled_fn = None

    if compile_ref:
        # Compile the reference implementation if requested
        bev_pool_backward_compiled_fn = torch.compile(bev_pool_backward_ref)

    if cuda_ref:
        # Imported lazily so the CUDA extension is only required when explicitly requested
        from conch_cuda_ext.ops.vision.bev_pool.bev_pool import bev_pool_backward as bev_pool_bwd_cuda

        bev_pool_backward_cuda_fn = bev_pool_bwd_cuda

    if compile_conch:
        bev_pool_backward_conch_compiled_fn = torch.compile(bev_pool_backward_conch)

    # Test both implementations
    args = (
        grad_output,
        geom_feats,
        interval_starts,
        interval_lengths,
    )

    ref_output = bev_pool_backward_ref(
        *args,
        batch_size,
        grid_cells_z,
        grid_cells_x,
        grid_cells_y,
    )
    conch_output = bev_pool_backward_conch(*args)

    # Accuracy checks. Shapes are compared first: torch.allclose broadcasts its inputs, so a
    # shape mismatch would either raise before the diagnostics below are printed or be
    # silently broadcast into a bogus "match".
    shapes_match = ref_output.shape == conch_output.shape
    if shapes_match and torch.allclose(ref_output, conch_output, atol=absolute_tolerance):
        print(f"Reference vs Conch: Results matched with atol={absolute_tolerance} :)", file=sys.stderr)
    else:
        print(f"WARNING: Reference and Conch results differ! (atol={absolute_tolerance})", file=sys.stderr)
        if shapes_match:
            # Element-wise difference is only meaningful when the shapes agree
            print(f"Output max diff: {(ref_output - conch_output).abs().max().item()}", file=sys.stderr)
        print(f"Ref shape: {ref_output.shape}, Conch shape: {conch_output.shape}", file=sys.stderr)

        if verbose:
            print(f"Reference output: {ref_output}", file=sys.stderr)
            print(f"Conch output: {conch_output}", file=sys.stderr)

    # Benchmark implementations
    baseline_result = benchmark_it(
        lambda: bev_pool_backward_ref(
            *args,
            batch_size=batch_size,
            grid_cells_z=grid_cells_z,
            grid_cells_x=grid_cells_x,
            grid_cells_y=grid_cells_y,
        ),
        tag="Baseline",
        metadata=metadata,
        iteration_time_ms=iteration_time_ms,
        warmup_time_ms=warmup_time_ms,
    )

    conch_result = benchmark_it(
        lambda: bev_pool_backward_conch(*args),
        tag="Conch",
        metadata=metadata,
        iteration_time_ms=iteration_time_ms,
        warmup_time_ms=warmup_time_ms,
    )

    reference_compiled_result = None
    reference_cuda_result = None
    conch_compiled_result = None

    if bev_pool_backward_compiled_fn is not None:
        reference_compiled_result = benchmark_it(
            lambda: bev_pool_backward_compiled_fn(
                *args,
                batch_size=batch_size,
                grid_cells_z=grid_cells_z,
                grid_cells_x=grid_cells_x,
                grid_cells_y=grid_cells_y,
            ),
            tag="Reference (Compiled)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    if bev_pool_backward_cuda_fn is not None:
        reference_cuda_result = benchmark_it(
            # Note: cannot use kwargs for CUDA fn
            lambda: bev_pool_backward_cuda_fn(
                *args,
                batch_size,
                grid_cells_z,
                grid_cells_x,
                grid_cells_y,
            ),
            tag="CUDA",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    if bev_pool_backward_conch_compiled_fn is not None:
        conch_compiled_result = benchmark_it(
            lambda: bev_pool_backward_conch_compiled_fn(*args),
            tag="Conch (Compiled)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    # Print results: parameters once, then one result row per implementation that ran
    conch_result.print_parameters(csv=csv)
    conch_result.print_results(csv=csv)
    baseline_result.print_results(csv=csv)
    for optional_result in (reference_compiled_result, reference_cuda_result, conch_compiled_result):
        if optional_result is not None:
            optional_result.print_results(csv=csv)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)