Skip to content

Commit df05de2

Browse files
committed
feat: add evaluation task (#149)
Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
1 parent 2173bd8 commit df05de2

4 files changed

Lines changed: 49 additions & 2 deletions

File tree

t4_devkit/evaluation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .matching import * # noqa
55
from .result import * # noqa
66
from .metric import * # noqa
7+
from .task import * # noqa

t4_devkit/evaluation/dataset.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from t4_devkit import Tier4
88
from t4_devkit.dataclass import HomogeneousMatrix, TransformBuffer
99

10+
from .task import EvaluationTask
11+
1012
if TYPE_CHECKING:
1113
from t4_devkit.dataclass import BoxLike
1214
from t4_devkit.schema import EgoPose, Sensor
@@ -15,11 +17,12 @@
1517
__all__ = ["load_dataset", "FrameGroundTruth", "SceneGroundTruth"]
1618

1719

18-
def load_dataset(data_root: str) -> SceneGroundTruth:
20+
def load_dataset(data_root: str, task: EvaluationTask) -> SceneGroundTruth:
1921
"""Load dataset.
2022
2123
Args:
2224
data_root (str): Root directory path to the dataset.
25+
task (EvaluationTask): Evaluation task.
2326
2427
Returns:
2528
SceneGroundTruth: Loaded container of ground truths.
@@ -29,7 +32,11 @@ def load_dataset(data_root: str) -> SceneGroundTruth:
2932
frames: list[FrameGroundTruth] = []
3033
for i, sample in enumerate(t4.sample):
3134
# annotation boxes
32-
boxes = list(map(t4.get_box3d, sample.ann_3ds))
35+
boxes = (
36+
list(map(t4.get_box3d, sample.ann_3ds))
37+
if task.is_3d()
38+
else list(map(t4.get_box2d, sample.ann_2ds))
39+
)
3340

3441
# transformation matrix from ego to map
3542
ego_pose = _closest_ego_pose(t4, sample.timestamp)

t4_devkit/evaluation/task.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from enum import Enum
4+
5+
__all__ = ["EvaluationTask"]
6+
7+
8+
class EvaluationTask(str, Enum):
9+
"""Enumeration of evaluation tasks."""
10+
11+
DETECTION3D = "detection3d"
12+
TRACKING3D = "tracking3d"
13+
PREDICTION3D = "prediction3d"
14+
DETECTION2D = "detection2d"
15+
TRACKING2D = "tracking2d"
16+
17+
def is_3d(self) -> bool:
18+
return self in (
19+
EvaluationTask.DETECTION3D,
20+
EvaluationTask.TRACKING3D,
21+
EvaluationTask.PREDICTION3D,
22+
)
23+
24+
def is_2d(self) -> bool:
25+
return self in (EvaluationTask.DETECTION2D, EvaluationTask.TRACKING2D)

tests/evaluation/test_task.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from __future__ import annotations
2+
3+
from t4_devkit.evaluation import EvaluationTask
4+
5+
6+
def test_task() -> None:
7+
task_names = {"detection3d", "tracking3d", "prediction3d", "detection2d", "tracking2d"}
8+
9+
assert task_names == {e.value for e in EvaluationTask}
10+
11+
for name in task_names:
12+
task = EvaluationTask(name)
13+
14+
assert task == name

0 commit comments

Comments
 (0)