Skip to content

Commit 459f1c2

Browse files
committed
feat: add box matching operations (#130)
* feat: add matching functions Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * feat: add ground truth class (#129) * feat: add ground truth class Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * feat: add support of setting transforms Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> --------- Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * feat: add matching functions Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * fix: resolve matching algorithm error Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * refactor: rename method from smaller_is_better to is_smaller_score_better Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * feat: add score calculation for plane distance Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * docs: update API reference Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * chore: remove duplicated Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * chore: remove unused matching/context.py Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * test: add unit test for matching params Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> --------- Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
1 parent fa69898 commit 459f1c2

19 files changed

Lines changed: 1103 additions & 0 deletions

File tree

docs/apis/evaluation.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# `evaluation`
2+
3+
## `matching`
4+
5+
<!-- prettier-ignore-start -->
6+
::: t4_devkit.evaluation.matching.parameter
7+
8+
::: t4_devkit.evaluation.matching.context
9+
10+
::: t4_devkit.evaluation.matching.scorer
11+
12+
::: t4_devkit.evaluation.matching.policy
13+
14+
::: t4_devkit.evaluation.matching.algorithm
15+
<!-- prettier-ignore-end -->
16+
17+
## `result`
18+
19+
<!-- prettier-ignore-start -->
20+
::: t4_devkit.evaluation.result.box
21+
22+
::: t4_devkit.evaluation.result.status
23+
<!-- prettier-ignore-end -->

mkdocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ nav:
2323
- Serialize Schema: apis/schema/serialize.md
2424
- t4_devkit.dataclass: apis/dataclass.md
2525
- t4_devkit.filtering: apis/filtering.md
26+
- t4_devkit.evaluation: apis/evaluation.md
2627
- t4_devkit.viewer: apis/viewer.md
2728
- t4_devkit.common: apis/common.md
2829

t4_devkit/dataclass/box.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,33 @@ def corners(self, box_scale: float = 1.0) -> NDArrayF64:
229229
# Rotate and translate
230230
return np.dot(self.rotation.rotation_matrix, corners).T + self.position
231231

232+
def diff_yaw(self, other: Box3D) -> float:
233+
"""Return the yaw difference between the two boxes.
234+
235+
Args:
236+
other (Box3D): Another box.
237+
238+
Raises:
239+
ValueError: Both boxes must have the same `frame_id`.
240+
241+
Returns:
242+
Yaw difference in the range of [-pi, pi].
243+
"""
244+
if self.frame_id != other.frame_id:
245+
raise ValueError(f"Invalid frame comparison: {self.frame_id=} and {other.frame_id=}")
246+
247+
yaw1, *_ = self.rotation.yaw_pitch_roll
248+
yaw2, *_ = other.rotation.yaw_pitch_roll
249+
250+
def _clip(diff: float) -> float:
251+
if diff < -np.pi:
252+
diff += 2 * np.pi
253+
elif diff > np.pi:
254+
diff -= 2 * np.pi
255+
return diff
256+
257+
return _clip(yaw2 - yaw1)
258+
232259

233260
@define(eq=False)
234261
class Box2D(BaseBox):

t4_devkit/evaluation/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
from .dataset import * # noqa
2+
from .matching import * # noqa
3+
from .result import * # noqa
4+
from .matching import * # noqa
5+
from .result import * # noqa
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .algorithm import * # noqa
2+
from .parameter import * # noqa
3+
from .policy import * # noqa
4+
from .scorer import * # noqa
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from copy import deepcopy
5+
from typing import TYPE_CHECKING, Sequence, TypeVar
6+
7+
import numpy as np
8+
9+
from ..result import BoxMatch
10+
11+
if TYPE_CHECKING:
12+
from t4_devkit.dataclass import BoxLike
13+
from t4_devkit.typing import NDArrayF64
14+
15+
from .policy import MatchingPolicyLike
16+
from .scorer import MatchingScorerLike
17+
18+
__all__ = ["GreedyMatcher", "MatchingAlgorithmLike"]
19+
20+
21+
# ===== Base Class for Matching Algorithm =====
22+
23+
24+
class MatchingAlgorithmImpl(ABC):
25+
"""Abstract base class for matching algorithm class."""
26+
27+
def __init__(
28+
self,
29+
scorer: MatchingScorerLike,
30+
policy: MatchingPolicyLike,
31+
matchable_threshold: float,
32+
) -> None:
33+
super().__init__()
34+
self._scorer = scorer
35+
self._policy = policy
36+
self._matchable_threshold = matchable_threshold
37+
38+
def __call__(
39+
self,
40+
estimations: Sequence[BoxLike],
41+
ground_truths: Sequence[BoxLike],
42+
) -> list[BoxMatch]:
43+
"""Execute matching.
44+
45+
Args:
46+
estimations (Sequence[BoxLike]): Sequence of estimations.
47+
ground_truths (Sequence[BoxLike]): Sequence of ground truths.
48+
49+
Returns:
50+
list[BoxMatch]: List of matches.
51+
"""
52+
score_table = self._score_table(estimations, ground_truths)
53+
return self._do_matching(estimations, ground_truths, score_table)
54+
55+
def _score_table(
56+
self,
57+
estimations: Sequence[BoxLike],
58+
ground_truths: Sequence[BoxLike],
59+
) -> NDArrayF64:
60+
"""Create a score table.
61+
62+
Args:
63+
estimations (Sequence[BoxLike]): Sequence of estimations.
64+
ground_truths (Sequence[BoxLike]): Sequence of ground truths.
65+
66+
Returns:
67+
NDArrayF64: Score table in the shape of (NumEst, NumGT).
68+
"""
69+
num_rows, num_cols = len(estimations), len(ground_truths)
70+
71+
table: NDArrayF64 = np.full((num_rows, num_cols), fill_value=np.nan)
72+
for i, box1 in enumerate(estimations):
73+
for j, box2 in enumerate(ground_truths):
74+
if box1.frame_id != box2.frame_id:
75+
continue
76+
77+
score = self._scorer(box1, box2)
78+
79+
# check if boxes distance and label is matchable
80+
if self._scorer.is_better_than(
81+
score, self._matchable_threshold
82+
) and self._policy.is_matchable(box1, box2):
83+
table[i, j] = score
84+
85+
return table
86+
87+
def _get_indices(self, score_table: NDArrayF64) -> tuple[int, int]:
88+
"""Return indices of estimation and ground truth in the score table at the best score.
89+
90+
Args:
91+
score_table (NDArrayF64): Score table in the shape of (NumEst, NumGt).
92+
93+
Returns:
94+
Estimation index and ground truth index.
95+
"""
96+
estimation_idx, ground_truth_idx = (
97+
np.unravel_index(np.nanargmin(score_table), score_table.shape)
98+
if self._scorer.is_smaller_score_better()
99+
else np.unravel_index(np.nanargmax(score_table), score_table.shape)
100+
)
101+
return estimation_idx, ground_truth_idx
102+
103+
@abstractmethod
104+
def _do_matching(
105+
self,
106+
estimations: Sequence[BoxLike],
107+
ground_truths: Sequence[BoxLike],
108+
score_table: NDArrayF64,
109+
) -> list[BoxMatch]:
110+
pass
111+
112+
113+
MatchingAlgorithmLike = TypeVar("MatchingAlgorithmLike", bound=MatchingAlgorithmImpl)
114+
115+
116+
# ===== Specific Matching Algorithms =====
117+
118+
119+
class GreedyMatcher(MatchingAlgorithmImpl):
120+
def _do_matching(
121+
self,
122+
estimations: Sequence[BoxLike],
123+
ground_truths: Sequence[BoxLike],
124+
score_table: NDArrayF64,
125+
) -> list[BoxMatch]:
126+
tmp_estimations = list(deepcopy(estimations))
127+
tmp_ground_truths = list(deepcopy(ground_truths))
128+
129+
output: list[BoxMatch] = []
130+
# 1. match the nearest matchable estimations and GTs
131+
num_estimations, *_ = score_table.shape
132+
for _ in range(num_estimations):
133+
if np.isnan(score_table).all():
134+
break
135+
136+
estimation_idx, ground_truth_idx = self._get_indices(score_table)
137+
138+
estimation_picked = tmp_estimations.pop(estimation_idx)
139+
ground_truth_picked = tmp_ground_truths.pop(ground_truth_idx)
140+
output.append(BoxMatch(estimation_picked, ground_truth_picked))
141+
142+
# remove picked estimations and GTs
143+
score_table = np.delete(score_table, estimation_idx, axis=0)
144+
score_table = np.delete(score_table, ground_truth_idx, axis=1)
145+
146+
# 2. assign remaining estimations(=FPs) and GTs(=FNs)
147+
output += [BoxMatch(estimation=estimation) for estimation in tmp_estimations]
148+
output += [BoxMatch(ground_truth=ground_truth) for ground_truth in tmp_ground_truths]
149+
150+
return output
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations
2+
3+
from enum import Enum
4+
5+
import numpy as np
6+
from attrs import define, field
7+
8+
from .algorithm import GreedyMatcher, MatchingAlgorithmLike
9+
from .policy import AllowAnyPolicy, AllowUnknownPolicy, MatchingPolicyLike, StrictPolicy
10+
from .scorer import CenterDistance, Iou2D, Iou3D, MatchingScorerLike, PlaneDistance
11+
12+
__all__ = [
13+
"build_matcher",
14+
"build_matching_scorer",
15+
"build_matching_policy",
16+
"MatchingScorer",
17+
"MatchingPolicy",
18+
"MatchingAlgorithm",
19+
"MatchingParams",
20+
]
21+
22+
# ===== Builder Functions =====
23+
24+
25+
def build_matcher(params: MatchingParams) -> MatchingAlgorithmLike:
26+
"""Build a matching algorithm from the parameter.
27+
28+
Examples:
29+
>>> params = MatchingParams(...)
30+
>>> matcher = build_matcher(params)
31+
"""
32+
scorer = build_matching_scorer(params)
33+
policy = build_matching_policy(params)
34+
if params.algorithm == MatchingAlgorithm.GREEDY:
35+
return GreedyMatcher(
36+
scorer=scorer,
37+
policy=policy,
38+
matchable_threshold=params.matchable_distance,
39+
)
40+
else:
41+
raise ValueError(f"Unexpected algorithm name: {params.algorithm}")
42+
43+
44+
def build_matching_scorer(params: MatchingParams) -> MatchingScorerLike:
45+
"""Build a matching scorer from the parameter.
46+
47+
Examples:
48+
>>> params = MatchingParams(...)
49+
>>> scorer = build_matching_scorer(params)
50+
"""
51+
if params.scorer == MatchingScorer.CENTER_DISTANCE:
52+
return CenterDistance()
53+
elif params.scorer == MatchingScorer.PLANE_DISTANCE:
54+
return PlaneDistance()
55+
elif params.scorer == MatchingScorer.IOU2D:
56+
return Iou2D()
57+
elif params.scorer == MatchingScorer.IOU3D:
58+
return Iou3D()
59+
else:
60+
raise ValueError(f"Unexpected scorer name: {params.scorer}")
61+
62+
63+
def build_matching_policy(params: MatchingParams) -> MatchingPolicyLike:
64+
"""Build a matching policy from the parameter.
65+
66+
Examples:
67+
>>> params = MatchingParams(...)
68+
>>> policy = build_matching_policy(params)
69+
"""
70+
if params.policy == MatchingPolicy.STRICT:
71+
return StrictPolicy()
72+
elif params.policy == MatchingPolicy.ALLOW_UNKNOWN:
73+
return AllowUnknownPolicy()
74+
elif params.policy == MatchingPolicy.ALLOW_ANY:
75+
return AllowAnyPolicy()
76+
else:
77+
raise ValueError(f"Unexpected policy name: {params.policy}")
78+
79+
80+
# ===== Parameters =====
81+
82+
83+
class MatchingScorer(str, Enum):
84+
"""An enum to represent matching scorer names."""
85+
86+
CENTER_DISTANCE = "CENTER_DISTANCE"
87+
PLANE_DISTANCE = "PLANE_DISTANCE"
88+
IOU2D = "IOU2D"
89+
IOU3D = "IOU3D"
90+
91+
92+
class MatchingPolicy(str, Enum):
93+
"""An enum to represent matching policy names."""
94+
95+
STRICT = "STRICT"
96+
ALLOW_UNKNOWN = "ALLOW_UNKNOWN"
97+
ALLOW_ANY = "ALLOW_ANY"
98+
99+
100+
class MatchingAlgorithm(str, Enum):
101+
"""An enum to represent matching algorithm names."""
102+
103+
GREEDY = "GREEDY"
104+
105+
106+
@define
107+
class MatchingParams:
108+
"""A dataclass to represent matching parameters.
109+
110+
Attributes:
111+
scorer (MatchingScorer): Name of matching scorer.
112+
policy (MatchingPolicy): Name of matching policy.
113+
algorithm (MatchingAlgorithm): Name of matching algorithm.
114+
matchable_distance (float): Max distance from a GT which a estimation can be matched.
115+
"""
116+
117+
scorer: MatchingScorer = field(default=MatchingScorer.CENTER_DISTANCE, converter=MatchingScorer)
118+
policy: MatchingPolicy = field(default=MatchingPolicy.STRICT, converter=MatchingPolicy)
119+
algorithm: MatchingAlgorithm = field(
120+
default=MatchingAlgorithm.GREEDY,
121+
converter=MatchingAlgorithm,
122+
)
123+
matchable_distance: float = field(default=np.inf)

0 commit comments

Comments
 (0)