|
7 | 7 | https://github.com/pytorch/vision/blob/0721867e42841171254c7acaa45fbaf8ee16d3d7/torchvision/csrc/ops/cuda/nms_kernel.cu |
8 | 8 | """ |
9 | 9 |
|
10 | | -from typing import Any |
11 | | - |
12 | 10 | import torch |
13 | 11 | import triton |
14 | 12 | import triton.language as tl |
@@ -77,71 +75,64 @@ def _create_iou_mask_kernel( |
77 | 75 | iou_mask_stride: Stride for IoU mask tensor. |
78 | 76 | cxpr_block_size: Block size for processing. |
79 | 77 | """ |
80 | | - row_block_start = tl.program_id(0) * cxpr_block_size |
81 | | - col_block_start = tl.program_id(1) * cxpr_block_size |
82 | | - |
83 | | - # Skip this block if we are entirely in the lower-triangular part of the matrix |
84 | | - if row_block_start > col_block_start: |
85 | | - return |
86 | | - |
87 | | - # Process a block of rows |
88 | | - row_offsets = row_block_start + tl.arange(0, cxpr_block_size) |
89 | | - row_mask = row_offsets < num_boxes |
90 | | - |
91 | | - # Load the reference boxes |
92 | | - # Shape: (cxpr_block_size,) |
93 | | - box1_offsets = row_offsets * boxes_stride |
94 | | - box1_x1 = tl.load(boxes_ptr + box1_offsets + 0, mask=row_mask, other=0.0) |
95 | | - box1_y1 = tl.load(boxes_ptr + box1_offsets + 1, mask=row_mask, other=0.0) |
96 | | - box1_x2 = tl.load(boxes_ptr + box1_offsets + 2, mask=row_mask, other=0.0) |
97 | | - box1_y2 = tl.load(boxes_ptr + box1_offsets + 3, mask=row_mask, other=0.0) |
98 | | - |
99 | | - # Calculate areas of the reference boxes |
 | 78 | + # Determine which row of the IoU matrix this program instance handles. |
 | 79 | + # Each row corresponds to one box; each program instance processes a single row. |
| 80 | + row_index = tl.program_id(0) |
| 81 | + |
| 82 | + # Load the reference box |
| 83 | + box1_offset = row_index * boxes_stride |
| 84 | + box1_x1 = tl.load(boxes_ptr + box1_offset + 0) |
| 85 | + box1_y1 = tl.load(boxes_ptr + box1_offset + 1) |
| 86 | + box1_x2 = tl.load(boxes_ptr + box1_offset + 2) |
| 87 | + box1_y2 = tl.load(boxes_ptr + box1_offset + 3) |
| 88 | + |
| 89 | + # Calculate area of the reference box |
100 | 90 | box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1) |
101 | 91 |
|
102 | | - # Process a block of columns |
103 | | - col_offsets = col_block_start + tl.arange(0, cxpr_block_size) |
104 | | - col_mask = col_offsets < num_boxes |
| 92 | + # Process all of the columns, blockwise |
| 93 | + for col_block_start in range(row_index, num_boxes, cxpr_block_size): |
| 94 | + # Column offsets for the current block |
| 95 | + col_offsets = col_block_start + tl.arange(0, cxpr_block_size) |
| 96 | + col_mask = col_offsets < num_boxes |
105 | 97 |
|
106 | | - # Load boxes in the current block |
107 | | - # Shape: (cxpr_block_size,) |
108 | | - box2_offsets = col_offsets * boxes_stride |
109 | | - box2_x1 = tl.load(boxes_ptr + box2_offsets + 0, mask=col_mask, other=0.0) |
110 | | - box2_y1 = tl.load(boxes_ptr + box2_offsets + 1, mask=col_mask, other=0.0) |
111 | | - box2_x2 = tl.load(boxes_ptr + box2_offsets + 2, mask=col_mask, other=0.0) |
112 | | - box2_y2 = tl.load(boxes_ptr + box2_offsets + 3, mask=col_mask, other=0.0) |
| 98 | + # Load boxes in the current block |
| 99 | + # Shape: (cxpr_block_size,) |
| 100 | + box2_offsets = col_offsets * boxes_stride |
| 101 | + box2_x1 = tl.load(boxes_ptr + box2_offsets + 0, mask=col_mask, other=0.0) |
| 102 | + box2_y1 = tl.load(boxes_ptr + box2_offsets + 1, mask=col_mask, other=0.0) |
| 103 | + box2_x2 = tl.load(boxes_ptr + box2_offsets + 2, mask=col_mask, other=0.0) |
| 104 | + box2_y2 = tl.load(boxes_ptr + box2_offsets + 3, mask=col_mask, other=0.0) |
113 | 105 |
|
114 | | - # Calculate areas of the boxes |
115 | | - box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1) |
| 106 | + # Calculate areas of the boxes |
| 107 | + box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1) |
116 | 108 |
|
117 | | - # Calculate intersection |
118 | | - inter_x1 = tl.maximum(box1_x1[:, None], box2_x1[None, :]) |
119 | | - inter_y1 = tl.maximum(box1_y1[:, None], box2_y1[None, :]) |
120 | | - inter_x2 = tl.minimum(box1_x2[:, None], box2_x2[None, :]) |
121 | | - inter_y2 = tl.minimum(box1_y2[:, None], box2_y2[None, :]) |
| 109 | + # Calculate intersection |
| 110 | + inter_x1 = tl.maximum(box1_x1, box2_x1) |
| 111 | + inter_y1 = tl.maximum(box1_y1, box2_y1) |
| 112 | + inter_x2 = tl.minimum(box1_x2, box2_x2) |
| 113 | + inter_y2 = tl.minimum(box1_y2, box2_y2) |
122 | 114 |
|
123 | | - # Check if there's valid intersection |
124 | | - inter_w = tl.maximum(0.0, inter_x2 - inter_x1) |
125 | | - inter_h = tl.maximum(0.0, inter_y2 - inter_y1) |
126 | | - inter_area = inter_w * inter_h |
| 115 | + # Check if there's valid intersection |
| 116 | + inter_w = tl.maximum(0.0, inter_x2 - inter_x1) |
| 117 | + inter_h = tl.maximum(0.0, inter_y2 - inter_y1) |
| 118 | + inter_area = inter_w * inter_h |
127 | 119 |
|
128 | | - # Calculate union and IoU |
129 | | - # Shape: (cxpr_block_size, cxpr_block_size) |
130 | | - union_area = box1_area[:, None] + box2_area[None, :] - inter_area |
131 | | - iou = tl.where(union_area > 0.0, inter_area / union_area, 0.0) |
| 120 | + # Calculate union and IoU |
| 121 | + # Shape: (cxpr_block_size,) |
| 122 | + union_area = box1_area + box2_area - inter_area |
| 123 | + iou = tl.where(union_area > 0.0, inter_area / union_area, 0.0) |
132 | 124 |
|
133 | | - # Create a mask for IoU values that exceed the threshold |
134 | | - # Shape: (cxpr_block_size, cxpr_block_size) |
135 | | - exceeds_threshold = iou > iou_threshold |
| 125 | + # Create a mask for IoU values that exceed the threshold |
| 126 | + # Shape: (cxpr_block_size,) |
| 127 | + exceeds_threshold = iou > iou_threshold |
136 | 128 |
|
137 | | - # Note: for debugging, if you want to store the actual IoU values instead of boolean, |
138 | | - # you can store `iou` instead of `exceeds_threshold`. You'll also need to update the |
139 | | - # `iou_mask_ptr` type to `boxes.dtype` or similar (instead of `torch.bool`). |
| 129 | + # Note: for debugging, if you want to store the actual IoU values instead of boolean, |
| 130 | + # you can store `iou` instead of `exceeds_threshold`. You'll also need to update the |
| 131 | + # `iou_mask_ptr` type to `boxes.dtype` or similar (instead of `torch.bool`). |
140 | 132 |
|
141 | | - # Store IoU mask -> upper triangular part of the matrix |
142 | | - iou_output_offsets = row_offsets[:, None] * iou_mask_stride + col_offsets[None, :] |
143 | | - iou_output_mask = row_mask[:, None] & col_mask[None, :] & (row_offsets[:, None] < col_offsets[None, :]) |
144 | | - tl.store(iou_mask_ptr + iou_output_offsets, exceeds_threshold, mask=iou_output_mask) |
| 133 | + # Store IoU mask -> upper triangular part of the matrix |
| 134 | + iou_output_offsets = row_index * iou_mask_stride + col_offsets |
| 135 | + tl.store(iou_mask_ptr + iou_output_offsets, exceeds_threshold, mask=col_mask) |
145 | 136 |
|
146 | 137 |
|
147 | 138 | @triton.autotune( # type: ignore[misc] |
@@ -244,11 +235,9 @@ def nms_launcher( |
244 | 235 | _, sorted_indices = torch.sort(scores, dim=0, stable=True, descending=True) |
245 | 236 | sorted_boxes = boxes[sorted_indices].contiguous() |
246 | 237 |
|
247 | | - # Determine if IoU of one box against all other boxes exceeds the threshold in parallel. |
248 | | - # Process blockwise in both dimensions, in chunks of size `cxpr_block_size`. |
249 | | - def stage1_grid(meta: dict[str, Any]) -> tuple[int, int]: |
250 | | - num_blocks = triton.cdiv(num_boxes, meta["cxpr_block_size"]) |
251 | | - return (num_blocks, num_blocks) |
 | 238 | + # For each box, create a mask marking the boxes whose IoU with it exceeds the threshold. |
| 239 | + # Process other boxes blockwise, in chunks of size `cxpr_block_size`. |
| 240 | + stage1_grid = (num_boxes,) |
252 | 241 |
|
253 | 242 | # Create the IoU mask in parallel; only the upper-triangular part of the matrix is populated. |
254 | 243 | _create_iou_mask_kernel[stage1_grid]( |
|
0 commit comments