Skip to content

Commit db78eb8

Browse files
mgumowsk, Copilot, ivanzati
authored
feat: model_converter maskrcnn support (#567)
* maskrcnn Co-authored-by: Copilot <copilot@github.com> * small refactor Co-authored-by: Copilot <copilot@github.com> * mypy Co-authored-by: Copilot <copilot@github.com> * unit tests Co-authored-by: Copilot <copilot@github.com> * ruff Co-authored-by: Copilot <copilot@github.com> * ruff * pytest github actions --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Ivan Zakharov <ivan.zakharov@intel.com>
1 parent 35ba75d commit db78eb8

22 files changed

Lines changed: 3924 additions & 1189 deletions

.github/workflows/pre_commit.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,12 @@ jobs:
152152

153153
- *install-dependencies
154154

155-
- name: Run python unit tests
155+
- name: Run model_api unit tests
156156
run: uv --directory model_api run pytest tests/unit --cov
157157

158+
- name: Run model_converter unit tests
159+
run: uv --directory model_converter run --group tests pytest tests/unit --cov
160+
158161
- &prepare-test-data
159162
name: Prepare test data
160163
run: |

model_converter/examples/config.json

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,35 @@
618618
"scale_values": "58.395 57.12 57.375",
619619
"license": "apache-2.0",
620620
"license_link": "https://spdx.org/licenses/Apache-2.0.html"
621+
},
622+
{
623+
"model_short_name": "maskrcnn_resnet50_fpn",
624+
"model_class_name": "torchvision.models.detection.maskrcnn_resnet50_fpn",
625+
"model_library": "torchvision",
626+
"model_full_name": "Mask R-CNN ResNet-50 FPN",
627+
"description": "Mask R-CNN with a ResNet-50-FPN backbone trained on COCO for object detection and instance segmentation",
628+
"docs": "https://docs.pytorch.org/vision/main/models/generated/torchvision.models.detection.maskrcnn_resnet50_fpn.html",
629+
"weights_url": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
630+
"input_shape": [1, 3, 800, 800],
631+
"input_names": ["image"],
632+
"output_names": ["boxes", "labels", "masks"],
633+
"model_params": null,
634+
"model_type": "MaskRCNN",
635+
"reverse_input_channels": true,
636+
"mean_values": "0 0 0",
637+
"scale_values": "255 255 255",
638+
"resize_type": "fit_to_window_letterbox",
639+
"pad_value": 0,
640+
"input_dtype": "u8",
641+
"confidence_threshold": 0.5,
642+
"postprocess_semantic_masks": true,
643+
"nms_execute": false,
644+
"iou_threshold": 0.5,
645+
"agnostic_nms": false,
646+
"nms_max_predictions": 200,
647+
"license": "bsd-3-clause",
648+
"license_link": "https://spdx.org/licenses/BSD-3-Clause.html",
649+
"labels": "COCO_V1"
621650
}
622651
]
623652
}

model_converter/pyproject.toml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ fixable = ["ALL"]
178178
unfixable = []
179179
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
180180

181+
[tool.ruff.lint.per-file-ignores]
182+
"**/tests/**/*.py" = ["SLF001", "FBT003"]
183+
181184
[tool.ruff.lint.mccabe]
182185
max-complexity = 15
183186

@@ -193,6 +196,16 @@ notice-rgx = """
193196
[tool.bandit]
194197
skips = ["B101", "B310"]
195198

199+
[tool.pytest.ini_options]
200+
pythonpath = ["src"]
201+
202+
[tool.coverage.run]
203+
source = ["model_converter"]
204+
196205
[tool.coverage.report]
197-
fail_under = 45
206+
fail_under = 100
198207
show_missing = true
208+
exclude_lines = [
209+
"if __name__ == .__main__.",
210+
"if TYPE_CHECKING:",
211+
]

model_converter/src/model_converter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55

66
"""Tools for converting models to OpenVINO IR."""
77

8-
from .model_converter import ModelConverter, list_models, main
8+
from .cli import ModelConverter, list_models, main
99

1010
__all__ = ["ModelConverter", "list_models", "main"]

model_converter/src/model_converter/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
"""Run the model converter with ``python -m model_converter``."""
77

8-
from .model_converter import main
8+
import sys
99

10-
if __name__ == "__main__":
11-
raise SystemExit(main())
10+
from model_converter.cli import main
11+
12+
sys.exit(main())
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# Copyright (C) 2026 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
6+
"""Export adapters for different model types."""
7+
8+
from __future__ import annotations
9+
10+
from typing import TYPE_CHECKING
11+
12+
from model_converter.adapters.base import ExportAdapter
13+
from model_converter.adapters.maskrcnn import TorchvisionMaskRCNNExportAdapter
14+
15+
if TYPE_CHECKING:
16+
import torch
17+
18+
# Maps a lower-cased model type name to the adapter class used to wrap that
# model for export. Keys must be lower case because get_adapter() lower-cases
# the requested model type before looking it up here.
_ADAPTER_REGISTRY: dict[str, type[ExportAdapter]] = {
    "maskrcnn": TorchvisionMaskRCNNExportAdapter,
}
21+
22+
23+
def get_adapter(model_type: str, model: torch.nn.Module) -> torch.nn.Module:
24+
"""
25+
Get the appropriate export adapter for a model type.
26+
27+
If no adapter is registered for the model type, returns the model unchanged.
28+
29+
Args:
30+
model_type: Model type string (e.g., "MaskRCNN")
31+
model: The PyTorch model to adapt
32+
33+
Returns:
34+
Adapted model (or original model if no adapter needed)
35+
"""
36+
adapter_class = _ADAPTER_REGISTRY.get(model_type.lower())
37+
if adapter_class is not None:
38+
return adapter_class(model)
39+
return model
40+
41+
42+
# Names exported as this module's public API.
__all__ = ["ExportAdapter", "TorchvisionMaskRCNNExportAdapter", "get_adapter"]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
# Copyright (C) 2026 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
6+
"""Base export adapter interface."""
7+
8+
import torch.nn as nn
9+
10+
11+
class ExportAdapter(nn.Module):
    """Common parent for wrappers that reshape a model's outputs for Model API.

    This base class only stores the wrapped network; it defines no ``forward``
    of its own.
    """

    def __init__(self, model: nn.Module) -> None:
        """Store the wrapped network on the adapter.

        Args:
            model: The network whose outputs the adapter will reshape. It is
                kept as the ``model`` attribute.
        """
        super().__init__()
        self.model = model
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#
2+
# Copyright (C) 2026 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
6+
"""Mask R-CNN export adapter for TorchVision models."""
7+
8+
from collections import OrderedDict
9+
10+
import torch
11+
12+
from model_converter.adapters.base import ExportAdapter
13+
14+
15+
class TorchvisionMaskRCNNExportAdapter(ExportAdapter):
    """Wrap a TorchVision Mask R-CNN so its outputs follow the Model API MaskRCNN contract."""

    def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the wrapped detector on the first image and return flat output tensors.

        Only ``images[0]`` is processed; any further entries along the first
        dimension are ignored.

        Args:
            images: Input tensor indexed by image along its first dimension.

        Returns:
            A ``(boxes, labels, masks)`` tuple: the predicted boxes with the
            score appended as a trailing column, the predicted labels shifted
            down by one, and the predicted masks with dimension 1 squeezed out.
        """
        wrapped = self.model

        # Stage 1: the wrapped model's ``transform`` on a one-image list (no targets).
        batch, _ = wrapped.transform([images[0]], None)

        # Stage 2: backbone features; a bare tensor is wrapped into a one-entry
        # ordered mapping under the key "0".
        feature_maps = wrapped.backbone(batch.tensors)
        if isinstance(feature_maps, torch.Tensor):
            feature_maps = OrderedDict([("0", feature_maps)])

        # Stage 3: proposals from the RPN, then the ROI heads; ``None`` is passed
        # where targets would go.
        region_proposals, _ = wrapped.rpn(batch, feature_maps, None)
        detections, _ = wrapped.roi_heads(feature_maps, region_proposals, batch.image_sizes, None)
        result = detections[0]

        # Assemble the three outputs from the single-image prediction dict.
        score_column = result["scores"].unsqueeze(1)
        boxes = torch.cat((result["boxes"], score_column), dim=1)
        labels = result["labels"] - 1
        masks = result["masks"].squeeze(1)
        return boxes, labels, masks

0 commit comments

Comments
 (0)