Skip to content

Commit fa647b8

Browse files
Merge pull request #36 from stackav-oss/feature/jmanning/nms-v5
Add Non-Max Suppression
2 parents f72e964 + c54e9a6 commit fa647b8

14 files changed

Lines changed: 1055 additions & 11 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ dist/
1010
results/
1111

1212
.coverage
13+
.vscode

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ Each operation is complete with a PyTorch-only reference implementation (and som
3232
- GEMM
3333
- Mixed-precision
3434
- Scaled
35+
- Vision
36+
- Non-Max Suppression (NMS)
3537
- vLLM
3638
- KV cache operations
3739
- Copy blocks
@@ -99,6 +101,7 @@ We were inspired by and leverage components of the following libraries:
99101

100102
- [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes)
101103
- [GemLite](https://github.com/mobiusml/gemlite)
104+
- [Torchvision](https://github.com/pytorch/vision)
102105
- [vLLM](https://github.com/vllm-project/vllm)
103106

104107
## License

benchmarks/nms_benchmark.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright 2025 Stack AV Co.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""NMS benchmark."""
5+
6+
import sys
7+
from typing import Final
8+
9+
import click
10+
import torch
11+
12+
from conch.ops.vision.nms import nms as nms_conch
13+
from conch.platforms import current_platform
14+
from conch.reference.vision.nms import nms as nms_ref
15+
from conch.third_party.vllm.utils import seed_everything
16+
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
17+
18+
19+
def _create_tensors_with_iou(num_boxes: int, iou_thresh: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Create random boxes and scores where the last box barely exceeds the IoU threshold with the first.

    The last box is a copy of the first one whose right edge is pushed outward.
    For b0 = [x0, y0, x1, y1] and b1 = [x0, y0, x1 + d, y1] the two boxes share
    the same height, so IoU(b0, b1) = (x1 - x0) / (x1 - x0 + d). Solving for a
    target IoU ``t`` gives d = (x1 - x0) * (1 - t) / t.

    Args:
        num_boxes: Number of boxes (and scores) to generate.
        iou_thresh: IoU threshold that the first/last box pair should just exceed.

    Returns:
        Tuple of (boxes, scores): a ``(num_boxes, 4)`` tensor of boxes and a
        ``(num_boxes,)`` tensor of random scores.
    """
    corners = torch.rand(num_boxes, 4) * 100
    # Offset the last two columns by the first two so that x1 >= x0 and y1 >= y0.
    corners[:, 2:] += corners[:, :2]

    # Start the last box as an exact duplicate of the first one.
    corners[-1, :] = corners[0, :]
    left, _top, right, _bottom = corners[-1].tolist()

    # Nudge the target slightly above the threshold so that the pair lands
    # (barely) on the "should be suppressed" side of the comparison.
    target_iou = iou_thresh + 1e-5
    widen_by = (right - left) * (1 - target_iou) / target_iou
    corners[-1, 2] += widen_by

    confidences = torch.rand(num_boxes)
    return corners, confidences
35+
36+
37+
@click.command()
@click.option(
    "--num-boxes",
    required=False,
    type=int,
    default=1000,
    help="Number of boxes to create",
)
@click.option(
    "--iou-threshold",
    required=False,
    type=float,
    default=0.5,
    help="IoU threshold for boxes to be kept",
)
@click.option(
    "--vectorize-ref",
    is_flag=True,
    help="Flag to enable vectorization in the reference implementation",
)
@click.option(
    "--gpu-ref",
    is_flag=True,
    help="Flag to enable GPU reference implementation",
)
@click.option(
    "--iteration-time-ms",
    required=False,
    type=int,
    default=10000,
    help="Time in milliseconds to run benchmark",
)
@click.option(
    "--warmup-time-ms",
    required=False,
    type=int,
    default=1000,
    help="Time in milliseconds to warmup before recording times",
)
@click.option(
    "--absolute-tolerance",
    required=False,
    type=float,
    default=1e-3,
    help="Absolute tolerance to match with",
)
@click.option(
    "--verbose",
    is_flag=True,
    help="Flag for printing verbose output",
)
@click.option(
    "--gpu",
    required=False,
    type=str,
    default=current_platform.device,
    help="Device to run on",
)
@click.option(
    "--csv",
    is_flag=True,
    help="Flag for printing results in CSV format",
)
@click.option(
    "--compile-ref",
    is_flag=True,
    help="Flag to torch.compile() the reference impl",
)
@click.option(
    "--compile-conch",
    is_flag=True,
    help="Flag to torch.compile() the Conch impl",
)
def main(
    num_boxes: int,
    iou_threshold: float,
    vectorize_ref: bool,
    gpu_ref: bool,
    iteration_time_ms: int,
    warmup_time_ms: int,
    absolute_tolerance: float,
    verbose: bool,
    gpu: str,
    csv: bool,
    compile_ref: bool,
    compile_conch: bool,
) -> None:
    """Benchmark NMS.

    Generates random boxes/scores, compares the Conch output against the
    reference output once (reporting the outcome on stderr), then times the
    reference and Conch implementations plus any optional variants that were
    requested via flags.

    Args:
        num_boxes: Number of boxes to create.
        iou_threshold: IoU threshold for boxes to be kept.
        vectorize_ref: Flag to enable vectorization in the reference implementation.
        gpu_ref: Flag to enable GPU reference implementation.
        iteration_time_ms: Time in milliseconds to run benchmark.
        warmup_time_ms: Time in milliseconds to warmup before recording times.
        absolute_tolerance: Absolute tolerance used to check accuracy.
        verbose: Flag to indicate whether or not to print verbose output.
        gpu: Which gpu to run on.
        csv: Flag to indicate whether or not to print results in CSV format.
        compile_ref: Flag to torch.compile() the reference implementation.
        compile_conch: Flag to torch.compile() the Conch implementation.
    """
    seed: Final = 0
    seed_everything(seed)

    device: Final = torch.device(gpu)
    torch.set_default_device(device)

    metadata = BenchmarkMetadata(
        platform=current_platform.name(),
        params={
            "num_boxes": num_boxes,
            "iou_threshold": iou_threshold,
        },
    )

    boxes, scores = _create_tensors_with_iou(num_boxes, iou_threshold)

    reference_vectorized_fn = None
    reference_gpu_fn = None
    if vectorize_ref:
        # Use vectorized reference implementation if requested
        from conch.reference.vision.nms import _nms_pytorch_vectorized

        reference_vectorized_fn = _nms_pytorch_vectorized
    if gpu_ref:
        # Use GPU reference implementation if requested.
        # Imported lazily so torchvision is only required when this flag is set.
        from torchvision.ops.boxes import nms as nms_torchvision  # type: ignore[import-untyped]

        reference_gpu_fn = nms_torchvision

    reference_compiled_fn = None
    reference_vectorized_compiled_fn = None
    if compile_ref:
        # Compile the reference implementation if requested
        reference_compiled_fn = torch.compile(nms_ref)
        if vectorize_ref:
            reference_vectorized_compiled_fn = torch.compile(reference_vectorized_fn)

    conch_compiled_fn = torch.compile(nms_conch) if compile_conch else None

    # Get reference output
    reference_output = nms_ref(boxes, scores, iou_threshold)

    # Test Conch implementation
    conch_output = nms_conch(boxes, scores, iou_threshold)

    # Accuracy checks.
    # The two implementations may keep a different number of boxes. torch.allclose
    # raises on non-broadcastable shapes (and silently broadcasts compatible ones),
    # so compare shapes first and treat any shape difference as a mismatch. This
    # way the diagnostic below is printed instead of the benchmark crashing.
    outputs_match = conch_output.shape == reference_output.shape and torch.allclose(
        conch_output, reference_output, atol=absolute_tolerance
    )

    if not outputs_match:
        print(f"WARNING: Reference and Conch results differ! (atol={absolute_tolerance})", file=sys.stderr)
        print(f"Ref kept: {len(reference_output)}, Conch kept: {len(conch_output)}", file=sys.stderr)

        if verbose:
            print(f"Reference output: {reference_output}", file=sys.stderr)
            print(f"Conch output: {conch_output}", file=sys.stderr)
    else:
        print(f"Reference vs Conch: Results matched with atol={absolute_tolerance} :)", file=sys.stderr)

    # Benchmark implementations
    baseline_result = benchmark_it(
        lambda: nms_ref(
            boxes,
            scores,
            iou_threshold,
        ),
        tag="Baseline",
        metadata=metadata,
        iteration_time_ms=iteration_time_ms,
        warmup_time_ms=warmup_time_ms,
    )

    conch_result = benchmark_it(
        lambda: nms_conch(
            boxes,
            scores,
            iou_threshold,
        ),
        tag="Conch",
        metadata=metadata,
        iteration_time_ms=iteration_time_ms,
        warmup_time_ms=warmup_time_ms,
    )

    # Optional variants: each result stays None unless its flag was passed.
    reference_compiled_result = None
    reference_vectorized_result = None
    reference_vectorized_compiled_result = None
    reference_gpu_result = None
    conch_compiled_result = None

    if reference_compiled_fn is not None:
        reference_compiled_result = benchmark_it(
            lambda: reference_compiled_fn(
                boxes,
                scores,
                iou_threshold,
            ),
            tag="PyTorch Reference (Compiled)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    if reference_vectorized_fn is not None:
        reference_vectorized_result = benchmark_it(
            lambda: reference_vectorized_fn(
                boxes,
                scores,
                iou_threshold,
            ),
            tag="PyTorch Reference (Vectorized)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    if reference_vectorized_compiled_fn is not None:
        reference_vectorized_compiled_result = benchmark_it(
            lambda: reference_vectorized_compiled_fn(  # type: ignore[call-arg]
                boxes,  # type: ignore[arg-type]
                scores,
                iou_threshold,
            ),
            tag="PyTorch Reference (Vectorized, Compiled)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    if reference_gpu_fn is not None:
        reference_gpu_result = benchmark_it(
            lambda: reference_gpu_fn(
                boxes,
                scores,
                iou_threshold,
            ),
            tag="PyTorch GPU Reference",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    if conch_compiled_fn is not None:
        conch_compiled_result = benchmark_it(
            lambda: conch_compiled_fn(
                boxes,
                scores,
                iou_threshold,
            ),
            tag="Conch (Compiled)",
            metadata=metadata,
            iteration_time_ms=iteration_time_ms,
            warmup_time_ms=warmup_time_ms,
        )

    # Report: parameters once, then Conch first, followed by every other variant that ran.
    conch_result.print_parameters(csv=csv)
    conch_result.print_results(csv=csv)
    baseline_result.print_results(csv=csv)
    if reference_compiled_result is not None:
        reference_compiled_result.print_results(csv=csv)
    if reference_vectorized_result is not None:
        reference_vectorized_result.print_results(csv=csv)
    if reference_vectorized_compiled_result is not None:
        reference_vectorized_compiled_result.print_results(csv=csv)
    if reference_gpu_result is not None:
        reference_gpu_result.print_results(csv=csv)
    if conch_compiled_result is not None:
        conch_compiled_result.print_results(csv=csv)
307+
if __name__ == "__main__":
308+
main()

conch/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ Here is a running directory of the various environment variables:
88
| ---| ---| ---|
99
| `CONCH_BENCH_ENABLE_ALL_REF` | `1` or `true` (case-insensitive) | Toggles whether or not to use bitsandbytes/vLLM reference implementations for benchmarking. |
1010
| `CONCH_ENABLE_BNB` | `1` or `true` (case-insensitive) | Toggles whether or not to use bitsandbytes reference implementations. |
11+
| `CONCH_ENABLE_TORCHVISION` | `1` or `true` (case-insensitive) | Toggles whether or not to use torchvision reference implementations. |
1112
| `CONCH_ENABLE_VLLM` | `1` or `true` (case-insensitive) | Toggles whether or not to use vLLM reference implementations. |

conch/envs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
),
2121
# Enable bitsandbytes kernels for testing/benchmarking
2222
"CONCH_ENABLE_BNB": lambda: (os.environ.get("CONCH_ENABLE_BNB", "0").strip().lower() in ("1", "true")),
23+
# Enable torchvision kernels for testing/benchmarking
24+
"CONCH_ENABLE_TORCHVISION": lambda: (
25+
os.environ.get("CONCH_ENABLE_TORCHVISION", "0").strip().lower() in ("1", "true")
26+
),
2327
# Enable vLLM kernels for testing/benchmarking
2428
"CONCH_ENABLE_VLLM": lambda: (os.environ.get("CONCH_ENABLE_VLLM", "0").strip().lower() in ("1", "true")),
2529
}

0 commit comments

Comments
 (0)