Skip to content

Commit c54e9a6

Browse files
Change to 1D grid for first stage of NMS kernel
1 parent 8611d1e commit c54e9a6

1 file changed

Lines changed: 51 additions & 62 deletions

File tree

conch/kernels/vision/nms.py

Lines changed: 51 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
https://github.com/pytorch/vision/blob/0721867e42841171254c7acaa45fbaf8ee16d3d7/torchvision/csrc/ops/cuda/nms_kernel.cu
88
"""
99

10-
from typing import Any
11-
1210
import torch
1311
import triton
1412
import triton.language as tl
@@ -77,71 +75,64 @@ def _create_iou_mask_kernel(
7775
iou_mask_stride: Stride for IoU mask tensor.
7876
cxpr_block_size: Block size for processing.
7977
"""
80-
row_block_start = tl.program_id(0) * cxpr_block_size
81-
col_block_start = tl.program_id(1) * cxpr_block_size
82-
83-
# Skip this block if we are entirely in the lower-triangular part of the matrix
84-
if row_block_start > col_block_start:
85-
return
86-
87-
# Process a block of rows
88-
row_offsets = row_block_start + tl.arange(0, cxpr_block_size)
89-
row_mask = row_offsets < num_boxes
90-
91-
# Load the reference boxes
92-
# Shape: (cxpr_block_size,)
93-
box1_offsets = row_offsets * boxes_stride
94-
box1_x1 = tl.load(boxes_ptr + box1_offsets + 0, mask=row_mask, other=0.0)
95-
box1_y1 = tl.load(boxes_ptr + box1_offsets + 1, mask=row_mask, other=0.0)
96-
box1_x2 = tl.load(boxes_ptr + box1_offsets + 2, mask=row_mask, other=0.0)
97-
box1_y2 = tl.load(boxes_ptr + box1_offsets + 3, mask=row_mask, other=0.0)
98-
99-
# Calculate areas of the reference boxes
78+
# What row of the matrix are we processing?
79+
# Each row corresponds to a box; each program instance processes exactly one row.
80+
row_index = tl.program_id(0)
81+
82+
# Load the reference box
83+
box1_offset = row_index * boxes_stride
84+
box1_x1 = tl.load(boxes_ptr + box1_offset + 0)
85+
box1_y1 = tl.load(boxes_ptr + box1_offset + 1)
86+
box1_x2 = tl.load(boxes_ptr + box1_offset + 2)
87+
box1_y2 = tl.load(boxes_ptr + box1_offset + 3)
88+
89+
# Calculate area of the reference box
10090
box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1)
10191

102-
# Process a block of columns
103-
col_offsets = col_block_start + tl.arange(0, cxpr_block_size)
104-
col_mask = col_offsets < num_boxes
92+
# Process all of the columns, blockwise
93+
for col_block_start in range(row_index, num_boxes, cxpr_block_size):
94+
# Column offsets for the current block
95+
col_offsets = col_block_start + tl.arange(0, cxpr_block_size)
96+
col_mask = col_offsets < num_boxes
10597

106-
# Load boxes in the current block
107-
# Shape: (cxpr_block_size,)
108-
box2_offsets = col_offsets * boxes_stride
109-
box2_x1 = tl.load(boxes_ptr + box2_offsets + 0, mask=col_mask, other=0.0)
110-
box2_y1 = tl.load(boxes_ptr + box2_offsets + 1, mask=col_mask, other=0.0)
111-
box2_x2 = tl.load(boxes_ptr + box2_offsets + 2, mask=col_mask, other=0.0)
112-
box2_y2 = tl.load(boxes_ptr + box2_offsets + 3, mask=col_mask, other=0.0)
98+
# Load boxes in the current block
99+
# Shape: (cxpr_block_size,)
100+
box2_offsets = col_offsets * boxes_stride
101+
box2_x1 = tl.load(boxes_ptr + box2_offsets + 0, mask=col_mask, other=0.0)
102+
box2_y1 = tl.load(boxes_ptr + box2_offsets + 1, mask=col_mask, other=0.0)
103+
box2_x2 = tl.load(boxes_ptr + box2_offsets + 2, mask=col_mask, other=0.0)
104+
box2_y2 = tl.load(boxes_ptr + box2_offsets + 3, mask=col_mask, other=0.0)
113105

114-
# Calculate areas of the boxes
115-
box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1)
106+
# Calculate areas of the boxes
107+
box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1)
116108

117-
# Calculate intersection
118-
inter_x1 = tl.maximum(box1_x1[:, None], box2_x1[None, :])
119-
inter_y1 = tl.maximum(box1_y1[:, None], box2_y1[None, :])
120-
inter_x2 = tl.minimum(box1_x2[:, None], box2_x2[None, :])
121-
inter_y2 = tl.minimum(box1_y2[:, None], box2_y2[None, :])
109+
# Calculate intersection
110+
inter_x1 = tl.maximum(box1_x1, box2_x1)
111+
inter_y1 = tl.maximum(box1_y1, box2_y1)
112+
inter_x2 = tl.minimum(box1_x2, box2_x2)
113+
inter_y2 = tl.minimum(box1_y2, box2_y2)
122114

123-
# Check if there's valid intersection
124-
inter_w = tl.maximum(0.0, inter_x2 - inter_x1)
125-
inter_h = tl.maximum(0.0, inter_y2 - inter_y1)
126-
inter_area = inter_w * inter_h
115+
# Check if there's valid intersection
116+
inter_w = tl.maximum(0.0, inter_x2 - inter_x1)
117+
inter_h = tl.maximum(0.0, inter_y2 - inter_y1)
118+
inter_area = inter_w * inter_h
127119

128-
# Calculate union and IoU
129-
# Shape: (cxpr_block_size, cxpr_block_size)
130-
union_area = box1_area[:, None] + box2_area[None, :] - inter_area
131-
iou = tl.where(union_area > 0.0, inter_area / union_area, 0.0)
120+
# Calculate union and IoU
121+
# Shape: (cxpr_block_size,)
122+
union_area = box1_area + box2_area - inter_area
123+
iou = tl.where(union_area > 0.0, inter_area / union_area, 0.0)
132124

133-
# Create a mask for IoU values that exceed the threshold
134-
# Shape: (cxpr_block_size, cxpr_block_size)
135-
exceeds_threshold = iou > iou_threshold
125+
# Create a mask for IoU values that exceed the threshold
126+
# Shape: (cxpr_block_size,)
127+
exceeds_threshold = iou > iou_threshold
136128

137-
# Note: for debugging, if you want to store the actual IoU values instead of boolean,
138-
# you can store `iou` instead of `exceeds_threshold`. You'll also need to update the
139-
# `iou_mask_ptr` type to `boxes.dtype` or similar (instead of `torch.bool`).
129+
# Note: for debugging, if you want to store the actual IoU values instead of boolean,
130+
# you can store `iou` instead of `exceeds_threshold`. You'll also need to update the
131+
# `iou_mask_ptr` type to `boxes.dtype` or similar (instead of `torch.bool`).
140132

141-
# Store IoU mask -> upper triangular part of the matrix
142-
iou_output_offsets = row_offsets[:, None] * iou_mask_stride + col_offsets[None, :]
143-
iou_output_mask = row_mask[:, None] & col_mask[None, :] & (row_offsets[:, None] < col_offsets[None, :])
144-
tl.store(iou_mask_ptr + iou_output_offsets, exceeds_threshold, mask=iou_output_mask)
133+
# Store IoU mask -> upper triangular part of the matrix
134+
iou_output_offsets = row_index * iou_mask_stride + col_offsets
135+
tl.store(iou_mask_ptr + iou_output_offsets, exceeds_threshold, mask=col_mask)
145136

146137

147138
@triton.autotune( # type: ignore[misc]
@@ -244,11 +235,9 @@ def nms_launcher(
244235
_, sorted_indices = torch.sort(scores, dim=0, stable=True, descending=True)
245236
sorted_boxes = boxes[sorted_indices].contiguous()
246237

247-
# Determine if IoU of one box against all other boxes exceeds the threshold in parallel.
248-
# Process blockwise in both dimensions, in chunks of size `cxpr_block_size`.
249-
def stage1_grid(meta: dict[str, Any]) -> tuple[int, int]:
250-
num_blocks = triton.cdiv(num_boxes, meta["cxpr_block_size"])
251-
return (num_blocks, num_blocks)
238+
# For each box, create a mask marking the boxes whose IoU with it exceeds the threshold.
239+
# Process other boxes blockwise, in chunks of size `cxpr_block_size`.
240+
stage1_grid = (num_boxes,)
252241

253242
# Create the IoU mask in parallel; only the upper-triangular part of the matrix is populated.
254243
_create_iou_mask_kernel[stage1_grid](

0 commit comments

Comments
 (0)