Skip to content

Commit b80641f

Browse files
committed
feat: add filter class with box visibility
Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
1 parent 347ba2a commit b80641f

6 files changed

Lines changed: 96 additions & 5 deletions

File tree

t4_devkit/dataclass/box.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ class Box3D(BaseBox):
126126
)
127127
visibility: VisibilityLevel = field(
128128
default=VisibilityLevel.UNAVAILABLE,
129-
converter=converters.optional(VisibilityLevel),
130-
validator=validators.optional(validators.instance_of(VisibilityLevel)),
129+
converter=VisibilityLevel,
130+
validator=validators.instance_of(VisibilityLevel),
131131
)
132132

133133
# additional attributes: set by `with_**`

t4_devkit/filtering/functional.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from typing_extensions import Self
88

99
from t4_devkit.dataclass import Box2D, Box3D, HomogeneousMatrix, distance_box
10-
from t4_devkit.filtering.parameter import FilterParams
10+
from t4_devkit.schema import VisibilityLevel
11+
12+
from .parameter import FilterParams
1113

1214
if TYPE_CHECKING:
1315
from t4_devkit.dataclass import BoxLike, SemanticLabel
@@ -19,6 +21,9 @@
1921
"FilterByUUID",
2022
"FilterByDistance",
2123
"FilterByRegion",
24+
"FilterBySpeed",
25+
"FilterByNumPoints",
26+
"FilterByVisibility",
2227
"BoxFilterFunction",
2328
]
2429

@@ -236,4 +241,36 @@ def __call__(self, box: BoxLike, _tf_matrix: HomogeneousMatrix | None = None) ->
236241
return self.min_num_points <= box.num_points
237242

238243

244+
class FilterByVisibility(BaseBoxFilter):
245+
"""A filter that excludes 3D boxes with lower visibility than a specified threshold.
246+
247+
Boxes with `UNAVAILABLE` visibility are always passed through (i.e., not filtered).
248+
"""
249+
250+
def __init__(self, visibility: VisibilityLevel = VisibilityLevel.NONE) -> None:
251+
"""
252+
Initialize the filter with a visibility threshold.
253+
254+
Args:
255+
visibility (VisibilityLevel): The minimum visibility level for a box to pass the filter.
256+
257+
Raises:
258+
ValueError: If the given visibility is not comparable (e.g., UNAVAILABLE).
259+
"""
260+
super().__init__()
261+
if not visibility.is_comparable():
262+
raise ValueError(f"Comparable visibility must be set as threshold: {visibility}")
263+
264+
self.visibility = visibility
265+
266+
def from_params(cls, params: FilterParams) -> Self:
267+
return cls(params.visibility)
268+
269+
def __call__(self, box: BoxLike, _tf_matrix: HomogeneousMatrix | None = None) -> bool:
270+
if not isinstance(box, Box3D):
271+
return True
272+
else:
273+
return self.visibility <= box.visibility if box.visibility.is_comparable() else True
274+
275+
239276
BoxFilterFunction = TypeVar("BoxFilterFunction", bound=BaseBoxFilter)

t4_devkit/filtering/parameter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from attrs import define, field, validators
77

88
from t4_devkit.dataclass import SemanticLabel
9+
from t4_devkit.schema import VisibilityLevel
910

1011

1112
@define
@@ -22,6 +23,7 @@ class FilterParams:
2223
min_speed (float, optional): Minimum speed [m/s].
2324
max_speed (float, optional): Maximum speed [m/s].
2425
min_num_points (int): The minimum number of points which the 3D box should include.
26+
visibility (str | VisibilityLevel, optional): Visibility threshold.
2527
"""
2628

2729
labels: Sequence[str | SemanticLabel] | None = field(
@@ -44,3 +46,8 @@ class FilterParams:
4446
default=0,
4547
validator=[validators.instance_of(int), validators.ge(0)],
4648
)
49+
visibility: VisibilityLevel = field(
50+
default=VisibilityLevel.NONE,
51+
converter=VisibilityLevel,
52+
validator=validators.instance_of(VisibilityLevel),
53+
)

t4_devkit/schema/tables/visibility.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,41 @@ def _from_alias(level: str) -> Self:
5959
)
6060
return VisibilityLevel.UNAVAILABLE
6161

62+
def rank(self) -> int:
63+
"""Return an integer rank for comparison (higher is more visible)."""
64+
ranking = {
65+
"full": 4,
66+
"most": 3,
67+
"partial": 2,
68+
"none": 1,
69+
"unavailable": None,
70+
}
71+
return ranking[self.value]
72+
73+
def is_comparable(self) -> bool:
74+
"""Return True if the visibility level has a defined rank."""
75+
return self.rank() is not None
76+
77+
def _check_comparability(self, other: VisibilityLevel) -> None:
78+
if not (self.is_comparable() and other.is_comparable()):
79+
raise ValueError(f"Cannot compare unknown visibility levels: {self}, {other}")
80+
81+
def __lt__(self, other: VisibilityLevel) -> bool:
82+
self._check_comparability(other)
83+
return self.rank() < other.rank()
84+
85+
def __le__(self, other: VisibilityLevel) -> bool:
86+
self._check_comparability(other)
87+
return self.rank() <= other.rank()
88+
89+
def __gt__(self, other: VisibilityLevel) -> bool:
90+
self._check_comparability(other)
91+
return self.rank() > other.rank()
92+
93+
def __ge__(self, other: VisibilityLevel) -> bool:
94+
self._check_comparability(other)
95+
return self.rank() >= other.rank()
96+
6297

6398
@define(slots=False)
6499
@SCHEMAS.register(SchemaName.VISIBILITY)

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def dummy_box3ds() -> list[Box3D]:
130130
confidence=1.0,
131131
uuid="bicycle3d_1",
132132
num_points=1,
133-
visibility=VisibilityLevel.FULL,
133+
visibility=VisibilityLevel.MOST,
134134
).with_future(
135135
timestamps=[101, 102, 103, 104],
136136
confidences=[1.0, 0.5],
@@ -164,7 +164,7 @@ def dummy_box3ds() -> list[Box3D]:
164164
confidence=1.0,
165165
uuid="pedestrian3d_1",
166166
num_points=1,
167-
visibility="full",
167+
visibility="none", # str is also OK
168168
).with_future(
169169
timestamps=[101, 102, 103, 104],
170170
confidences=[1.0, 0.5, 0.2],

tests/fitering/test_filter_function.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from __future__ import annotations
2+
13
from t4_devkit.filtering.functional import (
24
FilterByDistance,
35
FilterByLabel,
46
FilterByNumPoints,
57
FilterByRegion,
68
FilterBySpeed,
79
FilterByUUID,
10+
FilterByVisibility,
811
)
12+
from t4_devkit.schema import VisibilityLevel
913

1014

1115
def test_filter_by_label(dummy_box3ds, dummy_box2ds) -> None:
@@ -116,3 +120,11 @@ def test_filter_by_num_points(dummy_box3ds) -> None:
116120
answer = [box for box in dummy_box3ds if box_filter(box)]
117121

118122
assert len(answer) == 3
123+
124+
125+
def test_filter_by_visibility(dummy_box3ds) -> None:
126+
box_filter = FilterByVisibility(visibility=VisibilityLevel.MOST)
127+
128+
answer = [box for box in dummy_box3ds if box_filter(box)]
129+
130+
assert len(answer) == 2

0 commit comments

Comments
 (0)