Skip to content

Commit fb4bb21

Browse files
[metrics] Add COCO mAP like ObjectDetection metric (#2061)
1 parent db9d92a commit fb4bb21

3 files changed

Lines changed: 445 additions & 12 deletions

File tree

docs/source/modules/utils.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,8 @@ Implementations of task-specific metrics to easily assess your model performance
4949

5050
.. automethod:: update
5151
.. automethod:: summary
52+
53+
.. autoclass:: ObjectDetectionMetric
54+
55+
.. automethod:: update
56+
.. automethod:: summary

doctr/utils/metrics.py

Lines changed: 263 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from anyascii import anyascii
99
from scipy.optimize import linear_sum_assignment
10-
from shapely.geometry import Polygon
10+
from shapely import area, intersection, polygons
1111

1212
__all__ = [
1313
"TextMatch",
@@ -17,6 +17,7 @@
1717
"LocalizationConfusion",
1818
"OCRMetric",
1919
"DetectionMetric",
20+
"ObjectDetectionMetric",
2021
]
2122

2223

@@ -155,27 +156,31 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray:
155156
Args:
156157
polys_1: rotated bounding boxes of shape (N, 4, 2)
157158
polys_2: rotated bounding boxes of shape (M, 4, 2)
158-
mask_shape: spatial shape of the intermediate masks
159-
use_broadcasting: if set to True, leverage broadcasting speedup by consuming more memory
160159
161160
Returns:
162161
the IoU matrix of shape (N, M)
163162
"""
164163
if polys_1.ndim != 3 or polys_2.ndim != 3:
165164
raise AssertionError("expects boxes to be in format (N, 4, 2)")
166165

167-
iou_mat = np.zeros((polys_1.shape[0], polys_2.shape[0]), dtype=np.float32)
166+
n, m = polys_1.shape[0], polys_2.shape[0]
167+
if n == 0 or m == 0:
168+
return np.zeros((n, m), dtype=np.float32)
168169

169-
shapely_polys_1 = [Polygon(poly) for poly in polys_1]
170-
shapely_polys_2 = [Polygon(poly) for poly in polys_2]
170+
geoms_1 = polygons(polys_1)
171+
geoms_2 = polygons(polys_2)
172+
grid_1 = np.repeat(geoms_1, m)
173+
grid_2 = np.tile(geoms_2, n)
171174

172-
for i, poly1 in enumerate(shapely_polys_1):
173-
for j, poly2 in enumerate(shapely_polys_2):
174-
intersection_area = poly1.intersection(poly2).area
175-
union_area = poly1.area + poly2.area - intersection_area
176-
iou_mat[i, j] = intersection_area / union_area
175+
# Compute intersections and areas
176+
intersections = area(intersection(grid_1, grid_2))
177+
areas_1 = area(grid_1)
178+
areas_2 = area(grid_2)
177179

178-
return iou_mat
180+
# Compute IoU
181+
unions = areas_1 + areas_2 - intersections
182+
iou_flat = np.divide(intersections, unions, out=np.zeros_like(intersections), where=unions > 0)
183+
return iou_flat.reshape(n, m).astype(np.float32)
179184

180185

181186
def nms(boxes: np.ndarray, thresh: float = 0.5) -> list[int]:
@@ -549,3 +554,249 @@ def reset(self) -> None:
549554
self.num_preds = 0
550555
self.tot_iou = 0.0
551556
self.num_matches = 0
557+
558+
559+
class ObjectDetectionMetric:
560+
r"""Implements a COCO-style object detection metric (mAP@[.5:.95]) inspired by the COCO evaluation protocol.
561+
The aggregated metrics are computed as follows:
562+
563+
.. math::
564+
565+
\forall (B, C) \in \mathcal{B}^N \times \mathcal{C}^N,
566+
\forall (\hat{B}, \hat{C}, S) \in \mathcal{B}^M \times \mathcal{C}^M \times \mathbb{R}^M, \\
567+
568+
AP_t(C) =
569+
\frac{1}{101}
570+
\sum\limits_{r \in \{0, 0.01, \dots, 1.0\}}
571+
\max_{\tilde{r} \geq r} Precision_t(\tilde{r}, C) \\
572+
573+
mAP@[.5:.95] =
574+
\frac{1}{|\mathcal{T}|}
575+
\sum\limits_{t \in \mathcal{T}}
576+
\frac{1}{|\mathcal{C}|}
577+
\sum\limits_{c \in \mathcal{C}} AP_t(c)
578+
579+
where:
580+
- :math:`\mathcal{B}` is the set of possible bounding boxes,
581+
- :math:`\mathcal{C}` is the set of possible class indices,
582+
- :math:`S` are confidence scores associated to predictions,
583+
- :math:`\mathcal{T} = \{0.5, 0.55, \dots, 0.95\}` is the set of IoU thresholds,
584+
- :math:`AP_t(c)` is the Average Precision for class :math:`c`
585+
at IoU threshold :math:`t`.
586+
587+
For a given class and IoU threshold, predictions from all images are
588+
aggregated and sorted globally by decreasing confidence score.
589+
590+
Each prediction is greedily matched to the unmatched ground-truth box
591+
with the highest IoU, provided that:
592+
- the IoU is greater than or equal to the threshold,
593+
- the ground-truth box has not already been matched.
594+
595+
True positives and false positives are accumulated to build a
596+
precision-recall curve.
597+
598+
Average Precision is computed using the COCO 101-point interpolated
599+
precision-recall curve.
600+
601+
>>> import numpy as np
602+
>>> from doctr.utils import ObjectDetectionMetric
603+
>>> metric = ObjectDetectionMetric()
604+
>>> metric.update(
605+
... np.asarray([[0, 0, 100, 100]]),
606+
... np.asarray([[0, 0, 80, 80], [120, 120, 200, 200]]),
607+
... np.asarray([0]),
608+
... np.asarray([0, 1]),
609+
... np.asarray([0.9, 0.3])
610+
... )
611+
>>> metric.summary()
612+
613+
Args:
614+
iou_thresholds: sequence of IoU thresholds used to compute the metric
615+
(defaults to np.arange(0.5, 1.0, 0.05))
616+
num_classes: total number of classes. If None, inferred from data
617+
use_polygons: if set to True, predictions and targets will be expected
618+
to have rotated format
619+
"""
620+
621+
def __init__(
622+
self,
623+
iou_thresholds: np.ndarray | None = None,
624+
num_classes: int | None = None,
625+
use_polygons: bool = False,
626+
) -> None:
627+
self.iou_thresholds = iou_thresholds if iou_thresholds is not None else np.round(np.arange(0.5, 1.0, 0.05), 2)
628+
self.num_classes = num_classes
629+
self.use_polygons = use_polygons
630+
self.reset()
631+
632+
def update(
633+
self,
634+
gt_boxes: np.ndarray,
635+
pred_boxes: np.ndarray,
636+
gt_labels: np.ndarray,
637+
pred_labels: np.ndarray,
638+
pred_scores: np.ndarray,
639+
) -> None:
640+
if (
641+
gt_boxes.shape[0] != gt_labels.shape[0]
642+
or pred_boxes.shape[0] != pred_labels.shape[0]
643+
or pred_boxes.shape[0] != pred_scores.shape[0]
644+
):
645+
raise AssertionError("Mismatch between boxes, labels, scores")
646+
647+
self._gts.append({"boxes": gt_boxes, "labels": gt_labels})
648+
self._preds.append({"boxes": pred_boxes, "labels": pred_labels, "scores": pred_scores})
649+
650+
def summary(self) -> dict[str, float | dict[float, float]]:
651+
"""Computes the aggregated metrics
652+
653+
Returns:
654+
a dictionary with the mAP@[.5:.95], AP@[.5], AP@[.75] and AP per IoU threshold
655+
"""
656+
if len(self._gts) == 0:
657+
raise AssertionError("No samples added")
658+
659+
# Determine classes
660+
if self.num_classes is None:
661+
labels = []
662+
for g in self._gts:
663+
labels.extend(g["labels"].tolist())
664+
for p in self._preds:
665+
labels.extend(p["labels"].tolist())
666+
classes = np.unique(labels)
667+
else:
668+
classes = np.arange(self.num_classes)
669+
670+
ap_per_iou = {}
671+
672+
for iou_thresh in self.iou_thresholds:
673+
class_aps = []
674+
675+
for c in classes:
676+
# Collect GTs per image
677+
gt_by_image = {}
678+
total_gt = 0
679+
680+
for img_idx, gt in enumerate(self._gts):
681+
mask = gt["labels"] == c
682+
gt_boxes = gt["boxes"][mask]
683+
684+
gt_by_image[img_idx] = {
685+
"boxes": gt_boxes,
686+
"matched": np.zeros(len(gt_boxes), dtype=bool),
687+
}
688+
689+
total_gt += len(gt_boxes)
690+
691+
if total_gt == 0:
692+
continue
693+
694+
# Collect all detections globally
695+
detections = []
696+
697+
for img_idx, pred in enumerate(self._preds):
698+
mask = pred["labels"] == c
699+
700+
pred_boxes = pred["boxes"][mask]
701+
pred_scores = pred["scores"][mask]
702+
703+
for box, score in zip(pred_boxes, pred_scores):
704+
detections.append({
705+
"image_id": img_idx,
706+
"box": box,
707+
"score": float(score),
708+
})
709+
710+
if len(detections) == 0:
711+
class_aps.append(0.0)
712+
continue
713+
714+
# Global sorting COCO-style
715+
detections.sort(key=lambda x: -x["score"])
716+
717+
tp = np.zeros(len(detections))
718+
fp = np.zeros(len(detections))
719+
720+
# Evaluate detections
721+
for det_idx, det in enumerate(detections):
722+
img_idx = det["image_id"]
723+
pred_box = det["box"]
724+
725+
gt_data = gt_by_image[img_idx]
726+
gt_boxes = gt_data["boxes"]
727+
728+
if len(gt_boxes) == 0:
729+
fp[det_idx] = 1
730+
continue
731+
732+
# Compute IoUs
733+
if self.use_polygons:
734+
iou_mat = polygon_iou(
735+
gt_boxes,
736+
np.expand_dims(pred_box, axis=0),
737+
)
738+
else:
739+
iou_mat = box_iou(
740+
gt_boxes,
741+
np.expand_dims(pred_box, axis=0),
742+
)
743+
744+
ious = iou_mat[:, 0]
745+
746+
best_gt = np.argmax(ious)
747+
best_iou = ious[best_gt]
748+
749+
if best_iou >= iou_thresh and not gt_data["matched"][best_gt]:
750+
tp[det_idx] = 1
751+
gt_data["matched"][best_gt] = True
752+
else:
753+
fp[det_idx] = 1
754+
755+
# Precision / Recall
756+
tp_cum = np.cumsum(tp)
757+
fp_cum = np.cumsum(fp)
758+
759+
recall = tp_cum / total_gt
760+
precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-8)
761+
762+
ap = self._compute_ap(recall, precision)
763+
class_aps.append(ap)
764+
765+
ap_per_iou[float(iou_thresh)] = float(np.mean(class_aps)) if len(class_aps) > 0 else 0.0
766+
767+
map_value = float(np.mean(list(ap_per_iou.values())))
768+
ap50 = ap_per_iou.get(0.5, 0.0)
769+
ap75 = ap_per_iou.get(0.75, 0.0)
770+
771+
return {
772+
"mAP@[.5:.95]": map_value,
773+
"AP@[.5]": ap50,
774+
"AP@[.75]": ap75,
775+
"AP_per_IoU": ap_per_iou,
776+
}
777+
778+
def _compute_ap(self, recall: np.ndarray, precision: np.ndarray) -> float:
779+
"""Computes the Average Precision using the 101-point interpolation method from COCO
780+
781+
Args:
782+
recall: array of recall values
783+
precision: array of precision values
784+
785+
Returns:
786+
the Average Precision score
787+
"""
788+
# 101-point interpolation as per COCO
789+
precision = np.maximum.accumulate(precision[::-1])[::-1]
790+
791+
recall_levels = np.linspace(0, 1, 101)
792+
precisions = np.zeros_like(recall_levels)
793+
794+
for i, r in enumerate(recall_levels):
795+
p = precision[recall >= r]
796+
precisions[i] = np.max(p) if p.size > 0 else 0.0
797+
798+
return float(np.mean(precisions))
799+
800+
def reset(self):
801+
self._gts = []
802+
self._preds = []

0 commit comments

Comments
 (0)