55import numpy as np
66from attrs import define , field
77
8- from ..matching import CenterDistance , HeadingYaw
9-
108if TYPE_CHECKING :
11- from t4_devkit .evaluation import FrameBoxMatch , MatchingScorerLike
9+ from t4_devkit .evaluation import BoxMatch , FrameBoxMatch , MatchingScorerLike
1210
1311__all__ = ["Ap" , "ApH" ]
1412
@@ -61,18 +59,18 @@ def compute_ap(self) -> float:
6159
6260 return float (np .mean (filtered_precision )) / (1.0 - Ap .min_precision )
6361
64- def __init__ (self , threshold : float ) -> None :
65- self .scorer = self . _configure_scorer ()
62+ def __init__ (self , scorer : MatchingScorerLike , threshold : float ) -> None :
63+ self .scorer = scorer
6664 self .threshold = threshold
6765
68- def _configure_scorer (self ) -> MatchingScorerLike :
69- return CenterDistance ()
70-
7166 def __call__ (self , frames : list [FrameBoxMatch ]) -> float :
72- component = self ._compute_tp_fp (frames )
73- return component .compute_ap ()
67+ buffer = self ._update_buffer (frames )
68+ return buffer .compute_ap ()
7469
75- def _compute_tp_fp (self , frames : list [FrameBoxMatch ]) -> ApBuffer :
70+ def _compute_tp (self , _box_match : BoxMatch ) -> float :
71+ return 1.0
72+
73+ def _update_buffer (self , frames : list [FrameBoxMatch ]) -> ApBuffer :
7674 buffer = self .ApBuffer ()
7775 for frame in frames :
7876 buffer .num_gt += frame .num_gt
@@ -86,7 +84,7 @@ def _compute_tp_fp(self, frames: list[FrameBoxMatch]) -> ApBuffer:
8684 threshold = self .threshold ,
8785 ego2map = frame .ego2map ,
8886 ):
89- buffer .tp_list .append (1.0 )
87+ buffer .tp_list .append (self . _compute_tp ( box_match ) )
9088 buffer .fp_list .append (0.0 )
9189 else :
9290 buffer .tp_list .append (0.0 )
@@ -98,5 +96,11 @@ class ApH(Ap):
9896 def __init__ (self , threshold : float ) -> None :
9997 super ().__init__ (threshold = threshold )
10098
101- def _configure_scorer (self ) -> HeadingYaw :
102- return HeadingYaw ()
99+ def _compute_tp (self , box_match : BoxMatch ) -> float :
100+ if not box_match .is_matched ():
101+ return 0.0
102+
103+ diff_yaw = box_match .estimation .diff_yaw (box_match .ground_truth )
104+ if diff_yaw > np .pi :
105+ diff_yaw = 2.0 * np .pi - diff_yaw
106+ return min (1.0 , max (0.0 , 1.0 - diff_yaw / np .pi ))
0 commit comments