Skip to content

Commit db22603

Browse files
Add wrapper for ultralytics yolo for instance segmentation task (#601)
* add wrapper for inst seg yolo * shorter description * add unit test * minor fix linter --------- Co-authored-by: Alexander Barabanov <97449232+AlexanderBarabanov@users.noreply.github.com>
1 parent 292f628 commit db22603

3 files changed

Lines changed: 493 additions & 0 deletions

File tree

model_api/src/model_api/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636
from .visual_prompting import Prompt, SAMLearnableVisualPrompter, SAMVisualPrompter
3737
from .yolo import YOLO, YOLO11, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8
38+
from .yolo_seg import YOLOSeg
3839

3940
classification_models = [
4041
"resnet-18-pytorch",
@@ -95,6 +96,7 @@
9596
"VisualPromptingResult",
9697
"YOLO",
9798
"YOLO11",
99+
"YOLOSeg",
98100
"YOLOF",
99101
"YOLOv3ONNX",
100102
"YOLOv4",
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
#
2+
# Copyright (C) 2026 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
6+
"""Custom ModelAPI wrapper for Ultralytics YOLO instance-segmentation inference."""
7+
8+
from __future__ import annotations
9+
10+
from typing import Any, cast
11+
12+
import numpy as np
13+
14+
from model_api.adapters.utils import resize_image_ocv
15+
from model_api.models.detection_model import DetectionModel
16+
from model_api.models.result import InstanceSegmentationResult
17+
from model_api.models.utils import ResizeMetadata
18+
from model_api.models.yolo import xywh2xyxy
19+
20+
21+
class YOLOSeg(DetectionModel):
22+
"""ModelAPI wrapper for YOLO instance-segmentation models.
23+
24+
Expects 2 outputs:
25+
* detection output: ``[1, 4 + num_classes + mask_dim, num_boxes]``
26+
* prototype output: ``[1, mask_dim, proto_h, proto_w]``
27+
28+
Post-processing:
29+
1. Parse detection predictions (boxes + class scores + mask coefficients).
30+
2. Filter by confidence, apply NMS.
31+
3. Decode masks: ``coefficients @ protos.reshape(mask_dim, -1)`` → sigmoid → crop → resize.
32+
4. Return ``InstanceSegmentationResult``.
33+
"""
34+
35+
__model__ = "YOLO-seg"
36+
37+
def __init__(self, inference_adapter: object, configuration: dict | None = None, preload: bool = False) -> None:
38+
super().__init__(inference_adapter, configuration or {}, preload)
39+
self._check_io_number(1, 2)
40+
41+
self._det_output_name: str = ""
42+
self._proto_output_name: str = ""
43+
outputs = cast("dict[str, Any]", self.outputs or {})
44+
45+
for name, output in outputs.items():
46+
shape = output.shape
47+
if len(shape) == 3:
48+
self._det_output_name = name
49+
elif len(shape) == 4:
50+
self._proto_output_name = name
51+
52+
if not self._det_output_name or not self._proto_output_name:
53+
self.raise_error(
54+
"Expected one rank-3 detection output and one rank-4 prototype output, "
55+
f"but got shapes: {[(name, out.shape) for name, out in outputs.items()]}",
56+
)
57+
58+
det_shape = outputs[self._det_output_name].shape
59+
proto_shape = outputs[self._proto_output_name].shape
60+
self._mask_dim = proto_shape[1]
61+
self._proto_h = proto_shape[2]
62+
self._proto_w = proto_shape[3]
63+
64+
self._num_classes = det_shape[1] - 4 - self._mask_dim
65+
if self._num_classes <= 0:
66+
self.raise_error(f"Detection output channel dim ({det_shape[1]}) must be > 4 + mask_dim ({self._mask_dim})")
67+
68+
@classmethod
69+
def parameters(cls):
70+
parameters = super().parameters()
71+
parameters["pad_value"].update_default_value(114)
72+
parameters["resize_type"].update_default_value("fit_to_window_letterbox")
73+
parameters["reverse_input_channels"].update_default_value(default_value=False)
74+
parameters["scale_values"].update_default_value([255.0])
75+
parameters["confidence_threshold"].update_default_value(0.25)
76+
parameters["iou_threshold"].update_default_value(0.5)
77+
return parameters
78+
79+
def postprocess(self, outputs: dict[str, Any], meta: dict[str, Any]) -> InstanceSegmentationResult:
80+
"""Decode detections and instance masks from raw model outputs.
81+
82+
Args:
83+
outputs: Raw model outputs keyed by output tensor name.
84+
meta: Preprocessing metadata from ModelAPI (original_shape, etc.).
85+
86+
Returns:
87+
InstanceSegmentationResult with boxes in original image coordinates
88+
and binary masks at original image resolution.
89+
"""
90+
det_output = outputs[self._det_output_name]
91+
proto_output = outputs[self._proto_output_name]
92+
93+
prediction = det_output.astype(np.float32)
94+
protos = proto_output[0].astype(np.float32)
95+
96+
pred = prediction[0].T
97+
98+
boxes_xywh = pred[:, :4]
99+
class_scores = pred[:, 4 : 4 + self._num_classes]
100+
mask_coeffs = pred[:, 4 + self._num_classes :]
101+
102+
params = cast("Any", self.params)
103+
conf_threshold = params.confidence_threshold
104+
max_scores = class_scores.max(axis=1)
105+
keep_conf = max_scores > conf_threshold
106+
107+
if not keep_conf.any():
108+
return self._empty_result(meta)
109+
110+
boxes_xywh = boxes_xywh[keep_conf]
111+
class_scores = class_scores[keep_conf]
112+
mask_coeffs = mask_coeffs[keep_conf]
113+
114+
labels = class_scores.argmax(axis=1)
115+
confidences = class_scores[np.arange(len(labels)), labels]
116+
117+
boxes_xyxy = xywh2xyxy(boxes_xywh.copy())
118+
119+
keep_nms = self._calculate_nms(
120+
boxes=boxes_xyxy,
121+
scores=confidences,
122+
labels=labels.astype(np.float32),
123+
)
124+
boxes_xyxy = boxes_xyxy[keep_nms]
125+
confidences = confidences[keep_nms]
126+
labels = labels[keep_nms]
127+
mask_coeffs = mask_coeffs[keep_nms]
128+
129+
masks = self._decode_masks(mask_coeffs, protos, boxes_xyxy, meta)
130+
131+
input_img_w = meta["original_shape"][1]
132+
input_img_h = meta["original_shape"][0]
133+
resize_meta = ResizeMetadata.compute(
134+
original_width=input_img_w,
135+
original_height=input_img_h,
136+
model_width=self.orig_width,
137+
model_height=self.orig_height,
138+
resize_type=params.resize_type,
139+
)
140+
141+
coords = boxes_xyxy.copy()
142+
coords -= (resize_meta.pad_left, resize_meta.pad_top, resize_meta.pad_left, resize_meta.pad_top)
143+
coords *= (
144+
resize_meta.inverted_scale_x,
145+
resize_meta.inverted_scale_y,
146+
resize_meta.inverted_scale_x,
147+
resize_meta.inverted_scale_y,
148+
)
149+
150+
int_boxes = np.round(coords).astype(np.int32)
151+
np.clip(
152+
int_boxes,
153+
0,
154+
[input_img_w, input_img_h, input_img_w, input_img_h],
155+
out=int_boxes,
156+
)
157+
158+
int_labels = labels.astype(np.int32)
159+
return InstanceSegmentationResult(
160+
bboxes=int_boxes,
161+
scores=confidences,
162+
labels=int_labels + 1,
163+
masks=masks,
164+
label_names=[self.get_label_name(i) for i in int_labels],
165+
saliency_map=[],
166+
feature_vector=np.ndarray(0),
167+
)
168+
169+
def _decode_masks(
170+
self,
171+
mask_coeffs: np.ndarray,
172+
protos: np.ndarray,
173+
boxes_xyxy: np.ndarray,
174+
meta: dict,
175+
) -> np.ndarray:
176+
"""Decode instance masks from mask coefficients and prototypes.
177+
178+
Args:
179+
mask_coeffs: Mask coefficients ``(N, mask_dim)``.
180+
protos: Prototype masks ``(mask_dim, proto_h, proto_w)``.
181+
boxes_xyxy: Bounding boxes in model input coordinates ``(N, 4)``.
182+
meta: Preprocessing metadata (original_shape, etc.).
183+
184+
Returns:
185+
Binary masks at original image resolution ``(N, orig_h, orig_w)``.
186+
"""
187+
mask_dim, proto_h, proto_w = protos.shape
188+
raw_masks = mask_coeffs @ protos.reshape(mask_dim, -1)
189+
raw_masks = raw_masks.reshape(-1, proto_h, proto_w)
190+
191+
raw_masks = 1.0 / (1.0 + np.exp(-raw_masks))
192+
193+
model_h, model_w = self.orig_height, self.orig_width
194+
scale_x = proto_w / model_w
195+
scale_y = proto_h / model_h
196+
proto_boxes = boxes_xyxy * np.array([scale_x, scale_y, scale_x, scale_y], dtype=np.float32)
197+
198+
raw_masks = self.crop_mask(raw_masks, proto_boxes)
199+
200+
input_img_h = meta["original_shape"][0]
201+
input_img_w = meta["original_shape"][1]
202+
203+
resize_meta = ResizeMetadata.compute(
204+
original_width=input_img_w,
205+
original_height=input_img_h,
206+
model_width=model_w,
207+
model_height=model_h,
208+
resize_type=cast("Any", self.params).resize_type,
209+
)
210+
211+
n = raw_masks.shape[0]
212+
upsampled = np.zeros((n, model_h, model_w), dtype=np.float32)
213+
for i in range(n):
214+
upsampled[i] = resize_image_ocv(raw_masks[i], (model_w, model_h))
215+
216+
pad_t = resize_meta.pad_top
217+
pad_l = resize_meta.pad_left
218+
effective_w = round(input_img_w / resize_meta.inverted_scale_x)
219+
effective_h = round(input_img_h / resize_meta.inverted_scale_y)
220+
cropped = upsampled[:, pad_t : pad_t + effective_h, pad_l : pad_l + effective_w]
221+
222+
final_masks = np.zeros((n, input_img_h, input_img_w), dtype=np.uint8)
223+
for i in range(n):
224+
resized = resize_image_ocv(cropped[i], (input_img_w, input_img_h))
225+
final_masks[i] = (resized > 0.5).astype(np.uint8)
226+
227+
return final_masks
228+
229+
def _empty_result(self, meta: dict) -> InstanceSegmentationResult:
230+
"""Return an empty result when no detections pass filtering."""
231+
return InstanceSegmentationResult(
232+
bboxes=np.empty((0, 4), dtype=np.int32),
233+
scores=np.empty(0, dtype=np.float32),
234+
labels=np.empty(0, dtype=np.int32),
235+
masks=np.empty((0, meta["original_shape"][0], meta["original_shape"][1]), dtype=np.uint8),
236+
label_names=[],
237+
saliency_map=[],
238+
feature_vector=np.ndarray(0),
239+
)
240+
241+
@staticmethod
242+
def crop_mask(masks: np.ndarray, boxes: np.ndarray) -> np.ndarray:
243+
"""Zero-out mask pixels outside the bounding box.
244+
245+
Args:
246+
masks: Binary or float masks of shape ``(N, H, W)``.
247+
boxes: Bounding boxes ``(N, 4)`` in xyxy format, scaled to mask dims.
248+
249+
Returns:
250+
Cropped masks of shape ``(N, H, W)``.
251+
"""
252+
n, h, w = masks.shape
253+
rows = np.arange(h, dtype=np.float32).reshape(1, h, 1)
254+
cols = np.arange(w, dtype=np.float32).reshape(1, 1, w)
255+
x1 = boxes[:, 0].reshape(n, 1, 1)
256+
y1 = boxes[:, 1].reshape(n, 1, 1)
257+
x2 = boxes[:, 2].reshape(n, 1, 1)
258+
y2 = boxes[:, 3].reshape(n, 1, 1)
259+
inside = (cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2)
260+
return masks * inside

0 commit comments

Comments
 (0)