diff --git a/t4_devkit/evaluation/__init__.py b/t4_devkit/evaluation/__init__.py index 461c4f5..c336c05 100644 --- a/t4_devkit/evaluation/__init__.py +++ b/t4_devkit/evaluation/__init__.py @@ -4,3 +4,4 @@ from .matching import * # noqa from .result import * # noqa from .metric import * # noqa +from .task import * # noqa diff --git a/t4_devkit/evaluation/dataset.py b/t4_devkit/evaluation/dataset.py index f721d0e..7fb4c39 100644 --- a/t4_devkit/evaluation/dataset.py +++ b/t4_devkit/evaluation/dataset.py @@ -7,6 +7,8 @@ from t4_devkit import Tier4 from t4_devkit.dataclass import HomogeneousMatrix, TransformBuffer +from .task import EvaluationTask + if TYPE_CHECKING: from t4_devkit.dataclass import BoxLike from t4_devkit.schema import EgoPose, Sensor @@ -15,11 +17,12 @@ __all__ = ["load_dataset", "FrameGroundTruth", "SceneGroundTruth"] -def load_dataset(data_root: str) -> SceneGroundTruth: +def load_dataset(data_root: str, task: EvaluationTask) -> SceneGroundTruth: """Load dataset. Args: data_root (str): Root directory path to the dataset. + task (EvaluationTask): Evaluation task. Returns: SceneGroundTruth: Loaded container of ground truths. @@ -29,7 +32,11 @@ def load_dataset(data_root: str) -> SceneGroundTruth: frames: list[FrameGroundTruth] = [] for i, sample in enumerate(t4.sample): # annotation boxes - boxes = list(map(t4.get_box3d, sample.ann_3ds)) + boxes = ( + list(map(t4.get_box3d, sample.ann_3ds)) + if task.is_3d() + else list(map(t4.get_box2d, sample.ann_2ds)) + ) # transformation matrix from ego to map ego_pose = _closest_ego_pose(t4, sample.timestamp) diff --git a/t4_devkit/evaluation/task.py b/t4_devkit/evaluation/task.py new file mode 100644 index 0000000..b452e25 --- /dev/null +++ b/t4_devkit/evaluation/task.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from enum import Enum + +__all__ = ["EvaluationTask"] + + +class EvaluationTask(str, Enum): + """Enumeration of evaluation tasks.""" + + DETECTION3D = "detection3d" + TRACKING3D = "tracking3d" + PREDICTION3D = "prediction3d" + DETECTION2D = "detection2d" + TRACKING2D = "tracking2d" + + def is_3d(self) -> bool: + return self in ( + EvaluationTask.DETECTION3D, + EvaluationTask.TRACKING3D, + EvaluationTask.PREDICTION3D, + ) + + def is_2d(self) -> bool: + return self in (EvaluationTask.DETECTION2D, EvaluationTask.TRACKING2D) diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py new file mode 100644 index 0000000..638ed2b --- /dev/null +++ b/tests/evaluation/test_task.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from t4_devkit.evaluation import EvaluationTask + + +def test_task() -> None: + task_names = {"detection3d", "tracking3d", "prediction3d", "detection2d", "tracking2d"} + + assert task_names == {e.value for e in EvaluationTask} + + for name in task_names: + task = EvaluationTask(name) + + assert task == name