Skip to content
Merged
175 changes: 168 additions & 7 deletions t4_devkit/dataclass/pointcloud.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from t4_devkit.common.io import load_json
import struct
from abc import abstractmethod
from typing import TYPE_CHECKING, ClassVar, TypeVar
Expand All @@ -18,14 +19,87 @@
"RadarPointCloud",
"SegmentationPointCloud",
"PointCloudLike",
"PointCloudMetainfo",
"PointcloudSourceInfo",
"Stamp",
]


@define
class Stamp:
"""A dataclass to represent timestamp.

Attributes:
sec (int): Seconds.
nanosec (int): Nanoseconds.
"""

sec: int
nanosec: int


@define
class PointcloudSourceInfo:
"""A dataclass to represent pointcloud source information.

Attributes:
id (str): source identifier.
idx_begin (int): Begin index of points for the source in the concatenated pointcloud structure.
length (int): Length of points for the source in the concatenated pointcloud structure.
stamp (Stamp): Timestamp.
"""

id: str
idx_begin: int
length: int
stamp: Stamp = field(converter=lambda x: Stamp(**x) if isinstance(x, dict) else x)


@define
class PointCloudMetainfo:
"""A dataclass to represent pointcloud metadata.

Attributes:
stamp (Stamp): Timestamp.
sources (list[PointcloudSourceInfo]): List of source information.
"""

stamp: Stamp = field(converter=lambda x: Stamp(**x) if isinstance(x, dict) else x)
sources: list[PointcloudSourceInfo] = field(factory=list)

@classmethod
def from_file(cls, filepath: str) -> Self:
"""Create an instance from a JSON file.

Args:
filepath (str): Path to the JSON file containing metadata.

Returns:
Self: PointCloudMetainfo instance.
"""
data = load_json(filepath)
stamp = Stamp(**data["stamp"])
sources = []
for source_data in data.get("sources", []):
sources.append(PointcloudSourceInfo(**source_data))
Comment thread
SamratThapa120 marked this conversation as resolved.
return cls(stamp=stamp, sources=sources)

@property
def source_ids(self) -> list[str]:
"""Get the list of source sensor IDs.

Returns:
list[str]: List of sensor names.
"""
return [source.id for source in self.sources]


@define
class PointCloud:
"""Abstract base dataclass for pointcloud data."""

points: NDArrayFloat = field(converter=np.array)
metainfo: PointCloudMetainfo | None = field(default=None)

@points.validator
def _check_dims(self, attribute, value) -> None:
Expand All @@ -34,6 +108,68 @@ def _check_dims(self, attribute, value) -> None:
f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}"
)

@metainfo.validator
def _validate_metainfo(self, attribute, value) -> None:
"""Validate that sources in metainfo form non-overlapping parts covering all points.

This validator ensures backward compatibility by allowing None metainfo.
"""
if value is None:
# Backward compatibility: metainfo is optional
return

if not value.sources:
# No sources to validate
return

num_points = self.num_points()

# Collect all intervals defined by sources
intervals = []
for source_info in value.sources:
Comment thread
SamratThapa120 marked this conversation as resolved.
source_id = source_info.id
idx_begin = source_info.idx_begin
length = source_info.length
idx_end = idx_begin + length

# Check bounds
if idx_begin < 0:
raise ValueError(f"Source '{source_id}' has negative idx_begin: {idx_begin}")
if length < 0:
raise ValueError(f"Source '{source_id}' has negative length: {length}")
if idx_end > num_points:
raise ValueError(
f"Source '{source_id}' exceeds point cloud size: "
f"idx_begin={idx_begin}, length={length}, but num_points={num_points}"
)

intervals.append((idx_begin, idx_end, source_id))

# Sort intervals by start index
intervals.sort(key=lambda x: x[0])

# Check for non-overlapping and complete coverage
expected_start = 0
for idx_begin, idx_end, source_id in intervals:
if idx_begin != expected_start:
if idx_begin > expected_start:
raise ValueError(
f"Gap detected: points [{expected_start}:{idx_begin}) are not covered by any source"
)
else:
raise ValueError(
f"Overlap detected: source '{source_id}' starts at {idx_begin}, "
f"but previous source ends at {expected_start}"
)
expected_start = idx_end

# Check if all points are covered
if expected_start != num_points:
raise ValueError(
f"Incomplete coverage: sources cover up to index {expected_start}, "
f"but num_points={num_points}"
)

@staticmethod
@abstractmethod
def num_dims() -> int:
Expand Down Expand Up @@ -91,12 +227,19 @@ def num_dims() -> int:
return 4

@classmethod
def from_file(cls, filepath: str) -> Self:
def from_file(cls, filepath: str, metainfo_filepath: str | None = None) -> Self:
assert filepath.endswith(".bin"), f"Unexpected filetype: {filepath}"

scan = np.fromfile(filepath, dtype=np.float32)
points = scan.reshape((-1, 5))[:, : cls.num_dims()]
return cls(points.T)

metainfo = (
PointCloudMetainfo.from_file(metainfo_filepath)
if metainfo_filepath is not None
else None
)

return cls(points.T, metainfo=metainfo)


@define
Expand All @@ -123,6 +266,7 @@ def from_file(
invalid_states: list[int] | None = None,
dynprop_states: list[int] | None = None,
ambig_states: list[int] | None = None,
metainfo_filepath: str | None = None,
) -> Self:
assert filepath.endswith(".pcd"), f"Unexpected filetype: {filepath}"

Expand Down Expand Up @@ -177,7 +321,12 @@ def from_file(
# A NaN in the first point indicates an empty pointcloud.
point = np.array(points[0])
if np.any(np.isnan(point)):
return cls(np.zeros((feature_count, 0)))
metainfo = (
PointCloudMetainfo.from_file(metainfo_filepath)
if metainfo_filepath is not None
else None
)
return cls(np.zeros((feature_count, 0)), metainfo=metainfo)

# Convert to numpy matrix.
points = np.array(points).transpose()
Expand All @@ -199,7 +348,12 @@ def from_file(
valid = [p in ambig_states for p in points[11, :]]
points = points[:, valid]

return cls(points)
metainfo = (
PointCloudMetainfo.from_file(metainfo_filepath)
if metainfo_filepath is not None
else None
)
return cls(points, metainfo=metainfo)


@define
Expand All @@ -211,18 +365,25 @@ class SegmentationPointCloud(PointCloud):
labels (NDArrayU8): Label matrix.
"""

labels: NDArrayU8 = field(converter=lambda x: np.array(x, dtype=np.uint8))
labels: NDArrayU8 = field(converter=lambda x: np.array(x, dtype=np.uint8), kw_only=True)

@staticmethod
def num_dims() -> int:
return 4

@classmethod
def from_file(cls, point_filepath: str, label_filepath: str) -> Self:
def from_file(
cls, point_filepath: str, label_filepath: str, metainfo_filepath: str | None = None
) -> Self:
scan = np.fromfile(point_filepath, dtype=np.float32)
points = scan.reshape((-1, 5))[:, : cls.num_dims()]
labels = np.fromfile(label_filepath, dtype=np.uint8)
return cls(points.T, labels)
metainfo = (
PointCloudMetainfo.from_file(metainfo_filepath)
if metainfo_filepath is not None
else None
)
return cls(points.T, labels=labels, metainfo=metainfo)


PointCloudLike = TypeVar("PointCloudLike", bound=PointCloud)
Loading