diff --git a/t4_devkit/dataclass/pointcloud.py b/t4_devkit/dataclass/pointcloud.py index 54f124e..d131c66 100644 --- a/t4_devkit/dataclass/pointcloud.py +++ b/t4_devkit/dataclass/pointcloud.py @@ -1,5 +1,6 @@ from __future__ import annotations +from t4_devkit.common.io import load_json import struct from abc import abstractmethod from typing import TYPE_CHECKING, ClassVar, TypeVar @@ -18,14 +19,87 @@ "RadarPointCloud", "SegmentationPointCloud", "PointCloudLike", + "PointCloudMetainfo", + "PointcloudSourceInfo", + "Stamp", ] +@define +class Stamp: + """A dataclass to represent timestamp. + + Attributes: + sec (int): Seconds. + nanosec (int): Nanoseconds. + """ + + sec: int + nanosec: int + + +@define +class PointcloudSourceInfo: + """A dataclass to represent pointcloud source information. + + Attributes: + id (str): source identifier. + idx_begin (int): Begin index of points for the source in the concatenated pointcloud structure. + length (int): Length of points for the source in the concatenated pointcloud structure. + stamp (Stamp): Timestamp. + """ + + id: str + idx_begin: int + length: int + stamp: Stamp = field(converter=lambda x: Stamp(**x) if isinstance(x, dict) else x) + + +@define +class PointCloudMetainfo: + """A dataclass to represent pointcloud metadata. + + Attributes: + stamp (Stamp): Timestamp. + sources (list[PointcloudSourceInfo]): List of source information. + """ + + stamp: Stamp = field(converter=lambda x: Stamp(**x) if isinstance(x, dict) else x) + sources: list[PointcloudSourceInfo] = field(factory=list) + + @classmethod + def from_file(cls, filepath: str) -> Self: + """Create an instance from a JSON file. + + Args: + filepath (str): Path to the JSON file containing metadata. + + Returns: + Self: PointCloudMetainfo instance. + """ + data = load_json(filepath) + stamp = Stamp(**data["stamp"]) + sources = [] + for source_data in data.get("sources", []): + sources.append(PointcloudSourceInfo(**source_data)) + return cls(stamp=stamp, sources=sources) + + @property + def source_ids(self) -> list[str]: + """Get the list of source sensor IDs. + + Returns: + list[str]: List of sensor names. + """ + return [source.id for source in self.sources] + + @define class PointCloud: """Abstract base dataclass for pointcloud data.""" points: NDArrayFloat = field(converter=np.array) + metainfo: PointCloudMetainfo | None = field(default=None) @points.validator def _check_dims(self, attribute, value) -> None: @@ -34,6 +108,68 @@ def _check_dims(self, attribute, value) -> None: f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}" ) + @metainfo.validator + def _validate_metainfo(self, attribute, value) -> None: + """Validate that sources in metainfo form non-overlapping parts covering all points. + + This validator ensures backward compatibility by allowing None metainfo. + """ + if value is None: + # Backward compatibility: metainfo is optional + return + + if not value.sources: + # No sources to validate + return + + num_points = self.num_points() + + # Collect all intervals defined by sources + intervals = [] + for source_info in value.sources: + source_id = source_info.id + idx_begin = source_info.idx_begin + length = source_info.length + idx_end = idx_begin + length + + # Check bounds + if idx_begin < 0: + raise ValueError(f"Source '{source_id}' has negative idx_begin: {idx_begin}") + if length < 0: + raise ValueError(f"Source '{source_id}' has negative length: {length}") + if idx_end > num_points: + raise ValueError( + f"Source '{source_id}' exceeds point cloud size: " + f"idx_begin={idx_begin}, length={length}, but num_points={num_points}" + ) + + intervals.append((idx_begin, idx_end, source_id)) + + # Sort intervals by start index + intervals.sort(key=lambda x: x[0]) + + # Check for non-overlapping and complete coverage + expected_start = 0 + for idx_begin, idx_end, source_id in intervals: + if idx_begin != expected_start: + if idx_begin > expected_start: + raise ValueError( + f"Gap detected: points [{expected_start}:{idx_begin}) are not covered by any source" + ) + else: + raise ValueError( + f"Overlap detected: source '{source_id}' starts at {idx_begin}, " + f"but previous source ends at {expected_start}" + ) + expected_start = idx_end + + # Check if all points are covered + if expected_start != num_points: + raise ValueError( + f"Incomplete coverage: sources cover up to index {expected_start}, " + f"but num_points={num_points}" + ) + @staticmethod @abstractmethod def num_dims() -> int: @@ -91,12 +227,19 @@ def num_dims() -> int: return 4 @classmethod - def from_file(cls, filepath: str) -> Self: + def from_file(cls, filepath: str, metainfo_filepath: str | None = None) -> Self: assert filepath.endswith(".bin"), f"Unexpected filetype: {filepath}" scan = np.fromfile(filepath, dtype=np.float32) points = scan.reshape((-1, 5))[:, : cls.num_dims()] - return cls(points.T) + + metainfo = ( + PointCloudMetainfo.from_file(metainfo_filepath) + if metainfo_filepath is not None + else None + ) + + return cls(points.T, metainfo=metainfo) @define @@ -123,6 +266,7 @@ def from_file( invalid_states: list[int] | None = None, dynprop_states: list[int] | None = None, ambig_states: list[int] | None = None, + metainfo_filepath: str | None = None, ) -> Self: assert filepath.endswith(".pcd"), f"Unexpected filetype: {filepath}" @@ -177,7 +321,12 @@ def from_file( # A NaN in the first point indicates an empty pointcloud. point = np.array(points[0]) if np.any(np.isnan(point)): - return cls(np.zeros((feature_count, 0))) + metainfo = ( + PointCloudMetainfo.from_file(metainfo_filepath) + if metainfo_filepath is not None + else None + ) + return cls(np.zeros((feature_count, 0)), metainfo=metainfo) # Convert to numpy matrix. points = np.array(points).transpose() @@ -199,7 +348,12 @@ def from_file( valid = [p in ambig_states for p in points[11, :]] points = points[:, valid] - return cls(points) + metainfo = ( + PointCloudMetainfo.from_file(metainfo_filepath) + if metainfo_filepath is not None + else None + ) + return cls(points, metainfo=metainfo) @define @@ -211,18 +365,25 @@ class SegmentationPointCloud(PointCloud): labels (NDArrayU8): Label matrix. """ - labels: NDArrayU8 = field(converter=lambda x: np.array(x, dtype=np.uint8)) + labels: NDArrayU8 = field(converter=lambda x: np.array(x, dtype=np.uint8), kw_only=True) @staticmethod def num_dims() -> int: return 4 @classmethod - def from_file(cls, point_filepath: str, label_filepath: str) -> Self: + def from_file( + cls, point_filepath: str, label_filepath: str, metainfo_filepath: str | None = None + ) -> Self: scan = np.fromfile(point_filepath, dtype=np.float32) points = scan.reshape((-1, 5))[:, : cls.num_dims()] labels = np.fromfile(label_filepath, dtype=np.uint8) - return cls(points.T, labels) + metainfo = ( + PointCloudMetainfo.from_file(metainfo_filepath) + if metainfo_filepath is not None + else None + ) + return cls(points.T, labels=labels, metainfo=metainfo) PointCloudLike = TypeVar("PointCloudLike", bound=PointCloud)