Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions TPTBox/core/poi.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,16 @@ class POI(Abstract_POI, Has_Grid):
_vert_orientation_pir = {} # Elusive; will not be saved; will not be copied. For Buffering results # noqa: RUF012

def _set_inplace(self, poi: Self) -> Self:
"""Copy all grid/affine attributes and centroids from ``poi`` into ``self``."""
"""Copy all grid/affine attributes, centroids, and naming metadata from ``poi`` into ``self``."""
self.orientation = poi.orientation
self.centroids = poi.centroids
self.zoom = poi.zoom
self.shape = poi.shape
self.origin = poi.origin
self.rotation = poi.rotation
self.info = poi.info
self.level_one_info = poi.level_one_info
self.level_two_info = poi.level_two_info
return self

@property
Expand Down Expand Up @@ -224,6 +227,8 @@ def copy(
origin=origin if not isinstance(origin, Sentinel) else self.origin,
info=deepcopy(self.info),
format=self.format,
level_one_info=self.level_one_info,
level_two_info=self.level_two_info,
)

def local_to_global(self, x: COORDINATE, itk_coords=False) -> COORDINATE:
Expand Down Expand Up @@ -1330,7 +1335,18 @@ def calc_poi_average(pois: list[POI], keep_points_not_present_in_all_pois: bool

# Sort the new ctd by keys
ctd = dict(sorted(ctd.items()))
return POI(centroids=ctd, orientation=pois[0].orientation, zoom=pois[0].zoom, shape=pois[0].shape, rotation=pois[0].rotation)
return POI(
centroids=ctd,
orientation=pois[0].orientation,
zoom=pois[0].zoom,
shape=pois[0].shape,
rotation=pois[0].rotation,
origin=pois[0].origin,
info=deepcopy(pois[0].info),
format=pois[0].format,
level_one_info=pois[0].level_one_info,
level_two_info=pois[0].level_two_info,
)


def _load_from_POI_spine_r(data: dict) -> POI:
Expand Down
191 changes: 165 additions & 26 deletions TPTBox/core/poi_fun/poi_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,51 @@ def __init__(
self.subregion_name2idx = {value: key for key, value in subregion.items()}


def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition) -> tuple[int, int]:
def _resolve_poi_part(part, name2idx: dict, level_info=None) -> int:
"""Resolve one half of a POI key (region or subregion) to its integer id.

Order: Enum -> ``.value``; numeric string -> int; other string -> ``level_info`` enum
name (when ``level_info`` is a concrete enum), then ``name2idx``. When ``level_info`` is
the ``Any`` wildcard and an Enum is used, a warning is emitted (the level info is not set).
"""
if isinstance(part, Enum):
if level_info is Any:
import warnings

warnings.warn(
f"Indexing a POI with the enum {part!r} but the matching level_one_info/level_two_info is not set "
f"(it is 'Any'); the integer value {part.value} is used.",
stacklevel=4,
)
return part.value
if isinstance(part, str):
try:
return int(part)
except ValueError:
pass
if level_info is not None and level_info is not Any:
try:
return level_info._get_id(part, no_raise=False)
except Exception:
pass
return name2idx[part]
return part


def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition, level_one_info=None, level_two_info=None) -> tuple[int, int]:
"""Convert any supported POI key type to a ``(region, subregion)`` integer pair.

Accepted key forms: plain integer (packed label), ``slice(region, subregion)``,
2-tuple of ints, 2-tuple of ``Abstract_lvl`` / ``Enum`` members, or mixed tuples.
String values are resolved via ``definition``'s name-to-index mappings.
String values resolve via ``level_one_info``/``level_two_info`` enum names (when given)
and otherwise via ``definition``'s name-to-index mappings, so ``poi[idx, "level2name"]``
and ``poi["level1name", "level2name"]`` work in addition to ids and Enums.

Args:
key: POI identifier in any of the supported formats.
definition: Name-to-integer mapping used to resolve string labels.
level_one_info: Optional region-level enum class for name/enum resolution + warnings.
level_two_info: Optional subregion-level enum class for name/enum resolution + warnings.

Returns:
``(region, subregion)`` tuple of plain Python integers.
Expand All @@ -101,23 +136,66 @@ def unpack_poi_id(key: POI_ID, definition: _Abstract_POI_Definition) -> tuple[in
subregion = key.stop
else:
region, subregion = key
if isinstance(region, str):
try:
region = int(region)
except ValueError:
region = definition.region_name2idx[region]
if isinstance(region, Enum):
region = region.value
if isinstance(subregion, str):
try:
subregion = int(subregion)
except ValueError:
subregion = definition.subregion_name2idx[subregion]
if isinstance(subregion, Enum):
subregion = subregion.value
region = _resolve_poi_part(region, definition.region_name2idx, level_one_info)
subregion = _resolve_poi_part(subregion, definition.subregion_name2idx, level_two_info)
return region, subregion


# ---------------------------------------------------------------------------
# label_name (custom per-point / per-region names stored in POI.info["label_name"])
# In-memory (new) format: {region:int -> {subregion:int -> name:str, "name": group_name:str}}
# ---------------------------------------------------------------------------
LABEL_NAME = "label_name"
_GROUP_NAME_KEY = "name"


def normalize_label_name(d: dict | None) -> dict[int, dict]:
"""Normalize any ``label_name`` mapping to the nested form (migration helper).

Target: ``{region:int -> {subregion:int -> name:str, "name": group_name:str}}``.
Accepts and migrates: the old flat ``{"(1, 2)": "C2"}`` form, a JSON-loaded nested
form with string keys ``{"1": {"2": "C2", "name": "Spine"}}``, and the already-nested
form (idempotent).
"""
import ast

if not d:
return {}
out: dict[int, dict] = {}
# old flat format: every key is a "(region, subregion)" tuple string
if all(isinstance(k, str) and k.strip().startswith("(") for k in d):
for k, name in d.items():
region, subregion = ast.literal_eval(k)
out.setdefault(int(region), {})[int(subregion)] = name
return out
# nested format (region keys may be JSON strings)
for region, sub in d.items():
target = out.setdefault(int(region), {})
if isinstance(sub, dict):
for s, name in sub.items():
target[_GROUP_NAME_KEY if s == _GROUP_NAME_KEY else int(s)] = name
else: # degenerate {region: name} -> treat as the region group name
target[_GROUP_NAME_KEY] = sub
return out


def label_name_dict(info: dict) -> dict[int, dict]:
"""Return the normalized nested ``label_name`` dict from ``info`` (normalizing in place)."""
ln = info.get(LABEL_NAME)
if not ln:
return {}
norm = normalize_label_name(ln)
info[LABEL_NAME] = norm # cache the normalized form back into info
return norm


def _id_of(x) -> int:
"""Resolve an Enum member / int / numeric string to its integer id."""
if isinstance(x, Enum):
return x.value
return int(x)


class POI_Descriptor(AbstractSet, MutableMapping):
"""Two-level dictionary that maps ``(region, subregion)`` pairs to 3-D coordinates.

Expand Down Expand Up @@ -169,10 +247,10 @@ def __set__(self, obj, value) -> None:
setattr(obj, self._name, int(value))

def copy(self) -> POI_Descriptor:
"""Return a deep copy of this descriptor."""
"""Return a deep copy of this descriptor (keeping its name<->id ``definition``)."""
from copy import deepcopy

return POI_Descriptor(default=deepcopy(self.pois))
return POI_Descriptor(default=deepcopy(self.pois), definition=self.definition)

def _sort(self: Self, inplace=True, order_dict: dict | None = None):
"""Sort vertebra dictionary by sorting_list."""
Expand Down Expand Up @@ -676,17 +754,20 @@ def __iter__(self):
"""Iterate over all ``(region, subregion)`` key pairs."""
return iter(self.centroids.keys())

def _resolve_key(self, key: POI_ID) -> tuple[int, int]:
"""Resolve any key (ids, Enums, or ``level_one_info``/``level_two_info`` names) to ``(region, subregion)`` ints."""
return unpack_poi_id(key, self.centroids.definition, self.level_one_info, self.level_two_info)

def __contains__(self, key: POI_ID) -> bool:
key = unpack_poi_id(key, self.centroids.definition)
return key in self.centroids
return self._resolve_key(key) in self.centroids

def __getitem__(self, key: POI_ID) -> COORDINATE:
return tuple(self.centroids[key])
return tuple(self.centroids[self._resolve_key(key)])

def __setitem__(self, key: POI_ID, value: tuple[float, float, float] | Sequence[float] | np.ndarray) -> None:
if len(value) != DIMENSIONS:
raise ValueError(value)
self.centroids[key] = tuple(value)
self.centroids[self._resolve_key(key)] = tuple(value)

def __len__(self) -> int:
return self.centroids.__len__()
Expand Down Expand Up @@ -754,6 +835,57 @@ def keys_subregion(self, sort: bool = False) -> list[int]:
self.sort()
return list(self.centroids.keys_subregion())

def label_name(self, region, subregion) -> str | None:
"""Return the human-readable name of the point ``(region, subregion)``.

A custom name in ``info["label_name"]`` takes priority over the ``level_two_info`` enum
name (and finally falls back to the raw id as a string). Accepts ints or Enum members.
A warning is raised when a custom name is set that differs from the ``level_two_info``
name for that id (when ``level_two_info`` is given).
"""
import warnings

region_i, subregion_i = _id_of(region), _id_of(subregion)
custom = label_name_dict(self.info).get(region_i, {}).get(subregion_i)
enum_name = None
if self.level_two_info not in (None, Any):
n = self.level_two_info._get_name(subregion_i, no_raise=True)
enum_name = n if n != str(subregion_i) else None
if custom is not None:
if enum_name is not None and custom != enum_name:
warnings.warn(
f"label_name {custom!r} for subregion {subregion_i} is not the {self.level_two_info.__name__} name {enum_name!r}",
stacklevel=2,
)
return custom
return enum_name if enum_name is not None else str(subregion_i)

def set_label_name(self, region, subregion, name: str) -> None:
"""Set a custom name for the point ``(region, subregion)`` in ``info["label_name"]``."""
d = label_name_dict(self.info)
d.setdefault(_id_of(region), {})[_id_of(subregion)] = name
self.info[LABEL_NAME] = d

def level_one_name(self, region) -> str | None:
"""Return the group (level-one) name of ``region``.

A custom group name in ``info["label_name"][region]["name"]`` takes priority over the
``level_one_info`` enum name. Accepts an int or Enum member.
"""
region_i = _id_of(region)
custom = label_name_dict(self.info).get(region_i, {}).get(_GROUP_NAME_KEY)
if custom is not None:
return custom
if self.level_one_info not in (None, Any):
return self.level_one_info._get_name(region_i, no_raise=True)
return str(region_i)

def set_level_one_name(self, region, name: str) -> None:
"""Set a custom group (level-one) name for ``region`` in ``info["label_name"]``."""
d = label_name_dict(self.info)
d.setdefault(_id_of(region), {})[_GROUP_NAME_KEY] = name
self.info[LABEL_NAME] = d

def values(self, sort: bool = False) -> list[COORDINATE]:
"""Return all stored ``(x, y, z)`` coordinate tuples.

Expand Down Expand Up @@ -1030,16 +1162,23 @@ def join_left(self, pois: Self, inplace=False, _right_join=False) -> Self:
Self: The combined set of centroids, either in-place or as a new set, depending on the 'inplace' parameter.
"""
ctd_list = self.centroids
if "label_name" in pois.info and "label_name" not in self.info:
self.info["label_name"] = {}
src_ln = label_name_dict(pois.info)
if not inplace:
ctd_list = ctd_list.copy()
dst_ln = label_name_dict(self.info) if src_ln else None
if dst_ln is not None:
self.info[LABEL_NAME] = dst_ln
for x, y, c in pois.items():
if (x, y) in self and not _right_join:
continue
ctd_list[x, y] = c
if "label_name" in pois.info and f"({x}, {y})" in pois.info["label_name"]:
self.info["label_name"][f"({x}, {y})"] = pois.info["label_name"][f"({x}, {y})"]
if dst_ln is not None:
name = src_ln.get(x, {}).get(y)
if name is not None:
dst_ln.setdefault(x, {})[y] = name
grp = src_ln.get(x, {}).get(_GROUP_NAME_KEY)
if grp is not None:
dst_ln.setdefault(x, {}).setdefault(_GROUP_NAME_KEY, grp)
if inplace:
return self
return self.copy(ctd_list)
Expand Down
9 changes: 8 additions & 1 deletion TPTBox/core/poi_fun/poi_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,14 @@ def to_other(self, msk: Has_Grid, verbose=False) -> poi.POI:
log.print(v, "-->", v_out)
out[k1, k2] = tuple(v_out)

return poi.POI(centroids=out, **msk._extract_affine(), info=self.info, format=self.format)
return poi.POI(
centroids=out,
**msk._extract_affine(),
info=self.info,
format=self.format,
level_one_info=self.level_one_info,
level_two_info=self.level_two_info,
)

def copy(self, centroids: POI_Descriptor | None = None) -> Self:
"""Return a deep copy of this ``POI_Global``.
Expand Down
13 changes: 9 additions & 4 deletions TPTBox/core/poi_fun/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# from TPTBox import POI, POI_Global
from TPTBox.core import bids_files
from TPTBox.core.nii_poi_abstract import Has_Grid
from TPTBox.core.poi_fun.poi_abstract import POI_Descriptor
from TPTBox.core.poi_fun.poi_abstract import _GROUP_NAME_KEY, LABEL_NAME, POI_Descriptor, label_name_dict, normalize_label_name
from TPTBox.core.vert_constants import (
AX_CODES,
COORDINATE,
Expand Down Expand Up @@ -207,7 +207,8 @@ def _poi_to_dict_list( # noqa: C901

for k, v in ctd.info.items():
if k not in ori:
ori[k] = v
# always persist label_name in the new nested format {region:{sub:name,"name":group}}
ori[k] = normalize_label_name(v) if k == LABEL_NAME else v

dict_list: list[_Orientation | (_Point3D | dict)] = [ori]

Expand Down Expand Up @@ -353,6 +354,9 @@ def load_poi(ctd_path: POI_Reference, verbose=True) -> POI | POI_Global: # noqa
level_one_info = _register_lvl[dict_list[0].get("level_one_info", Vertebra_Instance.__name__)]
level_two_info = _register_lvl[dict_list[0].get("level_two_info", Location.__name__)]
info = {k: v for k, v in dict_list[0].items() if k not in ctd_info_blacklist}
if LABEL_NAME in info:
# migrate old flat {"(1, 2)": "C2"} / JSON string keys -> nested {1: {2: "C2", "name": ...}}
info[LABEL_NAME] = normalize_label_name(info[LABEL_NAME])
if format_ in (FORMAT_GLOBAL, FORMAT_PLST):
from TPTBox import POI_Global

Expand Down Expand Up @@ -649,8 +653,9 @@ def _load_mkr_POI(dict_mkr: dict) -> POI_Global:
description = control_points.get("description", region)
associatedNodeID = control_points.get("associatedNodeID", description)
label_group_name[region] = associatedNodeID
label_name.setdefault(region, {})[_GROUP_NAME_KEY] = associatedNodeID

label_name[str((region, subregion))] = label
label_name.setdefault(region, {})[subregion] = label
assert itk_coords is not None, "itk_coords not set"
from TPTBox import POI_Global

Expand Down Expand Up @@ -821,7 +826,7 @@ def _load_landmark_txt(path: Path) -> list:
id_ = label_group_name[current_group]
new_id = len(points[id_]) + 1
points[id_][new_id] = coords
label_name[str((id_, new_id))] = key
label_name.setdefault(id_, {})[new_id] = key
if len(label_name) != 0:
header["label_name"] = label_name
if len(label_group_name) != 0:
Expand Down
8 changes: 6 additions & 2 deletions TPTBox/core/poi_fun/save_mkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
from typing_extensions import NotRequired

from TPTBox.core.poi_fun.poi_abstract import _GROUP_NAME_KEY, label_name_dict
from TPTBox.logger.log_file import log
from TPTBox.mesh3D.mesh_colors import RGB_Color, get_color_by_label

Expand Down Expand Up @@ -404,15 +405,18 @@ def get_desc(self: POI_Global, region: int, subregion: int) -> tuple[str, str, s
Tuple ``(name, name2, label)`` where ``name`` and ``name2`` are the
group/region name and ``label`` is the subregion label string.
"""
label = self.info.get("label_name", {}).get(str((region, subregion)))
_ln = label_name_dict(self.info)
label = _ln.get(region, {}).get(subregion)
if label is None:
label = str(subregion)
try:
label = self.level_two_info(subregion).name
except Exception:
label = str(subregion)
try:
if region in self.info.get("label_group_name", {}):
if _ln.get(region, {}).get(_GROUP_NAME_KEY) is not None:
name2 = _ln[region][_GROUP_NAME_KEY]
elif region in self.info.get("label_group_name", {}):
name2 = self.info["label_group_name"][region]
else:
name2 = self.level_one_info(region).name
Expand Down
Loading
Loading