33
44"""Quick cumsum kernel for 3D voxel grids."""
55
6+ from typing import Any
7+
68import torch
79import triton
810import triton .language as tl
@@ -18,6 +20,7 @@ def _bev_pool_kernel(
1820 interval_lengths_ptr : tl .tensor ,
1921 # Scalars
2022 num_channels : tl .int32 ,
23+ num_intervals : tl .int32 ,
2124 # Strides
2225 image_feats_stride : tl .int32 ,
2326 geom_feats_stride : tl .int32 ,
@@ -49,69 +52,80 @@ def _bev_pool_kernel(
4952 interval_starts_ptr: Pointer to the starting positions for pooled points, shape: (num_intervals,).
5053 interval_lengths_ptr: Pointer to the lengths of each pooled point, shape: (num_intervals,).
5154 num_channels: The number of channels in the image features.
55+ num_intervals: The number of intervals to process.
5256 image_feats_stride: The stride of the image features tensor.
5357 geom_feats_stride: The stride of the geometry features tensor.
5458 output_batch_stride: The stride of the output tensor for the batch dimension.
5559 output_z_stride: The stride of the output tensor for the z dimension.
5660 output_x_stride: The stride of the output tensor for the x dimension.
5761 output_y_stride: The stride of the output tensor for the y dimension.
5862 cxpr_num_channels_padded: The number of channels padded to the next power of 2.
59- cxpr_block_size: The block size for processing points in parallel.
63+ cxpr_block_size: The block size for processing intervals in parallel.
6064 """
61- # What is the index of the interval this program is processing?
62- interval_index = tl .program_id (0 )
65+ # What is the starting index of the block of intervals this program is processing?
66+ interval_block_start = tl .program_id (0 ) * cxpr_block_size
67+ # Offsets for each interval in the block
68+ interval_block_offsets = interval_block_start + tl .arange (0 , cxpr_block_size )
69+ # Mask out-of-bounds intervals
70+ interval_block_mask = interval_block_offsets < num_intervals
6371
64- # Load current interval start and length
65- interval_start = tl .load (interval_starts_ptr + interval_index )
66- interval_length = tl .load (interval_lengths_ptr + interval_index )
72+ # Load start and length for each interval in the block
73+ interval_starts = tl .load (interval_starts_ptr + interval_block_offsets , mask = interval_block_mask , other = 0 )
74+ interval_lengths = tl .load (interval_lengths_ptr + interval_block_offsets , mask = interval_block_mask , other = 0 )
6775
6876 # Offsets and masks for the channels
6977 channel_offsets = tl .arange (0 , cxpr_num_channels_padded )
7078 channel_mask = channel_offsets < num_channels
7179
72- # Accumulator for the sum of image features for each channel for all points in the interval
73- output = tl .zeros ([cxpr_num_channels_padded ], dtype = image_feats_ptr .dtype .element_ty )
80+ # Combine interval block mask with channel mask
81+ output_mask = interval_block_mask [:, None ] & channel_mask [None , :]
82+
83+ # Accumulator for the sum of image features for each channel for all intervals in the block
84+ output = tl .zeros ([cxpr_block_size , cxpr_num_channels_padded ], dtype = image_feats_ptr .dtype .element_ty )
85+
86+ # Calculate the pointer to the start of the image features for the current block of intervals
87+ # Shape: (cxpr_block_size, cxpr_num_channels_padded)
88+ current_image_feats_ptr = image_feats_ptr + interval_starts [:, None ] * image_feats_stride + channel_offsets [None , :]
7489
75- # Calculate the pointer to the start of the image features for the current interval
76- current_image_feats_ptr = image_feats_ptr + interval_start * image_feats_stride + channel_offsets [None , :]
90+ # Determine the maximum interval length in the block
91+ # This is used to determine how many points we need to process in the block
92+ max_interval_length = tl .max (interval_lengths , axis = 0 )
7793
78- # Iterate blockwise over the points in the interval
79- for block_index in range (tl .cdiv (interval_length , cxpr_block_size )):
80- # Offsets for the current block
81- block_offsets = block_index * cxpr_block_size + tl .arange (0 , cxpr_block_size )
82- # Mask out any indices that are out of bounds for the interval length
83- block_mask = block_offsets < interval_length
94+ # Iterate over the max number of points in any interval in the block
95+ for point_index in range (max_interval_length ):
96+ # Mask for intervals where this point index is valid
97+ index_mask = point_index < interval_lengths
8498
85- # Load image features, shape: (cxpr_block_size, num_channels )
99+ # Load image features for the current point , shape: (cxpr_block_size, cxpr_num_channels_padded )
86100 image_feats = tl .load (
87- current_image_feats_ptr + block_offsets [:, None ] * image_feats_stride ,
88- mask = block_mask [:, None ] & channel_mask [ None , :] ,
101+ current_image_feats_ptr + point_index * image_feats_stride ,
102+ mask = index_mask [:, None ] & output_mask ,
89103 other = 0.0 ,
90104 )
91105
92- # Calculate sum of image features for the current block, per channel
93- # Shape: (cxpr_num_channels_padded,)
94- output += tl .sum (image_feats , axis = 0 )
106+ # Accumulate the image features into the output tensor
107+ output += image_feats
95108
96- # Load geometry coordinates for the first point in this interval
97- # Note: all points in the interval share the same geom_feats, so we only need to load the first point's geom_feats.
98- current_geom_feats_ptr = geom_feats_ptr + interval_start * geom_feats_stride
99- geom_x = tl .load (current_geom_feats_ptr + 0 ) # X coordinate
100- geom_y = tl .load (current_geom_feats_ptr + 1 ) # Y coordinate
101- geom_z = tl .load (current_geom_feats_ptr + 2 ) # Z coordinate
102- geom_b = tl .load (current_geom_feats_ptr + 3 ) # Batch index
109+ # Load geometry coordinates for the first point in each interval
110+ # Note: all points in each interval share the same geom_feats, so we only need to load the first point's geom_feats.
111+ # geom_{x|y|z|b} shape: (cxpr_block_size, 1) -- the trailing unit axis broadcasts against channel_offsets[None, :]
112+ current_geom_feats_ptrs = geom_feats_ptr + interval_starts [:, None ] * geom_feats_stride
113+ geom_x = tl .load (current_geom_feats_ptrs + 0 ) # X coordinates
114+ geom_y = tl .load (current_geom_feats_ptrs + 1 ) # Y coordinates
115+ geom_z = tl .load (current_geom_feats_ptrs + 2 ) # Z coordinates
116+ geom_b = tl .load (current_geom_feats_ptrs + 3 ) # Batch indices
103117
104- # Calculate output tensor offset for shape [batch_size, grid_cells_z, grid_cells_x, grid_cells_y, num_channels]
105- batch_offset = geom_b * output_batch_stride
106- z_offset = geom_z * output_z_stride
107- x_offset = geom_x * output_x_stride
108- y_offset = geom_y * output_y_stride
118+ # Calculate output tensor offsets for shape [batch_size, grid_cells_z, grid_cells_x, grid_cells_y, num_channels]
119+ batch_offsets = geom_b * output_batch_stride
120+ z_offsets = geom_z * output_z_stride
121+ x_offsets = geom_x * output_x_stride
122+ y_offsets = geom_y * output_y_stride
109123
110124 # Store the accumulated output for each interval in the block
111125 tl .store (
112- output_ptr + batch_offset + z_offset + x_offset + y_offset + channel_offsets ,
126+ output_ptr + batch_offsets + z_offsets + x_offsets + y_offsets + channel_offsets [ None , :] ,
113127 output ,
114- mask = channel_mask ,
128+ mask = output_mask ,
115129 )
116130
117131
@@ -125,6 +139,7 @@ def _bev_pool_backward_kernel(
125139 interval_lengths_ptr : tl .tensor ,
126140 # Scalars
127141 num_channels : tl .int32 ,
142+ num_intervals : tl .int32 ,
128143 # Strides
129144 x_grad_stride : tl .int32 ,
130145 grad_output_batch_stride : tl .int32 ,
@@ -148,6 +163,7 @@ def _bev_pool_backward_kernel(
148163 interval_starts_ptr: Pointer to the starting positions for pooled points.
149164 interval_lengths_ptr: Pointer to the lengths of each pooled point.
150165 num_channels: The number of channels in the image features.
166+ num_intervals: The number of intervals to process.
151167 x_grad_stride: The stride of the x_grad tensor.
152168 grad_output_batch_stride: The stride of the grad_output tensor for the batch dimension.
153169 grad_output_z_stride: The stride of the grad_output tensor for the z dimension.
@@ -157,58 +173,65 @@ def _bev_pool_backward_kernel(
157173 cxpr_num_channels_padded: The number of channels padded to the next power of 2.
158174 cxpr_block_size: The block size for processing intervals in parallel.
159175 """
160- # What is the index of the interval this program is processing?
161- interval_index = tl .program_id (0 )
176+ # What is the starting index of the block of intervals this program is processing?
177+ interval_block_start = tl .program_id (0 ) * cxpr_block_size
178+ # Offsets for each interval in the block
179+ interval_block_offsets = interval_block_start + tl .arange (0 , cxpr_block_size )
180+ # Mask out-of-bounds intervals
181+ interval_block_mask = interval_block_offsets < num_intervals
162182
163- # Load current interval start and length
164- interval_start = tl .load (interval_starts_ptr + interval_index )
165- interval_length = tl .load (interval_lengths_ptr + interval_index )
183+ # Load start and length for each interval in the block
184+ interval_starts = tl .load (interval_starts_ptr + interval_block_offsets , mask = interval_block_mask , other = 0 )
185+ interval_lengths = tl .load (interval_lengths_ptr + interval_block_offsets , mask = interval_block_mask , other = 0 )
166186
167187 # Offsets and masks for the channels
168188 channel_offsets = tl .arange (0 , cxpr_num_channels_padded )
169189 channel_mask = channel_offsets < num_channels
170190
171- # Load geometry coordinates for the first point in this interval
172- # Note: all points in the interval share the same geom_feats, so we only need to load the first point's geom_feats.
173- current_geom_feats_ptr = geom_feats_ptr + interval_start * geom_feats_stride
174- geom_x = tl .load (current_geom_feats_ptr + 0 ) # X coordinate
175- geom_y = tl .load (current_geom_feats_ptr + 1 ) # Y coordinate
176- geom_z = tl .load (current_geom_feats_ptr + 2 ) # Z coordinate
177- geom_b = tl .load (current_geom_feats_ptr + 3 ) # Batch index
191+ # Combine interval block mask with channel mask
192+ output_mask = interval_block_mask [:, None ] & channel_mask [None , :]
178193
179- # Offset for the entry in the grad_output tensor for this interval
180- grad_output_offset = (
194+ # Load geometry coordinates for the first point in each interval
195+ # Note: all points in each interval share the same geom_feats, so we only need to load the first point's geom_feats.
196+ current_geom_feats_ptrs = geom_feats_ptr + interval_starts [:, None ] * geom_feats_stride
197+ # geom_{x|y|z|b} shape: (cxpr_block_size, 1) -- the trailing unit axis broadcasts against channel_offsets[None, :]
198+ geom_x = tl .load (current_geom_feats_ptrs + 0 ) # X coordinates
199+ geom_y = tl .load (current_geom_feats_ptrs + 1 ) # Y coordinates
200+ geom_z = tl .load (current_geom_feats_ptrs + 2 ) # Z coordinates
201+ geom_b = tl .load (current_geom_feats_ptrs + 3 ) # Batch indices
202+
203+ # Offsets for the entry in the grad_output tensor for each interval
204+ grad_output_offsets = (
181205 geom_b * grad_output_batch_stride
182206 + geom_z * grad_output_z_stride
183207 + geom_x * grad_output_x_stride
184208 + geom_y * grad_output_y_stride
185209 )
186210
187- # Load gradient output, shape: (num_channels, )
211+ # Load gradient output, shape: (cxpr_block_size, cxpr_num_channels_padded )
188212 grad_output = tl .load (
189- grad_output_ptr + grad_output_offset + channel_offsets ,
190- mask = channel_mask ,
213+ grad_output_ptr + grad_output_offsets + channel_offsets [ None , :] ,
214+ mask = output_mask ,
191215 other = 0.0 ,
192216 )
193217
194- # Broadcast the grad_output to match the block size and number of channels
195- # This is necessary to ensure we can write the gradients in blocks
196- grad_output_expanded = grad_output [None , :].broadcast_to (cxpr_block_size , cxpr_num_channels_padded )
197-
198218 # Pointers to the first x_grad row of each interval in the block, shape: (cxpr_block_size, cxpr_num_channels_padded)
199- current_x_grad_ptr = x_grad_ptr + interval_start * x_grad_stride + channel_offsets
219+ current_x_grad_ptr = x_grad_ptr + interval_starts [:, None ] * x_grad_stride + channel_offsets [None , :]
220+
221+ # Determine the maximum interval length in the block
222+ # This is used to determine how many points we need to process in the block
223+ max_interval_length = tl .max (interval_lengths , axis = 0 )
200224
201- for block_index in range (tl .cdiv (interval_length , cxpr_block_size )):
202- # Offsets for the current block
203- block_offsets = block_index * cxpr_block_size + tl .arange (0 , cxpr_block_size )
204- # Mask out any indices that are out of bounds for the interval length
205- block_mask = block_offsets < interval_length
225+ # Iterate over the max number of points in any interval in the block
226+ for point_index in range (max_interval_length ):
227+ # Mask for intervals where this point index is valid
228+ index_mask = point_index < interval_lengths
206229
207- # Store gradients for the current block
230+ # Store gradients for the current point, shape: (cxpr_block_size, cxpr_num_channels_padded)
208231 tl .store (
209- current_x_grad_ptr + block_offsets [:, None ] * x_grad_stride ,
210- grad_output_expanded ,
211- mask = block_mask [:, None ] & channel_mask [ None , :] ,
232+ current_x_grad_ptr + point_index * x_grad_stride ,
233+ grad_output ,
234+ mask = index_mask [:, None ] & output_mask ,
212235 )
213236
214237
@@ -231,8 +254,9 @@ def bev_pool_launcher(
231254 _ , num_channels = image_feats .shape
232255 num_intervals = interval_lengths .size (0 )
233256
234- # Process each interval in parallel
235- grid = (num_intervals ,)
257+ def grid (meta : dict [str , Any ]) -> tuple [int , ...]:
258+ # Process each interval in parallel, blockwise
259+ return (triton .cdiv (num_intervals , meta ["cxpr_block_size" ]),)
236260
237261 _bev_pool_kernel [grid ](
238262 # Pointers to tensors
@@ -243,6 +267,7 @@ def bev_pool_launcher(
243267 interval_lengths_ptr = interval_lengths ,
244268 # Scalars
245269 num_channels = num_channels ,
270+ num_intervals = num_intervals ,
246271 # Strides
247272 image_feats_stride = image_feats .stride (0 ),
248273 geom_feats_stride = geom_feats .stride (0 ),
@@ -252,10 +277,8 @@ def bev_pool_launcher(
252277 output_y_stride = output .stride (3 ),
253278 # Constexprs
254279 cxpr_num_channels_padded = triton .next_power_of_2 (num_channels ),
255- # TODO(jmanning): Autotune?
256- # The tricky thing here is the optimal block size depends on the average/maximum interval length,
257- # not the number of intervals or number of points.
258- # It also just may depend on the platform/device what the optimal block size is.
280+ # TODO(jmanning): We _could_ autotune based on the number of intervals,
281+ # but that would likely trigger many recompilations.
259282 cxpr_block_size = 64 ,
260283 )
261284
@@ -279,8 +302,9 @@ def bev_pool_backward_launcher(
279302 num_intervals = interval_starts .size (0 )
280303 _ , num_channels = x_grad .shape
281304
282- # Process each interval in parallel
283- grid = (num_intervals ,)
305+ def grid (meta : dict [str , Any ]) -> tuple [int , ...]:
306+ # Process each interval in parallel, blockwise
307+ return (triton .cdiv (num_intervals , meta ["cxpr_block_size" ]),)
284308
285309 _bev_pool_backward_kernel [grid ](
286310 # Pointers to tensors
@@ -291,6 +315,7 @@ def bev_pool_backward_launcher(
291315 interval_lengths_ptr = interval_lengths ,
292316 # Scalars
293317 num_channels = num_channels ,
318+ num_intervals = num_intervals ,
294319 # Strides
295320 x_grad_stride = x_grad .stride (0 ),
296321 grad_output_batch_stride = grad_output .stride (0 ),
0 commit comments