Skip to content

Commit 791669d

Browse files
Cleanup bevpool benchmarks
1 parent 8a68608 commit 791669d

2 files changed

Lines changed: 147 additions & 20 deletions

File tree

benchmarks/bev_pool_backward_benchmark.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _create_bev_pool_backward_data(
9191
"--num-points",
9292
required=False,
9393
type=int,
94-
default=500000,
94+
default=6000000,
9595
help="Number of input points",
9696
)
9797
@click.option(
@@ -112,21 +112,21 @@ def _create_bev_pool_backward_data(
112112
"--grid-cells-z",
113113
required=False,
114114
type=int,
115-
default=32,
115+
default=20,
116116
help="Number of Z grid cells",
117117
)
118118
@click.option(
119119
"--grid-cells-x",
120120
required=False,
121121
type=int,
122-
default=250,
122+
default=800,
123123
help="Number of X grid cells",
124124
)
125125
@click.option(
126126
"--grid-cells-y",
127127
required=False,
128128
type=int,
129-
default=250,
129+
default=800,
130130
help="Number of Y grid cells",
131131
)
132132
@click.option(
@@ -177,6 +177,11 @@ def _create_bev_pool_backward_data(
177177
is_flag=True,
178178
help="Flag to torch.compile() the Conch impl",
179179
)
180+
@click.option(
181+
"--cuda-ref",
182+
is_flag=True,
183+
help="Flag to enable CUDA reference implementation",
184+
)
180185
def main(
181186
num_points: int,
182187
num_channels: int,
@@ -192,6 +197,7 @@ def main(
192197
csv: bool,
193198
compile_ref: bool,
194199
compile_conch: bool,
200+
cuda_ref: bool,
195201
) -> None:
196202
"""Benchmark BEV Pool backward pass.
197203
@@ -210,6 +216,7 @@ def main(
210216
csv: Flag to indicate whether or not to print results in CSV format.
211217
compile_ref: Flag to torch.compile() the reference implementation.
212218
compile_conch: Flag to torch.compile() the Conch implementation.
219+
cuda_ref: Flag to enable CUDA reference implementation.
213220
"""
214221
seed: Final = 0
215222
seed_everything(seed)
@@ -241,8 +248,21 @@ def main(
241248
)
242249

243250
# Compile functions if requested
244-
bev_pool_backward_ref_fn = torch.compile(bev_pool_backward_ref) if compile_ref else bev_pool_backward_ref
245-
bev_pool_backward_conch_fn = torch.compile(bev_pool_backward_conch) if compile_conch else bev_pool_backward_conch
251+
bev_pool_backward_compiled_fn = None
252+
bev_pool_backward_cuda_fn = None
253+
254+
if compile_ref:
255+
# Compile the reference implementation if requested
256+
bev_pool_backward_compiled_fn = torch.compile(bev_pool_backward_ref)
257+
258+
if cuda_ref:
259+
from conch_cuda_ext.ops.vision.bev_pool.bev_pool import bev_pool_backward as bev_pool_bwd_cuda
260+
261+
bev_pool_backward_cuda_fn = bev_pool_bwd_cuda
262+
263+
bev_pool_backward_conch_compiled_fn = None
264+
if compile_conch:
265+
bev_pool_backward_conch_compiled_fn = torch.compile(bev_pool_backward_conch)
246266

247267
# Test both implementations
248268
args = (
@@ -252,14 +272,14 @@ def main(
252272
interval_lengths,
253273
)
254274

255-
ref_output = bev_pool_backward_ref_fn(
275+
ref_output = bev_pool_backward_ref(
256276
*args,
257-
batch_size=batch_size,
258-
grid_cells_z=grid_cells_z,
259-
grid_cells_x=grid_cells_x,
260-
grid_cells_y=grid_cells_y,
277+
batch_size,
278+
grid_cells_z,
279+
grid_cells_x,
280+
grid_cells_y,
261281
)
262-
conch_output = bev_pool_backward_conch_fn(*args)
282+
conch_output = bev_pool_backward_conch(*args)
263283

264284
# Accuracy checks
265285
if not torch.allclose(ref_output, conch_output, atol=absolute_tolerance):
@@ -275,7 +295,7 @@ def main(
275295

276296
# Benchmark implementations
277297
baseline_result = benchmark_it(
278-
lambda: bev_pool_backward_ref_fn(
298+
lambda: bev_pool_backward_ref(
279299
*args,
280300
batch_size=batch_size,
281301
grid_cells_z=grid_cells_z,
@@ -289,17 +309,67 @@ def main(
289309
)
290310

291311
conch_result = benchmark_it(
292-
lambda: bev_pool_backward_conch_fn(*args),
312+
lambda: bev_pool_backward_conch(*args),
293313
tag="Conch",
294314
metadata=metadata,
295315
iteration_time_ms=iteration_time_ms,
296316
warmup_time_ms=warmup_time_ms,
297317
)
298318

319+
reference_compiled_result = None
320+
reference_cuda_result = None
321+
conch_compiled_result = None
322+
323+
if bev_pool_backward_compiled_fn:
324+
reference_compiled_result = benchmark_it(
325+
lambda: bev_pool_backward_compiled_fn(
326+
*args,
327+
batch_size=batch_size,
328+
grid_cells_z=grid_cells_z,
329+
grid_cells_x=grid_cells_x,
330+
grid_cells_y=grid_cells_y,
331+
),
332+
tag="Reference (Compiled)",
333+
metadata=metadata,
334+
iteration_time_ms=iteration_time_ms,
335+
warmup_time_ms=warmup_time_ms,
336+
)
337+
338+
if bev_pool_backward_cuda_fn:
339+
reference_cuda_result = benchmark_it(
340+
# Note: cannot use kwargs for CUDA fn
341+
lambda: bev_pool_backward_cuda_fn(
342+
*args,
343+
batch_size,
344+
grid_cells_z,
345+
grid_cells_x,
346+
grid_cells_y,
347+
),
348+
tag="CUDA",
349+
metadata=metadata,
350+
iteration_time_ms=iteration_time_ms,
351+
warmup_time_ms=warmup_time_ms,
352+
)
353+
354+
if bev_pool_backward_conch_compiled_fn:
355+
conch_compiled_result = benchmark_it(
356+
lambda: bev_pool_backward_conch_compiled_fn(*args),
357+
tag="Conch (Compiled)",
358+
metadata=metadata,
359+
iteration_time_ms=iteration_time_ms,
360+
warmup_time_ms=warmup_time_ms,
361+
)
362+
299363
# Print results
300364
conch_result.print_parameters(csv=csv)
301365
conch_result.print_results(csv=csv)
302366
baseline_result.print_results(csv=csv)
367+
if reference_compiled_result:
368+
reference_compiled_result.print_results(csv=csv)
369+
if reference_cuda_result:
370+
reference_cuda_result.print_results(csv=csv)
371+
if conch_compiled_result:
372+
conch_compiled_result.print_results(csv=csv)
303373

304374

305375
if __name__ == "__main__":

benchmarks/bev_pool_benchmark.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ def _create_bev_pool_data(
176176
is_flag=True,
177177
help="Flag to torch.compile() the Conch impl",
178178
)
179+
@click.option(
180+
"--cuda-ref",
181+
is_flag=True,
182+
help="Flag to enable CUDA reference implementation",
183+
)
179184
def main(
180185
num_points: int,
181186
num_channels: int,
@@ -191,6 +196,7 @@ def main(
191196
csv: bool,
192197
compile_ref: bool,
193198
compile_conch: bool,
199+
cuda_ref: bool,
194200
) -> None:
195201
"""Benchmark BEV Pool.
196202
@@ -209,6 +215,7 @@ def main(
209215
csv: Flag to indicate whether or not to print results in CSV format.
210216
compile_ref: Flag to torch.compile() the reference implementation.
211217
compile_conch: Flag to torch.compile() the Conch implementation.
218+
cuda_ref: Flag to enable CUDA reference implementation.
212219
"""
213220
seed: Final = 0
214221
seed_everything(seed)
@@ -245,8 +252,21 @@ def main(
245252
print(f"Max interval length: {interval_lengths.float().max().item()}", file=sys.stderr)
246253

247254
# Compile functions if requested
248-
bev_pool_ref_fn = torch.compile(bev_pool_ref) if compile_ref else bev_pool_ref
249-
bev_pool_conch_fn = torch.compile(bev_pool_conch) if compile_conch else bev_pool_conch
255+
bev_pool_forward_compiled_fn = None
256+
bev_pool_forward_cuda_fn = None
257+
258+
if compile_ref:
259+
# Compile the reference implementation if requested
260+
bev_pool_forward_compiled_fn = torch.compile(bev_pool_ref)
261+
262+
if cuda_ref:
263+
from conch_cuda_ext.ops.vision.bev_pool.bev_pool import bev_pool_forward as bev_pool_fwd_cuda
264+
265+
bev_pool_forward_cuda_fn = bev_pool_fwd_cuda
266+
267+
bev_pool_forward_conch_compiled_fn = None
268+
if compile_conch:
269+
bev_pool_forward_conch_compiled_fn = torch.compile(bev_pool_conch)
250270

251271
# Test both implementations
252272
args = (
@@ -260,8 +280,8 @@ def main(
260280
grid_cells_y,
261281
)
262282

263-
ref_output = bev_pool_ref_fn(*args)
264-
conch_output = bev_pool_conch_fn(*args)
283+
ref_output = bev_pool_ref(*args)
284+
conch_output = bev_pool_conch(*args)
265285

266286
# Accuracy checks
267287
if not torch.allclose(ref_output, conch_output, atol=absolute_tolerance):
@@ -277,25 +297,62 @@ def main(
277297

278298
# Benchmark implementations
279299
baseline_result = benchmark_it(
280-
lambda: bev_pool_ref_fn(*args),
300+
lambda: bev_pool_ref(*args),
281301
tag="Baseline",
282302
metadata=metadata,
283303
iteration_time_ms=iteration_time_ms,
284304
warmup_time_ms=warmup_time_ms,
285305
)
286306

287307
conch_result = benchmark_it(
288-
lambda: bev_pool_conch_fn(*args),
308+
lambda: bev_pool_conch(*args),
289309
tag="Conch",
290310
metadata=metadata,
291311
iteration_time_ms=iteration_time_ms,
292312
warmup_time_ms=warmup_time_ms,
293313
)
294314

315+
reference_compiled_result = None
316+
reference_cuda_result = None
317+
conch_compiled_result = None
318+
319+
if bev_pool_forward_compiled_fn:
320+
reference_compiled_result = benchmark_it(
321+
lambda: bev_pool_forward_compiled_fn(*args),
322+
tag="Reference (Compiled)",
323+
metadata=metadata,
324+
iteration_time_ms=iteration_time_ms,
325+
warmup_time_ms=warmup_time_ms,
326+
)
327+
328+
if bev_pool_forward_cuda_fn:
329+
reference_cuda_result = benchmark_it(
330+
lambda: bev_pool_forward_cuda_fn(*args),
331+
tag="CUDA",
332+
metadata=metadata,
333+
iteration_time_ms=iteration_time_ms,
334+
warmup_time_ms=warmup_time_ms,
335+
)
336+
337+
if bev_pool_forward_conch_compiled_fn:
338+
conch_compiled_result = benchmark_it(
339+
lambda: bev_pool_forward_conch_compiled_fn(*args),
340+
tag="Conch (Compiled)",
341+
metadata=metadata,
342+
iteration_time_ms=iteration_time_ms,
343+
warmup_time_ms=warmup_time_ms,
344+
)
345+
295346
# Print results
296347
conch_result.print_parameters(csv=csv)
297348
conch_result.print_results(csv=csv)
298349
baseline_result.print_results(csv=csv)
350+
if reference_compiled_result:
351+
reference_compiled_result.print_results(csv=csv)
352+
if reference_cuda_result:
353+
reference_cuda_result.print_results(csv=csv)
354+
if conch_compiled_result:
355+
conch_compiled_result.print_results(csv=csv)
299356

300357

301358
if __name__ == "__main__":

0 commit comments

Comments (0)