Skip to content

Commit 8a68608

Browse files
Process intervals in blocks, instead of points within interval in blocks
The previous implementation used a grid of `(num_intervals,)`, where each program would process all of the points in an interval, blockwise, in parallel. This is optimal if there are many points in an interval. However, in some cases, we don't have many points in an interval, so its actually better to process the intervals blockwise with each program processing a block of the intervals.
1 parent ce3c1d6 commit 8a68608

3 files changed

Lines changed: 115 additions & 85 deletions

File tree

benchmarks/bev_pool_backward_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _create_bev_pool_backward_data(
9191
"--num-points",
9292
required=False,
9393
type=int,
94-
default=10000,
94+
default=500000,
9595
help="Number of input points",
9696
)
9797
@click.option(
@@ -105,28 +105,28 @@ def _create_bev_pool_backward_data(
105105
"--batch-size",
106106
required=False,
107107
type=int,
108-
default=2,
108+
default=1,
109109
help="Batch size",
110110
)
111111
@click.option(
112112
"--grid-cells-z",
113113
required=False,
114114
type=int,
115-
default=16,
115+
default=32,
116116
help="Number of Z grid cells",
117117
)
118118
@click.option(
119119
"--grid-cells-x",
120120
required=False,
121121
type=int,
122-
default=200,
122+
default=250,
123123
help="Number of X grid cells",
124124
)
125125
@click.option(
126126
"--grid-cells-y",
127127
required=False,
128128
type=int,
129-
default=200,
129+
default=250,
130130
help="Number of Y grid cells",
131131
)
132132
@click.option(

benchmarks/bev_pool_benchmark.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _create_bev_pool_data(
9090
"--num-points",
9191
required=False,
9292
type=int,
93-
default=10000,
93+
default=6000000,
9494
help="Number of input points",
9595
)
9696
@click.option(
@@ -104,28 +104,28 @@ def _create_bev_pool_data(
104104
"--batch-size",
105105
required=False,
106106
type=int,
107-
default=2,
107+
default=1,
108108
help="Batch size",
109109
)
110110
@click.option(
111111
"--grid-cells-z",
112112
required=False,
113113
type=int,
114-
default=16,
114+
default=20,
115115
help="Number of Z grid cells",
116116
)
117117
@click.option(
118118
"--grid-cells-x",
119119
required=False,
120120
type=int,
121-
default=200,
121+
default=800,
122122
help="Number of X grid cells",
123123
)
124124
@click.option(
125125
"--grid-cells-y",
126126
required=False,
127127
type=int,
128-
default=200,
128+
default=800,
129129
help="Number of Y grid cells",
130130
)
131131
@click.option(
@@ -239,6 +239,11 @@ def main(
239239
device=device,
240240
)
241241

242+
print(f"Number of intervals: {len(interval_starts)}", file=sys.stderr)
243+
print(f"Min interval length: {interval_lengths.float().min().item()}", file=sys.stderr)
244+
print(f"Mean interval length: {interval_lengths.float().mean().item()}", file=sys.stderr)
245+
print(f"Max interval length: {interval_lengths.float().max().item()}", file=sys.stderr)
246+
242247
# Compile functions if requested
243248
bev_pool_ref_fn = torch.compile(bev_pool_ref) if compile_ref else bev_pool_ref
244249
bev_pool_conch_fn = torch.compile(bev_pool_conch) if compile_conch else bev_pool_conch

conch/kernels/vision/bev_pool.py

Lines changed: 100 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
"""Quick cumsum kernel for 3D voxel grids."""
55

6+
from typing import Any
7+
68
import torch
79
import triton
810
import triton.language as tl
@@ -18,6 +20,7 @@ def _bev_pool_kernel(
1820
interval_lengths_ptr: tl.tensor,
1921
# Scalars
2022
num_channels: tl.int32,
23+
num_intervals: tl.int32,
2124
# Strides
2225
image_feats_stride: tl.int32,
2326
geom_feats_stride: tl.int32,
@@ -49,69 +52,80 @@ def _bev_pool_kernel(
4952
interval_starts_ptr: Pointer to the starting positions for pooled points, shape: (num_intervals,).
5053
interval_lengths_ptr: Pointer to the lengths of each pooled point, shape: (num_intervals,).
5154
num_channels: The number of channels in the image features.
55+
num_intervals: The number of intervals to process.
5256
image_feats_stride: The stride of the image features tensor.
5357
geom_feats_stride: The stride of the geometry features tensor.
5458
output_batch_stride: The stride of the output tensor for the batch dimension.
5559
output_z_stride: The stride of the output tensor for the z dimension.
5660
output_x_stride: The stride of the output tensor for the x dimension.
5761
output_y_stride: The stride of the output tensor for the y dimension.
5862
cxpr_num_channels_padded: The number of channels padded to the next power of 2.
59-
cxpr_block_size: The block size for processing points in parallel.
63+
cxpr_block_size: The block size for processing intervals in parallel.
6064
"""
61-
# What is the index of the interval this program is processing?
62-
interval_index = tl.program_id(0)
65+
# What is the starting index of the block of intervals this program is processing?
66+
interval_block_start = tl.program_id(0) * cxpr_block_size
67+
# Offsets for each interval in the block
68+
interval_block_offsets = interval_block_start + tl.arange(0, cxpr_block_size)
69+
# Mask out-of-bounds intervals
70+
interval_block_mask = interval_block_offsets < num_intervals
6371

64-
# Load current interval start and length
65-
interval_start = tl.load(interval_starts_ptr + interval_index)
66-
interval_length = tl.load(interval_lengths_ptr + interval_index)
72+
# Load start and length for each interval in the block
73+
interval_starts = tl.load(interval_starts_ptr + interval_block_offsets, mask=interval_block_mask, other=0)
74+
interval_lengths = tl.load(interval_lengths_ptr + interval_block_offsets, mask=interval_block_mask, other=0)
6775

6876
# Offsets and masks for the channels
6977
channel_offsets = tl.arange(0, cxpr_num_channels_padded)
7078
channel_mask = channel_offsets < num_channels
7179

72-
# Accumulator for the sum of image features for each channel for all points in the interval
73-
output = tl.zeros([cxpr_num_channels_padded], dtype=image_feats_ptr.dtype.element_ty)
80+
# Combine interval block mask with channel mask
81+
output_mask = interval_block_mask[:, None] & channel_mask[None, :]
82+
83+
# Accumulator for the sum of image features for each channel for all intervals in the block
84+
output = tl.zeros([cxpr_block_size, cxpr_num_channels_padded], dtype=image_feats_ptr.dtype.element_ty)
85+
86+
# Calculate the pointer to the start of the image features for the current block of intervals
87+
# Shape: (cxpr_block_size, cxpr_num_channels_padded)
88+
current_image_feats_ptr = image_feats_ptr + interval_starts[:, None] * image_feats_stride + channel_offsets[None, :]
7489

75-
# Calculate the pointer to the start of the image features for the current interval
76-
current_image_feats_ptr = image_feats_ptr + interval_start * image_feats_stride + channel_offsets[None, :]
90+
# Determine the maximum interval length in the block
91+
# This is used to determine how many points we need to process in the block
92+
max_interval_length = tl.max(interval_lengths, axis=0)
7793

78-
# Iterate blockwise over the points in the interval
79-
for block_index in range(tl.cdiv(interval_length, cxpr_block_size)):
80-
# Offsets for the current block
81-
block_offsets = block_index * cxpr_block_size + tl.arange(0, cxpr_block_size)
82-
# Mask out any indices that are out of bounds for the interval length
83-
block_mask = block_offsets < interval_length
94+
# Iterate over the max number of points in any interval in the block
95+
for point_index in range(max_interval_length):
96+
# Mask for intervals where this point index is valid
97+
index_mask = point_index < interval_lengths
8498

85-
# Load image features, shape: (cxpr_block_size, num_channels)
99+
# Load image features for the current point, shape: (cxpr_block_size, cxpr_num_channels_padded)
86100
image_feats = tl.load(
87-
current_image_feats_ptr + block_offsets[:, None] * image_feats_stride,
88-
mask=block_mask[:, None] & channel_mask[None, :],
101+
current_image_feats_ptr + point_index * image_feats_stride,
102+
mask=index_mask[:, None] & output_mask,
89103
other=0.0,
90104
)
91105

92-
# Calculate sum of image features for the current block, per channel
93-
# Shape: (cxpr_num_channels_padded,)
94-
output += tl.sum(image_feats, axis=0)
106+
# Accumulate the image features into the output tensor
107+
output += image_feats
95108

96-
# Load geometry coordinates for the first point in this interval
97-
# Note: all points in the interval share the same geom_feats, so we only need to load the first point's geom_feats.
98-
current_geom_feats_ptr = geom_feats_ptr + interval_start * geom_feats_stride
99-
geom_x = tl.load(current_geom_feats_ptr + 0) # X coordinate
100-
geom_y = tl.load(current_geom_feats_ptr + 1) # Y coordinate
101-
geom_z = tl.load(current_geom_feats_ptr + 2) # Z coordinate
102-
geom_b = tl.load(current_geom_feats_ptr + 3) # Batch index
109+
# Load geometry coordinates for the first point in each interval
110+
# Note: all points in each interval share the same geom_feats, so we only need to load the first point's geom_feats.
111+
# geom_{x|y|z|b} shape: (cxpr_block_size,)
112+
current_geom_feats_ptrs = geom_feats_ptr + interval_starts[:, None] * geom_feats_stride
113+
geom_x = tl.load(current_geom_feats_ptrs + 0) # X coordinates
114+
geom_y = tl.load(current_geom_feats_ptrs + 1) # Y coordinates
115+
geom_z = tl.load(current_geom_feats_ptrs + 2) # Z coordinates
116+
geom_b = tl.load(current_geom_feats_ptrs + 3) # Batch indices
103117

104-
# Calculate output tensor offset for shape [batch_size, grid_cells_z, grid_cells_x, grid_cells_y, num_channels]
105-
batch_offset = geom_b * output_batch_stride
106-
z_offset = geom_z * output_z_stride
107-
x_offset = geom_x * output_x_stride
108-
y_offset = geom_y * output_y_stride
118+
# Calculate output tensor offsets for shape [batch_size, grid_cells_z, grid_cells_x, grid_cells_y, num_channels]
119+
batch_offsets = geom_b * output_batch_stride
120+
z_offsets = geom_z * output_z_stride
121+
x_offsets = geom_x * output_x_stride
122+
y_offsets = geom_y * output_y_stride
109123

110124
# Store the accumulated output for the current interval
111125
tl.store(
112-
output_ptr + batch_offset + z_offset + x_offset + y_offset + channel_offsets,
126+
output_ptr + batch_offsets + z_offsets + x_offsets + y_offsets + channel_offsets[None, :],
113127
output,
114-
mask=channel_mask,
128+
mask=output_mask,
115129
)
116130

117131

@@ -125,6 +139,7 @@ def _bev_pool_backward_kernel(
125139
interval_lengths_ptr: tl.tensor,
126140
# Scalars
127141
num_channels: tl.int32,
142+
num_intervals: tl.int32,
128143
# Strides
129144
x_grad_stride: tl.int32,
130145
grad_output_batch_stride: tl.int32,
@@ -148,6 +163,7 @@ def _bev_pool_backward_kernel(
148163
interval_starts_ptr: Pointer to the starting positions for pooled points.
149164
interval_lengths_ptr: Pointer to the lengths of each pooled point.
150165
num_channels: The number of channels in the image features.
166+
num_intervals: The number of intervals to process.
151167
x_grad_stride: The stride of the x_grad tensor.
152168
grad_output_batch_stride: The stride of the grad_output tensor for the batch dimension.
153169
grad_output_z_stride: The stride of the grad_output tensor for the z dimension.
@@ -157,58 +173,65 @@ def _bev_pool_backward_kernel(
157173
cxpr_num_channels_padded: The number of channels padded to the next power of 2.
158174
cxpr_block_size: The block size for processing points in parallel.
159175
"""
160-
# What is the index of the interval this program is processing?
161-
interval_index = tl.program_id(0)
176+
# What is the starting index of the block of intervals this program is processing?
177+
interval_block_start = tl.program_id(0) * cxpr_block_size
178+
# Offsets for each interval in the block
179+
interval_block_offsets = interval_block_start + tl.arange(0, cxpr_block_size)
180+
# Mask out-of-bounds intervals
181+
interval_block_mask = interval_block_offsets < num_intervals
162182

163-
# Load current interval start and length
164-
interval_start = tl.load(interval_starts_ptr + interval_index)
165-
interval_length = tl.load(interval_lengths_ptr + interval_index)
183+
# Load start and length for each interval in the block
184+
interval_starts = tl.load(interval_starts_ptr + interval_block_offsets, mask=interval_block_mask, other=0)
185+
interval_lengths = tl.load(interval_lengths_ptr + interval_block_offsets, mask=interval_block_mask, other=0)
166186

167187
# Offsets and masks for the channels
168188
channel_offsets = tl.arange(0, cxpr_num_channels_padded)
169189
channel_mask = channel_offsets < num_channels
170190

171-
# Load geometry coordinates for the first point in this interval
172-
# Note: all points in the interval share the same geom_feats, so we only need to load the first point's geom_feats.
173-
current_geom_feats_ptr = geom_feats_ptr + interval_start * geom_feats_stride
174-
geom_x = tl.load(current_geom_feats_ptr + 0) # X coordinate
175-
geom_y = tl.load(current_geom_feats_ptr + 1) # Y coordinate
176-
geom_z = tl.load(current_geom_feats_ptr + 2) # Z coordinate
177-
geom_b = tl.load(current_geom_feats_ptr + 3) # Batch index
191+
# Combine interval block mask with channel mask
192+
output_mask = interval_block_mask[:, None] & channel_mask[None, :]
178193

179-
# Offset for the entry in the grad_output tensor for this interval
180-
grad_output_offset = (
194+
# Load geometry coordinates for the first point in each interval
195+
# Note: all points in each interval share the same geom_feats, so we only need to load the first point's geom_feats.
196+
current_geom_feats_ptrs = geom_feats_ptr + interval_starts[:, None] * geom_feats_stride
197+
# geom_{x|y|z|b} shape: (cxpr_block_size,)
198+
geom_x = tl.load(current_geom_feats_ptrs + 0) # X coordinates
199+
geom_y = tl.load(current_geom_feats_ptrs + 1) # Y coordinates
200+
geom_z = tl.load(current_geom_feats_ptrs + 2) # Z coordinates
201+
geom_b = tl.load(current_geom_feats_ptrs + 3) # Batch indices
202+
203+
# Offsets for the entry in the grad_output tensor for each interval
204+
grad_output_offsets = (
181205
geom_b * grad_output_batch_stride
182206
+ geom_z * grad_output_z_stride
183207
+ geom_x * grad_output_x_stride
184208
+ geom_y * grad_output_y_stride
185209
)
186210

187-
# Load gradient output, shape: (num_channels,)
211+
# Load gradient output, shape: (cxpr_block_size, cxpr_num_channels_padded)
188212
grad_output = tl.load(
189-
grad_output_ptr + grad_output_offset + channel_offsets,
190-
mask=channel_mask,
213+
grad_output_ptr + grad_output_offsets + channel_offsets[None, :],
214+
mask=output_mask,
191215
other=0.0,
192216
)
193217

194-
# Broadcast the grad_output to match the block size and number of channels
195-
# This is necessary to ensure we can write the gradients in blocks
196-
grad_output_expanded = grad_output[None, :].broadcast_to(cxpr_block_size, cxpr_num_channels_padded)
197-
198218
# Pointer to the start of the output for this block
199-
current_x_grad_ptr = x_grad_ptr + interval_start * x_grad_stride + channel_offsets
219+
current_x_grad_ptr = x_grad_ptr + interval_starts[:, None] * x_grad_stride + channel_offsets[None, :]
220+
221+
# Determine the maximum interval length in the block
222+
# This is used to determine how many points we need to process in the block
223+
max_interval_length = tl.max(interval_lengths, axis=0)
200224

201-
for block_index in range(tl.cdiv(interval_length, cxpr_block_size)):
202-
# Offsets for the current block
203-
block_offsets = block_index * cxpr_block_size + tl.arange(0, cxpr_block_size)
204-
# Mask out any indices that are out of bounds for the interval length
205-
block_mask = block_offsets < interval_length
225+
# Iterate over the max number of points in any interval in the block
226+
for point_index in range(max_interval_length):
227+
# Mask for intervals where this point index is valid
228+
index_mask = point_index < interval_lengths
206229

207-
# Store gradients for the current block
230+
# Store gradients for the current point, shape: (cxpr_block_size, cxpr_num_channels_padded)
208231
tl.store(
209-
current_x_grad_ptr + block_offsets[:, None] * x_grad_stride,
210-
grad_output_expanded,
211-
mask=block_mask[:, None] & channel_mask[None, :],
232+
current_x_grad_ptr + point_index * x_grad_stride,
233+
grad_output,
234+
mask=index_mask[:, None] & output_mask,
212235
)
213236

214237

@@ -231,8 +254,9 @@ def bev_pool_launcher(
231254
_, num_channels = image_feats.shape
232255
num_intervals = interval_lengths.size(0)
233256

234-
# Process each interval in parallel
235-
grid = (num_intervals,)
257+
def grid(meta: dict[str, Any]) -> tuple[int, ...]:
258+
# Process each interval in parallel, blockwise
259+
return (triton.cdiv(num_intervals, meta["cxpr_block_size"]),)
236260

237261
_bev_pool_kernel[grid](
238262
# Pointers to tensors
@@ -243,6 +267,7 @@ def bev_pool_launcher(
243267
interval_lengths_ptr=interval_lengths,
244268
# Scalars
245269
num_channels=num_channels,
270+
num_intervals=num_intervals,
246271
# Strides
247272
image_feats_stride=image_feats.stride(0),
248273
geom_feats_stride=geom_feats.stride(0),
@@ -252,10 +277,8 @@ def bev_pool_launcher(
252277
output_y_stride=output.stride(3),
253278
# Constexprs
254279
cxpr_num_channels_padded=triton.next_power_of_2(num_channels),
255-
# TODO(jmanning): Autotune?
256-
# The tricky thing here is the optimal block size depends on the average/maximum interval length,
257-
# not the number of intervals or number of points.
258-
# It also just may depend on the platform/device what the optimal block size is.
280+
# TODO(jmanning): We _could_ autotune based on the number of intervals,
281+
# but that would likely trigger many recompilations.
259282
cxpr_block_size=64,
260283
)
261284

@@ -279,8 +302,9 @@ def bev_pool_backward_launcher(
279302
num_intervals = interval_starts.size(0)
280303
_, num_channels = x_grad.shape
281304

282-
# Process each interval in parallel
283-
grid = (num_intervals,)
305+
def grid(meta: dict[str, Any]) -> tuple[int, ...]:
306+
# Process each interval in parallel, blockwise
307+
return (triton.cdiv(num_intervals, meta["cxpr_block_size"]),)
284308

285309
_bev_pool_backward_kernel[grid](
286310
# Pointers to tensors
@@ -291,6 +315,7 @@ def bev_pool_backward_launcher(
291315
interval_lengths_ptr=interval_lengths,
292316
# Scalars
293317
num_channels=num_channels,
318+
num_intervals=num_intervals,
294319
# Strides
295320
x_grad_stride=x_grad.stride(0),
296321
grad_output_batch_stride=grad_output.stride(0),

0 commit comments

Comments
 (0)