Skip to content

Commit b8874e7

Browse files
authored
Support yololite with fused NMS (#2203)
1 parent f29cc8f commit b8874e7

4 files changed

Lines changed: 205 additions & 100 deletions

File tree

inference_models/inference_models/models/auto_loaders/models_registry.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,12 @@ class RegistryEntry:
255255
module_name="inference_models.models.yolo26.yolo26_instance_segmentation_trt",
256256
class_name="YOLO26ForInstanceSegmentationTRT",
257257
),
258-
("yololite", OBJECT_DETECTION_TASK, BackendType.ONNX): LazyClass(
259-
module_name="inference_models.models.yololite.yololite_object_detection_onnx",
260-
class_name="YOLOLiteForObjectDetectionOnnx",
258+
("yololite", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
259+
model_class=LazyClass(
260+
module_name="inference_models.models.yololite.yololite_object_detection_onnx",
261+
class_name="YOLOLiteForObjectDetectionOnnx",
262+
),
263+
supported_model_features={"nms_fused"},
261264
),
262265
("paligemma-2", VLM_TASK, BackendType.HF): LazyClass(
263266
module_name="inference_models.models.paligemma.paligemma_hf",

inference_models/inference_models/models/yololite/yololite_object_detection_onnx.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
parse_inference_config,
3131
)
3232
from inference_models.models.common.roboflow.post_processing import (
33+
post_process_nms_fused_model_output,
3334
rescale_detections,
3435
run_nms_for_object_detection,
3536
)
@@ -198,38 +199,13 @@ def post_process(
198199
class_agnostic_nms: bool = INFERENCE_MODELS_YOLOLITE_DEFAULT_CLASS_AGNOSTIC_NMS,
199200
**kwargs,
200201
) -> List[Detections]:
201-
# YOLOLite decoded export outputs 3 tensors:
202-
# boxes_xyxy: [B, N, 4] - decoded bounding boxes in xyxy pixel coords
203-
# obj_logits: [B, N, 1] - objectness logits (pre-sigmoid)
204-
# cls_logits: [B, N, C] - class logits (pre-sigmoid)
205-
boxes_xyxy, obj_logits, cls_logits = (
206-
model_results[0],
207-
model_results[1],
208-
model_results[2],
209-
)
210-
211-
# Apply sigmoid to convert logits to probabilities
212-
obj_conf = torch.sigmoid(obj_logits) # [B, N, 1]
213-
cls_conf = torch.sigmoid(cls_logits) # [B, N, C]
214-
215-
# Combined score: objectness * class confidence
216-
combined_scores = obj_conf * cls_conf # [B, N, C]
217-
218-
# Reshape to [B, 4+C, N] format expected by run_nms_for_object_detection:
219-
# channels 0-3: box coords (xyxy)
220-
# channels 4+: class scores
221-
boxes_t = boxes_xyxy.permute(0, 2, 1) # [B, 4, N]
222-
scores_t = combined_scores.permute(0, 2, 1) # [B, C, N]
223-
nms_input = torch.cat([boxes_t, scores_t], dim=1) # [B, 4+C, N]
224-
225-
nms_results = run_nms_for_object_detection(
226-
output=nms_input,
227-
conf_thresh=confidence,
228-
iou_thresh=iou_threshold,
229-
max_detections=max_detections,
230-
class_agnostic=class_agnostic_nms,
231-
box_format="xyxy",
232-
)
202+
# Backward compatibility: earlier model packages have no post_processing config — always unfused 3-tensor output
203+
if self._inference_config.post_processing and self._inference_config.post_processing.fused:
204+
nms_results = self._post_process_fused(model_results, confidence)
205+
else:
206+
nms_results = self._post_process_unfused(
207+
model_results, confidence, iou_threshold, max_detections, class_agnostic_nms,
208+
)
233209
rescaled_results = rescale_detections(
234210
detections=nms_results,
235211
images_metadata=pre_processing_meta,
@@ -244,3 +220,41 @@ def post_process(
244220
)
245221
)
246222
return results
223+
224+
def _post_process_fused(
225+
self,
226+
model_results: Tuple[torch.Tensor, ...],
227+
confidence: float,
228+
) -> List[torch.Tensor]:
229+
# Single output tensor [B, max_det, 6]: x1, y1, x2, y2, conf, class_id
230+
output = model_results[0]
231+
return post_process_nms_fused_model_output(output=output, conf_thresh=confidence)
232+
233+
def _post_process_unfused(
234+
self,
235+
model_results: Tuple[torch.Tensor, ...],
236+
confidence: float,
237+
iou_threshold: float,
238+
max_detections: int,
239+
class_agnostic_nms: bool,
240+
) -> List[torch.Tensor]:
241+
# Decoded outputs without fused NMS: boxes_xyxy [B,N,4], obj_logits [B,N,1], cls_logits [B,N,C]
242+
boxes_xyxy, obj_logits, cls_logits = (
243+
model_results[0], model_results[1], model_results[2],
244+
)
245+
obj_conf = torch.sigmoid(obj_logits)
246+
cls_conf = torch.sigmoid(cls_logits)
247+
combined_scores = obj_conf * cls_conf
248+
249+
boxes_t = boxes_xyxy.permute(0, 2, 1)
250+
scores_t = combined_scores.permute(0, 2, 1)
251+
nms_input = torch.cat([boxes_t, scores_t], dim=1)
252+
253+
return run_nms_for_object_detection(
254+
output=nms_input,
255+
conf_thresh=confidence,
256+
iou_thresh=iou_threshold,
257+
max_detections=max_detections,
258+
class_agnostic=class_agnostic_nms,
259+
box_format="xyxy",
260+
)

inference_models/tests/integration_tests/models/conftest.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@
7373
COIN_COUNTING_YOLACT_ONNX_STATIC_BS_STATIC_CROP_STRETCH_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolact-static-bs-static-crop-stretch-onnx.zip"
7474
COIN_COUNTING_YOLACT_ONNX_STATIC_BS_STRETCH_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolact-static-bs-stretch-onnx.zip"
7575

76-
COIN_COUNTING_YOLOLITE_N_ONNX_DYNAMIC_BS_LETTERBOX_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/coin-counting-yololite-n-onnx-dynamic-bs-letterbox.zip"
76+
COIN_COUNTING_YOLOLITE_EDGE_N_ONNX_STATIC_BS_STRETCH_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/coin-counting-yololite-edge-n-onnx-static-bs-stretch.zip"
77+
COIN_COUNTING_YOLOLITE_EDGE_N_ONNX_DYNAMIC_BS_STRETCH_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/coin-counting-yololite-edge-n-onnx-dynamic-bs-stretch.zip"
78+
COIN_COUNTING_YOLOLITE_EDGE_N_ONNX_DYNAMIC_BS_STRETCH_FUSED_NMS_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/coin-counting-yololite-edge-n-onnx-dynamic-bs-stretch-fused-nms.zip"
7779

7880
ASL_YOLOV8N_SEG_ONNX_DYNAMIC_BS_STRETCH_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8n-seg-onnx-dynamic-bs-stretch.zip"
7981
ASL_YOLOV8N_SEG_ONNX_DYNAMIC_BS_STRETCH_FUSED_NMS_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8n-seg-onnx-dynamic-bs-stretch-fused-nms.zip"
@@ -700,10 +702,26 @@ def coin_counting_yolo_nas_onnx_static_bs_center_crop_package() -> str:
700702

701703

702704
@pytest.fixture(scope="module")
703-
def coin_counting_yololite_n_onnx_dynamic_bs_letterbox_package() -> str:
705+
def coin_counting_yololite_edge_n_onnx_static_bs_stretch_package() -> str:
704706
return download_model_package(
705-
model_package_zip_url=COIN_COUNTING_YOLOLITE_N_ONNX_DYNAMIC_BS_LETTERBOX_URL,
706-
package_name="coin-counting-yololite-n-onnx-dynamic-bs-letterbox",
707+
model_package_zip_url=COIN_COUNTING_YOLOLITE_EDGE_N_ONNX_STATIC_BS_STRETCH_URL,
708+
package_name="coin-counting-yololite-edge-n-onnx-static-bs-stretch",
709+
)
710+
711+
712+
@pytest.fixture(scope="module")
713+
def coin_counting_yololite_edge_n_onnx_dynamic_bs_stretch_package() -> str:
714+
return download_model_package(
715+
model_package_zip_url=COIN_COUNTING_YOLOLITE_EDGE_N_ONNX_DYNAMIC_BS_STRETCH_URL,
716+
package_name="coin-counting-yololite-edge-n-onnx-dynamic-bs-stretch",
717+
)
718+
719+
720+
@pytest.fixture(scope="module")
721+
def coin_counting_yololite_edge_n_onnx_dynamic_bs_stretch_fused_nms_package() -> str:
722+
return download_model_package(
723+
model_package_zip_url=COIN_COUNTING_YOLOLITE_EDGE_N_ONNX_DYNAMIC_BS_STRETCH_FUSED_NMS_URL,
724+
package_name="coin-counting-yololite-edge-n-onnx-dynamic-bs-stretch-fused-nms",
707725
)
708726

709727

0 commit comments

Comments
 (0)