Skip to content

Commit 367abd8

Browse files
committed
feat: add ground truth class (#129)
* feat: add ground truth class Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> * feat: add support of setting transforms Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp> --------- Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
1 parent 757f6b4 commit 367abd8

2 files changed

Lines changed: 118 additions & 0 deletions

File tree

t4_devkit/evaluation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .dataset import * # noqa

t4_devkit/evaluation/dataset.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from attrs import define
6+
7+
from t4_devkit import Tier4
8+
from t4_devkit.dataclass import HomogeneousMatrix, TransformBuffer
9+
10+
if TYPE_CHECKING:
11+
from t4_devkit.dataclass import BoxLike
12+
from t4_devkit.schema import EgoPose, Sensor
13+
14+
15+
__all__ = ["load_dataset", "FrameGroundTruth", "SceneGroundTruth"]
16+
17+
18+
def load_dataset(data_root: str) -> SceneGroundTruth:
19+
"""Load dataset.
20+
21+
Args:
22+
data_root (str): Root directory path to the dataset.
23+
24+
Returns:
25+
SceneGroundTruth: Loaded container of ground truths.
26+
"""
27+
t4 = Tier4("annotation", data_root=data_root, verbose=False)
28+
29+
frames: list[FrameGroundTruth] = []
30+
for i, sample in enumerate(t4.sample):
31+
# annotation boxes
32+
boxes = list(map(t4.get_box3d, sample.ann_3ds))
33+
34+
# transformation matrix from ego to map
35+
ego_pose = _closest_ego_pose(t4, sample.timestamp)
36+
ego2map = HomogeneousMatrix(
37+
position=ego_pose.translation,
38+
rotation=ego_pose.rotation,
39+
src="map",
40+
dst="base_link",
41+
)
42+
43+
frames.append(
44+
FrameGroundTruth(
45+
unix_time=sample.timestamp,
46+
frame_index=i,
47+
boxes=boxes,
48+
ego2map=ego2map,
49+
)
50+
)
51+
52+
# transformation matrices from ego to each sensor
53+
ego2sensors = TransformBuffer()
54+
for cs_record in t4.calibrated_sensor:
55+
sensor: Sensor = t4.get("sensor", cs_record.sensor_token)
56+
matrix = HomogeneousMatrix(
57+
position=cs_record.translation,
58+
rotation=cs_record.rotation,
59+
src="base_link",
60+
dst=sensor.channel,
61+
)
62+
63+
ego2sensors.set_transform(matrix)
64+
65+
return SceneGroundTruth(data_root=data_root, frames=frames, ego2sensors=ego2sensors)
66+
67+
68+
def _closest_ego_pose(t4: Tier4, timestamp: int) -> EgoPose:
69+
"""Lookup the ego pose record at the closest timestamp."""
70+
return min(t4.ego_pose, key=lambda e: abs(e.timestamp - timestamp))
71+
72+
73+
@define
74+
class FrameGroundTruth:
75+
"""A container of boxes at a single frame.
76+
77+
Attributes:
78+
unix_time (int): Unix timestamp.
79+
frame_index (int): Index number of the frame.
80+
boxes (list[BoxLike]): List of ground truth instances.
81+
ego2map (HomogeneousMatrix): Transformation matrix from ego to map coordinate.
82+
"""
83+
84+
unix_time: int
85+
frame_index: int
86+
boxes: list[BoxLike]
87+
ego2map: HomogeneousMatrix
88+
89+
90+
@define
91+
class SceneGroundTruth:
92+
"""A container of frame ground truths.
93+
94+
Attributes:
95+
data_root (str): Root directory path to the dataset.
96+
frames (list[FrameGroundTruth]): List of frame ground truths.
97+
ego2sensors (TransformBuffer): Buffer of transformation matrices from ego to each sensor coordinates.
98+
"""
99+
100+
data_root: str
101+
frames: list[FrameGroundTruth]
102+
ego2sensors: TransformBuffer
103+
104+
def lookup_frame(self, unix_time: int, tolerance: int) -> FrameGroundTruth | None:
105+
"""Lookup the closest set of ground truth frame.
106+
107+
Return None if the minimum time difference exceeds `tolerance`.
108+
109+
Args:
110+
unix_time (int): Unix timestamp.
111+
tolerance (int): Time difference tolerance in micro seconds.
112+
113+
Returns:
114+
Return frame ground truth if succeeded, otherwise None.
115+
"""
116+
closest = min(self.frames, key=lambda f: abs(unix_time - f.unix_time))
117+
return closest if abs(unix_time - closest.unix_time) <= tolerance else None

0 commit comments

Comments
 (0)