diff --git a/faster_coco_eval/core/faster_eval_api.py b/faster_coco_eval/core/faster_eval_api.py
index 2c56290..ed65811 100644
--- a/faster_coco_eval/core/faster_eval_api.py
+++ b/faster_coco_eval/core/faster_eval_api.py
@@ -229,33 +229,93 @@ def extended_metrics(self):
         Notes:
             - Uses COCO-style evaluation results (precision and scores arrays).
             - Filters out classes with NaN results in any metric.
-            - The best F1-score across recall thresholds is used to select macro precision and recall.
+            - The best F1-score across confidence thresholds is used to select macro precision and recall.
+            - Precision and recall are computed from actual (non-interpolated) detection data to avoid
+              over-estimating precision when false positives exist below the recall ceiling.
         """
-        # Extract IoU and recall thresholds from parameters
-        iou_thrs, rec_thrs = self.params.iouThrs, self.params.recThrs
+        # Extract IoU thresholds from parameters
+        iou_thrs = self.params.iouThrs
 
         # Indices for IoU=0.50, first area, and last max dets
         iou50_idx, area_idx, maxdet_idx = (np.argwhere(np.isclose(iou_thrs, 0.50)).item(), 0, -1)
 
         P = self.eval["precision"]
-        S = self.eval["scores"]
 
-        # Get precision for IoU=0.50, area, and max dets
-        prec_raw = P[iou50_idx, :, :, area_idx, maxdet_idx]
-        prec = prec_raw.copy().astype(float)
-        prec[prec < 0] = np.nan
-
-        # Compute F1 score for each class and recall threshold
-        f1_cls = 2 * prec * rec_thrs[:, None] / (prec + rec_thrs[:, None])
-        f1_macro = np.nanmean(f1_cls, axis=1)
-        best_j = int(f1_macro.argmax())
-
-        # Macro precision and recall at the best F1 score
-        macro_precision = float(np.nanmean(prec[best_j]))
-        macro_recall = float(rec_thrs[best_j])
+        # --- Compute actual (non-interpolated) precision/recall by sweeping confidence thresholds ---
+        # Build set of TP detection IDs: detections matched to a GT with actual IoU >= 0.50
+        tp_dt_ids = {
+            int(k.split("_")[0])
+            for k, iou in self.eval["matched"].items()
+            if iou >= 0.5
+        }
 
-        # Score vector for the best recall threshold
-        score_vec = S[iou50_idx, best_j, :, area_idx, maxdet_idx].astype(float)
-        score_vec[prec_raw[best_j] < 0] = np.nan
+        cat_ids_eval = self.params.catIds if self.params.useCats else list(
+            {ann["category_id"] for ann in self.cocoDt.anns.values()}
+        )
+
+        # Per-class: build sorted (descending) score arrays and cumulative TP counts
+        class_arrays = {}  # cat_id -> (scores_desc, cum_tp)
+        class_items: dict = {}
+        for dt_id, ann in self.cocoDt.anns.items():
+            cat_id = ann["category_id"]
+            class_items.setdefault(cat_id, []).append((float(ann["score"]), int(dt_id in tp_dt_ids)))
+        for cat_id, items in class_items.items():
+            items.sort(key=lambda x: -x[0])
+            scores = np.array([s for s, _ in items])
+            is_tp = np.array([t for _, t in items])
+            class_arrays[cat_id] = (scores, np.cumsum(is_tp))
+
+        # Total non-crowd GTs per class
+        total_gts: dict = {}
+        for ann in self.cocoGt.anns.values():
+            if not ann.get("iscrowd", 0):
+                total_gts[ann["category_id"]] = total_gts.get(ann["category_id"], 0) + 1
+
+        # Candidate thresholds: all unique detection scores, ascending.
+        # Iterating ascending ensures the FIRST threshold reaching the maximum macro-F1
+        # is the most inclusive one (lowest confidence → highest recall).
+        all_thresholds = sorted({float(ann["score"]) for ann in self.cocoDt.anns.values()})
+
+        best_macro_f1 = -np.inf
+        # best_class_metrics is populated inside the loop when a better macro-F1 is found.
+        # If no threshold produces valid metrics (e.g., no detections at all), it stays
+        # empty and per-class precision/recall will be reported as NaN (and filtered out).
+        best_class_metrics: dict = {}
+        macro_precision = 0.0
+        macro_recall = 0.0
+
+        for threshold in all_thresholds:
+            cat_precs, cat_recs, cat_f1s, cat_ids_valid = [], [], [], []
+            for cat_id in cat_ids_eval:
+                n_gt = total_gts.get(cat_id, 0)
+                if n_gt == 0:
+                    continue
+                if cat_id in class_arrays:
+                    scores, cum_tp = class_arrays[cat_id]
+                    # Count detections with score >= threshold using binary search
+                    n_above = int(np.searchsorted(-scores, -threshold, side="right"))
+                    tp = int(cum_tp[n_above - 1]) if n_above > 0 else 0
+                    prec = tp / n_above if n_above > 0 else 0.0
+                    rec = tp / n_gt
+                else:
+                    prec, rec = 0.0, 0.0
+                f1 = 2.0 * prec * rec / (prec + rec) if (prec + rec) > 0.0 else 0.0
+                cat_precs.append(prec)
+                cat_recs.append(rec)
+                cat_f1s.append(f1)
+                cat_ids_valid.append(cat_id)
+
+            if not cat_f1s:
+                continue
+
+            macro_f1 = float(np.mean(cat_f1s))
+            if macro_f1 > best_macro_f1:
+                best_macro_f1 = macro_f1
+                macro_precision = float(np.mean(cat_precs))
+                macro_recall = float(np.mean(cat_recs))
+                best_class_metrics = {
+                    cid: {"precision": p, "recall": r}
+                    for cid, p, r in zip(cat_ids_valid, cat_precs, cat_recs)
+                }
 
         per_class = []
         if self.params.useCats:
@@ -263,7 +323,7 @@ def extended_metrics(self):
             cat_ids = self.params.catIds
             cat_id_to_name = {c["id"]: c["name"] for c in self.cocoGt.loadCats(cat_ids)}
             for k, cid in enumerate(cat_ids):
-                # Precision per category
+                # AP per category (unchanged: uses interpolated P for mAP, which is correct)
                 p_slice = P[:, :, k, area_idx, maxdet_idx]
                 valid = p_slice > -1
                 ap_50_95 = float(p_slice[valid].mean()) if valid.any() else float("nan")
@@ -273,8 +333,9 @@ def extended_metrics(self):
                     else float("nan")
                 )
 
-                pc = float(prec[best_j, k]) if prec_raw[best_j, k] > -1 else float("nan")
-                rc = macro_recall
+                class_m = best_class_metrics.get(int(cid), {})
+                pc = class_m.get("precision", float("nan"))
+                rc = class_m.get("recall", float("nan"))
 
                 # Filter out dataset class if any metric is NaN
                 if np.isnan(ap_50_95) or np.isnan(ap_50) or np.isnan(pc) or np.isnan(rc):
diff --git a/tests/test_basic.py b/tests/test_basic.py
index e404bfd..bcf5764 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -1,8 +1,10 @@
 #!/usr/bin/python3
+import math
 import os
 import unittest
 
 import numpy as np
+import pytest
 
 import faster_coco_eval
 import faster_coco_eval.core.mask as mask_util
@@ -233,5 +235,130 @@ def test_to_dict(self):
         self.assertListEqual(parsed_data["images"], orig_data["images"])
 
 
+# ---------------------------------------------------------------------------
+# Extended-metrics tests (pytest-style)
+# ---------------------------------------------------------------------------
+_SIZE = 200
+_SPACING = 250
+_ROW = 260
+
+
+def _contained_box(gt_box, iou):
+    x, y, s, _ = gt_box
+    p = s * math.sqrt(iou)
+    off = (s - p) / 2
+    return [x + off, y + off, p, p]
+
+
+@pytest.fixture
+def coco_gt_dt_with_fp():
+    """COCO GT/DT pair with two categories:
+    - cat1: 10 GTs, 10 TPs (conf 0.0–0.9), 0 FPs
+    - cat2: 10 GTs, 10 TPs (conf 0.5–0.95), 10 FPs (conf 0.0–0.45)
+    """
+    image_id = 1
+    anns, dets = [], []
+    ann_id = 1
+
+    # cat1: perfect predictions
+    for i, conf in enumerate([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]):
+        gt = [float(i * _SPACING), 0.0, float(_SIZE), float(_SIZE)]
+        anns.append({
+            "id": ann_id, "image_id": image_id, "category_id": 1,
+            "bbox": gt, "area": gt[2] * gt[3], "iscrowd": 0,
+        })
+        dets.append({
+            "image_id": image_id, "category_id": 1,
+            "bbox": _contained_box(gt, 0.96), "score": conf,
+        })
+        ann_id += 1
+
+    # cat2: 10 TPs followed by 10 FPs at lower confidence
+    for i, conf in enumerate([0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65, 0.60, 0.55, 0.50]):
+        gt = [float(i * _SPACING), float(_ROW), float(_SIZE), float(_SIZE)]
+        anns.append({
+            "id": ann_id, "image_id": image_id, "category_id": 2,
+            "bbox": gt, "area": gt[2] * gt[3], "iscrowd": 0,
+        })
+        dets.append({
+            "image_id": image_id, "category_id": 2,
+            "bbox": _contained_box(gt, 0.96), "score": conf,
+        })
+        ann_id += 1
+
+    # cat2 FPs (no GT overlap)
+    for i, conf in enumerate([0.45, 0.40, 0.35, 0.30, 0.25, 0.20, 0.15, 0.10, 0.05, 0.00]):
+        dets.append({
+            "image_id": image_id, "category_id": 2,
+            "bbox": [float(i * _SPACING), float(2 * _ROW), float(_SIZE), float(_SIZE)],
+            "score": conf,
+        })
+
+    coco_gt = COCO()
+    coco_gt.dataset = {
+        "images": [{"id": image_id, "width": 10 * _SPACING, "height": 3 * _ROW}],
+        "annotations": anns,
+        "categories": [{"id": 1, "name": "cat1"}, {"id": 2, "name": "cat2"}],
+    }
+    coco_gt.createIndex()
+    coco_dt = coco_gt.loadRes(dets)
+    return coco_gt, coco_dt
+
+
+def test_extended_metrics_precision_not_overestimated(coco_gt_dt_with_fp):
+    """Regression test: extended_metrics must not over-estimate precision.
+
+    When cat2 has FPs at low confidence (below the recall ceiling), the
+    interpolated PR-curve hides those FPs and reports precision=1.0. The
+    correct macro-precision at the F1-optimal confidence threshold is 0.75.
+    """
+    coco_gt, coco_dt = coco_gt_dt_with_fp
+    coco_eval = COCOeval_faster(coco_gt, coco_dt, iouType="bbox")
+    coco_eval.evaluate()
+    coco_eval.accumulate()
+    coco_eval.summarize()
+    m = coco_eval.extended_metrics
+
+    # At the F1-optimal global threshold both classes must reach recall=1.0.
+    # cat1: P=1.0, R=1.0; cat2: P=0.5, R=1.0 → macro P=0.75, R=1.0
+    assert m["precision"] == pytest.approx(0.75, abs=1e-6), (
+        "Macro-precision must not use interpolated (overestimated) values"
+    )
+    assert m["recall"] == pytest.approx(1.0, abs=1e-6), (
+        "Macro-recall should be 1.0 at the F1-optimal threshold"
+    )
+
+
+def test_extended_metrics_perfect_predictions():
+    """All predictions are correct TPs: precision=recall=1.0."""
+    image_id = 1
+    anns, dets = [], []
+    ann_id = 1
+    for i, conf in enumerate([0.9, 0.8, 0.7, 0.6, 0.5]):
+        gt = [float(i * _SPACING), 0.0, float(_SIZE), float(_SIZE)]
+        anns.append({"id": ann_id, "image_id": image_id, "category_id": 1,
+                     "bbox": gt, "area": gt[2] * gt[3], "iscrowd": 0})
+        dets.append({"image_id": image_id, "category_id": 1,
+                     "bbox": _contained_box(gt, 0.96), "score": conf})
+        ann_id += 1
+
+    coco_gt = COCO()
+    coco_gt.dataset = {
+        "images": [{"id": image_id, "width": 10 * _SPACING, "height": _SIZE + 10}],
+        "annotations": anns,
+        "categories": [{"id": 1, "name": "cat1"}],
+    }
+    coco_gt.createIndex()
+    coco_dt = coco_gt.loadRes(dets)
+    coco_eval = COCOeval_faster(coco_gt, coco_dt, iouType="bbox")
+    coco_eval.evaluate()
+    coco_eval.accumulate()
+    coco_eval.summarize()
+
+    m = coco_eval.extended_metrics
+    assert m["precision"] == pytest.approx(1.0, abs=1e-6)
+    assert m["recall"] == pytest.approx(1.0, abs=1e-6)
+
+
 if __name__ == "__main__":
     unittest.main()
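
Reviewer note (not part of the patch): the snippet below is a minimal standalone sketch of the threshold-sweep arithmetic that the regression test asserts, using detections already labelled TP/FP at IoU >= 0.5. The helper name macro_pr_at_best_f1 and its toy inputs are illustrative only; the real implementation lives in extended_metrics above.

import numpy as np


def macro_pr_at_best_f1(dets_by_class, gts_by_class):
    """Macro precision/recall at the F1-optimal global confidence threshold.

    dets_by_class: {cat_id: [(score, is_tp), ...]}; gts_by_class: {cat_id: n_gt}.
    """
    # Sweep all unique scores ascending so the first (most inclusive) threshold
    # reaching the maximum macro-F1 wins, mirroring the patched extended_metrics.
    thresholds = sorted({s for dets in dets_by_class.values() for s, _ in dets})
    best_f1, best_p, best_r = -np.inf, float("nan"), float("nan")
    for t in thresholds:
        precs, recs, f1s = [], [], []
        for cat_id, dets in dets_by_class.items():
            n_gt = gts_by_class.get(cat_id, 0)
            if n_gt == 0:
                continue
            kept = [is_tp for score, is_tp in dets if score >= t]
            tp = sum(kept)
            prec = tp / len(kept) if kept else 0.0
            rec = tp / n_gt
            f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
            precs.append(prec)
            recs.append(rec)
        if not f1s:
            continue
        macro_f1 = float(np.mean(f1s))
        if macro_f1 > best_f1:
            best_f1, best_p, best_r = macro_f1, float(np.mean(precs)), float(np.mean(recs))
    return best_p, best_r


# Scenario from test_extended_metrics_precision_not_overestimated:
# cat1 has 10 TPs (scores 0.0-0.9); cat2 has 10 TPs (0.5-0.95) plus 10 FPs (0.0-0.45).
dets = {
    1: [(i / 10, 1) for i in range(10)],
    2: [(0.5 + i / 20, 1) for i in range(10)] + [(i / 20, 0) for i in range(10)],
}
gts = {1: 10, 2: 10}
print(macro_pr_at_best_f1(dets, gts))  # (0.75, 1.0), not the interpolated precision of 1.0

At the lowest threshold both classes reach recall 1.0 (cat1 P=1.0, cat2 P=0.5, macro P=0.75); a threshold of 0.5 ties on macro-F1 with P=1.0 and R=0.75, and the ascending sweep deliberately keeps the more inclusive first optimum.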