|
| 1 | +# |
| 2 | +# Copyright (C) 2026 Intel Corporation |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | + |
| 6 | +"""Custom ModelAPI wrapper for Ultralytics YOLO instance-segmentation inference.""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +from typing import Any, cast |
| 11 | + |
| 12 | +import numpy as np |
| 13 | + |
| 14 | +from model_api.adapters.utils import resize_image_ocv |
| 15 | +from model_api.models.detection_model import DetectionModel |
| 16 | +from model_api.models.result import InstanceSegmentationResult |
| 17 | +from model_api.models.utils import ResizeMetadata |
| 18 | +from model_api.models.yolo import xywh2xyxy |
| 19 | + |
| 20 | + |
| 21 | +class YOLOSeg(DetectionModel): |
| 22 | + """ModelAPI wrapper for YOLO instance-segmentation models. |
| 23 | +
|
| 24 | + Expects 2 outputs: |
| 25 | + * detection output: ``[1, 4 + num_classes + mask_dim, num_boxes]`` |
| 26 | + * prototype output: ``[1, mask_dim, proto_h, proto_w]`` |
| 27 | +
|
| 28 | + Post-processing: |
| 29 | + 1. Parse detection predictions (boxes + class scores + mask coefficients). |
| 30 | + 2. Filter by confidence, apply NMS. |
| 31 | + 3. Decode masks: ``coefficients @ protos.reshape(mask_dim, -1)`` → sigmoid → crop → resize. |
| 32 | + 4. Return ``InstanceSegmentationResult``. |
| 33 | + """ |
| 34 | + |
| 35 | + __model__ = "YOLO-seg" |
| 36 | + |
| 37 | + def __init__(self, inference_adapter: object, configuration: dict | None = None, preload: bool = False) -> None: |
| 38 | + super().__init__(inference_adapter, configuration or {}, preload) |
| 39 | + self._check_io_number(1, 2) |
| 40 | + |
| 41 | + self._det_output_name: str = "" |
| 42 | + self._proto_output_name: str = "" |
| 43 | + outputs = cast("dict[str, Any]", self.outputs or {}) |
| 44 | + |
| 45 | + for name, output in outputs.items(): |
| 46 | + shape = output.shape |
| 47 | + if len(shape) == 3: |
| 48 | + self._det_output_name = name |
| 49 | + elif len(shape) == 4: |
| 50 | + self._proto_output_name = name |
| 51 | + |
| 52 | + if not self._det_output_name or not self._proto_output_name: |
| 53 | + self.raise_error( |
| 54 | + "Expected one rank-3 detection output and one rank-4 prototype output, " |
| 55 | + f"but got shapes: {[(name, out.shape) for name, out in outputs.items()]}", |
| 56 | + ) |
| 57 | + |
| 58 | + det_shape = outputs[self._det_output_name].shape |
| 59 | + proto_shape = outputs[self._proto_output_name].shape |
| 60 | + self._mask_dim = proto_shape[1] |
| 61 | + self._proto_h = proto_shape[2] |
| 62 | + self._proto_w = proto_shape[3] |
| 63 | + |
| 64 | + self._num_classes = det_shape[1] - 4 - self._mask_dim |
| 65 | + if self._num_classes <= 0: |
| 66 | + self.raise_error(f"Detection output channel dim ({det_shape[1]}) must be > 4 + mask_dim ({self._mask_dim})") |
| 67 | + |
| 68 | + @classmethod |
| 69 | + def parameters(cls): |
| 70 | + parameters = super().parameters() |
| 71 | + parameters["pad_value"].update_default_value(114) |
| 72 | + parameters["resize_type"].update_default_value("fit_to_window_letterbox") |
| 73 | + parameters["reverse_input_channels"].update_default_value(default_value=False) |
| 74 | + parameters["scale_values"].update_default_value([255.0]) |
| 75 | + parameters["confidence_threshold"].update_default_value(0.25) |
| 76 | + parameters["iou_threshold"].update_default_value(0.5) |
| 77 | + return parameters |
| 78 | + |
| 79 | + def postprocess(self, outputs: dict[str, Any], meta: dict[str, Any]) -> InstanceSegmentationResult: |
| 80 | + """Decode detections and instance masks from raw model outputs. |
| 81 | +
|
| 82 | + Args: |
| 83 | + outputs: Raw model outputs keyed by output tensor name. |
| 84 | + meta: Preprocessing metadata from ModelAPI (original_shape, etc.). |
| 85 | +
|
| 86 | + Returns: |
| 87 | + InstanceSegmentationResult with boxes in original image coordinates |
| 88 | + and binary masks at original image resolution. |
| 89 | + """ |
| 90 | + det_output = outputs[self._det_output_name] |
| 91 | + proto_output = outputs[self._proto_output_name] |
| 92 | + |
| 93 | + prediction = det_output.astype(np.float32) |
| 94 | + protos = proto_output[0].astype(np.float32) |
| 95 | + |
| 96 | + pred = prediction[0].T |
| 97 | + |
| 98 | + boxes_xywh = pred[:, :4] |
| 99 | + class_scores = pred[:, 4 : 4 + self._num_classes] |
| 100 | + mask_coeffs = pred[:, 4 + self._num_classes :] |
| 101 | + |
| 102 | + params = cast("Any", self.params) |
| 103 | + conf_threshold = params.confidence_threshold |
| 104 | + max_scores = class_scores.max(axis=1) |
| 105 | + keep_conf = max_scores > conf_threshold |
| 106 | + |
| 107 | + if not keep_conf.any(): |
| 108 | + return self._empty_result(meta) |
| 109 | + |
| 110 | + boxes_xywh = boxes_xywh[keep_conf] |
| 111 | + class_scores = class_scores[keep_conf] |
| 112 | + mask_coeffs = mask_coeffs[keep_conf] |
| 113 | + |
| 114 | + labels = class_scores.argmax(axis=1) |
| 115 | + confidences = class_scores[np.arange(len(labels)), labels] |
| 116 | + |
| 117 | + boxes_xyxy = xywh2xyxy(boxes_xywh.copy()) |
| 118 | + |
| 119 | + keep_nms = self._calculate_nms( |
| 120 | + boxes=boxes_xyxy, |
| 121 | + scores=confidences, |
| 122 | + labels=labels.astype(np.float32), |
| 123 | + ) |
| 124 | + boxes_xyxy = boxes_xyxy[keep_nms] |
| 125 | + confidences = confidences[keep_nms] |
| 126 | + labels = labels[keep_nms] |
| 127 | + mask_coeffs = mask_coeffs[keep_nms] |
| 128 | + |
| 129 | + masks = self._decode_masks(mask_coeffs, protos, boxes_xyxy, meta) |
| 130 | + |
| 131 | + input_img_w = meta["original_shape"][1] |
| 132 | + input_img_h = meta["original_shape"][0] |
| 133 | + resize_meta = ResizeMetadata.compute( |
| 134 | + original_width=input_img_w, |
| 135 | + original_height=input_img_h, |
| 136 | + model_width=self.orig_width, |
| 137 | + model_height=self.orig_height, |
| 138 | + resize_type=params.resize_type, |
| 139 | + ) |
| 140 | + |
| 141 | + coords = boxes_xyxy.copy() |
| 142 | + coords -= (resize_meta.pad_left, resize_meta.pad_top, resize_meta.pad_left, resize_meta.pad_top) |
| 143 | + coords *= ( |
| 144 | + resize_meta.inverted_scale_x, |
| 145 | + resize_meta.inverted_scale_y, |
| 146 | + resize_meta.inverted_scale_x, |
| 147 | + resize_meta.inverted_scale_y, |
| 148 | + ) |
| 149 | + |
| 150 | + int_boxes = np.round(coords).astype(np.int32) |
| 151 | + np.clip( |
| 152 | + int_boxes, |
| 153 | + 0, |
| 154 | + [input_img_w, input_img_h, input_img_w, input_img_h], |
| 155 | + out=int_boxes, |
| 156 | + ) |
| 157 | + |
| 158 | + int_labels = labels.astype(np.int32) |
| 159 | + return InstanceSegmentationResult( |
| 160 | + bboxes=int_boxes, |
| 161 | + scores=confidences, |
| 162 | + labels=int_labels + 1, |
| 163 | + masks=masks, |
| 164 | + label_names=[self.get_label_name(i) for i in int_labels], |
| 165 | + saliency_map=[], |
| 166 | + feature_vector=np.ndarray(0), |
| 167 | + ) |
| 168 | + |
| 169 | + def _decode_masks( |
| 170 | + self, |
| 171 | + mask_coeffs: np.ndarray, |
| 172 | + protos: np.ndarray, |
| 173 | + boxes_xyxy: np.ndarray, |
| 174 | + meta: dict, |
| 175 | + ) -> np.ndarray: |
| 176 | + """Decode instance masks from mask coefficients and prototypes. |
| 177 | +
|
| 178 | + Args: |
| 179 | + mask_coeffs: Mask coefficients ``(N, mask_dim)``. |
| 180 | + protos: Prototype masks ``(mask_dim, proto_h, proto_w)``. |
| 181 | + boxes_xyxy: Bounding boxes in model input coordinates ``(N, 4)``. |
| 182 | + meta: Preprocessing metadata (original_shape, etc.). |
| 183 | +
|
| 184 | + Returns: |
| 185 | + Binary masks at original image resolution ``(N, orig_h, orig_w)``. |
| 186 | + """ |
| 187 | + mask_dim, proto_h, proto_w = protos.shape |
| 188 | + raw_masks = mask_coeffs @ protos.reshape(mask_dim, -1) |
| 189 | + raw_masks = raw_masks.reshape(-1, proto_h, proto_w) |
| 190 | + |
| 191 | + raw_masks = 1.0 / (1.0 + np.exp(-raw_masks)) |
| 192 | + |
| 193 | + model_h, model_w = self.orig_height, self.orig_width |
| 194 | + scale_x = proto_w / model_w |
| 195 | + scale_y = proto_h / model_h |
| 196 | + proto_boxes = boxes_xyxy * np.array([scale_x, scale_y, scale_x, scale_y], dtype=np.float32) |
| 197 | + |
| 198 | + raw_masks = self.crop_mask(raw_masks, proto_boxes) |
| 199 | + |
| 200 | + input_img_h = meta["original_shape"][0] |
| 201 | + input_img_w = meta["original_shape"][1] |
| 202 | + |
| 203 | + resize_meta = ResizeMetadata.compute( |
| 204 | + original_width=input_img_w, |
| 205 | + original_height=input_img_h, |
| 206 | + model_width=model_w, |
| 207 | + model_height=model_h, |
| 208 | + resize_type=cast("Any", self.params).resize_type, |
| 209 | + ) |
| 210 | + |
| 211 | + n = raw_masks.shape[0] |
| 212 | + upsampled = np.zeros((n, model_h, model_w), dtype=np.float32) |
| 213 | + for i in range(n): |
| 214 | + upsampled[i] = resize_image_ocv(raw_masks[i], (model_w, model_h)) |
| 215 | + |
| 216 | + pad_t = resize_meta.pad_top |
| 217 | + pad_l = resize_meta.pad_left |
| 218 | + effective_w = round(input_img_w / resize_meta.inverted_scale_x) |
| 219 | + effective_h = round(input_img_h / resize_meta.inverted_scale_y) |
| 220 | + cropped = upsampled[:, pad_t : pad_t + effective_h, pad_l : pad_l + effective_w] |
| 221 | + |
| 222 | + final_masks = np.zeros((n, input_img_h, input_img_w), dtype=np.uint8) |
| 223 | + for i in range(n): |
| 224 | + resized = resize_image_ocv(cropped[i], (input_img_w, input_img_h)) |
| 225 | + final_masks[i] = (resized > 0.5).astype(np.uint8) |
| 226 | + |
| 227 | + return final_masks |
| 228 | + |
| 229 | + def _empty_result(self, meta: dict) -> InstanceSegmentationResult: |
| 230 | + """Return an empty result when no detections pass filtering.""" |
| 231 | + return InstanceSegmentationResult( |
| 232 | + bboxes=np.empty((0, 4), dtype=np.int32), |
| 233 | + scores=np.empty(0, dtype=np.float32), |
| 234 | + labels=np.empty(0, dtype=np.int32), |
| 235 | + masks=np.empty((0, meta["original_shape"][0], meta["original_shape"][1]), dtype=np.uint8), |
| 236 | + label_names=[], |
| 237 | + saliency_map=[], |
| 238 | + feature_vector=np.ndarray(0), |
| 239 | + ) |
| 240 | + |
| 241 | + @staticmethod |
| 242 | + def crop_mask(masks: np.ndarray, boxes: np.ndarray) -> np.ndarray: |
| 243 | + """Zero-out mask pixels outside the bounding box. |
| 244 | +
|
| 245 | + Args: |
| 246 | + masks: Binary or float masks of shape ``(N, H, W)``. |
| 247 | + boxes: Bounding boxes ``(N, 4)`` in xyxy format, scaled to mask dims. |
| 248 | +
|
| 249 | + Returns: |
| 250 | + Cropped masks of shape ``(N, H, W)``. |
| 251 | + """ |
| 252 | + n, h, w = masks.shape |
| 253 | + rows = np.arange(h, dtype=np.float32).reshape(1, h, 1) |
| 254 | + cols = np.arange(w, dtype=np.float32).reshape(1, 1, w) |
| 255 | + x1 = boxes[:, 0].reshape(n, 1, 1) |
| 256 | + y1 = boxes[:, 1].reshape(n, 1, 1) |
| 257 | + x2 = boxes[:, 2].reshape(n, 1, 1) |
| 258 | + y2 = boxes[:, 3].reshape(n, 1, 1) |
| 259 | + inside = (cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2) |
| 260 | + return masks * inside |
0 commit comments