@@ -229,10 +229,12 @@ def extended_metrics(self):
229229 Notes:
230230 - Uses COCO-style evaluation results (precision and scores arrays).
231231 - Filters out classes with NaN results in any metric.
232- - The best F1-score across recall thresholds is used to select macro precision and recall.
232+ - The best F1-score across confidence thresholds is used to select macro precision and recall.
233+ - Precision and recall are computed from actual (non-interpolated) detection data to avoid
234+ over-estimating precision when false positives exist below the recall ceiling.
233235 """
234- # Extract IoU and recall thresholds from parameters
235- iou_thrs , rec_thrs = self .params .iouThrs , self . params . recThrs
236+ # Extract IoU thresholds from parameters
237+ iou_thrs = self .params .iouThrs
236238
237239 # Indices for IoU=0.50, first area, and last max dets
238240 _iou50_hits = np .where (np .isclose (iou_thrs , 0.50 ))[0 ]
@@ -243,33 +245,91 @@ def extended_metrics(self):
243245 )
244246 iou50_idx , area_idx , maxdet_idx = (int (_iou50_hits [0 ]), 0 , - 1 )
245247 P = self .eval ["precision" ]
246- S = self .eval ["scores" ]
247248
248- # Get precision for IoU=0.50, area, and max dets
249- prec_raw = P [iou50_idx , :, :, area_idx , maxdet_idx ]
250- prec = prec_raw .copy ().astype (float )
251- prec [prec < 0 ] = np .nan
252-
253- # Compute F1 score for each class and recall threshold
254- f1_cls = 2 * prec * rec_thrs [:, None ] / (prec + rec_thrs [:, None ])
255- f1_macro = np .nanmean (f1_cls , axis = 1 )
256- best_j = int (f1_macro .argmax ())
257-
258- # Macro precision and recall at the best F1 score
259- macro_precision = float (np .nanmean (prec [best_j ]))
260- macro_recall = float (rec_thrs [best_j ])
249+ # --- Compute actual (non-interpolated) precision/recall by sweeping confidence thresholds ---
250+ # Build set of TP detection IDs: detections matched to a GT with actual IoU >= 0.50
251+ tp_dt_ids = {
252+ int (k .split ("_" )[0 ])
253+ for k , iou in self .eval ["matched" ].items ()
254+ if iou >= 0.5
255+ }
261256
262- # Score vector for the best recall threshold
263- score_vec = S [iou50_idx , best_j , :, area_idx , maxdet_idx ].astype (float )
264- score_vec [prec_raw [best_j ] < 0 ] = np .nan
257+ cat_ids_eval = self .params .catIds if self .params .useCats else list (
258+ {ann ["category_id" ] for ann in self .cocoDt .anns .values ()}
259+ )
260+
261+ # Per-class: build sorted (descending) score arrays and cumulative TP counts
262+ class_arrays = {} # cat_id -> (scores_desc, cum_tp)
263+ class_items : dict = {}
264+ for dt_id , ann in self .cocoDt .anns .items ():
265+ cat_id = ann ["category_id" ]
266+ class_items .setdefault (cat_id , []).append ((float (ann ["score" ]), int (dt_id in tp_dt_ids )))
267+ for cat_id , items in class_items .items ():
268+ items .sort (key = lambda x : - x [0 ])
269+ scores = np .array ([s for s , _ in items ])
270+ is_tp = np .array ([t for _ , t in items ])
271+ class_arrays [cat_id ] = (scores , np .cumsum (is_tp ))
272+
273+ # Total non-crowd GTs per class
274+ total_gts : dict = {}
275+ for ann in self .cocoGt .anns .values ():
276+ if not ann .get ("iscrowd" , 0 ):
277+ total_gts [ann ["category_id" ]] = total_gts .get (ann ["category_id" ], 0 ) + 1
278+
279+ # Candidate thresholds: all unique detection scores, ascending
280+ # Iterating ascending ensures the FIRST threshold reaching the maximum macro-F1
281+ # is the most inclusive one (lowest confidence → highest recall).
282+ all_thresholds = sorted ({float (ann ["score" ]) for ann in self .cocoDt .anns .values ()})
283+
284+ best_macro_f1 = - np .inf
285+ # best_class_metrics is populated inside the loop when a better macro-F1 is found.
286+ # If no threshold produces valid metrics (e.g., no detections at all), it stays
287+ # empty and per-class precision/recall will be reported as NaN (and filtered out).
288+ best_class_metrics : dict = {}
289+ macro_precision = 0.0
290+ macro_recall = 0.0
291+
292+ for threshold in all_thresholds :
293+ cat_precs , cat_recs , cat_f1s , cat_ids_valid = [], [], [], []
294+ for cat_id in cat_ids_eval :
295+ n_gt = total_gts .get (cat_id , 0 )
296+ if n_gt == 0 :
297+ continue
298+ if cat_id in class_arrays :
299+ scores , cum_tp = class_arrays [cat_id ]
300+ # Count detections with score >= threshold using binary search
301+ n_above = int (np .searchsorted (- scores , - threshold , side = "right" ))
302+ tp = int (cum_tp [n_above - 1 ]) if n_above > 0 else 0
303+ prec = tp / n_above if n_above > 0 else 0.0
304+ rec = tp / n_gt
305+ else :
306+ prec , rec = 0.0 , 0.0
307+ f1 = 2.0 * prec * rec / (prec + rec ) if (prec + rec ) > 0.0 else 0.0
308+ cat_precs .append (prec )
309+ cat_recs .append (rec )
310+ cat_f1s .append (f1 )
311+ cat_ids_valid .append (cat_id )
312+
313+ if not cat_f1s :
314+ continue
315+
316+ macro_f1 = float (np .mean (cat_f1s ))
317+ if macro_f1 > best_macro_f1 :
318+ best_macro_f1 = macro_f1
319+ macro_precision = float (np .mean (cat_precs ))
320+ macro_recall = float (np .mean (cat_recs ))
321+ best_class_metrics = {
322+ cid : {"precision" : p , "recall" : r }
323+ for cid , p , r in zip (cat_ids_valid , cat_precs , cat_recs )
324+ }
265325
266326 per_class = []
267327 if self .params .useCats :
268328 # Map category IDs to names
269329 cat_ids = self .params .catIds
270330 cat_id_to_name = {c ["id" ]: c ["name" ] for c in self .cocoGt .loadCats (cat_ids )}
271331 for k , cid in enumerate (cat_ids ):
272- # Precision per category
332+ # AP per category (unchanged: uses interpolated P for mAP, which is correct)
273333 p_slice = P [:, :, k , area_idx , maxdet_idx ]
274334 valid = p_slice > - 1
275335 ap_50_95 = float (p_slice [valid ].mean ()) if valid .any () else float ("nan" )
@@ -279,8 +339,9 @@ def extended_metrics(self):
279339 else float ("nan" )
280340 )
281341
282- pc = float (prec [best_j , k ]) if prec_raw [best_j , k ] > - 1 else float ("nan" )
283- rc = macro_recall
342+ class_m = best_class_metrics .get (int (cid ), {})
343+ pc = class_m .get ("precision" , float ("nan" ))
344+ rc = class_m .get ("recall" , float ("nan" ))
284345
285346 # Filter out dataset class if any metric is NaN
286347 if np .isnan (ap_50_95 ) or np .isnan (ap_50 ) or np .isnan (pc ) or np .isnan (rc ):
0 commit comments