Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ class Box3D(BaseBox):
)
visibility: VisibilityLevel = field(
default=VisibilityLevel.UNAVAILABLE,
converter=converters.optional(VisibilityLevel),
validator=validators.optional(validators.instance_of(VisibilityLevel)),
converter=VisibilityLevel,
validator=validators.instance_of(VisibilityLevel),
)

# additional attributes: set by `with_**`
Expand Down
2 changes: 2 additions & 0 deletions t4_devkit/filtering/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
FilterByRegion,
FilterBySpeed,
FilterByUUID,
FilterByVisibility,
)
from .parameter import FilterParams

Expand All @@ -35,6 +36,7 @@ def __init__(self, params: FilterParams, tf_buffer: TransformBuffer) -> None:
FilterByRegion.from_params(params),
FilterBySpeed.from_params(params),
FilterByNumPoints.from_params(params),
FilterByVisibility.from_params(params),
]

self.tf_buffer = tf_buffer
Expand Down
40 changes: 39 additions & 1 deletion t4_devkit/filtering/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing_extensions import Self

from t4_devkit.dataclass import Box2D, Box3D, HomogeneousMatrix, distance_box
from t4_devkit.filtering.parameter import FilterParams
from t4_devkit.schema import VisibilityLevel

from .parameter import FilterParams

if TYPE_CHECKING:
from t4_devkit.dataclass import BoxLike, SemanticLabel
Expand All @@ -19,6 +21,9 @@
"FilterByUUID",
"FilterByDistance",
"FilterByRegion",
"FilterBySpeed",
"FilterByNumPoints",
"FilterByVisibility",
"BoxFilterFunction",
]

Expand Down Expand Up @@ -236,4 +241,37 @@ def __call__(self, box: BoxLike, _tf_matrix: HomogeneousMatrix | None = None) ->
return self.min_num_points <= box.num_points


class FilterByVisibility(BaseBoxFilter):
"""A filter that excludes 3D boxes with lower visibility than a specified threshold.

Boxes with `UNAVAILABLE` visibility are always passed through (i.e., not filtered).
"""

def __init__(self, visibility: VisibilityLevel = VisibilityLevel.NONE) -> None:
"""
Initialize the filter with a visibility threshold.

Args:
visibility (VisibilityLevel): The minimum visibility level for a box to pass the filter.

Raises:
ValueError: If the given visibility is not comparable (e.g., UNAVAILABLE).
"""
super().__init__()
if not visibility.is_comparable():
raise ValueError(f"Comparable visibility must be set as threshold: {visibility}")

self.visibility = visibility

@classmethod
def from_params(cls, params: FilterParams) -> Self:
return cls(params.visibility)

def __call__(self, box: BoxLike, _tf_matrix: HomogeneousMatrix | None = None) -> bool:
if not isinstance(box, Box3D):
return True
else:
return self.visibility <= box.visibility if box.visibility.is_comparable() else True


BoxFilterFunction = TypeVar("BoxFilterFunction", bound=BaseBoxFilter)
7 changes: 7 additions & 0 deletions t4_devkit/filtering/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from attrs import define, field, validators

from t4_devkit.dataclass import SemanticLabel
from t4_devkit.schema import VisibilityLevel


@define
Expand All @@ -22,6 +23,7 @@ class FilterParams:
min_speed (float, optional): Minimum speed [m/s].
max_speed (float, optional): Maximum speed [m/s].
min_num_points (int): The minimum number of points which the 3D box should include.
visibility (str | VisibilityLevel, optional): Visibility threshold.
"""

labels: Sequence[str | SemanticLabel] | None = field(
Expand All @@ -44,3 +46,8 @@ class FilterParams:
default=0,
validator=[validators.instance_of(int), validators.ge(0)],
)
visibility: VisibilityLevel = field(
default=VisibilityLevel.NONE,
converter=VisibilityLevel,
validator=validators.instance_of(VisibilityLevel),
)
35 changes: 35 additions & 0 deletions t4_devkit/schema/tables/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,41 @@ def _from_alias(level: str) -> Self:
)
return VisibilityLevel.UNAVAILABLE

def rank(self) -> int:
"""Return an integer rank for comparison (higher is more visible)."""
ranking = {
"full": 4,
"most": 3,
"partial": 2,
"none": 1,
"unavailable": None,
}
return ranking[self.value]

def is_comparable(self) -> bool:
"""Return True if the visibility level has a defined rank."""
return self.rank() is not None

def _check_comparability(self, other: VisibilityLevel) -> None:
if not (self.is_comparable() and other.is_comparable()):
raise ValueError(f"Cannot compare unknown visibility levels: {self}, {other}")

def __lt__(self, other: VisibilityLevel) -> bool:
self._check_comparability(other)
return self.rank() < other.rank()

def __le__(self, other: VisibilityLevel) -> bool:
self._check_comparability(other)
return self.rank() <= other.rank()

def __gt__(self, other: VisibilityLevel) -> bool:
self._check_comparability(other)
return self.rank() > other.rank()

def __ge__(self, other: VisibilityLevel) -> bool:
self._check_comparability(other)
return self.rank() >= other.rank()


@define(slots=False)
@SCHEMAS.register(SchemaName.VISIBILITY)
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def dummy_box3ds() -> list[Box3D]:
confidence=1.0,
uuid="bicycle3d_1",
num_points=1,
visibility=VisibilityLevel.FULL,
visibility=VisibilityLevel.MOST,
).with_future(
timestamps=[101, 102, 103, 104],
confidences=[1.0, 0.5],
Expand Down Expand Up @@ -164,7 +164,7 @@ def dummy_box3ds() -> list[Box3D]:
confidence=1.0,
uuid="pedestrian3d_1",
num_points=1,
visibility="full",
visibility="none", # str is also OK
).with_future(
timestamps=[101, 102, 103, 104],
confidences=[1.0, 0.5, 0.2],
Expand Down
4 changes: 4 additions & 0 deletions tests/fitering/test_filter_compose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from t4_devkit.filtering import BoxFilter, FilterParams
from t4_devkit.schema import VisibilityLevel


def test_composite_filter(dummy_box3ds, dummy_box2ds, dummy_tf_buffer) -> None:
Expand All @@ -19,6 +22,7 @@ def test_composite_filter(dummy_box3ds, dummy_box2ds, dummy_tf_buffer) -> None:
min_speed=0.5,
max_speed=2.0,
min_num_points=0,
visibility=VisibilityLevel.FULL,
)

box_filter = BoxFilter(params, dummy_tf_buffer)
Expand Down
17 changes: 17 additions & 0 deletions tests/fitering/test_filter_function.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

from t4_devkit.filtering.functional import (
FilterByDistance,
FilterByLabel,
FilterByNumPoints,
FilterByRegion,
FilterBySpeed,
FilterByUUID,
FilterByVisibility,
)
from t4_devkit.schema import VisibilityLevel


def test_filter_by_label(dummy_box3ds, dummy_box2ds) -> None:
Expand Down Expand Up @@ -116,3 +120,16 @@ def test_filter_by_num_points(dummy_box3ds) -> None:
answer = [box for box in dummy_box3ds if box_filter(box)]

assert len(answer) == 3


def test_filter_by_visibility(dummy_box3ds) -> None:
"""Test `FilterByVisibility`.

Args:
dummy_box3ds (list[Box3D]): List of 3D boxes.
"""
box_filter = FilterByVisibility(visibility=VisibilityLevel.MOST)

answer = [box for box in dummy_box3ds if box_filter(box)]

assert len(answer) == 2
Loading