diff --git a/TPTBox/core/poi.py b/TPTBox/core/poi.py index 6a3d53a..cc3a1e9 100755 --- a/TPTBox/core/poi.py +++ b/TPTBox/core/poi.py @@ -123,13 +123,16 @@ class POI(Abstract_POI, Has_Grid): _vert_orientation_pir = {} # Elusive; will not be saved; will not be copied. For Buffering results # noqa: RUF012 def _set_inplace(self, poi: Self) -> Self: - """Copy all grid/affine attributes and centroids from ``poi`` into ``self``.""" + """Copy all grid/affine attributes, centroids, and naming metadata from ``poi`` into ``self``.""" self.orientation = poi.orientation self.centroids = poi.centroids self.zoom = poi.zoom self.shape = poi.shape self.origin = poi.origin self.rotation = poi.rotation + self.info = poi.info + self.level_one_info = poi.level_one_info + self.level_two_info = poi.level_two_info return self @property @@ -224,6 +227,8 @@ def copy( origin=origin if not isinstance(origin, Sentinel) else self.origin, info=deepcopy(self.info), format=self.format, + level_one_info=self.level_one_info, + level_two_info=self.level_two_info, ) def local_to_global(self, x: COORDINATE, itk_coords=False) -> COORDINATE: @@ -1330,7 +1335,18 @@ def calc_poi_average(pois: list[POI], keep_points_not_present_in_all_pois: bool # Sort the new ctd by keys ctd = dict(sorted(ctd.items())) - return POI(centroids=ctd, orientation=pois[0].orientation, zoom=pois[0].zoom, shape=pois[0].shape, rotation=pois[0].rotation) + return POI( + centroids=ctd, + orientation=pois[0].orientation, + zoom=pois[0].zoom, + shape=pois[0].shape, + rotation=pois[0].rotation, + origin=pois[0].origin, + info=deepcopy(pois[0].info), + format=pois[0].format, + level_one_info=pois[0].level_one_info, + level_two_info=pois[0].level_two_info, + ) def _load_from_POI_spine_r(data: dict) -> POI: diff --git a/TPTBox/core/poi_fun/poi_abstract.py b/TPTBox/core/poi_fun/poi_abstract.py index d74dfee..809420e 100755 --- a/TPTBox/core/poi_fun/poi_abstract.py +++ b/TPTBox/core/poi_fun/poi_abstract.py @@ -79,16 +79,51 @@ def __init__( self.subregion_name2idx = {value: key for key, value in subregion.items()} -def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition) -> tuple[int, int]: +def _resolve_poi_part(part, name2idx: dict, level_info=None) -> int: + """Resolve one half of a POI key (region or subregion) to its integer id. + + Order: Enum -> ``.value``; numeric string -> int; other string -> ``level_info`` enum + name (when ``level_info`` is a concrete enum), then ``name2idx``. When ``level_info`` is + the ``Any`` wildcard and an Enum is used, a warning is emitted (the level info is not set). + """ + if isinstance(part, Enum): + if level_info is Any: + import warnings + + warnings.warn( + f"Indexing a POI with the enum {part!r} but the matching level_one_info/level_two_info is not set " + f"(it is 'Any'); the integer value {part.value} is used.", + stacklevel=4, + ) + return part.value + if isinstance(part, str): + try: + return int(part) + except ValueError: + pass + if level_info is not None and level_info is not Any: + try: + return level_info._get_id(part, no_raise=False) + except Exception: + pass + return name2idx[part] + return part + + +def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition, level_one_info=None, level_two_info=None) -> tuple[int, int]: """Convert any supported POI key type to a ``(region, subregion)`` integer pair. Accepted key forms: plain integer (packed label), ``slice(region, subregion)``, 2-tuple of ints, 2-tuple of ``Abstract_lvl`` / ``Enum`` members, or mixed tuples. - String values are resolved via ``definition``'s name-to-index mappings. + String values resolve via ``level_one_info``/``level_two_info`` enum names (when given) + and otherwise via ``definition``'s name-to-index mappings, so ``poi[idx, "level2name"]`` + and ``poi["level1name", "level2name"]`` work in addition to ids and Enums. Args: key: POI identifier in any of the supported formats. definition: Name-to-integer mapping used to resolve string labels. + level_one_info: Optional region-level enum class for name/enum resolution + warnings. + level_two_info: Optional subregion-level enum class for name/enum resolution + warnings. Returns: ``(region, subregion)`` tuple of plain Python integers. @@ -101,23 +136,66 @@ def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition) -> tuple[in subregion = key.stop else: region, subregion = key - if isinstance(region, str): - try: - region = int(region) - except ValueError: - region = definition.region_name2idx[region] - if isinstance(region, Enum): - region = region.value - if isinstance(subregion, str): - try: - subregion = int(subregion) - except ValueError: - subregion = definition.subregion_name2idx[subregion] - if isinstance(subregion, Enum): - subregion = subregion.value + region = _resolve_poi_part(region, definition.region_name2idx, level_one_info) + subregion = _resolve_poi_part(subregion, definition.subregion_name2idx, level_two_info) return region, subregion +# --------------------------------------------------------------------------- +# label_name (custom per-point / per-region names stored in POI.info["label_name"]) +# In-memory (new) format: {region:int -> {subregion:int -> name:str, "name": group_name:str}} +# --------------------------------------------------------------------------- +LABEL_NAME = "label_name" +_GROUP_NAME_KEY = "name" + + +def normalize_label_name(d: dict | None) -> dict[int, dict]: + """Normalize any ``label_name`` mapping to the nested form (migration helper). + + Target: ``{region:int -> {subregion:int -> name:str, "name": group_name:str}}``. + Accepts and migrates: the old flat ``{"(1, 2)": "C2"}`` form, a JSON-loaded nested + form with string keys ``{"1": {"2": "C2", "name": "Spine"}}``, and the already-nested + form (idempotent). + """ + import ast + + if not d: + return {} + out: dict[int, dict] = {} + # old flat format: every key is a "(region, subregion)" tuple string + if all(isinstance(k, str) and k.strip().startswith("(") for k in d): + for k, name in d.items(): + region, subregion = ast.literal_eval(k) + out.setdefault(int(region), {})[int(subregion)] = name + return out + # nested format (region keys may be JSON strings) + for region, sub in d.items(): + target = out.setdefault(int(region), {}) + if isinstance(sub, dict): + for s, name in sub.items(): + target[_GROUP_NAME_KEY if s == _GROUP_NAME_KEY else int(s)] = name + else: # degenerate {region: name} -> treat as the region group name + target[_GROUP_NAME_KEY] = sub + return out + + +def label_name_dict(info: dict) -> dict[int, dict]: + """Return the normalized nested ``label_name`` dict from ``info`` (normalizing in place).""" + ln = info.get(LABEL_NAME) + if not ln: + return {} + norm = normalize_label_name(ln) + info[LABEL_NAME] = norm # cache the normalized form back into info + return norm + + +def _id_of(x) -> int: + """Resolve an Enum member / int / numeric string to its integer id.""" + if isinstance(x, Enum): + return x.value + return int(x) + + class POI_Descriptor(AbstractSet, MutableMapping): """Two-level dictionary that maps ``(region, subregion)`` pairs to 3-D coordinates. @@ -169,10 +247,10 @@ def __set__(self, obj, value) -> None: setattr(obj, self._name, int(value)) def copy(self) -> POI_Descriptor: - """Return a deep copy of this descriptor.""" + """Return a deep copy of this descriptor (keeping its name<->id ``definition``).""" from copy import deepcopy - return POI_Descriptor(default=deepcopy(self.pois)) + return POI_Descriptor(default=deepcopy(self.pois), definition=self.definition) def _sort(self: Self, inplace=True, order_dict: dict | None = None): """Sort vertebra dictionary by sorting_list.""" @@ -676,17 +754,20 @@ def __iter__(self): """Iterate over all ``(region, subregion)`` key pairs.""" return iter(self.centroids.keys()) + def _resolve_key(self, key: POI_ID) -> tuple[int, int]: + """Resolve any key (ids, Enums, or ``level_one_info``/``level_two_info`` names) to ``(region, subregion)`` ints.""" + return unpack_poi_id(key, self.centroids.definition, self.level_one_info, self.level_two_info) + def __contains__(self, key: POI_ID) -> bool: - key = unpack_poi_id(key, self.centroids.definition) - return key in self.centroids + return self._resolve_key(key) in self.centroids def __getitem__(self, key: POI_ID) -> COORDINATE: - return tuple(self.centroids[key]) + return tuple(self.centroids[self._resolve_key(key)]) def __setitem__(self, key: POI_ID, value: tuple[float, float, float] | Sequence[float] | np.ndarray) -> None: if len(value) != DIMENSIONS: raise ValueError(value) - self.centroids[key] = tuple(value) + self.centroids[self._resolve_key(key)] = tuple(value) def __len__(self) -> int: return self.centroids.__len__() @@ -754,6 +835,57 @@ def keys_subregion(self, sort: bool = False) -> list[int]: self.sort() return list(self.centroids.keys_subregion()) + def label_name(self, region, subregion) -> str | None: + """Return the human-readable name of the point ``(region, subregion)``. + + A custom name in ``info["label_name"]`` takes priority over the ``level_two_info`` enum + name (and finally falls back to the raw id as a string). Accepts ints or Enum members. + A warning is raised when a custom name is set that differs from the ``level_two_info`` + name for that id (when ``level_two_info`` is given). + """ + import warnings + + region_i, subregion_i = _id_of(region), _id_of(subregion) + custom = label_name_dict(self.info).get(region_i, {}).get(subregion_i) + enum_name = None + if self.level_two_info not in (None, Any): + n = self.level_two_info._get_name(subregion_i, no_raise=True) + enum_name = n if n != str(subregion_i) else None + if custom is not None: + if enum_name is not None and custom != enum_name: + warnings.warn( + f"label_name {custom!r} for subregion {subregion_i} is not the {self.level_two_info.__name__} name {enum_name!r}", + stacklevel=2, + ) + return custom + return enum_name if enum_name is not None else str(subregion_i) + + def set_label_name(self, region, subregion, name: str) -> None: + """Set a custom name for the point ``(region, subregion)`` in ``info["label_name"]``.""" + d = label_name_dict(self.info) + d.setdefault(_id_of(region), {})[_id_of(subregion)] = name + self.info[LABEL_NAME] = d + + def level_one_name(self, region) -> str | None: + """Return the group (level-one) name of ``region``. + + A custom group name in ``info["label_name"][region]["name"]`` takes priority over the + ``level_one_info`` enum name. Accepts an int or Enum member. + """ + region_i = _id_of(region) + custom = label_name_dict(self.info).get(region_i, {}).get(_GROUP_NAME_KEY) + if custom is not None: + return custom + if self.level_one_info not in (None, Any): + return self.level_one_info._get_name(region_i, no_raise=True) + return str(region_i) + + def set_level_one_name(self, region, name: str) -> None: + """Set a custom group (level-one) name for ``region`` in ``info["label_name"]``.""" + d = label_name_dict(self.info) + d.setdefault(_id_of(region), {})[_GROUP_NAME_KEY] = name + self.info[LABEL_NAME] = d + def values(self, sort: bool = False) -> list[COORDINATE]: """Return all stored ``(x, y, z)`` coordinate tuples. @@ -1030,16 +1162,23 @@ def join_left(self, pois: Self, inplace=False, _right_join=False) -> Self: Self: The combined set of centroids, either in-place or as a new set, depending on the 'inplace' parameter. """ ctd_list = self.centroids - if "label_name" in pois.info and "label_name" not in self.info: - self.info["label_name"] = {} + src_ln = label_name_dict(pois.info) if not inplace: ctd_list = ctd_list.copy() + dst_ln = label_name_dict(self.info) if src_ln else None + if dst_ln is not None: + self.info[LABEL_NAME] = dst_ln for x, y, c in pois.items(): if (x, y) in self and not _right_join: continue ctd_list[x, y] = c - if "label_name" in pois.info and f"({x}, {y})" in pois.info["label_name"]: - self.info["label_name"][f"({x}, {y})"] = pois.info["label_name"][f"({x}, {y})"] + if dst_ln is not None: + name = src_ln.get(x, {}).get(y) + if name is not None: + dst_ln.setdefault(x, {})[y] = name + grp = src_ln.get(x, {}).get(_GROUP_NAME_KEY) + if grp is not None: + dst_ln.setdefault(x, {}).setdefault(_GROUP_NAME_KEY, grp) if inplace: return self return self.copy(ctd_list) diff --git a/TPTBox/core/poi_fun/poi_global.py b/TPTBox/core/poi_fun/poi_global.py index 4c1ba8d..36778e7 100755 --- a/TPTBox/core/poi_fun/poi_global.py +++ b/TPTBox/core/poi_fun/poi_global.py @@ -187,7 +187,14 @@ def to_other(self, msk: Has_Grid, verbose=False) -> poi.POI: log.print(v, "-->", v_out) out[k1, k2] = tuple(v_out) - return poi.POI(centroids=out, **msk._extract_affine(), info=self.info, format=self.format) + return poi.POI( + centroids=out, + **msk._extract_affine(), + info=self.info, + format=self.format, + level_one_info=self.level_one_info, + level_two_info=self.level_two_info, + ) def copy(self, centroids: POI_Descriptor | None = None) -> Self: """Return a deep copy of this ``POI_Global``. diff --git a/TPTBox/core/poi_fun/save_load.py b/TPTBox/core/poi_fun/save_load.py index 966d991..fb638a4 100644 --- a/TPTBox/core/poi_fun/save_load.py +++ b/TPTBox/core/poi_fun/save_load.py @@ -11,7 +11,7 @@ # from TPTBox import POI, POI_Global from TPTBox.core import bids_files from TPTBox.core.nii_poi_abstract import Has_Grid -from TPTBox.core.poi_fun.poi_abstract import POI_Descriptor +from TPTBox.core.poi_fun.poi_abstract import _GROUP_NAME_KEY, LABEL_NAME, POI_Descriptor, label_name_dict, normalize_label_name from TPTBox.core.vert_constants import ( AX_CODES, COORDINATE, @@ -207,7 +207,8 @@ def _poi_to_dict_list( # noqa: C901 for k, v in ctd.info.items(): if k not in ori: - ori[k] = v + # always persist label_name in the new nested format {region:{sub:name,"name":group}} + ori[k] = normalize_label_name(v) if k == LABEL_NAME else v dict_list: list[_Orientation | (_Point3D | dict)] = [ori] @@ -353,6 +354,9 @@ def load_poi(ctd_path: POI_Reference, verbose=True) -> POI | POI_Global: # noqa level_one_info = _register_lvl[dict_list[0].get("level_one_info", Vertebra_Instance.__name__)] level_two_info = _register_lvl[dict_list[0].get("level_two_info", Location.__name__)] info = {k: v for k, v in dict_list[0].items() if k not in ctd_info_blacklist} + if LABEL_NAME in info: + # migrate old flat {"(1, 2)": "C2"} / JSON string keys -> nested {1: {2: "C2", "name": ...}} + info[LABEL_NAME] = normalize_label_name(info[LABEL_NAME]) if format_ in (FORMAT_GLOBAL, FORMAT_PLST): from TPTBox import POI_Global @@ -649,8 +653,9 @@ def _load_mkr_POI(dict_mkr: dict) -> POI_Global: description = control_points.get("description", region) associatedNodeID = control_points.get("associatedNodeID", description) label_group_name[region] = associatedNodeID + label_name.setdefault(region, {})[_GROUP_NAME_KEY] = associatedNodeID - label_name[str((region, subregion))] = label + label_name.setdefault(region, {})[subregion] = label assert itk_coords is not None, "itk_coords not set" from TPTBox import POI_Global @@ -821,7 +826,7 @@ def _load_landmark_txt(path: Path) -> list: id_ = label_group_name[current_group] new_id = len(points[id_]) + 1 points[id_][new_id] = coords - label_name[str((id_, new_id))] = key + label_name.setdefault(id_, {})[new_id] = key if len(label_name) != 0: header["label_name"] = label_name if len(label_group_name) != 0: diff --git a/TPTBox/core/poi_fun/save_mkr.py b/TPTBox/core/poi_fun/save_mkr.py index 9783aa3..9da952d 100644 --- a/TPTBox/core/poi_fun/save_mkr.py +++ b/TPTBox/core/poi_fun/save_mkr.py @@ -10,6 +10,7 @@ import numpy as np from typing_extensions import NotRequired +from TPTBox.core.poi_fun.poi_abstract import _GROUP_NAME_KEY, label_name_dict from TPTBox.logger.log_file import log from TPTBox.mesh3D.mesh_colors import RGB_Color, get_color_by_label @@ -404,7 +405,8 @@ def get_desc(self: POI_Global, region: int, subregion: int) -> tuple[str, str, s Tuple ``(name, name2, label)`` where ``name`` and ``name2`` are the group/region name and ``label`` is the subregion label string. """ - label = self.info.get("label_name", {}).get(str((region, subregion))) + _ln = label_name_dict(self.info) + label = _ln.get(region, {}).get(subregion) if label is None: label = str(subregion) try: @@ -412,7 +414,9 @@ def get_desc(self: POI_Global, region: int, subregion: int) -> tuple[str, str, s except Exception: label = str(subregion) try: - if region in self.info.get("label_group_name", {}): + if _ln.get(region, {}).get(_GROUP_NAME_KEY) is not None: + name2 = _ln[region][_GROUP_NAME_KEY] + elif region in self.info.get("label_group_name", {}): name2 = self.info["label_group_name"][region] else: name2 = self.level_one_info(region).name diff --git a/unit_tests/test_poi_label_names.py b/unit_tests/test_poi_label_names.py new file mode 100644 index 0000000..88a0208 --- /dev/null +++ b/unit_tests/test_poi_label_names.py @@ -0,0 +1,147 @@ +"""Tests for POI label-name handling: nested format + migration, accessors, and metadata transfer.""" + +from __future__ import annotations + +import json +import os +import tempfile +import unittest +from pathlib import Path + +import numpy as np + +from TPTBox.core.poi import POI, calc_poi_average +from TPTBox.core.poi_fun.poi_abstract import normalize_label_name +from TPTBox.core.vert_constants import Location, Vertebra_Instance + + +def _poi(): + return POI( + centroids={1: {50: (1.0, 2.0, 3.0)}, 2: {50: (4.0, 5.0, 6.0)}}, + orientation=("R", "A", "S"), + zoom=(1.0, 1.0, 1.0), + shape=(10.0, 10.0, 10.0), + rotation=np.eye(3), + origin=(0.0, 0.0, 0.0), + level_one_info=Vertebra_Instance, + level_two_info=Location, + ) + + +class TestLabelNameMigration(unittest.TestCase): + def test_old_flat_to_nested(self): + self.assertEqual( + normalize_label_name({"(1, 2)": "C2", "(1, 3)": "C3", "(2, 50)": "Body"}), + {1: {2: "C2", 3: "C3"}, 2: {50: "Body"}}, + ) + + def test_json_string_keys_to_int(self): + self.assertEqual(normalize_label_name({"1": {"2": "C2", "name": "Spine"}}), {1: {2: "C2", "name": "Spine"}}) + + def test_idempotent_and_empty(self): + nested = {1: {2: "C2", "name": "Spine"}} + self.assertEqual(normalize_label_name(nested), nested) + self.assertEqual(normalize_label_name(None), {}) + self.assertEqual(normalize_label_name({}), {}) + + +class TestAccessors(unittest.TestCase): + def test_set_get_point_and_group_name(self): + p = _poi() + p.set_label_name(1, 50, "MyCorpus") + p.set_level_one_name(1, "CervicalGroup") + self.assertEqual(p.label_name(1, 50), "MyCorpus") + self.assertEqual(p.level_one_name(1), "CervicalGroup") + self.assertEqual(p.info["label_name"], {1: {50: "MyCorpus", "name": "CervicalGroup"}}) + + def test_enum_fallback(self): + p = _poi() + # Vertebra_Instance saves as name -> group name falls back to the enum name + self.assertEqual(p.level_one_name(2), Vertebra_Instance._get_name(2)) + # accepts Enum members too + self.assertEqual(p.level_one_name(Vertebra_Instance.C2), Vertebra_Instance._get_name(2)) + + +class TestRoundTrip(unittest.TestCase): + def test_save_new_format_load_back(self): + p = _poi() + p.set_label_name(1, 50, "MyCorpus") + p.set_level_one_name(1, "CervicalGroup") + fp = os.path.join(tempfile.mkdtemp(), "sub-x_ctd.json") + p.save(fp, verbose=False) + on_disk = json.loads(Path(fp).read_text())[0]["label_name"] + self.assertEqual(on_disk, {"1": {"50": "MyCorpus", "name": "CervicalGroup"}}) # nested, JSON string keys + q = POI.load(fp) + self.assertEqual(q.info["label_name"], {1: {50: "MyCorpus", "name": "CervicalGroup"}}) + self.assertEqual(q.label_name(1, 50), "MyCorpus") + + def test_load_old_flat_format_migrates(self): + p = _poi() + fp = os.path.join(tempfile.mkdtemp(), "sub-x_ctd.json") + p.save(fp, verbose=False) + raw = json.loads(Path(fp).read_text()) + raw[0]["label_name"] = {"(1, 50)": "OldName"} # legacy flat format on disk + Path(fp).write_text(json.dumps(raw)) + q = POI.load(fp) + self.assertEqual(q.info["label_name"], {1: {50: "OldName"}}) + + +class TestMetadataTransfer(unittest.TestCase): + def _check(self, q): + self.assertIs(q.level_one_info, Vertebra_Instance) + self.assertIs(q.level_two_info, Location) + self.assertIn("label_name", q.info) + + def test_metadata_survives_ops(self): + p = _poi() + p.set_label_name(1, 50, "MyCorpus") + for q in ( + p.copy(), + p.reorient(("P", "I", "R")), + p.rescale((2.0, 2.0, 2.0), verbose=False), + p.extract_subregion(Location.Vertebra_Corpus), + p.extract_region(1), + p.map_labels(label_map_region={1: 10}), + calc_poi_average([p, p.copy()]), + ): + self._check(q) + + +class TestNameIndexing(unittest.TestCase): + def test_index_by_level_names(self): + p = _poi() + self.assertEqual(p[1, "Vertebra_Corpus"], (1.0, 2.0, 3.0)) # [idx, "level2name"] + self.assertEqual(p["C1", "Vertebra_Corpus"], (1.0, 2.0, 3.0)) # ["level1name", "level2name"] + self.assertEqual(p["C2", "Vertebra_Corpus"], (4.0, 5.0, 6.0)) + self.assertIn(("C2", "Vertebra_Corpus"), p) + p["C1", "Vertebra_Disc"] = (7.0, 8.0, 9.0) + self.assertEqual(p[1, 100], (7.0, 8.0, 9.0)) # Vertebra_Disc == 100 + + def test_enum_with_level_set_no_warning(self): + import warnings + + p = _poi() + with warnings.catch_warnings(): + warnings.simplefilter("error") + self.assertEqual(p[Vertebra_Instance.C1, Location.Vertebra_Corpus], (1.0, 2.0, 3.0)) + + def test_enum_without_level_warns(self): + import warnings + + from TPTBox.core.vert_constants import Any + + q = POI( + centroids={1: {50: (1.0, 2.0, 3.0)}}, + orientation=("R", "A", "S"), + zoom=(1.0, 1.0, 1.0), + shape=(10.0, 10.0, 10.0), + ) + self.assertIs(q.level_one_info, Any) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = q[Vertebra_Instance.C1, Location.Vertebra_Corpus] + self.assertTrue(any("not set" in str(x.message) for x in w)) + + +if __name__ == "__main__": + unittest.main()