@@ -91,7 +91,7 @@ def _create_bev_pool_backward_data(
9191 "--num-points" ,
9292 required = False ,
9393 type = int ,
94- default = 500000 ,
94+ default = 6000000 ,
9595 help = "Number of input points" ,
9696)
9797@click .option (
@@ -112,21 +112,21 @@ def _create_bev_pool_backward_data(
112112 "--grid-cells-z" ,
113113 required = False ,
114114 type = int ,
115- default = 32 ,
115+ default = 20 ,
116116 help = "Number of Z grid cells" ,
117117)
118118@click .option (
119119 "--grid-cells-x" ,
120120 required = False ,
121121 type = int ,
122- default = 250 ,
122+ default = 800 ,
123123 help = "Number of X grid cells" ,
124124)
125125@click .option (
126126 "--grid-cells-y" ,
127127 required = False ,
128128 type = int ,
129- default = 250 ,
129+ default = 800 ,
130130 help = "Number of Y grid cells" ,
131131)
132132@click .option (
@@ -177,6 +177,11 @@ def _create_bev_pool_backward_data(
177177 is_flag = True ,
178178 help = "Flag to torch.compile() the Conch impl" ,
179179)
180+ @click .option (
181+ "--cuda-ref" ,
182+ is_flag = True ,
183+ help = "Flag to enable CUDA reference implementation" ,
184+ )
180185def main (
181186 num_points : int ,
182187 num_channels : int ,
@@ -192,6 +197,7 @@ def main(
192197 csv : bool ,
193198 compile_ref : bool ,
194199 compile_conch : bool ,
200+ cuda_ref : bool ,
195201) -> None :
196202 """Benchmark BEV Pool backward pass.
197203
@@ -210,6 +216,7 @@ def main(
210216 csv: Flag to indicate whether or not to print results in CSV format.
211217 compile_ref: Flag to torch.compile() the reference implementation.
212218 compile_conch: Flag to torch.compile() the Conch implementation.
219+ cuda_ref: Flag to additionally benchmark the CUDA reference implementation (imported from conch_cuda_ext).
213220 """
214221 seed : Final = 0
215222 seed_everything (seed )
@@ -241,8 +248,21 @@ def main(
241248 )
242249
243250 # Compile functions if requested
244- bev_pool_backward_ref_fn = torch .compile (bev_pool_backward_ref ) if compile_ref else bev_pool_backward_ref
245- bev_pool_backward_conch_fn = torch .compile (bev_pool_backward_conch ) if compile_conch else bev_pool_backward_conch
251+ bev_pool_backward_compiled_fn = None
252+ bev_pool_backward_cuda_fn = None
253+
254+ if compile_ref :
255+ # Compile the reference implementation if requested
256+ bev_pool_backward_compiled_fn = torch .compile (bev_pool_backward_ref )
257+
258+ if cuda_ref :
259+ from conch_cuda_ext .ops .vision .bev_pool .bev_pool import bev_pool_backward as bev_pool_bwd_cuda
260+
261+ bev_pool_backward_cuda_fn = bev_pool_bwd_cuda
262+
263+ bev_pool_backward_conch_compiled_fn = None
264+ if compile_conch :
265+ bev_pool_backward_conch_compiled_fn = torch .compile (bev_pool_backward_conch )
246266
247267 # Test both implementations
248268 args = (
@@ -252,14 +272,14 @@ def main(
252272 interval_lengths ,
253273 )
254274
255- ref_output = bev_pool_backward_ref_fn (
275+ ref_output = bev_pool_backward_ref (
256276 * args ,
257- batch_size = batch_size ,
258- grid_cells_z = grid_cells_z ,
259- grid_cells_x = grid_cells_x ,
260- grid_cells_y = grid_cells_y ,
277+ batch_size ,
278+ grid_cells_z ,
279+ grid_cells_x ,
280+ grid_cells_y ,
261281 )
262- conch_output = bev_pool_backward_conch_fn (* args )
282+ conch_output = bev_pool_backward_conch (* args )
263283
264284 # Accuracy checks
265285 if not torch .allclose (ref_output , conch_output , atol = absolute_tolerance ):
@@ -275,7 +295,7 @@ def main(
275295
276296 # Benchmark implementations
277297 baseline_result = benchmark_it (
278- lambda : bev_pool_backward_ref_fn (
298+ lambda : bev_pool_backward_ref (
279299 * args ,
280300 batch_size = batch_size ,
281301 grid_cells_z = grid_cells_z ,
@@ -289,17 +309,67 @@ def main(
289309 )
290310
291311 conch_result = benchmark_it (
292- lambda : bev_pool_backward_conch_fn (* args ),
312+ lambda : bev_pool_backward_conch (* args ),
293313 tag = "Conch" ,
294314 metadata = metadata ,
295315 iteration_time_ms = iteration_time_ms ,
296316 warmup_time_ms = warmup_time_ms ,
297317 )
298318
319+ reference_compiled_result = None
320+ reference_cuda_result = None
321+ conch_compiled_result = None
322+
323+ if bev_pool_backward_compiled_fn :
324+ reference_compiled_result = benchmark_it (
325+ lambda : bev_pool_backward_compiled_fn (
326+ * args ,
327+ batch_size = batch_size ,
328+ grid_cells_z = grid_cells_z ,
329+ grid_cells_x = grid_cells_x ,
330+ grid_cells_y = grid_cells_y ,
331+ ),
332+ tag = "Reference (Compiled)" ,
333+ metadata = metadata ,
334+ iteration_time_ms = iteration_time_ms ,
335+ warmup_time_ms = warmup_time_ms ,
336+ )
337+
338+ if bev_pool_backward_cuda_fn :
339+ reference_cuda_result = benchmark_it (
340+ # Note: cannot use kwargs for CUDA fn
341+ lambda : bev_pool_backward_cuda_fn (
342+ * args ,
343+ batch_size ,
344+ grid_cells_z ,
345+ grid_cells_x ,
346+ grid_cells_y ,
347+ ),
348+ tag = "CUDA" ,
349+ metadata = metadata ,
350+ iteration_time_ms = iteration_time_ms ,
351+ warmup_time_ms = warmup_time_ms ,
352+ )
353+
354+ if bev_pool_backward_conch_compiled_fn :
355+ conch_compiled_result = benchmark_it (
356+ lambda : bev_pool_backward_conch_compiled_fn (* args ),
357+ tag = "Conch (Compiled)" ,
358+ metadata = metadata ,
359+ iteration_time_ms = iteration_time_ms ,
360+ warmup_time_ms = warmup_time_ms ,
361+ )
362+
299363 # Print results
300364 conch_result .print_parameters (csv = csv )
301365 conch_result .print_results (csv = csv )
302366 baseline_result .print_results (csv = csv )
367+ if reference_compiled_result :
368+ reference_compiled_result .print_results (csv = csv )
369+ if reference_cuda_result :
370+ reference_cuda_result .print_results (csv = csv )
371+ if conch_compiled_result :
372+ conch_compiled_result .print_results (csv = csv )
303373
304374
305375if __name__ == "__main__" :
0 commit comments