Skip to content

Commit 3ab343d

Browse files
Hendrik-codeclaude
andcommitted
feat(poi): index POIs by level names + warn on enum without level_*_info
unpack_poi_id now resolves a key part via level_one_info/level_two_info enum names (when set) before the descriptor's definition, so poi[idx, 'level2name'] and poi['level1name', 'level2name'] work for get/set/contains alongside ids and Enums. Using an Enum key while the matching level_*_info is still 'Any' (unset) now emits a warning (the .value is still used). 100 POI tests pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 38be00e commit 3ab343d

2 files changed

Lines changed: 82 additions & 20 deletions

File tree

TPTBox/core/poi_fun/poi_abstract.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,51 @@ def __init__(
7979
self.subregion_name2idx = {value: key for key, value in subregion.items()}
8080

8181

82-
def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition) -> tuple[int, int]:
82+
def _resolve_poi_part(part, name2idx: dict, level_info=None) -> int:
83+
"""Resolve one half of a POI key (region or subregion) to its integer id.
84+
85+
Order: Enum -> ``.value``; numeric string -> int; other string -> ``level_info`` enum
86+
name (when ``level_info`` is a concrete enum), then ``name2idx``. When ``level_info`` is
87+
the ``Any`` wildcard and an Enum is used, a warning is emitted (the level info is not set).
88+
"""
89+
if isinstance(part, Enum):
90+
if level_info is Any:
91+
import warnings
92+
93+
warnings.warn(
94+
f"Indexing a POI with the enum {part!r} but the matching level_one_info/level_two_info is not set "
95+
f"(it is 'Any'); the integer value {part.value} is used.",
96+
stacklevel=4,
97+
)
98+
return part.value
99+
if isinstance(part, str):
100+
try:
101+
return int(part)
102+
except ValueError:
103+
pass
104+
if level_info is not None and level_info is not Any:
105+
try:
106+
return level_info._get_id(part, no_raise=False)
107+
except Exception:
108+
pass
109+
return name2idx[part]
110+
return part
111+
112+
113+
def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition, level_one_info=None, level_two_info=None) -> tuple[int, int]:
83114
"""Convert any supported POI key type to a ``(region, subregion)`` integer pair.
84115
85116
Accepted key forms: plain integer (packed label), ``slice(region, subregion)``,
86117
2-tuple of ints, 2-tuple of ``Abstract_lvl`` / ``Enum`` members, or mixed tuples.
87-
String values are resolved via ``definition``'s name-to-index mappings.
118+
String values resolve via ``level_one_info``/``level_two_info`` enum names (when given)
119+
and otherwise via ``definition``'s name-to-index mappings, so ``poi[idx, "level2name"]``
120+
and ``poi["level1name", "level2name"]`` work in addition to ids and Enums.
88121
89122
Args:
90123
key: POI identifier in any of the supported formats.
91124
definition: Name-to-integer mapping used to resolve string labels.
125+
level_one_info: Optional region-level enum class for name/enum resolution + warnings.
126+
level_two_info: Optional subregion-level enum class for name/enum resolution + warnings.
92127
93128
Returns:
94129
``(region, subregion)`` tuple of plain Python integers.
@@ -101,20 +136,8 @@ def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition) -> tuple[in
101136
subregion = key.stop
102137
else:
103138
region, subregion = key
104-
if isinstance(region, str):
105-
try:
106-
region = int(region)
107-
except ValueError:
108-
region = definition.region_name2idx[region]
109-
if isinstance(region, Enum):
110-
region = region.value
111-
if isinstance(subregion, str):
112-
try:
113-
subregion = int(subregion)
114-
except ValueError:
115-
subregion = definition.subregion_name2idx[subregion]
116-
if isinstance(subregion, Enum):
117-
subregion = subregion.value
139+
region = _resolve_poi_part(region, definition.region_name2idx, level_one_info)
140+
subregion = _resolve_poi_part(subregion, definition.subregion_name2idx, level_two_info)
118141
return region, subregion
119142

120143

@@ -731,17 +754,20 @@ def __iter__(self):
731754
"""Iterate over all ``(region, subregion)`` key pairs."""
732755
return iter(self.centroids.keys())
733756

757+
def _resolve_key(self, key: POI_ID) -> tuple[int, int]:
758+
"""Resolve any key (ids, Enums, or ``level_one_info``/``level_two_info`` names) to ``(region, subregion)`` ints."""
759+
return unpack_poi_id(key, self.centroids.definition, self.level_one_info, self.level_two_info)
760+
734761
def __contains__(self, key: POI_ID) -> bool:
735-
key = unpack_poi_id(key, self.centroids.definition)
736-
return key in self.centroids
762+
return self._resolve_key(key) in self.centroids
737763

738764
def __getitem__(self, key: POI_ID) -> COORDINATE:
739-
return tuple(self.centroids[key])
765+
return tuple(self.centroids[self._resolve_key(key)])
740766

741767
def __setitem__(self, key: POI_ID, value: tuple[float, float, float] | Sequence[float] | np.ndarray) -> None:
742768
if len(value) != DIMENSIONS:
743769
raise ValueError(value)
744-
self.centroids[key] = tuple(value)
770+
self.centroids[self._resolve_key(key)] = tuple(value)
745771

746772
def __len__(self) -> int:
747773
return self.centroids.__len__()

unit_tests/test_poi_label_names.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,41 @@ def test_metadata_survives_ops(self):
107107
self._check(q)
108108

109109

110+
class TestNameIndexing(unittest.TestCase):
111+
def test_index_by_level_names(self):
112+
p = _poi()
113+
self.assertEqual(p[1, "Vertebra_Corpus"], (1.0, 2.0, 3.0)) # [idx, "level2name"]
114+
self.assertEqual(p["C1", "Vertebra_Corpus"], (1.0, 2.0, 3.0)) # ["level1name", "level2name"]
115+
self.assertEqual(p["C2", "Vertebra_Corpus"], (4.0, 5.0, 6.0))
116+
self.assertIn(("C2", "Vertebra_Corpus"), p)
117+
p["C1", "Vertebra_Disc"] = (7.0, 8.0, 9.0)
118+
self.assertEqual(p[1, 100], (7.0, 8.0, 9.0)) # Vertebra_Disc == 100
119+
120+
def test_enum_with_level_set_no_warning(self):
121+
import warnings
122+
123+
p = _poi()
124+
with warnings.catch_warnings():
125+
warnings.simplefilter("error")
126+
self.assertEqual(p[Vertebra_Instance.C1, Location.Vertebra_Corpus], (1.0, 2.0, 3.0))
127+
128+
def test_enum_without_level_warns(self):
129+
import warnings
130+
131+
from TPTBox.core.vert_constants import Any
132+
133+
q = POI(
134+
centroids={1: {50: (1.0, 2.0, 3.0)}},
135+
orientation=("R", "A", "S"),
136+
zoom=(1.0, 1.0, 1.0),
137+
shape=(10.0, 10.0, 10.0),
138+
)
139+
self.assertIs(q.level_one_info, Any)
140+
with warnings.catch_warnings(record=True) as w:
141+
warnings.simplefilter("always")
142+
_ = q[Vertebra_Instance.C1, Location.Vertebra_Corpus]
143+
self.assertTrue(any("not set" in str(x.message) for x in w))
144+
145+
110146
if __name__ == "__main__":
111147
unittest.main()

0 commit comments

Comments
 (0)