Skip to content

Commit 1951d64

Browse files
authored
fix: support full-image masks in instance segmentation postprocessing (#588)
* feat: add DETRInstanceSegmentation model wrapper for full-image masks DETR-family instance segmentation models (e.g. RF-DETR-Seg) output full-image masks at reduced resolution (input_size/4) rather than per-box crop masks (28x28) like Mask R-CNN. Add DETRInstanceSegmentation class (__model__ = "DETRInstSeg") that inherits from MaskRCNNModel and overrides postprocess() to resize masks to original image dimensions directly, instead of the box-crop placement logic used by MaskRCNNModel. This follows the same pattern as SSD vs YOLO for detection -- different architectures get different model wrappers, selected via model_type in the exported model's rt_info. MaskRCNNModel remains unchanged for backward compatibility. Resolves: open-edge-platform/geti#6488 * test: add unit tests for DETRInstanceSegmentation postprocess Tests cover: - _full_image_mask_postprocess: resize, threshold, dtype, spatial pattern preservation - Comparison between full-image and per-box-crop postprocessing approaches - DETRInstanceSegmentation.postprocess: basic flow, batch dim squeezing, confidence filtering, empty results, label increment, label names, mask positioning (verifies masks are NOT shifted to box position), multiple detections, class attributes, and inheritance * refactor: extract InstanceSegmentationModel base class Introduce InstanceSegmentationModel as the common base for both MaskRCNNModel and DETRInstanceSegmentation. The base class contains all shared logic: initialization, output detection, preprocessing, box rescaling, confidence/area filtering, and NMS. Subclasses only need to implement _postprocess_single_mask(): - MaskRCNNModel: per-box-crop postprocess (_segm_postprocess) - DETRInstanceSegmentation: full-image resize (_full_image_mask_postprocess) This eliminates the duplicated postprocess code and makes the hierarchy cleanly express the architectural difference between the two approaches. Also updates the tiler to use InstanceSegmentationModel for isinstance checks, and adds tests verifying the new hierarchy. * fix: resolve ruff lint errors (import sorting, unused var, naming)
1 parent fce6479 commit 1951d64

5 files changed

Lines changed: 371 additions & 10 deletions

File tree

model_api/src/model_api/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .classification import ClassificationModel
99
from .detection_model import DetectionModel
1010
from .image_model import ImageModel
11-
from .instance_segmentation import MaskRCNNModel
11+
from .instance_segmentation import DETRInstanceSegmentation, InstanceSegmentationModel, MaskRCNNModel
1212
from .keypoint_detection import KeypointDetectionModel, TopDownKeypointDetectionPipeline
1313
from .model import Model
1414
from .result import (
@@ -68,9 +68,11 @@
6868
"DetectedKeypoints",
6969
"DetectionModel",
7070
"DetectionResult",
71+
"DETRInstanceSegmentation",
7172
"get_contours",
7273
"ImageModel",
7374
"ImageResultWithSoftPrediction",
75+
"InstanceSegmentationModel",
7476
"InstanceSegmentationResult",
7577
"KeypointDetectionModel",
7678
"Label",

model_api/src/model_api/models/instance_segmentation.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6+
from abc import abstractmethod
7+
68
import cv2
79
import numpy as np
810

@@ -14,8 +16,13 @@
1416
from .utils import ResizeMetadata, calculate_nms, load_labels
1517

1618

17-
class MaskRCNNModel(ImageModel):
18-
__model__ = "MaskRCNN"
19+
class InstanceSegmentationModel(ImageModel):
20+
"""Base class for instance segmentation models.
21+
22+
Handles common initialization, output detection, preprocessing, box rescaling,
23+
confidence filtering, and NMS. Subclasses implement mask-specific postprocessing
24+
via `_postprocess_single_mask`.
25+
"""
1926

2027
def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}, preload: bool = False) -> None:
2128
super().__init__(inference_adapter, configuration, preload)
@@ -107,6 +114,20 @@ def preprocess(self, dict_inputs: dict, meta: dict) -> tuple[dict, dict]:
107114
dict_inputs[self.image_info_blob_names[0]] = input_image_info
108115
return dict_inputs, meta
109116

117+
@abstractmethod
118+
def _postprocess_single_mask(self, box: np.ndarray, raw_cls_mask: np.ndarray, im_h: int, im_w: int) -> np.ndarray:
119+
"""Process a single raw mask into a full-image binary mask.
120+
121+
Args:
122+
box: Bounding box [x1, y1, x2, y2] in original image coordinates.
123+
raw_cls_mask: Raw mask output from the model (2D array).
124+
im_h: Original image height.
125+
im_w: Original image width.
126+
127+
Returns:
128+
Binary mask of shape (im_h, im_w) with dtype uint8.
129+
"""
130+
110131
def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
111132
if (
112133
outputs[self.output_blob_name["labels"]].ndim == 2
@@ -213,7 +234,7 @@ def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
213234

214235
raw_cls_mask = raw_mask[label_idx, ...] if self.is_segmentoly else raw_mask
215236
if self.params.postprocess_semantic_masks or has_feature_vector_name:
216-
resized_mask = _segm_postprocess(
237+
resized_mask = self._postprocess_single_mask(
217238
box,
218239
raw_cls_mask,
219240
*meta["original_shape"][:-1],
@@ -226,18 +247,44 @@ def postprocess(self, outputs: dict, meta: dict) -> InstanceSegmentationResult:
226247
if has_feature_vector_name:
227248
saliency_maps[label_idx - 1].append(resized_mask)
228249

229-
_masks = np.stack(resized_masks) if len(resized_masks) > 0 else np.empty((0, 16, 16), dtype=np.uint8)
250+
result_masks = np.stack(resized_masks) if len(resized_masks) > 0 else np.empty((0, 16, 16), dtype=np.uint8)
230251
return InstanceSegmentationResult(
231252
bboxes=boxes,
232253
labels=labels,
233254
scores=scores,
234-
masks=_masks,
255+
masks=result_masks,
235256
label_names=label_names or None,
236257
saliency_map=_average_and_normalize(saliency_maps),
237258
feature_vector=outputs.get(_feature_vector_name, np.ndarray(0)),
238259
)
239260

240261

262+
class MaskRCNNModel(InstanceSegmentationModel):
263+
"""Instance segmentation model for Mask R-CNN-style architectures.
264+
265+
Uses per-box-crop mask postprocessing: resizes the small mask (e.g. 28x28)
266+
to the bounding box dimensions and places it at the box position.
267+
"""
268+
269+
__model__ = "MaskRCNN"
270+
271+
def _postprocess_single_mask(self, box: np.ndarray, raw_cls_mask: np.ndarray, im_h: int, im_w: int) -> np.ndarray:
272+
return _segm_postprocess(box, raw_cls_mask, im_h, im_w)
273+
274+
275+
class DETRInstanceSegmentation(InstanceSegmentationModel):
276+
"""Instance segmentation model for DETR-family architectures (e.g. RF-DETR-Seg).
277+
278+
Uses full-image mask postprocessing: resizes the mask (e.g. 96x96 covering the
279+
entire image) to the original image dimensions and applies a threshold.
280+
"""
281+
282+
__model__ = "DETRInstSeg"
283+
284+
def _postprocess_single_mask(self, box: np.ndarray, raw_cls_mask: np.ndarray, im_h: int, im_w: int) -> np.ndarray:
285+
return _full_image_mask_postprocess(raw_cls_mask, im_h, im_w)
286+
287+
241288
def _average_and_normalize(saliency_maps: list) -> list:
242289
aggregated = []
243290
for per_object_maps in saliency_maps:
@@ -286,6 +333,12 @@ def _segm_postprocess(box: np.ndarray, raw_cls_mask: np.ndarray, im_h: int, im_w
286333
return im_mask
287334

288335

336+
def _full_image_mask_postprocess(raw_cls_mask: np.ndarray, im_h: int, im_w: int) -> np.ndarray:
337+
"""Resize a full-image mask to original dimensions and threshold."""
338+
resized = cv2.resize(raw_cls_mask.astype(np.float32), (im_w, im_h), interpolation=cv2.INTER_LINEAR)
339+
return (resized > 0.5).astype(np.uint8)
340+
341+
289342
_saliency_map_name = "saliency_map"
290343
_feature_vector_name = "feature_vector"
291344

model_api/src/model_api/models/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,14 @@ def __init__(self, inference_adapter: InferenceAdapter, configuration: dict = {}
8484
ONNXRuntimeAdapter,
8585
) and self.__model__ not in {
8686
"Classification",
87+
"DETRInstSeg",
8788
"MaskRCNN",
8889
"SSD",
8990
"Segmentation",
9091
}:
9192
self.raise_error(
92-
"ONNXRuntimeAdapter is only supported for Classification, MaskRCNN, SSD, and Segmentation wrappers",
93+
"ONNXRuntimeAdapter is only supported for Classification, DETRInstSeg, MaskRCNN, SSD,"
94+
" and Segmentation wrappers",
9395
)
9496

9597
self.inputs = self.inference_adapter.get_input_layers()

model_api/src/model_api/tilers/instance_segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from model_api.models import InstanceSegmentationResult
12-
from model_api.models.instance_segmentation import MaskRCNNModel, _segm_postprocess
12+
from model_api.models.instance_segmentation import InstanceSegmentationModel, _segm_postprocess
1313
from model_api.models.utils import multiclass_nms
1414

1515
from .detection import DetectionTiler
@@ -194,13 +194,13 @@ def __call__(self, inputs):
194194
@contextmanager
195195
def setup_maskrcnn(*args, **kwds):
196196
postprocess_state = None
197-
if isinstance(self.model, MaskRCNNModel):
197+
if isinstance(self.model, InstanceSegmentationModel):
198198
postprocess_state = self.model.params.postprocess_semantic_masks
199199
self.model._postprocess_semantic_masks = False # noqa: SLF001
200200
try:
201201
yield
202202
finally:
203-
if isinstance(self.model, MaskRCNNModel):
203+
if isinstance(self.model, InstanceSegmentationModel):
204204
self.model._postprocess_semantic_masks = postprocess_state # noqa: SLF001
205205

206206
with setup_maskrcnn():

0 commit comments

Comments
 (0)