Skip to content

Commit bdfdcd0

Browse files
authored
YOLO11 support (#468)
* Add support for YOLO11 * style fix
1 parent 66a5772 commit bdfdcd0

2 files changed

Lines changed: 12 additions & 5 deletions

File tree

src/model_api/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
get_contours,
3535
)
3636
from .visual_prompting import Prompt, SAMLearnableVisualPrompter, SAMVisualPrompter
37-
from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8
37+
from .yolo import YOLO, YOLO11, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8
3838

3939
classification_models = [
4040
"resnet-18-pytorch",
@@ -92,6 +92,7 @@
9292
"TopDownKeypointDetectionPipeline",
9393
"VisualPromptingResult",
9494
"YOLO",
95+
"YOLO11",
9596
"YOLOF",
9697
"YOLOv3ONNX",
9798
"YOLOv4",

src/model_api/models/yolo.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def __init__(self, inference_adapter, configuration, preload=False):
746746
out_shape = output.shape
747747
if len(out_shape) != 3:
748748
self.raise_error("the output must be of rank 3")
749-
if self.labels and len(self.labels) + 4 != out_shape[1]:
749+
if self.params.labels and len(self.params.labels) + 4 != out_shape[1]:
750750
self.raise_error("number of labels must be smaller than out_shape[1] by 4")
751751

752752
@classmethod
@@ -799,20 +799,20 @@ def postprocess(self, outputs, meta) -> DetectionResult:
799799
)
800800
keep_top_k = 30000
801801
iou_threshold = self.params.iou_threshold
802-
if self.agnostic_nms: # type: ignore[attr-defined]
802+
if self.params.agnostic_nms:
803803
boxes = boxes[
804804
nms(
805805
boxes[:, 2],
806806
boxes[:, 3],
807807
boxes[:, 4],
808808
boxes[:, 5],
809809
boxes[:, 1],
810-
iou_threshold, # type: ignore[attr-defined]
810+
iou_threshold,
811811
keep_top_k=keep_top_k,
812812
)
813813
]
814814
else:
815-
boxes, _ = multiclass_nms(boxes, iou_threshold, keep_top_k) # type: ignore[attr-defined]
815+
boxes, _ = multiclass_nms(boxes, iou_threshold, keep_top_k)
816816
inputImgWidth = meta["original_shape"][1]
817817
inputImgHeight = meta["original_shape"][0]
818818
resize_meta = ResizeMetadata.compute(
@@ -853,3 +853,9 @@ class YOLOv8(YOLOv5):
853853
"""YOLOv5 and YOLOv8 are identical in terms of inference"""
854854

855855
__model__ = "YOLOv8"
856+
857+
858+
class YOLO11(YOLOv5):
859+
"""YOLO11 uses the same inference approach as YOLOv5 and YOLOv8"""
860+
861+
__model__ = "YOLO11"

0 commit comments

Comments
 (0)