diff --git a/TPTBox/core/dicom/dicom2nii_utils.py b/TPTBox/core/dicom/dicom2nii_utils.py index 3f512c1..cc835b4 100755 --- a/TPTBox/core/dicom/dicom2nii_utils.py +++ b/TPTBox/core/dicom/dicom2nii_utils.py @@ -11,6 +11,7 @@ from tqdm import tqdm from TPTBox import BIDS_FILE, NII, BIDS_Global_info, Print_Logger +from TPTBox.core.internal.nii_help import save_json as secure_save_json # source_folder = Path("/DATA/NAS/datasets_source/epi/NAKO/NAKO-732_Nachlieferung_20_25/") source_folder = Path("/DATA/NAS/datasets_source/epi/NAKO/NAKO_2D_issue/") @@ -309,21 +310,12 @@ def save_json(json_ob: dict, file: str | Path, check_exist: bool = False, overri FileExistsError: When *check_exist* is ``True`` and the existing file contains different content. """ - - def convert(obj): - if isinstance(obj, np.integer): - return int(obj) - if isinstance(obj, np.floating): - return float(obj) - raise TypeError(type(obj)) - if check_exist and test_name_conflict(json_ob, file): raise FileExistsError(file) if Path(file).exists() and not override: return True Print_Logger().on_save("save json with grid info", file) - with open(file, "w") as file_handel: - json.dump(json_ob, file_handel, indent=4, default=convert) + secure_save_json(file, json_ob, indent=4) return False diff --git a/TPTBox/core/dicom/dicom_extract.py b/TPTBox/core/dicom/dicom_extract.py index 93519da..459c396 100644 --- a/TPTBox/core/dicom/dicom_extract.py +++ b/TPTBox/core/dicom/dicom_extract.py @@ -28,26 +28,51 @@ sys.path.append(str(Path(__file__).parent)) +import string + from TPTBox.core.dicom.dicom2nii_utils import get_json_from_dicom, load_json, save_json, test_name_conflict logger = Print_Logger() -def _inc_key(keys: dict, inc: int = 1) -> None: - """Increment the sequence key inside *keys* by *inc*.""" - k = "sequ" +def _next_letter_suffix(s: str, inc: int = 1) -> str: + """Increment a letter suffix: a -> b, z -> aa, aa -> ab.""" + alphabet = string.ascii_lowercase + # Convert to a number (base 26, 1-indexed) + n = 0 + for c in s: + n = n * 26 + (ord(c) - ord("a") + 1) + n += inc + # Convert back to letters + result = [] + while n > 0: + n -= 1 + result.append(alphabet[n % 26]) + n //= 26 + return "".join(reversed(result)) + + +def _inc_key(keys: dict, inc: int = 1, k="sequ") -> None: + """Increment the sequence key inside *keys* by appending letter suffixes.""" if k not in keys: - keys[k] = 0 + keys[k] = "0" + value = str(keys[k]) try: - v = int(keys[k]) - keys[k] = str(v + int(inc)) - except Exception: - try: - a, b = str(keys[k]).rsplit("-", maxsplit=2) - except Exception: - a = keys[k] - b = 0 - keys[k] = a + "-" + str(int(b) + int(inc)) + # Pure number: 100 -> 100-a + int(value) + keys[k] = f"{value}-a" + return # noqa: TRY300 + except ValueError: + pass + + try: + base, suffix = value.rsplit("-", maxsplit=1) + if suffix.isalpha(): + keys[k] = f"{base}-{_next_letter_suffix(suffix, inc)}" + else: + keys[k] = f"{base}-a" + except ValueError: + keys[k] = f"{value}-a" def _generate_bids_path( diff --git a/TPTBox/core/internal/elastic_deform.py b/TPTBox/core/internal/elastic_deform.py index 29aecbf..f3099c3 100644 --- a/TPTBox/core/internal/elastic_deform.py +++ b/TPTBox/core/internal/elastic_deform.py @@ -1,6 +1,7 @@ import time -import elasticdeform +# pip install elasticdeform +import elasticdeform # See https://github.com/gvtulder/elasticdeform/issues/24 to install this for >2.x import numpy as np from numpy.typing import NDArray @@ -49,19 +50,23 @@ def deformed_nii( deformed_data = deformed_NII(arr_dic, sigma=sigma, points=points) """ if sigma is None or points is None: + np.random.seed(None) sigma, points = get_random_deform_parameter(deform_factor=deform_factor) - print("deformation parameter sigma = ", round(sigma, 4), "; n_points = ", points) t = time.time() - values = list(nii_dic.values()) + # Deform + max_v = None if joint_normalize: max_v = max([img.max() for img in nii_dic.values() if not img.seg]) nii_dic = {k: img if img.seg else img.set_dtype(np.float32) / max_v for k, img in nii_dic.items()} elif normalize: - nii_dic = {k: img if img.seg else img.set_dtype(np.float32).normalize() for k, img in nii_dic.items()} + max_v = {k: None if img.seg else (float(max(img.max() - img.min(), 1)), float(img.min())) for k, img in nii_dic.items()} + nii_dic = {k: img if img.seg else (img.set_dtype(np.float32) - max_v[k][1]) / max_v[k][0] for k, img in nii_dic.items()} else: nii_dic = {k: img if img.seg else img.set_dtype(np.float32) for k, img in nii_dic.items()} + + values = list(nii_dic.values()) assert sigma is not None p = deform_padding out: list[NDArray] = elasticdeform.deform_random_grid( @@ -74,6 +79,10 @@ def deformed_nii( for (k, nii), arr in zip(nii_dic.items(), out, strict=True): out2[k] = nii.set_array(arr[p:-p, p:-p, p:-p]) print("Deformation took", round(time.time() - t, 1), "Seconds") + if joint_normalize: + out2 = {k: img if img.seg else img.set_dtype(np.float32) * max_v for k, img in out2.items()} + elif normalize: + out2 = {k: img if img.seg else ((img.set_dtype(np.float32) * max_v[k][0]) + max_v[k][1]) for k, img in out2.items()} return out2 diff --git a/TPTBox/core/internal/nii_help.py b/TPTBox/core/internal/nii_help.py index b677bf8..d51e636 100644 --- a/TPTBox/core/internal/nii_help.py +++ b/TPTBox/core/internal/nii_help.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import shutil from collections.abc import Callable from functools import wraps @@ -14,10 +15,11 @@ if TYPE_CHECKING: from TPTBox.core.nii_poi_abstract import Has_Grid from TPTBox.core.nii_wrapper import NII + from TPTBox.core.vert_constants import AFFINE, MODES, SHAPE, ZOOMS, Sentinel, _supported_img_files -def secure_save(func) -> Callable: +def secure_save(func, *, file_types=tuple(_supported_img_files)) -> Callable: """Decorator that writes to a `.backup` file first and restores it if saving fails. Steps: (1) back up existing file, (2) call the wrapped save function, (3) delete backup on @@ -48,7 +50,7 @@ def save_to_file(self, file: Path, data: Any): @wraps(func) def wrapper(self, file: str | Path | bids_files.BIDS_FILE, *args, **kwargs): if isinstance(file, bids_files.BIDS_FILE): - for file_type in _supported_img_files: + for file_type in file_types: if file_type in file.file: file = file.file[file_type] break @@ -81,6 +83,84 @@ def wrapper(self, file: str | Path | bids_files.BIDS_FILE, *args, **kwargs): return wrapper +def _convert(obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, Path): + return str(obj.absolute()) + raise TypeError(type(obj)) + + +def _save_json(data, filepath: str | Path | bids_files.BIDS_FILE, indent=4, convert=_convert): + + if isinstance(filepath, bids_files.BIDS_FILE): + if "json" in filepath.file: + filepath = filepath.file["json"] + else: + nf = filepath.get_nii_file() + if nf is not None: + filepath = (nf.parent) / (nf.name.split(".")[0] + ".json") + else: + nf = next(iter(filepath.file.values())) + filepath = (nf.parent) / (nf.name.rsplit(".", maxsplit=1)[0] + ".json") + # print(markups[-1].get("display")) + with open(filepath, "w") as f: + json.dump(data, f, indent=indent, default=convert) + + +def save_json(filepath: str | Path | bids_files.BIDS_FILE, data, indent=4, convert=_convert) -> None: + """Safely save a Python object as a JSON file with automatic backup protection. + + This function writes JSON data to disk using a safe save mechanism: + if the target file already exists, it is first moved to a `.backup` + file. If writing succeeds, the backup is removed. If writing fails, + the original file is restored. + + The function supports flexible input types for the target path: + - str or Path: written directly to disk + - bids_files.BIDS_FILE: resolved to an appropriate `.json` path + + Non-JSON-serializable types are handled via a custom converter + that supports: + - numpy integers → int + - numpy floats → float + - numpy arrays → list + - pathlib.Path → absolute string path + + Args: + filepath (str | Path | bids_files.BIDS_FILE): + Target file path or BIDS file container. + data (Any): + Python object to serialize into JSON. + indent (int, optional): + Pretty-print indentation level. Default is 4. + convert (callable, optional): + Custom serialization function for unsupported types. + Defaults to `_convert`. + + Returns: + None + + Notes: + - Uses `secure_save` to ensure atomic write semantics with backup/restore. + - If `filepath` is a `BIDS_FILE`, the `.json` path is inferred from: + 1. explicit "json" entry in the file map + 2. associated NIfTI file path + 3. fallback to any available file in the container + - This function is intended for structured metadata and annotation storage. + + Raises: + Exception: + Propagates any error raised during serialization or file writing, + after attempting automatic recovery of the original file. + """ + return secure_save(_save_json, file_types=["json"])(data, filepath, indent=indent, convert=convert) + + def _resample_from_to( from_img: NII, to_img: tuple[SHAPE, AFFINE, ZOOMS] | Has_Grid, diff --git a/TPTBox/core/internal/train_nnUnet/__init__.py b/TPTBox/core/internal/train_nnUnet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TPTBox/core/internal/train_nnUnet/_prep_ds.py b/TPTBox/core/internal/train_nnUnet/_prep_ds.py new file mode 100644 index 0000000..64f5e51 --- /dev/null +++ b/TPTBox/core/internal/train_nnUnet/_prep_ds.py @@ -0,0 +1,402 @@ +import hashlib +import json +import random +import sys +from functools import partial +from math import ceil +from multiprocessing.pool import Pool as Pool_type +from pathlib import Path +from typing import Literal + +import numpy as np +import torch + +from TPTBox import BIDS_FILE, NII, Print_Logger, to_nii +from TPTBox.core.internal.elastic_deform import deformed_nii + +sys.path.append(str(Path(__file__).parent.parent)) +sys.path.append(str(Path(__file__).parents[1])) + +logger = Print_Logger() + + +def run(p): + p() + + +####################################################### +def set_up_dataset( + idx: int, + dataset_mapping: dict[str, int], + spacing: tuple[float, ...], + is_ct=False, + orientation=("R", "A", "S"), + turn_on_mirroring=False, + turn_on_data_aug_5=False, + nn_trainier: Literal[ + "nnUNetTrainer", + "nnUNetTrainerNoMirroring", + "nnUNetTrainerDA5", + "nnUNetTrainerDAExtGPU", + ] + | None = None, + AUGLAB_PARAMS_GPU_JSON="transform_params_gpu_default01-23.json", + ignore=False, + num_input=1, + base="/DATA/NAS/FASTDATA/robert/nnUNet", + **setting, +): + out_base = Path(base, f"nnUNet_raw/Dataset{idx:03}/") + Path(out_base).mkdir(exist_ok=True, parents=True) + + files = {"0": "ct"} if is_ct else {"0": "any"} + for i in range(1, num_input): + files[str(i)] = "any" + dataset_mapping["background"] = 0 + + if ignore: + dataset_mapping["ignore"] = len(dataset_mapping) # "ignore": 999 + data = { + "channel_names": files, + "labels": dataset_mapping, + "numTraining": None, + "file_ending": ".nii.gz", + "reference": "deep-spine.de", + "licence": "https://github.com/robert-graf/VibeSegmentator", + # "regions_class_order": [i for i in sorted(dataset_mapping.values()) if i not in [0]], + "dataset_release": "1.0", + "orientation": orientation, + "nnUNetTrainer": nn_trainier, + **setting, + } + if nn_trainier == "nnUNetTrainerDAExtGPU": + data["AUGLAB_PARAMS_GPU_JSON"] = AUGLAB_PARAMS_GPU_JSON + if turn_on_mirroring: + data["turn_on_mirroring"] = turn_on_mirroring + if turn_on_data_aug_5: + data["turn_on_data_aug_5"] = turn_on_data_aug_5 + + if spacing is not None and -1 not in spacing: + # data["spacing"] = tuple(str(i) for i in spacing)[::-1] + data["resolution_range"] = tuple(str(i) for i in spacing) + + return data, out_base + + +def add_file( + p: Pool_type, + img: Path | list[Path] | str | list[str], + seg: Path, + dataset_settings: dict, + root: Path, + target_height_half=None, + defrom=False, + axis="S", + deform_factor=1.0, + defrom_count=1, + mirror=None, + degeneration_count=0, + mapping=None, + auto_crop=None, + ignore_crop=None, + defrom_p=1.0, + **qargs, # noqa: ARG001 +): + if p is not None: + return partial( + _add_file_async, + img, + seg, + dataset_settings, + root, + target_height_half, + defrom, + axis, + deform_factor, + defrom_count, + mirror, + degeneration_count, + mapping=mapping, + auto_crop=auto_crop, + ignore_crop=ignore_crop, + defrom_p=defrom_p, + ) + return _add_file_async( + img, + seg, + dataset_settings, + root, + target_height_half, + defrom, + axis, + deform_factor, + defrom_count, + mirror, + degeneration_count, + mapping=mapping, + auto_crop=auto_crop, + ignore_crop=ignore_crop, + defrom_p=defrom_p, + ) + + +def finalize_ds(dataset_settings, out_base: Path): + dataset_settings["numTraining"] = len(list((out_base / "labelsTr").iterdir())) + with open(out_base / "dataset.json", "w") as f: + json.dump(dataset_settings, f, indent=4) + logger.on_ok(f"Finished dataset generation. Num Sampels is {dataset_settings['numTraining']}") + + +####################################################### + + +def _add_file_async( + img: Path | list[Path] | str | list[str], + seg: Path, + dataset_settings: dict, + root: Path, + target_height_half=None, + defrom=False, + axis="S", + deform_factor=1.0, + defrom_count=1, + mirror=None, + degeneration_count=0, + delete_brocken=True, + mapping=None, + auto_crop=None, + ignore_crop=None, + defrom_p=1.0, +): + # check if image exists (asume on split, safe 0 split last) + if _get_file_name(root, img, 0, defrom, defrom_count, mirror is not None, degeneration_count).exists(): + # try: + # if to_nii(seg, True).max() > len(dataset_settings["labels"]): + # seg.unlink(missing_ok=True) + # logger.on_debug(f"wrong label, {to_nii(seg, True).unique()}") + # else: + logger.on_debug(f"Skip {seg}, exists") + return + # except Exception: + # seg.unlink(missing_ok=True) + if delete_brocken: + if isinstance(img, Path): + img = [img] + for i in img: + try: + to_nii(i, True).max() + except Exception: + [Path(i).unlink(missing_ok=True) for i in img] + Path(seg).unlink(missing_ok=True) + return + try: + to_nii(seg, True).max() + except Exception: + [Path(i).unlink(missing_ok=True) for i in img] + Path(seg).unlink(missing_ok=True) + return + # load image + img_nii = [to_nii(img, False)] if isinstance(img, (Path, str, BIDS_FILE)) else [to_nii(i, False) for i in img] + seg_nii = to_nii(seg, True) + + if mapping is not None: + seg_nii = seg_nii.map_labels(mapping) + # spacing + spacing = dataset_settings.get("resolution_range", (-1, -1, -1)) + spacing = tuple(float(f) for f in spacing) + if auto_crop is not None: + print(f"{auto_crop=}, {ignore_crop=}") + seg_nii = seg_nii.reorient(dataset_settings["orientation"]).rescale(spacing) + crop = seg_nii.compute_crop(0, auto_crop) + if ignore_crop: + crop = list(crop) + for d in ignore_crop: + crop[seg_nii.get_axis(d)] = slice(0, None) + crop = tuple(crop) + print(crop) + seg_nii.apply_crop_(crop) + img_nii[0] = img_nii[0].resample_from_to_(seg_nii) + else: + img_nii[0] = img_nii[0].reorient(dataset_settings["orientation"]).rescale(spacing) + img_nii = [i.resample_from_to_(img_nii[0]) for i in img_nii] + seg_nii.resample_from_to_(img_nii[0]) + + for split_id, offset in _split_task(seg_nii, target_height_half, axis="S"): + for mirror_ in [False] if mirror is None else [False, True]: + for deg in range(degeneration_count + 1): + c = (1 if not defrom else defrom_count + 1) if defrom_p >= 1 or defrom_p <= random.random() else 1 + for defrom_ in range(c): + out_name = _get_file_name(root, img, split_id, defrom_ != 0, defrom_, mirror_, deg) + _make_sample( + img_nii, + seg_nii, + out_name, + offset, + defrom=defrom_ != 0, + degen=deg != 0, + deform_factor=deform_factor, + target_height_half=target_height_half, + axis=axis, + mirror=mirror if mirror_ else None, + ) + + +def _get_file_name( + root, + img: Path | list[Path] | str | list[str], + split: int, + defrom: bool, + defrom_count, + mirror, + degeneration_count, +) -> Path: + if not isinstance(img, (Path, str)): + img = Path(img[0]) + addendum = _deterministic_hash(str(img)) # [:9] + return ( + root + / "labelsTr" + / f"{img.name.split('.')[0]}_{addendum}_{split}{f'_d{defrom_count}' if defrom else ''}{'_m' if mirror else ''}{f'_deg{degeneration_count}' if degeneration_count != 0 else ''}.nii.gz" + ) + + +def _split_task(seg_nii: NII, target_height_half, axis="S"): + if target_height_half is None: + return [(0, 0)] + shape = seg_nii.shape + axis = seg_nii.get_axis(axis) # type: ignore + x = shape[axis] + h = ceil(x / target_height_half) + h = ceil(x / h) + out = [] + for i in range(99999): + out.append((i, i * h)) + if (i * h + target_height_half) >= x: + break + return reversed(out) + + +def _make_sample( + img_nii: list[NII], + seg_nii, + outpath: Path, + offset, + defrom, + degen, + deform_factor, + target_height_half, + axis="S", + mirror: list | None = None, +): + stem = outpath.name.split(".nii.gz")[0] + + if outpath.exists() and (outpath.parent.parent / f"imagesTr/{stem}_{0:04}.nii.gz").exists(): + try: + to_nii(outpath, True).max() + to_nii(outpath.parent.parent / f"imagesTr/{stem}_{0:04}.nii.gz", True).max() + logger.on_ok("Skip:", outpath.name) + return # noqa: TRY300 + except Exception: + pass + try: + logger.on_ok(outpath.name) + # split image + # deform + out_d = extract_image( + img_nii, + seg_nii, + offset, + deform=defrom, + degen=degen, + target_height_half=target_height_half, + deform_factor=deform_factor, + axis=axis, + mirror=mirror, + ) + img_num = 0 + for name, nii in out_d.items(): + if name == "seg": + out = outpath + else: + img_num = int(name.replace("img", "")) + out = outpath.parent.parent / f"imagesTr/{stem}_{img_num:04}.nii.gz" + + Path(out).parent.mkdir(exist_ok=True, parents=True) + assert out_d["seg"].shape == nii.shape, (out_d, nii) + nii.save(out) + except Exception: + logger.on_fail("FAILED", outpath) + logger.print_error() + raise + + +def extract_image( + img_nii: list[NII], + nii_seg: NII, + offset, + deform, + degen, + target_height_half=None, + crop_top=0, + deform_factor=1, + axis="S", + mirror: list | None = None, +): + assert img_nii[0].assert_affine(nii_seg) + axis = img_nii[0].get_axis(axis) # type: ignore + img_nii = [i.clone() for i in img_nii] + nii_seg = nii_seg.clone() + + if target_height_half is None: + offset = None + else: + shape = img_nii[0].shape + offset = max(min(offset, shape[axis] - 2 * target_height_half), 0) + max_offset = min(offset + 2 * target_height_half, shape[axis] - crop_top) + if offset >= 0: + crop = [slice(0, shape[0]), slice(0, shape[1]), slice(0, shape[2])] + crop[axis] = slice(offset, max_offset) + [i.apply_crop_(tuple(crop)) for i in img_nii] + nii_seg.apply_crop_(tuple(crop)) + if mirror is not None: + mapping = {} + for a, b in mirror: + mapping[a] = b + mapping[b] = a + + nii_seg.map_labels_(mapping) + nii_seg = nii_seg.flip("R", keep_global_coords=False) + img_nii = [i.flip("R", keep_global_coords=False) for i in img_nii] + + if degen: + img_nii = [random_transform(i) for i in img_nii] + niis: dict[str, NII] = {} + for i, _nii in enumerate(img_nii): + niis[f"img{i}"] = _nii + niis["seg"] = nii_seg + + if deform: + niis = deformed_nii(niis, deform_factor=deform_factor, normalize=True) # type: ignore + + return niis + + +def _deterministic_hash(string: str) -> str: + return hashlib.md5(string.encode()).hexdigest() + + +@torch.no_grad() +def random_transform(img: NII, prob=0.35): + raise NotImplementedError("Online degeneration, has been removed") + from training import transforms3D + + t = torch.tensor(img.get_array().astype(dtype=np.float32)) + min_v = t.min() + max_v = t.max().item() + tens = {"img": (t - min_v) / max_v} + tens = transforms3D.RandomBlur(prob=prob, std=(0.5, 3), kernel_size=9)(tens) + tens = transforms3D.RandomNoise(prob=prob, std=(0.0, 0.05))(tens) + tens = transforms3D.RandomBiasField(coefficients=(0, 0.75), prob=prob)(tens) + tens = transforms3D.ColorJitter3D_(prob=prob * 2)(tens) + arr = (tens["img"] * max_v + min_v).numpy().astype(img.dtype) + return img.set_array_(arr) diff --git a/TPTBox/core/internal/train_nnUnet/fastProcessor.py b/TPTBox/core/internal/train_nnUnet/fastProcessor.py new file mode 100644 index 0000000..7417593 --- /dev/null +++ b/TPTBox/core/internal/train_nnUnet/fastProcessor.py @@ -0,0 +1,242 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from copy import deepcopy +from pathlib import Path + +import blosc2 +import nnunetv2.experiment_planning.plan_and_preprocess_api as pp +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import join, load_json, maybe_mkdir_p, write_pickle +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw +from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager + + +def preprocess( + dataset_ids: list[int], + plans_identifier: str = "nnUNetPlans", + configurations: tuple[str] | list[str] = ("2d", "3d_fullres", "3d_lowres"), # type: ignore + num_processes: int | tuple[int, ...] | list[int] = (8, 4, 8), + compress=True, + verbose: bool = False, +) -> None: + """Run nnunet data-processing.""" + for d in dataset_ids: + _preprocess_dataset(d, plans_identifier, configurations, num_processes, compress, verbose) + + +def _preprocess_dataset( + dataset_id: int, + plans_identifier: str = "nnUNetPlans", + configurations: tuple[str] | list[str] = ("2d", "3d_fullres", "3d_lowres"), # type: ignore + num_processes: int | tuple[int, ...] | list[int] = (8, 4, 8), + compress=True, + verbose: bool = False, +) -> None: + if not isinstance(num_processes, list): + num_processes = list(num_processes) # type: ignore + if len(num_processes) == 1: + num_processes = num_processes * len(configurations) + if len(num_processes) != len(configurations): + raise RuntimeError( + f"The list provided with num_processes must either have len 1 or as many elements as there are " + f"configurations (see --help). Number of configurations: {len(configurations)}, length " + f"of num_processes: " + f"{len(num_processes)}" + ) + + dataset_name = pp.convert_id_to_dataset_name(dataset_id) + + print(f"Preprocessing dataset {dataset_name}") + plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + ".json") # type: ignore + plans_manager = PlansManager(plans_file) + for n, c in zip(num_processes, configurations, strict=False): + print(f"Configuration: {c}...") + if c not in plans_manager.available_configurations: + print(f"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of dataset {dataset_name}. Skipping.") + continue + patch_size = plans_manager.get_configuration(c).patch_size + preprocessor = FastPreprocessor(verbose=verbose, compress=compress, patch_size=patch_size) + preprocessor.run(dataset_id, c, plans_identifier, num_processes=n) + + # copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no + # longer there (useful for compute cluster where only the preprocessed data is available) + from distutils.file_util import copy_file + + maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, "gt_segmentations")) # type: ignore + dataset_json = load_json(join(nnUNet_raw, dataset_name, "dataset.json")) # type: ignore + dataset = pp.get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) # type: ignore + # only copy files that are newer than the ones already present + for k in dataset: + copy_file( + dataset[k]["label"], + join(nnUNet_preprocessed, dataset_name, "gt_segmentations", k + dataset_json["file_ending"]), # type: ignore + update=True, # type: ignore + ) # type: ignore + + +def _comp_blosc2_params( + image_size: tuple[int, int, int, int], + patch_size: tuple[int, int] | tuple[int, int, int], + bytes_per_pixel: int = 4, # 4 byte are float32 + l1_cache_size_per_core_in_bytes=32768, # 1 Kibibyte (KiB) = 2^10 Byte; 32 KiB = 32768 Byte + l3_cache_size_per_core_in_bytes=1441792, + # 1 Mibibyte (MiB) = 2^20 Byte = 1.048.576 Byte; 1.375MiB = 1441792 Byte + safety_factor: float = 0.8, # we dont will the caches to the brim. 0.8 means we target 80% of the caches +): + """Computes a recommended block and chunk size for saving arrays with blosc v2. + + Bloscv2 NDIM doku: "Remember that having a second partition means that we have better flexibility to fit the + different partitions at the different CPU cache levels; typically the first partition (aka chunks) should + be made to fit in L3 cache, whereas the second partition (aka blocks) should rather fit in L2/L1 caches + (depending on whether compression ratio or speed is desired)." + (https://www.blosc.org/posts/blosc2-ndim-intro/) + -> We are not 100% sure how to optimize for that. For now we try to fit the uncompressed block in L1. This + might spill over into L2, which is fine in our books. + + Note: this is optimized for nnU-Net dataloading where each read operation is done by one core. We cannot use threading + + Cache default values computed based on old Intel 4110 CPU with 32K L1, 128K L2 and 1408K L3 cache per core. + We cannot optimize further for more modern CPUs with more cache as the data will need be be read by the + old ones as well. + + Args: + patch_size: Image size, must be 4D (c, x, y, z). For 2D images, make x=1 + patch_size: Patch size, spatial dimensions only. So (x, y) or (x, y, z) + bytes_per_pixel: Number of bytes per element. Example: float32 -> 4 bytes + l1_cache_size_per_core_in_bytes: The size of the L1 cache per core in Bytes. + l3_cache_size_per_core_in_bytes: The size of the L3 cache exclusively accessible by each core. Usually the global size of the L3 cache divided by the number of cores. + + Returns: + The recommended block and the chunk size. + """ + # Fabians code is ugly, but eh + + num_channels = image_size[0] + if len(patch_size) == 2: + patch_size = [1, *patch_size] + patch_size = np.array(patch_size) + block_size = np.array((num_channels, *[2 ** (max(0, math.ceil(math.log2(i)))) for i in patch_size])) + + # shrink the block size until it fits in L1 + estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel + while estimated_nbytes_block > (l1_cache_size_per_core_in_bytes * safety_factor): + # pick largest deviation from patch_size that is not 1 + axis_order = np.argsort(block_size[1:] / patch_size)[::-1] + idx = 0 + picked_axis = axis_order[idx] + while block_size[picked_axis + 1] == 1 or block_size[picked_axis + 1] == 1: + idx += 1 + picked_axis = axis_order[idx] + # now reduce that axis to the next lowest power of 2 + block_size[picked_axis + 1] = 2 ** (max(0, math.floor(math.log2(block_size[picked_axis + 1] - 1)))) + block_size[picked_axis + 1] = min(block_size[picked_axis + 1], image_size[picked_axis + 1]) + estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel + + block_size = np.array([min(i, j) for i, j in zip(image_size, block_size)]) + + # note: there is no use extending the chunk size to 3d when we have a 2d patch size! This would unnecessarily + # load data into L3 + # now tile the blocks into chunks until we hit image_size or the l3 cache per core limit + chunk_size = deepcopy(block_size) + estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel + while estimated_nbytes_chunk < (l3_cache_size_per_core_in_bytes * safety_factor): + if patch_size[0] == 1 and all(i == j for i, j in zip(chunk_size[2:], image_size[2:])): + break + if all(i == j for i, j in zip(chunk_size, image_size)): + break + # find axis that deviates from block_size the most + axis_order = np.argsort(chunk_size[1:] / block_size[1:]) + idx = 0 + picked_axis = axis_order[idx] + while chunk_size[picked_axis + 1] == image_size[picked_axis + 1] or patch_size[picked_axis] == 1: + idx += 1 + picked_axis = axis_order[idx] + chunk_size[picked_axis + 1] += block_size[picked_axis + 1] + chunk_size[picked_axis + 1] = min(chunk_size[picked_axis + 1], image_size[picked_axis + 1]) + estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel + if np.mean([i / j for i, j in zip(chunk_size[1:], patch_size)]) > 1.5: + # chunk size should not exceed patch size * 1.5 on average + chunk_size[picked_axis + 1] -= block_size[picked_axis + 1] + break + # better safe than sorry + chunk_size = [min(i, j) for i, j in zip(image_size, chunk_size)] + + # print(image_size, chunk_size, block_size) + return tuple(block_size), tuple(chunk_size) + + +class FastPreprocessor(DefaultPreprocessor): + """Saves nnUnet data set in a mem-mappable data format. compress needs 2.5.2 or higher.""" + + def __init__(self, verbose: bool = True, compress=True, patch_size=None): + super().__init__(verbose) + print(f"FastPreprocessor {compress=}") + self.compress = compress + self.patch_size = patch_size + + def run_case_save( + self, + output_filename_truncated: str, + image_files: list[str], + seg_file: str, + plans_manager: PlansManager, + configuration_manager: ConfigurationManager, + dataset_json: dict | str, + ) -> None: + """Internal nnUnet function.""" + if Path(output_filename_truncated + ".npz").exists() and Path(output_filename_truncated + ".pkl").exists(): + print("skip", output_filename_truncated, end="\r") + return + + data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json) + # print("dtypes", data.dtype, seg.dtype) + # print(data.dtype, data.shape, data.max(), data.min()) + if self.compress: + if self.patch_size is None: + np.savez_compressed(output_filename_truncated + ".npz", data=data.astype(np.float16), seg=seg) + else: + # IMPORTANT + blosc2.set_nthreads(1) + + # derive chunk/block layout + blocks, chunks = _comp_blosc2_params( + image_size=data.shape, + patch_size=self.patch_size, + bytes_per_pixel=2, # float16 + ) + cparams = {"codec": blosc2.Codec.ZSTD, "filters": [blosc2.Filter.BITSHUFFLE], "clevel": 5} + # save image + blosc2.asarray( + np.ascontiguousarray(data, dtype=np.float16), + urlpath=output_filename_truncated + ".b2nd", + chunks=chunks, + blocks=blocks, + cparams=cparams, + ) + cparams = {"codec": blosc2.Codec.ZSTD, "filters": [blosc2.Filter.BITSHUFFLE], "clevel": 5} + + # segmentation usually compresses extremely well + blosc2.asarray( + np.ascontiguousarray(seg), + urlpath=output_filename_truncated + "_seg.b2nd", + chunks=chunks, + blocks=blocks, + cparams=cparams, + ) + else: + np.savez(output_filename_truncated + ".npz", data=data.astype(np.float16), seg=seg) + write_pickle(properties, output_filename_truncated + ".pkl") diff --git a/TPTBox/core/internal/train_nnUnet/prepere_dataset.py b/TPTBox/core/internal/train_nnUnet/prepere_dataset.py new file mode 100644 index 0000000..8b50a9f --- /dev/null +++ b/TPTBox/core/internal/train_nnUnet/prepere_dataset.py @@ -0,0 +1,331 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ) +# Licensed under the Apache License, Version 2.0 + +from __future__ import annotations + +import os +import random +import sys +from dataclasses import dataclass +from enum import Enum +from multiprocessing import Pool +from pathlib import Path +from typing import Literal + +from tqdm import tqdm + +from TPTBox import Print_Logger, to_nii + +logger = Print_Logger() + + +# ── Config ───────────────────────────────────────────────────────────────────── +@dataclass +class DatasetConfig: + """All tuneable parameters for the feet dataset build.""" + + # ── Identifiers ────────────────────────────────────────────────────────── + dataset_id: int + files: list[tuple[Path, Path]] | list[tuple[list[Path], Path]] + raw_label_ids: list[int | Enum] | dict[int, str | Enum] | dict[int, str] | dict[int, Enum] + dataset_name_suffix: str = "" # appended after Dataset_ if non-empty + nnunet_base: Path = Path() + # ── Label IDs ───────────────────────────────────────────────────────────── + + # ── Preprocessing / spacing ─────────────────────────────────────────────── + spacing: tuple[float, float, float] = (1, 1, 1) + is_ct: bool = True + num_input: int = 1 + axis: str = "S" + target_height_half: int | None = None + auto_crop: int | None = None + ignore_crop: Literal["R", "L", "I", "S", "A", "P"] | str | None = None # noqa: PYI051 + # ── Augmentation ───────────────────────────────────────────────────────── + deform_count: int = 0 + deform_factor: float = 1.0 + degeneration_count: int = 0 + mirror: list[tuple[int | Enum, int | Enum]] | None = None + + # ── Trainer ────────────────────────────────────────────────────────────── + nn_trainer: Literal[ + "nnUNetTrainer", + "nnUNetTrainerNoMirroring", + "nnUNetTrainerDA5", + "nnUNetTrainerDAExtGPU", + ] = "nnUNetTrainer" + auglab_params_json: str = "transform_params_gpu_default01-23.json" + + # ── Runtime ─────────────────────────────────────────────────────────────── + cpu_workers: int | None = None # None → os.cpu_count()//2 + 3 + ignore_label: bool = False + dry_run: bool = True # print plan, skip actual processing + + +def _validate_config(cfg: DatasetConfig) -> None: + """Raise ValueError with a clear message if the config is inconsistent.""" + errors: list[str] = [] + + if cfg.mirror and "NoMirroring" not in cfg.nn_trainer: + errors.append( + f"use_mirror=True but nn_trainer='{cfg.nn_trainer}' does not contain " + "'NoMirroring'. Either set use_mirror=False or use nnUNetTrainerNoMirroring." + ) + if errors: + logger.on_fail("Config validation failed:") + for e in errors: + logger.on_fail(f" • {e}") + raise ValueError("Invalid DatasetConfig — see errors above.") + + +def _build_label_mapping( + cfg: DatasetConfig, +) -> tuple[dict[str, int], dict[int, int], dict[str, str | int], list[tuple[int, int]] | None]: + """Returns:. + ------- + labels_mapping: + nnUNet label definition + {"background": 0, "thymus": 1, ...} + + mapping_forward: + original_label_id -> consecutive_label_id + {17: 1, 42: 2, ...} + + labels_mapping_return: + consecutive_label_id -> original label/name + {"1": "thymus", "2": "femur", ...} + + mirror: + mirror pairs remapped to consecutive ids + """ # noqa: D205 + # ---------------------------------------------------------- + # normalize input to {original_id: name} + # ---------------------------------------------------------- + dataset_mapping: dict[int, str] + enums = {} + + if isinstance(cfg.raw_label_ids, dict): + dataset_mapping = {} + + for k, v in cfg.raw_label_ids.items(): + if isinstance(v, Enum): + dataset_mapping[int(k)] = v.name + enums[v.name] = v.value + else: + dataset_mapping[int(k)] = str(v) + + else: + dataset_mapping = {} + + for item in cfg.raw_label_ids: + if isinstance(item, int): + dataset_mapping[item] = str(item) + else: + dataset_mapping[item.value] = item.name + enums[item.value] = item.name + + # ---------------------------------------------------------- + # create consecutive mapping + # ---------------------------------------------------------- + labels_mapping: dict[str, int] = {"background": 0} + mapping_forward: dict[int, int] = {} + labels_mapping_return: dict[str, str | int] = {} + + for new_idx, (orig_idx, name) in enumerate( + sorted(dataset_mapping.items()), + start=1, + ): + labels_mapping[name] = new_idx + mapping_forward[orig_idx] = new_idx + labels_mapping_return[str(new_idx)] = enums.get(name, orig_idx) + + # ---------------------------------------------------------- + # remap mirror pairs + # ---------------------------------------------------------- + mirror_out: list[tuple[int, int]] | None = None + + if cfg.mirror is not None: + mirror_out = [] + + for left, right in cfg.mirror: + left_id = left.value if isinstance(left, Enum) else left + right_id = right.value if isinstance(right, Enum) else right + + if left_id not in mapping_forward: + raise ValueError(f"Mirror label {left_id} not present in raw_label_ids") + + if right_id not in mapping_forward: + raise ValueError(f"Mirror label {right_id} not present in raw_label_ids") + + mirror_out.append((mapping_forward[left_id], mapping_forward[right_id])) + + return (labels_mapping, mapping_forward, labels_mapping_return, mirror_out) + + +def build_dataset(cfg: DatasetConfig) -> None: + """Build a nnUnet dataset. + + Args: + cfg (DatasetConfig): _description_ + """ + # ── nnUNet env MUST be set before any nnunet import ─────────────────────────── + # These are module-level so they take effect the moment this file is imported. + + f"Building Dataset {cfg.dataset_id:03}" + _validate_config(cfg) + os.environ["nnUNet_raw"] = str(cfg.nnunet_base / "nnUNet_raw") # noqa: SIM112 + os.environ["nnUNet_preprocessed"] = str(cfg.nnunet_base / "nnUNet_preprocessed") # noqa: SIM112 + os.environ["nnUNet_results"] = str(cfg.nnunet_base / "nnUNet_results") # noqa: SIM112 + sys.path.append(str(Path(__file__).parent)) + from _prep_ds import add_file, finalize_ds, run, set_up_dataset + + labels_mapping, mapping_forward, mapping_back, mirror = _build_label_mapping(cfg) + logger.on_text(f"Label count : {len(mapping_forward)} classes") + logger.on_text(f"Mirror pairs : {len(mirror) if mirror else 0}") + logger.on_text(f"Trainer : {cfg.nn_trainer}") + logger.on_text(f"Spacing : {cfg.spacing}") + logger.on_text(f"Deform : (count={cfg.deform_count}, factor={cfg.deform_factor})") + logger.on_text(f"Degeneration : {cfg.degeneration_count}") + logger.on_text(f"Dry run : {cfg.dry_run}") + + if cfg.dry_run: + logger.on_warning("Dry run — stopping before file processing.") + logger.on_text(f"Forward mapping: {mapping_forward}") + + # Pick a random segmentation + _, seg = random.choice(cfg.files) + + seg_nii = to_nii(seg, seg=True) + + labels_found = set(seg_nii.unique()) + labels_found.discard(0) # ignore background + + expected_labels = set(mapping_forward.keys()) + + missing_mapping = labels_found - expected_labels + unused_mapping = expected_labels - labels_found + + logger.on_text(f"Sample segmentation: {seg}") + logger.on_text(f"Labels found : {sorted(labels_found)}") + + if missing_mapping: + logger.on_fail(f"Labels present in segmentation but missing in mapping: {sorted(missing_mapping)}") + + if unused_mapping: + logger.on_warning(f"Labels defined in mapping but not found in sample: {sorted(unused_mapping)}") + + # Test remapping + out = seg_nii.map_labels(mapping_forward) + remapped_labels = sorted(out.unique()) + + logger.on_text(f"Remapped labels : {remapped_labels}") + + expected_remapped = set(mapping_forward.values()) + unexpected = set(remapped_labels) - expected_remapped - {0} + + if unexpected: + logger.on_fail(f"Unexpected labels after remapping: {sorted(unexpected)}") + else: + logger.on_ok("Label mapping validation successful.") + + return + dataset_settings, out_base = set_up_dataset( + cfg.dataset_id, + labels_mapping, + spacing=cfg.spacing, + nn_trainier=cfg.nn_trainer, + AUGLAB_PARAMS_GPU_JSON=cfg.auglab_params_json, + ignore=cfg.ignore_label, + num_input=cfg.num_input, + is_ct=cfg.is_ct, + base=cfg.nnunet_base, + ) + dataset_settings["labels_mapping"] = mapping_back + + # ── Process files ───────────────────────────────────────────────────────── + cpu = cfg.cpu_workers if cfg.cpu_workers is not None else (os.cpu_count() or 4) // 2 + 3 + logger.on_text(f"Worker pool : {cpu} processes") + + results = [] + with Pool(cpu) as p: + logger.on_log("Scheduling file processing") + for img, seg in tqdm(cfg.files, desc="Queuing", unit="pair"): + seg_path = Path(seg) + if not seg_path.exists(): + logger.on_warning(f"Seg file missing, skipping: {seg_path}") + continue + + task = add_file( + p, + img, + seg_path, + dataset_settings, + out_base, + target_height_half=cfg.target_height_half, + defrom=cfg.deform_count > 0, + axis=cfg.axis, + deform_factor=cfg.deform_factor, + defrom_count=cfg.deform_count, + mirror=mirror, + degeneration_count=cfg.degeneration_count, + mapping=mapping_forward, + auto_crop=cfg.auto_crop, + ignore_crop=cfg.ignore_crop, + ) + if task is not None: + results.append(task) + + logger.on_text(f"Running {len(results)} async tasks …") + p.map(run, results) + + # ── Finalise ────────────────────────────────────────────────────────────── + finalize_ds(dataset_settings, out_base) + logger.on_ok(f"Dataset {cfg.dataset_id:03} written to {out_base}") + logger.on_text("Next step:") + logger.on_text("Single Folds") + logger.on_text( + f"python {Path(__file__).parent}/train.py -id {cfg.dataset_id} --gpu 0 -e 300 -el 1000 --num-folds 0 --start-fold 0 -b {cfg.nnunet_base.absolute()}" # noqa: G004 + ) # noqa: G004 + + logger.on_text("k-Folds") + logger.on_text( + f"python {Path(__file__).parent}/train.py -id {cfg.dataset_id} --gpu 0 -e 300 -el 1000 --num-folds 3 --start-fold 0 -b {cfg.nnunet_base.absolute()}" # noqa: G004 + ) + # logger.on_text( + # f" conda run --live-stream --name py3.12 python " + # f"/DATA/NAS/ongoing_projects/robert/code/totalvibesegmentor/" + # f"training_nn/train_ResEnc_.py" + # ) + + +if __name__ == "__main__": + from TPTBox import BIDS_FILE + + infolder = Path("/media/data/lisa/datasets/dataset-lu_dotatate_body_composition/seg_net-thymus/baseline") + data: list[tuple[Path, Path]] = [] + for file in infolder.glob("*.nii.gz"): + bf = BIDS_FILE(file, "/media/data/lisa/datasets/dataset-lu_dotatate_body_composition") + sub = bf.get("sub") + ses = bf.get("ses") + sequ = bf.get("sequ") + acq = bf.get("acq") + ce = bf.get("ce") + fn = file.name.replace("_seg-thym_msk", "_ct").replace("_seg-thym_net", "_ct") + ct = f"/media/data/lisa/datasets/dataset-lu_dotatate_body_composition/rawdata/sub-{sub}_seg/ses-{ses}/ct/{fn}" + ct = Path(ct) + assert ct.exists() + data.append((ct, file)) + + build_dataset( + DatasetConfig( + 7, + data, + {1: "thymus"}, + "thymus", + nnunet_base=Path("/media/data/lisa/code/nnUnet"), + auto_crop=150, + ignore_crop="RA", + is_ct=True, + spacing=(1, 1, 1), + dry_run=False, + ) + ) diff --git a/TPTBox/core/internal/train_nnUnet/train.py b/TPTBox/core/internal/train_nnUnet/train.py new file mode 100644 index 0000000..ba4a847 --- /dev/null +++ b/TPTBox/core/internal/train_nnUnet/train.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import argparse +import json +import os +from dataclasses import dataclass, field +from datetime import timedelta +from multiprocessing import Pool +from pathlib import Path +from time import time +from typing import TYPE_CHECKING, Union + +import torch +from torch.backends import cudnn + +if TYPE_CHECKING: + from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +# nnunetv2-2.5.2 or higher +@dataclass(slots=True) +class Config: + """Config how the nnUnet is ran.""" + + dataset_id: int + out_base: Path = Path("/media/data/lisa/code/nnUnet") + gpus: list[str] = field(default_factory=lambda: ["0"]) + big_model: bool = True + planner: str | None = None + num_epochs: int = 250 + num_iterations_per_epoch: int = 1000 + num_val_iterations_per_epoch: int = 50 + + num_folds: int = 0 + start_fold: int = 0 + + patch_size: tuple[int, ...] | None = None + batch_size: int | None = None + batch_size_val: int = 1 + + gpu_memory_target: int | None = None + overwrite_target_spacing: list[float] | None = None + + verify_dataset_integrity: bool = True + preprocess: bool | None = None + compress = True + debug: bool = False + configurations: list[str] = field( + default_factory=lambda: ["3d_fullres"] + ) # default=["2d", "3d_fullres", "3d_lowres","3d_cascade_fullres"], + num_processes: tuple[int] = (4,) # [32] # [8, 4, 8] + verbose = False + + @property + def plans(self) -> str: + """Return plan name.""" + if self.planner is not None: + return self.planner + return "nnUNetPlannerResEncL" if self.big_model else "nnUNetPlannerResEncM" + + @property + def dataset_folder(self) -> str: + """Get dataset folder name.""" + return f"Dataset{self.dataset_id:03}" + + @property + def single_gpu(self) -> bool: + """Test if this is single GPU.""" + return len(self.gpus) == 1 + + +def _run_training_highjack(self: nnUNetTrainer) -> None: + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + t = time() + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + + for batch_id in range(self.num_iterations_per_epoch): + x = time() - t + print( + f"{epoch}:{batch_id:04}/{self.num_iterations_per_epoch:04}", + " time:", + str(timedelta(seconds=x)), + "ETA:", + str(timedelta(seconds=x / (max(1, batch_id)) * self.num_iterations_per_epoch)), + end="\r", + ) + train_outputs.append(self.train_step(next(self.dataloader_train))) # type: ignore + print(f"{epoch}:{batch_id:05}/{self.num_iterations_per_epoch:05}", " time", str(timedelta(seconds=time() - t))) + self.on_train_epoch_end(train_outputs) + torch.cuda.empty_cache() + t = time() + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + print(f"{batch_id:05}/{self.num_val_iterations_per_epoch:05}", " time", str(timedelta(seconds=time() - t)), end="\r") + + val_outputs.append(self.validation_step(next(self.dataloader_val))) # type: ignore + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + l = list(self.dataset_json["labels"].keys()) + self.print_to_log_file( + "Dice", + ", ".join([f"{l[e]}:{i:.3f}" for e, i in enumerate(self.logger.my_fantastic_logging["dice_per_class_or_region"][-1], 1)]), + ) + self.on_train_end() + + +def _run_training( + dataset_name_or_id: Union[str, int], + configuration: str, + fold: Union[int, str], + trainer_class_name: str = "nnUNetTrainer", + plans_identifier: str = "nnUNetPlans", + pretrained_weights: str | None = None, + export_validation_probabilities: bool = False, + continue_training: bool = False, + only_run_validation: bool = False, + disable_checkpointing: bool = False, + val_with_best: bool = False, + device: torch.device = torch.device("cuda"), # noqa: B008 + num_iterations_per_epoch=250, + num_val_iterations_per_epoch=50, # 50 + num_epochs=250, # 1000 + oversample_foreground_percent=0.33, + current_epoch=0, + enable_deep_supervision=True, + save_every=1, # 50 +): + + from nnunetv2.run.run_training import get_trainer_from_args, join, maybe_load_checkpoint + + if plans_identifier == "nnUNetPlans": + print( + "\n############################\n" + "INFO: You are using the old nnU-Net default plans. We have updated our recommendations. " + "Please consider using those instead! " + "Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md" + "\n############################\n" + ) + if isinstance(fold, str) and fold != "all": + try: + fold = int(fold) + except ValueError: + print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!') + raise + + if val_with_best: + assert not disable_checkpointing, "--val_best is not compatible with --disable_checkpointing" + + nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name, plans_identifier, device=device) + + nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name, plans_identifier, device=device) + nnunet_trainer.oversample_foreground_percent = oversample_foreground_percent + nnunet_trainer.num_val_iterations_per_epoch = num_val_iterations_per_epoch + nnunet_trainer.num_epochs = num_epochs + nnunet_trainer.current_epoch = current_epoch + nnunet_trainer.num_iterations_per_epoch = num_iterations_per_epoch + nnunet_trainer.enable_deep_supervision = enable_deep_supervision # type: ignore + nnunet_trainer.save_every = save_every + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy." + + maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) + + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if not only_run_validation: + _run_training_highjack(nnunet_trainer) + + if val_with_best: + nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, "checkpoint_best.pth")) + nnunet_trainer.perform_actual_validation(export_validation_probabilities) + + +class NNUNetRunner: + """Runs the nnUnet training.""" + + def __init__(self, cfg: Config): + self.cfg = cfg + + # ---------------------------------------------------------- + # Environment + # ---------------------------------------------------------- + + def _setup_environment(self): + os.environ["nnUNet_raw"] = str(self.cfg.out_base / "nnUNet_raw") # noqa: SIM112 + os.environ["nnUNet_preprocessed"] = str(self.cfg.out_base / "nnUNet_preprocessed") # noqa: SIM112 + os.environ["nnUNet_results"] = str(self.cfg.out_base / "nnUNet_results") # noqa: SIM112 + os.environ["nnUNet_n_proc_DA"] = "40" # noqa: SIM112 + + # ---------------------------------------------------------- + # Dataset + # ---------------------------------------------------------- + + def _load_dataset_json(self) -> dict: + ds_file = self.cfg.out_base / "nnUNet_raw" / self.cfg.dataset_folder / "dataset.json" + with open(ds_file) as f: + return json.load(f) + + # ---------------------------------------------------------- + # Planning + # ---------------------------------------------------------- + + def _preprocess(self): + if self.cfg.preprocess is None: + plan_file = self.cfg.out_base / "nnUNet_preprocessed" / self.cfg.dataset_folder / f"{self.cfg.plans}.json" + self.cfg.preprocess = not (plan_file).exists() + if not self.cfg.preprocess: + return + + import nnunetv2.experiment_planning.plan_and_preprocess_api as pp + + print("Extract_fingerprints...") + pp.extract_fingerprints([self.cfg.dataset_id], "DatasetFingerprintExtractor", 8, self.cfg.verify_dataset_integrity, False, False) + print("Plan Experiments...") + pp.plan_experiments( + [self.cfg.dataset_id], + self.cfg.plans, + self.cfg.gpu_memory_target, + "DefaultPreprocessor", + self.cfg.overwrite_target_spacing, + self.cfg.plans, + ) + print("Preprocessing...") + + from TPTBox.core.internal.train_nnUnet.fastProcessor import preprocess + + preprocess( + [self.cfg.dataset_id], self.cfg.plans, self.cfg.configurations, self.cfg.num_processes, self.cfg.compress, self.cfg.verbose + ) + + # ---------------------------------------------------------- + # Plans patching + # ---------------------------------------------------------- + + def _patch_plans(self): + plan_file = self.cfg.out_base / "nnUNet_preprocessed" / self.cfg.dataset_folder / f"{self.cfg.plans}.json" + + if not plan_file.exists(): + return + + with open(plan_file) as f: + plans = json.load(f) + + changed = False + + if self.cfg.patch_size is not None: + plans["configurations"]["3d_fullres"]["patch_size"] = list(self.cfg.patch_size) + + changed = True + + if self.cfg.batch_size is not None: + plans["configurations"]["3d_fullres"]["batch_size"] = self.cfg.batch_size + + changed = True + + if changed: + with open(plan_file, "w") as f: + json.dump(plans, f, indent=2) + + # ---------------------------------------------------------- + # Training + # ---------------------------------------------------------- + + def _train_fold(self, fold: int | str): + + # from nnunetv2.run.run_training import run_training + + print(f"Training fold {fold}") + + best_checkpoints = list( + Path(self.cfg.out_base / "nnUNet_results").glob(f"Dataset{self.cfg.dataset_id:03}*/*_3d_full*/fold_{fold}/checkpoint_best.pth") + ) + print("existing Trainigs: ", best_checkpoints) + _run_training( + dataset_name_or_id=self.cfg.dataset_folder, + configuration="3d_fullres", + fold=fold, + trainer_class_name="nnUNetTrainer", + plans_identifier=self.cfg.plans, + num_iterations_per_epoch=self.cfg.num_iterations_per_epoch, + num_epochs=self.cfg.num_epochs, + continue_training=len(best_checkpoints) != 0, + ) + + # ---------------------------------------------------------- + # Multiprocessing + # ---------------------------------------------------------- + + def _train(self): + + if self.cfg.single_gpu: + os.environ["CUDA_VISIBLE_DEVICES"] = self.cfg.gpus[0] + + if self.cfg.num_folds == 0: + self._train_fold("all") + else: + for fold in range(self.cfg.start_fold, self.cfg.num_folds): + self._train_fold(fold) + return + + folds = list( + range( + self.cfg.start_fold, + max(self.cfg.num_folds, len(self.cfg.gpus)), + ) + ) + + def worker(args): + fold, gpu = args + + os.environ["CUDA_VISIBLE_DEVICES"] = gpu + + self._train_fold(fold) + + jobs = [(fold, self.cfg.gpus[i % len(self.cfg.gpus)]) for i, fold in enumerate(folds)] + + with Pool(len(self.cfg.gpus)) as p: + p.map(worker, jobs) + + # ---------------------------------------------------------- + # Main + # ---------------------------------------------------------- + + def run(self) -> None: + """Starts and runs the training.""" + self._setup_environment() + + ds = self._load_dataset_json() + + self.cfg.overwrite_target_spacing = ds.get("spacing", self.cfg.overwrite_target_spacing) + + self._preprocess() + + self._patch_plans() + + self._train() + + +def parse_args() -> Config: + """Arg parse.""" + parser = argparse.ArgumentParser() + + parser.add_argument("--dataset-id", "-id", required=True, type=int) + parser.add_argument("--base", "-b", required=True, type=str) + parser.add_argument("--gpu", nargs="+", default=["0"]) + parser.add_argument("--epochs", "-e", type=int, default=250) + parser.add_argument("--epoch-len", "-el", type=int, default=1000) + parser.add_argument("--planner", default=None) + parser.add_argument("--small", action="store_true") + parser.add_argument("--num-folds", type=int, default=0) + parser.add_argument("--start-fold", type=int, default=0) + parser.add_argument("--patch-size", nargs="+", type=int) + parser.add_argument("--batch-size", type=int) + parser.add_argument("--skip-preprocessing", action="store_true") + parser.add_argument("--force-preprocessing", action="store_true") + parser.add_argument("--num_processes", type=int, default=4) # 32 on server + + args = parser.parse_args() + preprocess = None + if args.skip_preprocessing: + preprocess = False + if args.force_preprocessing: + preprocess = True + return Config( + out_base=Path(args.base), + dataset_id=args.dataset_id, + gpus=args.gpu, + big_model=not args.small, + planner=args.planner, + num_epochs=args.epochs, + num_folds=args.num_folds, + start_fold=args.start_fold, + num_iterations_per_epoch=args.epoch_len, + patch_size=tuple(args.patch_size) if args.patch_size else None, + batch_size=args.batch_size, + preprocess=preprocess, + num_processes=(args.num_processes,), + ) + + +def main() -> None: + """Main.""" + cfg = parse_args() + + runner = NNUNetRunner(cfg) + + runner.run() + + +if __name__ == "__main__": + main() + # /media/data/anaconda3/envs/py3.12/bin/python /media/data/lisa/code/scripts/train_nnUnet/train.py -id 7 --gpu 0 -e 300 -el 1000 --force-preprocessing --num_processes 32 diff --git a/TPTBox/core/nii_poi_abstract.py b/TPTBox/core/nii_poi_abstract.py index 2c3ca6b..b2598ff 100755 --- a/TPTBox/core/nii_poi_abstract.py +++ b/TPTBox/core/nii_poi_abstract.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import sys from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -7,6 +8,7 @@ import nibabel as nib import nibabel.orientations as nio import numpy as np +from scipy.spatial import ConvexHull from scipy.spatial.transform import Rotation from typing_extensions import Self @@ -383,6 +385,11 @@ def get_plane(self, res_threshold: float | None = 1) -> str: plane = "iso" return plane + def voxel_volume(self) -> float: + """Returns the volume of a single voxel in mm³ (product of all zoom values).""" + product = math.prod(self.spacing) + return product + def get_axis(self, direction: DIRECTIONS = "S") -> int: """Return the axis index corresponding to the given anatomical direction. @@ -467,7 +474,7 @@ def make_nii(self, arr: np.ndarray | None = None, seg=False) -> NII: arr = np.zeros(self.shape_int) return self.make_empty_nii(_arr=arr, seg=seg) - def global_to_local(self, x: COORDINATE) -> tuple: + def global_to_local(self, x: COORDINATE, itk=False) -> tuple: """Convert world (RAS/LPS) coordinates to voxel (local) coordinates. Applies the inverse affine transform: rotation-transpose times @@ -479,19 +486,26 @@ def global_to_local(self, x: COORDINATE) -> tuple: Returns: tuple: Voxel-space coordinate rounded to 7 decimal places. """ - a = self.rotation.T @ (np.array(x) - self.origin) / np.array(self.zoom) + x_ = np.array(x) + if itk: + x_[0] *= -1 + x_[1] *= -1 + a = self.rotation.T @ (x_ - self.origin) / np.array(self.zoom) return tuple(round(float(v), 7) for v in a) - def global_to_local_arr(self, coords: np.ndarray) -> np.ndarray: + def global_to_local_arr(self, coords: np.ndarray, itk=False) -> np.ndarray: """Vectorized :meth:`global_to_local` for an ``(N, 3)`` array of world coordinates. Equivalent to applying ``global_to_local`` to each row but in a single batched inverse-affine matmul. """ a = (np.asarray(coords, dtype=float) - np.asarray(self.origin)) @ np.asarray(self.rotation) / np.asarray(self.zoom) + if itk: + a[:, 0] *= -1 + a[:, 1] *= -1 return np.round(a, 7) - def local_to_global(self, x: COORDINATE) -> tuple: + def local_to_global(self, x: COORDINATE | np.ndarray, itk=False) -> tuple: """Convert voxel (local) coordinates to world (RAS/LPS) coordinates. Applies the forward affine transform: rotation times @@ -503,8 +517,10 @@ def local_to_global(self, x: COORDINATE) -> tuple: Returns: tuple: World-space coordinate rounded to 7 decimal places. """ - # TODO ITK version a = self.rotation @ (np.array(x) * np.array(self.zoom)) + self.origin + if itk: + a[0] *= -1 + a[1] *= -1 return tuple(round(float(v), 7) for v in a) def to_deepali_grid(self, align_corners: bool = True) -> Any: @@ -597,6 +613,211 @@ def get_num_dims(self) -> int: """ return len(self.shape) + def get_corners(self) -> np.ndarray: + """Compute the 8 corner points of the grid's oriented bounding box (OBB). + + in world coordinates. + + The box is defined by: + - shape_int: voxel dimensions (nx, ny, nz) + - zoom: physical spacing per axis (mm/voxel) + - rotation: 3×3 matrix whose columns define the local axes + - local_to_global: transforms local coordinates to world space + + The box is centered at the grid center in local space and then mapped + into world space using the grid's rotation and translation. + + Returns: + np.ndarray of shape (8, 3): + Corner points ordered in a consistent binary pattern: + (-u/±v/±w combinations). + """ + s = np.array(self.shape_int, dtype=float) + zoom = np.array(self.zoom, dtype=float) + + ctr = np.array(self.local_to_global((s - 1) * 0.5)) + R = self.rotation + + u = R[:, 0] * 0.5 * s[0] * zoom[0] + v = R[:, 1] * 0.5 * s[1] * zoom[1] + w = R[:, 2] * 0.5 * s[2] * zoom[2] + + return np.array( + [ + ctr - u - v - w, + ctr - u - v + w, + ctr - u + v - w, + ctr - u + v + w, + ctr + u - v - w, + ctr + u - v + w, + ctr + u + v - w, + ctr + u + v + w, + ] + ) + + def get_obb_quads(self: Has_Grid) -> list[np.ndarray]: + """Return the 6 faces of the grid's oriented bounding box (OBB) as quadrilateral patches. + + Each face is represented as a (4, 3) array of world-space vertices. + Vertex ordering follows the right-hand rule so that outward-facing + normals are consistent for all faces. + + The OBB is constructed from: + - center position in world space + - half-extent vectors along rotated axes: + u, v, w = (R[:,i] * half_size_i) + + Returns: + list[np.ndarray]: + A list of 6 quads corresponding to: + +Z, -Z, +X, -X, +Y, -Y faces (in this order). + """ + if len(self.zoom) < 3: + raise NotImplementedError("only implemented for 3D images.") + s = np.array(self.shape_int, dtype=float) + zoom = np.array(self.zoom, dtype=float) + ctr = np.array(self.local_to_global((s - 1) * 0.5)) + R = self.rotation # columns are local x/y/z axes + u = R[:, 0] * 0.5 * s[0] * zoom[0] + v = R[:, 1] * 0.5 * s[1] * zoom[1] + w = R[:, 2] * 0.5 * s[2] * zoom[2] + return [ + np.array([ctr + u + v + w, ctr + u - v + w, ctr - u - v + w, ctr - u + v + w]), # +Z face + np.array([ctr + u + v - w, ctr - u + v - w, ctr - u - v - w, ctr + u - v - w]), # -Z face + np.array([ctr + u + v + w, ctr + u + v - w, ctr + u - v - w, ctr + u - v + w]), # +X face + np.array([ctr - u + v + w, ctr - u - v + w, ctr - u - v - w, ctr - u + v - w]), # -X face + np.array([ctr + u + v + w, ctr - u + v + w, ctr - u + v - w, ctr + u + v - w]), # +Y face + np.array([ctr + u - v + w, ctr + u - v - w, ctr - u - v - w, ctr - u - v + w]), # -Y face + ] + + def get_intersecting_volume(self, b: Has_Grid) -> float: + """Approximate the geometric intersection volume of two oriented bounding boxes (OBBs). + + This method computes the intersection by: + 1. Collecting candidate points: + - Corners of A inside B + - Corners of B inside A + - Edge/plane intersection points between both OBBs + 2. Deduplicating points in world space + 3. Constructing a 3D convex hull over these points + 4. Returning the hull volume as the intersection volume + + Important notes: + - This is NOT an exact polytope clipping implementation. + - The result is an approximation that is exact only when the + intersection is fully convex and all boundary points are captured + (which is typically but not strictly guaranteed). + - The method assumes the intersection of two OBBs is convex + (which it is), but numerical robustness depends on point sampling. + - Degenerate or near-tangent configurations may return 0.0 due to + insufficient hull points. + + Computational complexity is constant with respect to image resolution + (depends only on 8 corners and 12 edges per box). + + Args: + b (Has_Grid): Another grid with position, orientation, and spacing. + + Returns: + float: + Estimated intersection volume in mm³. Returns 0.0 if no valid + convex hull can be constructed. + """ + + def _half_spaces(grid: Has_Grid): + """Yield (point_on_plane, inward_unit_normal) for each of the 6 OBB faces.""" + s = np.array(grid.shape_int, dtype=float) + zoom = np.array(grid.zoom, dtype=float) + ctr = np.array(grid.local_to_global((s - 1) * 0.5)) + R = grid.rotation + half = 0.5 * s * zoom + + for i in range(3): + axis = R[:, i] + for sign in (+1.0, -1.0): + yield ctr + sign * half[i] * axis, -sign * axis + + def _obb_edges(grid: Has_Grid): + c = grid.get_corners() + edge_ids = [ + (0, 1), + (0, 2), + (0, 4), + (1, 3), + (1, 5), + (2, 3), + (2, 6), + (3, 7), + (4, 5), + (4, 6), + (5, 7), + (6, 7), + ] + return [(c[i], c[j]) for i, j in edge_ids] + + def _point_inside_box(p: np.ndarray, box: Has_Grid, eps=1e-8) -> bool: + return all(np.dot(p - pt, n) >= -eps for pt, n in _half_spaces(box)) + + def _segment_plane_intersection(p0: np.ndarray, p1: np.ndarray, plane_pt: np.ndarray, plane_n: np.ndarray, eps=1e-10): + u = p1 - p0 + denom = np.dot(u, plane_n) + if abs(denom) < eps: + return None + t = np.dot(plane_pt - p0, plane_n) / denom + if t < -eps or t > 1 + eps: + return None + return p0 + t * u + + points = [] + + # corners of A inside B + for p in self.get_corners(): + if _point_inside_box(p, b): + points.append(p) # noqa: PERF401 + # corners of B inside A + for p in b.get_corners(): + if _point_inside_box(p, self): + points.append(p) # noqa: PERF401 + # edges of A against planes of B + for e0, e1 in _obb_edges(self): + for plane_pt, plane_n in _half_spaces(b): + p = _segment_plane_intersection( + e0, + e1, + plane_pt, + plane_n, + ) + if p is None: + continue + if _point_inside_box(p, self) and _point_inside_box(p, b): + points.append(p) + # edges of B against planes of A + for e0, e1 in _obb_edges(b): + for plane_pt, plane_n in _half_spaces(self): + p = _segment_plane_intersection( + e0, + e1, + plane_pt, + plane_n, + ) + if p is None: + continue + if _point_inside_box(p, self) and _point_inside_box(p, b): + points.append(p) + if len(points) == 0: + return 0.0 + pts = np.unique( + np.round(np.asarray(points), 8), + axis=0, + ) + if len(pts) < 4: + return 0.0 + try: + hull = ConvexHull(pts) + return float(hull.volume) + except Exception: + return 0.0 + @dataclass class Grid(Has_Grid): diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index de809e8..19b4df1 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1,6 +1,5 @@ from __future__ import annotations -import math import traceback import warnings import zlib @@ -2038,27 +2037,35 @@ def get_overlapping_labels_to( assert self.seg and mask_other.seg return np_calc_overlapping_labels(self.get_seg_array(), mask_other.get_seg_array()) - def is_segmentation_in_border(self,minimum=0, voxel_tolerance: int = 2,use_mm=False) -> bool: - """Checks if the segmentation is touching the border of the image volume. + def is_segmentation_in_border(self,minimum=0,voxel_tolerance: int = 2,use_mm: bool = False) -> bool: + """Checks if the segmentation touches the border of the image volume. Parameters: - - minimum (int, optional): Minimum intensity threshold for segmentation. Defaults to 0. - - voxel_tolerance (int, optional): Number of voxels allowed as tolerance from the border. Defaults to 2. - - use_mm (bool, optional): Whether to use millimeter units instead of voxels. Defaults to False. + - minimum (int, optional): Minimum intensity threshold for segmentation. + Defaults to 0. + - voxel_tolerance (int, optional): Number of voxels allowed as tolerance + from the border. Defaults to 2. + - use_mm (bool, optional): Whether to use millimeter units instead of + voxels. Defaults to False. Returns: - - bool: True if the segmentation is within the defined voxel tolerance of the border, False otherwise. + - bool: True if the segmentation is within the defined tolerance of the + border, False otherwise. """ slices = self.compute_crop(minimum,dist=0,use_mm=use_mm,raise_error=False) if slices is None: return False - shp = self.shape - seg_at_border = False - for d in range(3): - if slices[d].start <= voxel_tolerance or slices[d].stop - 1 >= shp[d] - voxel_tolerance: - seg_at_border = True - break - return seg_at_border + for dim, s in enumerate(slices): + shape_dim = self.shape[dim] + # Interpret open-ended slices as full image bounds + start = 0 if s.start is None else s.start + stop = shape_dim if s.stop is None else s.stop + # stop is exclusive, so the last occupied voxel is stop - 1 + if start <= voxel_tolerance: + return True + if stop - 1 >= shape_dim - voxel_tolerance: + return True + return False def truncate_labels_beyond_reference_( self, idx: int | list[int] = 1, not_beyond: int | list[int] = 1, fill: int = 0, axis: DIRECTIONS = "S", inclusion: bool = False, inplace: bool = True @@ -2540,7 +2547,7 @@ def to_stl( elif number_path: out_path.with_name(f"{out_path.stem}_{label}.stl") log.on_save(f"Saving STL to {out_path}") - out_path.parent.mkdir(exist_ok=True) + out_path.parent.mkdir(exist_ok=True,parents=True) cube.save(str(out_path)) if include_normals: @@ -2641,7 +2648,7 @@ def is_intersecting_vertical(self, b: Self, min_overlap=40) -> bool: return True return min_v < x2[2] < max_v - def get_intersecting_volume(self, b: Self) -> float: + def get_intersecting_volume_slow(self, b: Self) -> float: """Returns the number of voxels in ``self``'s grid that overlap with image ``b``. ``b`` is binarised (all non-zero → 1) and resampled into ``self``'s voxel @@ -2773,10 +2780,7 @@ def unique(self,verbose:logging=False,crop=False) -> list[int]: out = np_unique_withoutzero(arr) log.print(out,verbose=verbose) return out - def voxel_volume(self) -> float: - """Returns the volume of a single voxel in mm³ (product of all zoom values).""" - product = math.prod(self.spacing) - return product + def volumes(self, include_zero: bool = False, in_mm3=False,sort=False) -> dict[int, float]|dict[int, int]: """Returns a dict stating how many pixels are present for each label.""" diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 74166bd..1ac5f10 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -235,7 +235,7 @@ def np_unique(arr: np.ndarray) -> list[int]: max_val = int(arr.max()) if max_val < 2**20: # ~1M labels threshold — bincount stays fast counts = np.bincount(arr.ravel()) - return list(np.where(counts > 0)[0]) + return np.where(counts > 0)[0].tolist() # For sparse label spaces fall back to np.unique return old_np_unique(arr) diff --git a/TPTBox/core/poi.py b/TPTBox/core/poi.py index bd9b20c..41cb0e0 100755 --- a/TPTBox/core/poi.py +++ b/TPTBox/core/poi.py @@ -36,6 +36,7 @@ Location, Sentinel, Vertebra_Instance, + _same_direction, log, logging, v_name2idx, @@ -655,7 +656,7 @@ def save( self, out_path, make_parents, additional_info, verbose=verbose, save_hint=save_hint, resample_reference=resample_reference ) - def make_point_cloud_nii(self, affine=None, s=8, sphere=False) -> tuple[NII, NII]: + def make_point_cloud_nii(self, affine=None, s=8, sphere=True) -> tuple[NII, NII]: """Create point cloud NIfTI images from the POI coordinates. This method generates two NIfTI images, one for the regions and another for the subregions, @@ -683,9 +684,6 @@ def make_point_cloud_nii(self, affine=None, s=8, sphere=False) -> tuple[NII, NII affine = self.affine arr = np.zeros(self.shape_int) arr2 = np.zeros(self.shape_int) - s1 = max(s // 2, 1) - s2 = max(s - s1, 1) - from math import ceil, floor if sphere: zoom = np.asarray(self.zoom) @@ -704,7 +702,9 @@ def make_point_cloud_nii(self, affine=None, s=8, sphere=False) -> tuple[NII, NII for region, subregion, (x, y, z) in self.items(): x, y, z = round(x), round(y), round(z) # noqa: PLW2901 - + if not (0 <= x < self.shape[0] and 0 <= y < self.shape[1] and 0 <= z < self.shape[2]): + print(f"Skipping POI outside image: {region}, {subregion},{(x, y, z)} shape={self.shape}") + continue # image bounds x0 = max(x - rx, 0) x1 = min(x + rx + 1, self.shape[0]) @@ -726,25 +726,69 @@ def make_point_cloud_nii(self, affine=None, s=8, sphere=False) -> tuple[NII, NII kz1 = kz0 + (z1 - z0) local_mask = sphere_mask[kx0:kx1, ky0:ky1, kz0:kz1] - + if region == 0: + region = 1 # noqa: PLW2901 + if subregion == 0: + subregion = 1 # noqa: PLW2901 arr[x0:x1, y0:y1, z0:z1][local_mask] = region arr2[x0:x1, y0:y1, z0:z1][local_mask] = subregion else: for region, subregion, (x, y, z) in self.items(): + if region == 0: + region = 1 # noqa: PLW2901 + if subregion == 0: + subregion = 1 # noqa: PLW2901 + if not (0 <= x < self.shape[0] and 0 <= y < self.shape[1] and 0 <= z < self.shape[2]): + print(f"Skipping POI outside image: {region}, {subregion}, {(x, y, z)} shape={self.shape}") + continue + rx = int(np.ceil((s / 2) / self.zoom[0])) + ry = int(np.ceil((s / 2) / self.zoom[1])) + rz = int(np.ceil((s / 2) / self.zoom[2])) arr[ - max((floor(x - s1 / self.zoom[0])) + 1, 0) : min((ceil(x + s2 / self.zoom[0] + 1)), self.shape[0]), - max((floor(y - s1 / self.zoom[1])) + 1, 0) : min((ceil(y + s2 / self.zoom[1] + 1)), self.shape[1]), - max((floor(z - s1 / self.zoom[2])) + 1, 0) : min((ceil(z + s2 / self.zoom[2] + 1)), self.shape[2]), + int(max(x - rx, 0)) : int(min(x + rx + 1, self.shape[0])), + int(max(y - ry, 0)) : int(min(y + ry + 1, self.shape[1])), + int(max(z - rz, 0)) : int(min(z + rz + 1, self.shape[2])), ] = region + arr2[ - max((floor(x - s1 / self.zoom[0])) + 1, 0) : min((ceil(x + s2 / self.zoom[0] + 1)), self.shape[0]), - max((floor(y - s1 / self.zoom[1])) + 1, 0) : min((ceil(y + s2 / self.zoom[1] + 1)), self.shape[1]), - max((floor(z - s1 / self.zoom[2])) + 1, 0) : min((ceil(z + s2 / self.zoom[2] + 1)), self.shape[2]), + int(max(x - rx, 0)) : int(min(x + rx + 1, self.shape[0])), + int(max(y - ry, 0)) : int(min(y + ry + 1, self.shape[1])), + int(max(z - rz, 0)) : int(min(z + rz + 1, self.shape[2])), ] = subregion nii = nib.Nifti1Image(arr, affine=affine) nii2 = nib.Nifti1Image(arr2, affine=affine) return NII(nii, seg=True), NII(nii2, seg=True) + def flip(self, axis: int | str, keep_global_coords: bool = True, inplace: bool = False) -> Self: + """Flip the POIs along a spatial axis. + + Args: + axis: Axis to flip, either as an integer or anatomical direction + string (e.g. ``"S"``, ``"R"``). + keep_global_coords: If True, perform the flip by changing the + orientation, preserving world-space coordinates. If False, + mirror the voxel coordinates without changing the affine. + inplace: Whether to modify this POI in place. + + Returns: + The flipped POI. + """ + axis = self.get_axis(axis) if not isinstance(axis, int) else axis + + if keep_global_coords: + orient = list(self.orientation) + orient[axis] = _same_direction[orient[axis]] + return self.reorient(tuple(orient), inplace=inplace) + + assert self.shape is not None, "Cannot flip voxel coordinates without shape information." + + def _flip(x: float, y: float, z: float): + p = [x, y, z] + p[axis] = self.shape[axis] - 1 - p[axis] + return tuple(p) + + return self.apply_all(_flip, inplace=inplace) + def filter_points_inside_shape(self, inplace=False) -> Self: """Filter out POI points that are outside the defined shape. diff --git a/TPTBox/core/poi_fun/poi_abstract.py b/TPTBox/core/poi_fun/poi_abstract.py index d74dfee..194f03d 100755 --- a/TPTBox/core/poi_fun/poi_abstract.py +++ b/TPTBox/core/poi_fun/poi_abstract.py @@ -355,8 +355,8 @@ def __eq__(self, x): return False def __len__(self) -> int: - if self._len is None: - self._len = len(list(self.items())) + # if self._len is None: + self._len = len(list(self.items())) return self._len def __iter__(self): @@ -803,7 +803,7 @@ def remove(self, *label: tuple[int, int], inplace: bool = False) -> Self: obj.centroids.pop(loc, None) return obj - def extract_subregion(self, *location: Abstract_lvl | int, inplace: bool = False) -> Self: + def extract_subregion(self, *location: int | list[int] | Enum, inplace: bool = False) -> Self: """Return a POI containing only the specified subregion(s). Args: @@ -829,16 +829,32 @@ def extract_subregion_(self, *location: Abstract_lvl | int) -> Self: """In-place alias for :meth:`extract_subregion`.""" return self.extract_subregion(*location, inplace=True) - def extract_vert(self, *vert_label: int, inplace: bool = False) -> Self: - """Deprecated — use :meth:`extract_region` instead.""" - import warnings + def extract( + self, + *vert_label: tuple[int, int] | list[tuple[int, int]], + inplace: bool = False, + ) -> Self: + """Return a POI containing only the specified region(s) (vertebrae). - warnings.warn("extract_vert id deprecated use extract_region instead", stacklevel=5) # TODO remove in version 2.0 - return self.extract_region(*vert_label, inplace=inplace) + Args: + *vert_label: One or more region IDs, lists of IDs, or ``Enum`` + members to retain. + inplace: Filter in place. Defaults to ``False``. - def extract_vert_(self, *vert_label: int) -> Self: - """Deprecated in-place alias — use :meth:`extract_region_` instead.""" - return self.extract_vert(*vert_label, inplace=True) + Returns: + Filtered POI. + """ + # flatten list + vert_label = _flatten(vert_label) + vert_labels = tuple(vert_label) + extracted_centroids = POI_Descriptor() + for x1, x2, y in self.centroids.items(): + if (x1, x2) in vert_labels: + extracted_centroids[x1, x2] = y + if inplace: + self.centroids = extracted_centroids + return self + return self.copy(centroids=extracted_centroids) def extract_region(self, *vert_label: int | list[int] | Enum, inplace: bool = False) -> Self: """Return a POI containing only the specified region(s) (vertebrae). diff --git a/TPTBox/core/poi_fun/save_load.py b/TPTBox/core/poi_fun/save_load.py index 966d991..01f6c06 100644 --- a/TPTBox/core/poi_fun/save_load.py +++ b/TPTBox/core/poi_fun/save_load.py @@ -10,6 +10,7 @@ # from TPTBox import POI, POI_Global from TPTBox.core import bids_files +from TPTBox.core.internal.nii_help import save_json from TPTBox.core.nii_poi_abstract import Has_Grid from TPTBox.core.poi_fun.poi_abstract import POI_Descriptor from TPTBox.core.vert_constants import ( @@ -123,21 +124,8 @@ def save_poi( return json_object, print_add = _poi_to_dict_list(poi, additional_info, save_hint, resample_reference, verbose) - # Problem with python 3 and int64 serialization. - def convert(o): - if isinstance(o, np.integer): - return int(o) - if isinstance(o, np.floating): - return float(o) - if isinstance(o, np.ndarray): - return o.tolist() - if isinstance(o, Path): - return str(o.absolute()) - raise TypeError(type(o)) - try: - with open(out_path, "w") as f: - json.dump(json_object, f, default=convert, indent=4) + save_json(out_path, json_object, indent=4) except TypeError: Path(out_path).unlink(missing_ok=True) raise diff --git a/TPTBox/core/poi_fun/save_mkr.py b/TPTBox/core/poi_fun/save_mkr.py index 9783aa3..17eec8d 100644 --- a/TPTBox/core/poi_fun/save_mkr.py +++ b/TPTBox/core/poi_fun/save_mkr.py @@ -10,6 +10,7 @@ import numpy as np from typing_extensions import NotRequired +from TPTBox.core.internal.nii_help import save_json from TPTBox.logger.log_file import log from TPTBox.mesh3D.mesh_colors import RGB_Color, get_color_by_label @@ -545,8 +546,5 @@ def _save_mrk( "@schema": "https://raw.githubusercontent.com/slicer/slicer/master/Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#", "markups": markups, } - # print(markups[-1].get("display")) - filepath.unlink(missing_ok=True) - with open(filepath, "w") as f: - json.dump(mrk_data, f, indent=2) + save_json(filepath, mrk_data, indent=2) log.on_save(f"Saved .mrk.json to {filepath}") diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index d7ef0e8..f8bbede 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -821,6 +821,7 @@ class Lower_Body(Abstract_lvl): LATERAL_CONDYLE_DISTAL = 16 MEDIAL_CONDYLE_DISTAL = 17 NOTCH_POINT = 18 + TGPP = 99 # Femur, Tibia ANATOMICAL_AXIS_PROXIMAL = 19 ANATOMICAL_AXIS_DISTAL = 20 @@ -837,7 +838,6 @@ class Lower_Body(Abstract_lvl): LATERAL_CONDYLE_LATERAL = 29 ANKLE_CENTER = 30 MEDIAL_MALLEOLUS = 31 - TGPP = 99 TTP = 98 # Fibula LATERAL_MALLEOLUS = 32 @@ -893,7 +893,7 @@ def get_mapping(cls) -> dict[str, tuple]: "TMM": (Full_Body_Instance.tibia_right, Lower_Body.MEDIAL_MALLEOLUS), "TAAP": (Full_Body_Instance.tibia_right, Lower_Body.ANATOMICAL_AXIS_PROXIMAL), "TADP": (Full_Body_Instance.tibia_right, Lower_Body.ANATOMICAL_AXIS_DISTAL), - "TGPP": (Full_Body_Instance.tibia_right, Lower_Body.TGPP), + "TGPP": (Full_Body_Instance.femur_right, Lower_Body.TGPP), "TTP": (Full_Body_Instance.tibia_right, Lower_Body.TTP), # Fibula "FLM": (Full_Body_Instance.fibula_right, Lower_Body.LATERAL_MALLEOLUS), diff --git a/TPTBox/mesh3D/snapshot3D.py b/TPTBox/mesh3D/snapshot3D.py index dffe31f..f0c01f0 100644 --- a/TPTBox/mesh3D/snapshot3D.py +++ b/TPTBox/mesh3D/snapshot3D.py @@ -11,6 +11,7 @@ import vtk from fury import window from PIL import Image +from tqdm import tqdm from vtk.util import numpy_support # type: ignore from xvfbwrapper import Xvfb @@ -36,9 +37,11 @@ def make_snapshot3D( resolution: float | None = None, width_factor: float = 1.0, scale_factor: int = 1, + debug: bool = False, verbose: bool = True, crop: bool = True, png_magnify: int = 1, + opacity: dict[int, float] | None = None, ) -> Image.Image: """Generate a 3D surface-rendered snapshot from a segmentation image. @@ -62,10 +65,13 @@ def make_snapshot3D( verbose: If True, logs the output path after saving. crop: If True, crops the image to its bounding box before rendering. png_magnify: Window pixel density multiplier for the fury renderer. + opacity: mapping idx to opacity (1 means full, 0 invisible) Returns: The rendered snapshot as a PIL Image object. """ + if opacity is None: + opacity = {} is_tmp = output_path is None t = None if output_path is None: @@ -101,14 +107,11 @@ def make_snapshot3D( show_m = window.ShowManager(scene=scene, size=window_size, reset_camera=False, png_magnify=png_magnify) show_m.initialize() for i, ids in enumerate(ids_list): + if debug: + logger.on_debug(f"{i + 1:02}/{len(ids_list):02} - Snapshot frames") x = width * i _plot_sub_seg( - scene, - nii.extract_label(ids, keep_label=True), - x, - 0, - smoothing, - view[i % len(view)], + scene, nii.extract_label(ids, keep_label=True), x, 0, smoothing, view[i % len(view)], opacity=opacity, debug=debug ) scene.projection(proj_type="parallel") scene.reset_camera_tight(margin_factor=1.02) @@ -141,6 +144,8 @@ def make_snapshot3D_parallel( scale_factor: int = 1, override: bool = True, crop: bool = True, + opacity: dict[int, float] | None = None, + debug=False, ) -> None: """Run :func:`make_snapshot3D` in parallel across multiple images. @@ -176,6 +181,8 @@ def make_snapshot3D_parallel( "png_magnify": png_magnify, "crop": crop, "scale_factor": scale_factor, + "opacity": opacity, + "debug": debug, }, ) ress.append(res) @@ -188,8 +195,12 @@ def make_snapshot3D_parallel( make_sub_snapshot_parallel = make_snapshot3D_parallel -def _plot_sub_seg(scene: window.Scene, nii: NII, x: int, y: int, smoothing: int, orientation: VIEW) -> None: +def _plot_sub_seg( + scene: window.Scene, nii: NII, x: int, y: int, smoothing: int, orientation: VIEW, opacity: dict[int, float] | None = None, debug=False +) -> None: """Render all labels from a segmentation NII into the fury scene at the given viewport offset.""" + if opacity is None: + opacity = {} if orientation == "A": # [ axis1(w) ] [ axis2(h) ] [ view in ] affine = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]) @@ -209,7 +220,12 @@ def _plot_sub_seg(scene: window.Scene, nii: NII, x: int, y: int, smoothing: int, affine = np.array([[0, 0, 1, 0], [1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) else: raise NotImplementedError() - for idx in nii.unique(): + u = nii.unique() + idxs = u if not debug else tqdm(u) + for idx in idxs: + o = opacity.get(idx, 1) + if o == 0: + continue color = get_color_by_label(idx) cont_actor = _plot_mask( nii.extract_label(idx), @@ -218,7 +234,7 @@ def _plot_sub_seg(scene: window.Scene, nii: NII, x: int, y: int, smoothing: int, y, smoothing=smoothing, color=color, - opacity=1, + opacity=o, ) scene.add(cont_actor) @@ -388,7 +404,7 @@ def _contour_from_roi_smooth( skin_actor = vtk.vtkActor() skin_actor.SetMapper(skin_mapper) - skin_actor.GetProperty().SetOpacity(opacity) + skin_actor.GetProperty().SetOpacity(opacity) if opacity != 1 else None skin_actor.GetProperty().SetColor(color[0], color[1], color[2]) return skin_actor diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 1ea4a46..d550dd5 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -84,7 +84,7 @@ def run_inference_on_file( gpu=None, keep_size: bool = False, fill_holes: bool = False, - logits: bool = False, + logits: bool = False, # deprecated mapping=None, crop: bool = False, max_folds=None, @@ -103,6 +103,7 @@ def run_inference_on_file( cache_model: bool = False, _key_ResEnc: str = "__nnUNet*ResEnc", fail_on_missing_memory=False, + _cpu_chunks: int | None = None, logger=logger, ) -> tuple[Image_Reference, np.ndarray | None]: """Load a VibeSeg model and run inference on the supplied NIfTI images. @@ -122,7 +123,8 @@ def run_inference_on_file( original image size. fill_holes: If ``True``, fill holes in the segmentation mask after inference. - logits: If ``True``, also return the raw softmax logits array. + logits: Deprecated; Was commented out for less GPU waste. + If ``True``, also return the raw softmax logits array. mapping: Optional label remapping dict applied to the segmentation after inference. crop: If ``True``, crop the input to its foreground bounding box before @@ -145,9 +147,12 @@ def run_inference_on_file( wait_till_gpu_percent_is_free: Minimum free GPU fraction to require before starting inference. tile_batch_size: Number of sliding-window tiles to run per network - forward pass. ``1`` (default) keeps the original per-tile behaviour; + forward pass. ``1`` (default) keeps the original per-tile behavior; larger values batch tiles to better saturate the GPU at the cost of higher peak memory. + _cpu_chunks: Split the prediction in k chunks along the largest axis. + This setting should only be used if there in not enough (CPU) RAM. + For GPU memory use the "memory_*" keys; verbose: Print progress information. cache_model: If ``True``, keep the loaded predictor in a process-wide cache and reuse it on subsequent calls with identical model and @@ -174,10 +179,12 @@ def run_inference_on_file( if out_file is not None and Path(out_file).exists() and not override: return out_file, None - from TPTBox.segmentation.nnUnet_utils.inference_api import ( - load_inf_model, - run_inference, - ) + if min(input_nii[0].shape) <= 1: + shape = input_nii[0].shape + logger.on_fail(f"{shape=} has only {min(shape)} slice in a dimension.") + return None, None + + from TPTBox.segmentation.nnUnet_utils.inference_api import _run_inference_patches, load_inf_model, run_inference if isinstance(idx, int): if auto_download: @@ -251,6 +258,7 @@ def run_inference_on_file( wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, tile_batch_size=tile_batch_size, fail_on_missing_memory=fail_on_missing_memory, + logger=logger, ) if cache_model: _model_cache[cache_key] = nnunet @@ -291,6 +299,7 @@ def run_inference_on_file( "\n", nnunet_path, ) + if orientation is not None: logger.print("orientation", orientation, f"from {input_nii[0].orientation}") if verbose else None input_nii = [i.reorient(orientation) for i in input_nii] @@ -300,6 +309,7 @@ def run_inference_on_file( input_nii = [i.rescale_(zoom, mode=mode, verbose=True) for i in input_nii] logger.print(input_nii) logger.print("squash to float16") if verbose else None + input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii] if crop: @@ -308,8 +318,17 @@ def run_inference_on_file( if padd != 0: p = (padd, padd) input_nii = [i.apply_pad([p, p, p], mode="reflect") for i in input_nii] + if _cpu_chunks is None or _cpu_chunks <= 1: + try: + seg_nii, _, softmax_logits = run_inference(input_nii, nnunet, logits=logits, logger=logger) + except MemoryError: + logger.print_error() + seg_nii = _run_inference_patches(input_nii, nnunet, None, logger=logger) + softmax_logits = None + else: + seg_nii = _run_inference_patches(input_nii, nnunet, _cpu_chunks, logger=logger) + softmax_logits = None - seg_nii, uncertainty_nii, softmax_logits = run_inference(input_nii, nnunet, logits=logits) if padd != 0: seg_nii = seg_nii[padd:-padd, padd:-padd, padd:-padd] diff --git a/TPTBox/segmentation/nnUnet_utils/export_prediction.py b/TPTBox/segmentation/nnUnet_utils/export_prediction.py index 712effe..013edae 100755 --- a/TPTBox/segmentation/nnUnet_utils/export_prediction.py +++ b/TPTBox/segmentation/nnUnet_utils/export_prediction.py @@ -10,12 +10,16 @@ from nnunetv2.utilities.label_handling.label_handling import LabelManager from tqdm import tqdm +from TPTBox import Print_Logger from TPTBox.segmentation.nnUnet_utils.plans_handler import ConfigurationManager, PlansManager +logger = Print_Logger() SAFETY_FACTOR = 0.5 # only use 50% of VRAM -def _argmax_with_gpu_fallback(predicted_logits: torch.Tensor | np.ndarray, device: torch.device, chunk_size: int = 64) -> np.ndarray: +def _argmax_with_gpu_fallback( + predicted_logits: torch.Tensor | np.ndarray, device: torch.device, chunk_size: int = 64, logger=logger +) -> np.ndarray: """Computes argmax(0). Tiered argmax: @@ -74,17 +78,17 @@ def _chunked_argmax_cpu(t: torch.Tensor | np.ndarray) -> np.ndarray: full_bytes = _array_bytes(t.shape) free_vram = _get_free_vram(device) - print(f"[argmax] array: {full_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB") + logger.on_debug(f"[argmax] array: {full_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB") # Tier 1: full GPU if full_bytes <= free_vram or device.type == "mps": try: return torch.argmax(t.to(device), dim=0).cpu().numpy().astype(np.int16) except torch.cuda.OutOfMemoryError: - print("[argmax] full GPU OOM despite estimate, trying chunked GPU") + logger.on_fail("[argmax] full GPU OOM despite estimate, trying chunked GPU") empty_cache(device) except Exception as e: - print(e) + logger.on_fail(e) empty_cache(device) for i in range(10): @@ -93,25 +97,27 @@ def _chunked_argmax_cpu(t: torch.Tensor | np.ndarray) -> np.ndarray: if chunk_bytes <= free_vram: chunk_size = max(int(chunk_size / 2**i), 1) break - print(f"[argmax] array chunk: {chunk_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB, {chunk_size=}") + logger.on_debug(f"[argmax] array chunk: {chunk_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB, {chunk_size=}") # Tier 2: chunked GPU if chunk_bytes <= free_vram: - print("[argmax] using chunked GPU") + logger.on_log("[argmax] using chunked GPU") try: return _chunked_argmax_gpu(t, device) except torch.cuda.OutOfMemoryError: - print("[argmax] chunked GPU OOM despite estimate, falling back to CPU") + logger.on_fail("[argmax] chunked GPU OOM despite estimate, falling back to CPU") empty_cache(device) else: - print("[argmax] chunk too large for VRAM, falling back to CPU") + logger.on_debug("[argmax] chunk too large for VRAM, falling back to CPU") # Tier 3: chunked CPU return _chunked_argmax_cpu(t) @torch.inference_mode() -def convert_probabilities_to_segmentation(self, predicted_probabilities: np.ndarray | torch.Tensor, device, chunk_size=64) -> np.ndarray: +def convert_probabilities_to_segmentation( + self, predicted_probabilities: np.ndarray | torch.Tensor, device, chunk_size=64, logger=logger +) -> np.ndarray: """Assumes that inference_nonlinearity was already applied! predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions @@ -143,7 +149,7 @@ def convert_probabilities_to_segmentation(self, predicted_probabilities: np.ndar segmentation = segmentation.cpu().numpy() else: # Issensee is no longer right when saying "numpy is faster than torch" newer torch versions no longer have this issue, on GPU we even get a 20x improvment. :facepalm: - segmentation = _argmax_with_gpu_fallback(predicted_probabilities, device, chunk_size=chunk_size) + segmentation = _argmax_with_gpu_fallback(predicted_probabilities, device, chunk_size=chunk_size, logger=logger) return segmentation @@ -157,6 +163,7 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( return_probabilities: bool = False, num_threads_torch: int = 8, device=None, + logger=logger, ) -> np.ndarray: """Revert all preprocessing steps and return a segmentation in the original image space. @@ -203,7 +210,7 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( # Softmax does not change when we use argmax in the next step predicted_logits = label_manager.apply_inference_nonlin(predicted_logits) # segmentation: np.ndarray = label_manager.convert_probabilities_to_segmentation(predicted_logits) # type: ignore - segmentation: np.ndarray = convert_probabilities_to_segmentation(label_manager, predicted_logits, device) + segmentation: np.ndarray = convert_probabilities_to_segmentation(label_manager, predicted_logits, device, logger=logger) segmentation = segmentation.astype(np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) del predicted_logits # put segmentation in bbox (revert cropping) @@ -217,7 +224,7 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( # revert transpose segmentation = segmentation.transpose(plans_manager.transpose_backward) - print(segmentation.shape) + logger.print(segmentation.shape) # if return_probabilities: # raise NotImplementedError() # # revert cropping diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 9ddc538..d8d43df 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -1,15 +1,16 @@ from __future__ import annotations +from math import ceil from pathlib import Path import numpy as np import torch -from TPTBox import NII, Log_Type, No_Logger +from TPTBox import NII, Log_Type, Print_Logger from .predictor import nnUNetPredictor -logger = No_Logger() +logger = Print_Logger() logger.prefix = "API" _interop = False @@ -36,6 +37,7 @@ def load_inf_model( wait_till_gpu_percent_is_free: float = 0.3, fail_on_missing_memory=False, tile_batch_size: int = 1, + logger=logger, ) -> nnUNetPredictor: """Load and initialise an nnU-Net model predictor from a trained model folder. @@ -144,11 +146,67 @@ def load_inf_model( return predictor +def _split_ranges(length: int, n_chunks: int, overlap: int): + step = length // n_chunks + ranges = [] + for i in range(n_chunks): + start = i * step + end = length if i == n_chunks - 1 else (i + 1) * step + read_start = max(0, start - overlap) + read_end = min(length, end + overlap) + crop_start = start - read_start + crop_end = crop_start + (end - start) + ranges.append((read_start, read_end, crop_start, crop_end)) + return ranges + + +def _run_inference_patches(input_nii: list[NII], nnunet, _cpu_chunks, logger=logger): + """Split image into k _cpu_chunks along the largest dimension. + + Should only be used if there is not enough RAM on the system. + """ + logger.on_debug("Run: _run_inference_patches, You should only run this if you have limited RAM.") + from TPTBox.segmentation.nnUnet_utils.predictor import empty_cache + + empty_cache(nnunet.device) + shape = input_nii[0].shape + split_axis = int(np.argmax(shape)) + + if _cpu_chunks is None: + _cpu_chunks = shape[split_axis] // 250 + patch_size = nnunet.configuration_manager.patch_size + overlap = ceil(patch_size[split_axis] * (1 - nnunet.tile_step_size)) + logger.print(f"{overlap=}") + ranges = _split_ranges( + shape[split_axis], + _cpu_chunks, + overlap, + ) + logger.print(f"{ranges=}") + seg_chunks = [] + for read_start, read_end, crop_start, crop_end in ranges: + chunk_inputs = [] + for nii in input_nii: + sl = [slice(None)] * 3 + sl[split_axis] = slice(read_start, read_end) + chunk_inputs.append(nii[tuple(sl)]) + seg_chunk, _, _ = run_inference(chunk_inputs, nnunet, logits=False, logger=logger) + sl = [slice(None)] * 3 + sl[split_axis] = slice(crop_start, crop_end) + seg_chunk = seg_chunk[tuple(sl)] + seg_chunks.append(seg_chunk) + seg_arr = np.concatenate([s.get_array() for s in seg_chunks], axis=split_axis) + seg_nii = input_nii[0].copy() + seg_nii.seg = True + return seg_nii.set_array_(seg_arr).set_dtype("smallest_uint") + + def run_inference( input_nii: str | NII | list[NII], predictor: nnUNetPredictor, reorient_PIR: bool = False, # noqa: N803 logits: bool = False, + logger=logger, verbose: bool = False, # noqa: ARG001 ) -> tuple[NII, NII | None, np.ndarray | None]: """Run nnU-Net inference on a single image or list of images (multi-channel). @@ -193,10 +251,10 @@ def run_inference( try: img = np.vstack(img_arrs) except Exception: - print("could not stack images; shapes=", [a.shape for a in img_arrs]) + logger.on_fail("could not stack images; shapes=", [a.shape for a in img_arrs]) raise props = {"spacing": i.zoom[::-1]} # PIR - out = predictor.predict_single_npy_array(img, props, save_or_return_probabilities=False) + out = predictor.predict_single_npy_array(img, props, save_or_return_probabilities=False, logger=logger) segmentation: np.ndarray = out # type: ignore softmax_logits = None segmentation = np.transpose(segmentation, axes=segmentation.ndim - 1 - np.arange(segmentation.ndim)) diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index ff9e6e2..4e95e71 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -26,6 +26,8 @@ from TPTBox.segmentation.nnUnet_utils.plans_handler import PlansManager from TPTBox.segmentation.nnUnet_utils.sliding_window_prediction import compute_gaussian, compute_steps_for_sliding_window +logger = Print_Logger() + def get_gpu_memory_MB(device) -> float: """Return the amount of free GPU memory in megabytes for the given device.""" @@ -137,6 +139,7 @@ def initialize_from_trained_model_folder( use_folds: tuple[int | str, ...] | None, checkpoint_name: str = "checkpoint_final.pth", cache_state_dicts: bool = True, + logger=logger, ) -> None: """Load model weights and plans from a trained nnU-Net output directory. @@ -299,11 +302,11 @@ def mapp(d: dict): # Warn early if the requested device is unavailable (runs once, independent of folds). if self.device.type == "cuda" and not torch.cuda.is_available(): - Print_Logger().on_warning( + logger.on_warning( "No CUDA device. If you have a CUDA-able GPU (Nvidia), reinstall pytorch with cuda or for non-cuda devices use ddevice=cpu or ddevice=mps" ) if self.device.type == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): - Print_Logger().on_warning("No MPS device found. Use ddevice=cpu or ddevice=mps") + logger.on_warning("No MPS device found. Use ddevice=cpu or ddevice=mps") # loaded_networks holds one ready-to-run network per fold (or None to load weights # lazily per fold). We only cache the single-fold case: previously this loop appended @@ -324,10 +327,7 @@ def mapp(d: dict): self.loaded_networks = [self.network] def predict_single_npy_array( - self, - input_image: np.ndarray, - image_properties: dict, - save_or_return_probabilities: bool = False, + self, input_image: np.ndarray, image_properties: dict, save_or_return_probabilities: bool = False, logger=logger ) -> np.ndarray: """Run full inference on a single numpy image array. @@ -358,13 +358,13 @@ def predict_single_npy_array( verbose=self.verbose, ) if self.verbose: - print("preprocessing") + logger.on_log("preprocessing") dct = next(ppa) if self.verbose: - print("predicting") - predicted_logits = self.predict_logits_from_preprocessed_data(dct["data"]) # type: ignore - print( + logger.on_log("predicting") + predicted_logits = self.predict_logits_from_preprocessed_data(dct["data"], logger=logger) # type: ignore + logger.on_log( "convert_predicted_logits_to_segmentation_with_correct_shape", predicted_logits.shape, ) @@ -379,12 +379,13 @@ def predict_single_npy_array( dct["data_properites"], return_probabilities=save_or_return_probabilities, device=self.device, + logger=logger, ) print("convert_predicted_logits_to_segmentation_with_correct_shape; Took", time.time() - t, " seconds") return ret - def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: int = 10) -> torch.Tensor: + def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: int = 10, logger=logger) -> torch.Tensor: """Run sliding-window inference on already-preprocessed data and average across folds. If running the cascade, the previous-stage segmentation must already be @@ -426,7 +427,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in else: self.network._orig_mod.load_state_dict(params) # print(type(self.network)) - new_prediction = self.predict_sliding_window_return_logits(data, network=network).to("cpu") + new_prediction = self.predict_sliding_window_return_logits(data, network=network, idx=idx, logger=logger).to("cpu") if prediction is None: prediction = new_prediction else: @@ -437,25 +438,27 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in # prediction = prediction.to("cpu") # type: ignore empty_cache(self.device) - except RuntimeError: - print( - "Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. " - "Falling back to perform_everything_on_gpu=False. Not a big deal, just slower..." - ) - print("Error:") - traceback.print_exc() - prediction = None - self.perform_everything_on_gpu = False + except RuntimeError as e: + logger.on_fail(e) + logger.on_debug("GPU attempts remaining: ", attempts) empty_cache(self.device) if attempts == 0 or self.fail_on_missing_memory: + logger.on_fail( + "Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. " + "Falling back to perform_everything_on_gpu=False. Not a big deal, just slower..." + ) + logger.on_fail("Error:") + logger.print_error() + prediction = None + self.perform_everything_on_gpu = False raise - return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1) + return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1, logger=logger) # CPU version if prediction is None: try: - print("Run on CPU") + logger.on_log("Run on CPU") for idx, params in enumerate(self.list_of_parameters): network = None if self.loaded_networks is not None: @@ -467,25 +470,27 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in self.network._orig_mod.load_state_dict(params) if prediction is None: - prediction = self.predict_sliding_window_return_logits(data, network=network).to("cpu") # type: ignore + prediction = self.predict_sliding_window_return_logits(data, network=network, idx=99, logger=logger).to("cpu") # type: ignore else: - new_prediction = self.predict_sliding_window_return_logits(data, network=network).to("cpu") # type: ignore + new_prediction = self.predict_sliding_window_return_logits(data, network=network, idx=99, logger=logger).to( + "cpu" + ) # type: ignore prediction += new_prediction if len(self.list_of_parameters) > 1: prediction /= len(self.list_of_parameters) # type: ignore except RuntimeError: - print(f"failed due to insufficient GPU memory. {attempts} attempts remaining.") + logger.on_fail(f"failed due to insufficient GPU memory. {attempts} attempts remaining.") # print("Error:") # traceback.print_exc() empty_cache(self.device) if attempts == 0: raise - print("Sleep for a minute and try again") + logger.on_bold("Sleep for a minute and try again") time.sleep(60) - return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1) + return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1, logger=logger) del data - print("Prediction done, transferring to CPU if needed") # if self.verbose else None + logger.on_log("Prediction done, transferring to CPU if needed") # if self.verbose else None prediction = prediction.to("cpu") # type: ignore self.perform_everything_on_gpu = original_perform_everything_on_gpu @@ -565,7 +570,9 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor, network) -> torch. prediction /= num_predictons return prediction - def predict_sliding_window_return_logits(self, input_image: torch.Tensor, network=None) -> np.ndarray | torch.Tensor: + def predict_sliding_window_return_logits( + self, input_image: torch.Tensor, network=None, idx=0, logger=logger + ) -> np.ndarray | torch.Tensor: """Tile the input image and aggregate per-tile logits into a full-volume prediction. Args: @@ -598,14 +605,11 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ ): assert len(input_image.shape) == 4, "input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)" if self.verbose: - print(f"Input shape: {input_image.shape}") + logger.print(f"Input shape: {input_image.shape}") if self.verbose: - print("step_size:", self.tile_step_size) + logger.print("step_size:", self.tile_step_size) if self.verbose: - print( - "mirror_axes:", - self.allowed_mirroring_axes if self.use_mirroring else None, - ) + logger.print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None) patch_size = self.configuration_manager.patch_size device = self.device # if input_image is smaller than tile_size we need to pad it to tile_size. @@ -614,9 +618,9 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ shape = data.shape[1:] slicers = self._internal_get_sliding_window_slicers(shape) - # print("pixel", np.prod(shape) / 1000000) - # print("memory", get_gpu_memory_MB(device), device) - if get_gpu_util(device) > 1 - self.wait_till_gpu_percent_is_free: + # logger.print("pixel", np.prod(shape) / 1000000) + # logger.print("memory", get_gpu_memory_MB(device), device) + if get_gpu_util(device) > 1 - self.wait_till_gpu_percent_is_free and idx == 0: t = tqdm(range(2400)) # Wait 40 minutes for i in t: util = get_gpu_util(device) @@ -637,7 +641,7 @@ def check_mem(shape): max_memory = self.memory_max min_memory = self.memory_base factor = self.memory_factor - # print(shape, "usage", np.prod(shape) / 1000000 * factor, max(min(memory, max_memory), min_memory)) + # logger.print(shape, "usage", np.prod(shape) / 1000000 * factor, max(min(memory, max_memory), min_memory)) return (np.prod(shape) / 1000000 * factor) + min_memory < max(min(memory, max_memory), min_memory) with tqdm(total=len(slicers), disable=not self.allow_tqdm) as pbar: @@ -655,7 +659,7 @@ def check_mem(shape): print("Fall Back into regular patch mode. Not enough space; s[j] == 1", shape, patch_size, splits, s) break shape_split = [ceil(s / sp) for s, sp in zip(shape, splits)] - # print(shape, patch_size, splits, s, np.prod(shape) / 1000000) + # logger.print(shape, patch_size, splits, s, np.prod(shape) / 1000000) if check_mem(shape_split): try: return self._run_prediction_splits( @@ -667,13 +671,13 @@ def check_mem(shape): pbar=pbar, )[(slice(None), *slicer_revert_padding[1:])] except AttributeError as e: - print("_run_prediction_splits failed; fallback to non splits") - print(e) + logger.on_fail("_run_prediction_splits failed; fallback to non splits") + logger.on_fail(e) break splits[j] += 1 - predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar) + predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar, logger=logger) pbar.desc = "finish" pbar.update(0) predicted_logits /= n_predictions @@ -733,7 +737,7 @@ def _run_prediction_splits( # empty_cache(self.device) return predicted_logits - def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool = True): + def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool = True, logger=logger): """Pre-allocate output logit and count tensors; falls back to CPU on OOM.""" pbar.desc = "preallocating arrays" pbar.update(0) @@ -753,37 +757,41 @@ def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool device=results_device, ) except RuntimeError as e: - n_predictions = None - gaussian = 1 - predicted_logits = 1 - print("allocate FALL BACK CPU") # raise - empty_cache(self.device) - print(e) - # sometimes the stuff is too large for GPUs. In that case fall back to CPU - results_device = torch.device("cpu") - predicted_logits = torch.zeros( - (self.label_manager.num_segmentation_heads, *data.shape[1:]), - dtype=torch.half, - device=results_device, - ) - n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) - if self.use_gaussian and gauss: - gaussian = compute_gaussian( - tuple(self.configuration_manager.patch_size), - sigma_scale=1.0 / 8, - value_scaling_factor=1000, + try: + n_predictions = None + gaussian = 1 + predicted_logits = 1 + logger.on_warning("allocate FALL BACK CPU") # raise + empty_cache(self.device) + logger.print(e) + # sometimes the stuff is too large for GPUs. In that case fall back to CPU + results_device = torch.device("cpu") + predicted_logits = torch.zeros( + (self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, device=results_device, ) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) + if self.use_gaussian and gauss: + gaussian = compute_gaussian( + tuple(self.configuration_manager.patch_size), + sigma_scale=1.0 / 8, + value_scaling_factor=1000, + device=results_device, + ) + except RuntimeError as e: + empty_cache(self.device) + raise MemoryError("Could not allocate RAM.", str(e)) from None # finally: # empty_cache(self.device) return predicted_logits, n_predictions, gaussian, results_device - def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm, addendum: str = ""): + def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm, addendum: str = "", logger=logger): """Iterate over slicers, run inference per tile (optionally batched), and accumulate results.""" slicers = list(slicers) try: data = data.to(self.device) # type: ignore - predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar) + predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar, logger=logger) pbar.desc = f"running prediction {addendum}" prediction = None work_on = None @@ -804,11 +812,14 @@ def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: t n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 return predicted_logits, n_predictions # noqa: TRY300 except RuntimeError: - del predicted_logits - del n_predictions - del gaussian - del work_on - del prediction + try: + del predicted_logits + del n_predictions + del gaussian + del work_on + del prediction + except UnboundLocalError: + pass empty_cache(self.device) empty_cache(results_device) self.memory_base += 1000 diff --git a/TPTBox/segmentation/spineps.py b/TPTBox/segmentation/spineps.py index 026b763..b7dd246 100644 --- a/TPTBox/segmentation/spineps.py +++ b/TPTBox/segmentation/spineps.py @@ -14,6 +14,7 @@ def get_outpaths_spineps( dataset: str | Path | None = None, derivative_name: str = "derivative", ignore_bids_filter: bool = True, + _dataset_id_ct_crop=100, ) -> dict[ Literal[ "out_spine", @@ -23,10 +24,11 @@ def get_outpaths_spineps( "out_unc", "out_logits", "out_snap", - "out_ctD", + "out_ctd", "out_snap2", "out_debug", "out_raw", + "out_vibeseg", ], Path, ]: @@ -54,6 +56,7 @@ def get_outpaths_spineps( None, input_format=file_path.format, non_strict_mode=ignore_bids_filter, + _dataset_id_ct_crop=_dataset_id_ct_crop, ) return output_paths @@ -74,7 +77,23 @@ def run_spineps( ignore_compatibility_issues: bool = False, use_cpu: bool = False, **args, -) -> dict: +) -> dict[ + Literal[ + "out_spine", + "out_spine_raw", + "out_vert", + "out_vert_raw", + "out_unc", + "out_logits", + "out_snap", + "out_ctd", + "out_snap2", + "out_debug", + "out_raw", + "out_vibeseg", + ], + Path, +]: """Run the SPINEPS spine segmentation pipeline on a single image. Handles model loading, BIDS path resolution, and delegates to SPINEPS' @@ -105,9 +124,14 @@ def run_spineps( Returns: The output paths dictionary returned by SPINEPS' ``process_img_nii``. """ - from spineps import get_instance_model, get_semantic_model, process_img_nii + from spineps import get_instance_model, get_semantic_model from spineps.get_models import get_actual_model + try: + from spineps import process_img_nii as segment_image + except Exception: + from spineps import segment_image + label = {} try: from spineps.get_models import get_labeling_model @@ -130,7 +154,7 @@ def run_spineps( model_instance = get_actual_model(model_instance, use_cpu=use_cpu) else: model_instance = get_instance_model(model_instance, use_cpu=use_cpu) - output_paths, errcode = process_img_nii( + output_paths, errcode = segment_image( img_ref=file_path, derivative_name=derivative_name, model_semantic=model_semantic, diff --git a/pyproject.toml b/pyproject.toml index c5356aa..cf85aa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,7 +167,6 @@ ignore = [ "PLR2004", "SIM105", "TRY003", - "UP038", "N999", "E741", "SIM118", # dictionary keys diff --git a/unit_tests/test_poi_autogen.py b/unit_tests/test_poi_autogen.py index 3dcd95c..fb318d3 100755 --- a/unit_tests/test_poi_autogen.py +++ b/unit_tests/test_poi_autogen.py @@ -9,6 +9,7 @@ import numpy as np +from TPTBox import NII from TPTBox.core.poi import POI from TPTBox.core.poi_fun.poi_global import POI_Global from TPTBox.tests.test_utils import get_random_ax_code @@ -712,3 +713,258 @@ def test_attribute_access(self): assert poi.info == {"key": "value"} assert np.array_equal(poi.rotation, np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])) assert poi.origin == (10, 10, 10) + + +class Test_Has_Grid_IntersectingVolume(unittest.TestCase): + """Tests for Has_Grid.get_intersecting_volume using OBB/SAT/Sutherland-Hodgman. + + All expected volumes are computed analytically from geometry, never from + the resampling reference implementation. + """ + + @staticmethod + def _make_grid(shape, zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0), rotation=None) -> NII: + """Build a NII (and therefore a Has_Grid) with explicit affine parameters.""" + aff = np.eye(4) + if rotation is not None: + aff[:3, :3] = np.array(rotation) @ np.diag(zoom) + else: + aff[:3, :3] = np.diag(zoom) + aff[:3, 3] = origin + arr = np.zeros(shape, dtype=np.uint8) + return NII.from_numpy(arr, affine=aff, seg=True) + + # ------------------------------------------------------------------ + # Basic identity / self-overlap + # ------------------------------------------------------------------ + + def test_self_overlap_equals_own_volume(self): + """Intersection of a grid with itself must equal its physical volume.""" + nii = self._make_grid((10, 8, 6), zoom=(1.0, 2.0, 3.0)) + expected = 10 * 1.0 * 8 * 2.0 * 6 * 3.0 + self.assertAlmostEqual(nii.get_intersecting_volume(nii), expected, places=3) + + def test_self_overlap_unit_zoom(self): + nii = self._make_grid((5, 5, 5)) + self.assertAlmostEqual(nii.get_intersecting_volume(nii), 125.0, places=3) + + # ------------------------------------------------------------------ + # No overlap → 0 + # ------------------------------------------------------------------ + + def test_no_overlap_separated_x(self): + a = self._make_grid((5, 5, 5), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((5, 5, 5), origin=(10.0, 0.0, 0.0)) # gap of 5 mm + self.assertAlmostEqual(a.get_intersecting_volume(b), 0.0, places=6) + + def test_no_overlap_separated_y(self): + a = self._make_grid((4, 4, 4), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((4, 4, 4), origin=(0.0, 10.0, 0.0)) + self.assertAlmostEqual(a.get_intersecting_volume(b), 0.0, places=6) + + def test_no_overlap_separated_z(self): + a = self._make_grid((4, 4, 4), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((4, 4, 4), origin=(0.0, 0.0, 10.0)) + self.assertAlmostEqual(a.get_intersecting_volume(b), 0.0, places=6) + + def test_touching_faces_no_volume_overlap(self): + """Boxes that share a face but don't penetrate have zero volume intersection.""" + a = self._make_grid((5, 5, 5), zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0)) + # b starts exactly where a ends (index 4 * zoom = 4, next voxel origin at 5) + b = self._make_grid((5, 5, 5), zoom=(1.0, 1.0, 1.0), origin=(5.0, 0.0, 0.0)) + self.assertAlmostEqual(a.get_intersecting_volume(b), 0.0, places=3) + + # ------------------------------------------------------------------ + # Partial axis-aligned overlaps — analytic ground truth + # ------------------------------------------------------------------ + + def test_half_overlap_x_axis(self): + """Two 10³ boxes offset by 5 along x → 5×10×10 = 500 mm³ overlap.""" + a = self._make_grid((10, 10, 10), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((10, 10, 10), origin=(5.0, 0.0, 0.0)) + # physical size 10 mm each; offset 5 mm → overlap 5 mm on x, full on y/z + self.assertAlmostEqual(a.get_intersecting_volume(b), 5.0 * 10.0 * 10.0, places=2) + + def test_partial_overlap_all_axes(self): + """Offset by (2,3,4) on a 10³ grid → (10-2)*(10-3)*(10-4) = 8*7*6 = 336.""" + a = self._make_grid((10, 10, 10), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((10, 10, 10), origin=(2.0, 3.0, 4.0)) + self.assertAlmostEqual(a.get_intersecting_volume(b), 8.0 * 7.0 * 6.0, places=2) + + def test_small_inside_large(self): + """Small box fully contained inside large box → volume equals small box.""" + large = self._make_grid((20, 20, 20), origin=(0.0, 0.0, 0.0)) + small = self._make_grid((5, 5, 5), origin=(7.0, 7.0, 7.0)) + expected = 5.0 * 5.0 * 5.0 + self.assertAlmostEqual(large.get_intersecting_volume(small), expected, places=3) + + def test_large_inside_small_is_symmetric(self): + """Containment is symmetric: small.intersect(large) == large.intersect(small).""" + large = self._make_grid((20, 20, 20), origin=(0.0, 0.0, 0.0)) + small = self._make_grid((5, 5, 5), origin=(7.0, 7.0, 7.0)) + self.assertAlmostEqual( + large.get_intersecting_volume(small), + small.get_intersecting_volume(large), + places=3, + ) + + def test_non_unit_zoom_partial_overlap(self): + """Non-unit zoom: 2 mm voxels, offset 4 mm (2 voxels) on x. + Overlap = (10*2 - 4) * 10*2 * 10*2 = 16 * 20 * 20 = 6400 mm³. + """ + a = self._make_grid((10, 10, 10), zoom=(2.0, 2.0, 2.0), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((10, 10, 10), zoom=(2.0, 2.0, 2.0), origin=(4.0, 0.0, 0.0)) + self.assertAlmostEqual(a.get_intersecting_volume(b), 16.0 * 20.0 * 20.0, places=2) + + def test_asymmetric_zoom_overlap(self): + """Different zoom per axis: overlap region is (8,6,4) mm = 192 mm³.""" + a = self._make_grid((10, 10, 10), zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((10, 10, 10), zoom=(1.0, 1.0, 1.0), origin=(2.0, 4.0, 6.0)) + # overlap on each axis: (10-2, 10-4, 10-6) = (8, 6, 4) + self.assertAlmostEqual(a.get_intersecting_volume(b), 8.0 * 6.0 * 4.0, places=2) + + # ------------------------------------------------------------------ + # Symmetry + # ------------------------------------------------------------------ + + def test_symmetry_axis_aligned(self): + a = self._make_grid((10, 10, 10), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((10, 10, 10), origin=(3.0, 4.0, 5.0)) + # print(a.get_intersecting_volume(b),b.get_intersecting_volume(a),a.get_intersecting_volume(b)-b.get_intersecting_volume(a)) + self.assertAlmostEqual( + a.get_intersecting_volume(b), + b.get_intersecting_volume(a), + places=3, + ) + + def test_symmetry_different_shapes(self): + a = self._make_grid((12, 8, 6), zoom=(1.0, 2.0, 3.0), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((6, 10, 4), zoom=(2.0, 1.0, 4.0), origin=(5.0, 3.0, 2.0)) + self.assertAlmostEqual( + a.get_intersecting_volume(b), + b.get_intersecting_volume(a), + places=3, + ) + + # ------------------------------------------------------------------ + # Rotated boxes — analytic ground truth via symmetry argument + # ------------------------------------------------------------------ + + def test_90_degree_rotation_z_self_overlap(self): + """A box rotated 90° around z intersected with itself → own volume.""" + angle = np.pi / 2 + R = np.array( + [ + [np.cos(angle), -np.sin(angle), 0.0], + [np.sin(angle), np.cos(angle), 0.0], + [0.0, 0.0, 1.0], + ] + ) + nii = self._make_grid((6, 6, 6), zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0), rotation=R) + self.assertAlmostEqual(nii.get_intersecting_volume(nii), 216.0, places=2) + + def test_45_degree_rotation_z_vs_axis_aligned_symmetry(self): + """Rotated box vs axis-aligned: result must be symmetric.""" + angle = np.pi / 4 + R = np.array( + [ + [np.cos(angle), -np.sin(angle), 0.0], + [np.sin(angle), np.cos(angle), 0.0], + [0.0, 0.0, 1.0], + ] + ) + a = self._make_grid((10, 10, 10), zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0)) + b = self._make_grid((10, 10, 10), zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0), rotation=R) + self.assertAlmostEqual( + a.get_intersecting_volume(b), + b.get_intersecting_volume(a), + places=3, + ) + + def test_rotated_box_fully_inside_larger_box(self): + """Small box rotated 45° around z, fully inside a large axis-aligned box. + + The rotated 4×4×4 box has a bounding diagonal of 4*sqrt(2) ≈ 5.66 mm. + The large 20³ box easily contains it regardless of rotation, so the + intersection must equal the small box's own volume = 64 mm³. + """ + angle = np.pi / 4 + R = np.array( + [ + [np.cos(angle), -np.sin(angle), 0.0], + [np.sin(angle), np.cos(angle), 0.0], + [0.0, 0.0, 1.0], + ] + ) + large = self._make_grid((20, 20, 20), origin=(0.0, 0.0, 0.0)) + small = self._make_grid((4, 4, 4), zoom=(1.0, 1.0, 1.0), origin=(8.0, 8.0, 8.0), rotation=R) + self.assertAlmostEqual(large.get_intersecting_volume(small), 64.0, places=2) + + def test_rotated_no_overlap(self): + """Two rotated boxes placed far apart must give zero.""" + angle = np.pi / 3 + R = np.array( + [ + [np.cos(angle), -np.sin(angle), 0.0], + [np.sin(angle), np.cos(angle), 0.0], + [0.0, 0.0, 1.0], + ] + ) + a = self._make_grid((5, 5, 5), zoom=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0), rotation=R) + b = self._make_grid((5, 5, 5), zoom=(1.0, 1.0, 1.0), origin=(50.0, 50.0, 50.0), rotation=R) + self.assertAlmostEqual(a.get_intersecting_volume(b), 0.0, places=6) + + # ------------------------------------------------------------------ + # Result is always non-negative + # ------------------------------------------------------------------ + + def test_result_non_negative_random(self): + rng = np.random.default_rng(42) + for _ in range(20): + shape_a = tuple(rng.integers(3, 15, size=3).tolist()) + shape_b = tuple(rng.integers(3, 15, size=3).tolist()) + zoom_a = tuple(rng.uniform(0.5, 3.0, size=3).tolist()) + zoom_b = tuple(rng.uniform(0.5, 3.0, size=3).tolist()) + origin_a = tuple(rng.uniform(-10, 10, size=3).tolist()) + origin_b = tuple(rng.uniform(-10, 10, size=3).tolist()) + a = self._make_grid(shape_a, zoom=zoom_a, origin=origin_a) + b = self._make_grid(shape_b, zoom=zoom_b, origin=origin_b) + self.assertGreaterEqual(a.get_intersecting_volume(b), 0.0) + + def test_symmetry_random(self): + rng = np.random.default_rng(7) + for _ in range(20): + shape_a = tuple(rng.integers(3, 12, size=3).tolist()) + shape_b = tuple(rng.integers(3, 12, size=3).tolist()) + zoom_a = tuple(rng.uniform(0.5, 2.0, size=3).tolist()) + zoom_b = tuple(rng.uniform(0.5, 2.0, size=3).tolist()) + origin_a = tuple(rng.uniform(-5, 5, size=3).tolist()) + origin_b = tuple(rng.uniform(-5, 5, size=3).tolist()) + a = self._make_grid(shape_a, zoom=zoom_a, origin=origin_a) + b = self._make_grid(shape_b, zoom=zoom_b, origin=origin_b) + self.assertAlmostEqual( + a.get_intersecting_volume(b), + b.get_intersecting_volume(a), + places=3, + ) + + # ------------------------------------------------------------------ + # Volume bounded by both inputs + # ------------------------------------------------------------------ + + def test_intersection_never_exceeds_either_input(self): + rng = np.random.default_rng(99) + for _ in range(20): + shape_a = tuple(rng.integers(4, 14, size=3).tolist()) + shape_b = tuple(rng.integers(4, 14, size=3).tolist()) + zoom_a = tuple(rng.uniform(0.5, 2.5, size=3).tolist()) + zoom_b = tuple(rng.uniform(0.5, 2.5, size=3).tolist()) + origin_a = tuple(rng.uniform(-8, 8, size=3).tolist()) + origin_b = tuple(rng.uniform(-8, 8, size=3).tolist()) + a = self._make_grid(shape_a, zoom=zoom_a, origin=origin_a) + b = self._make_grid(shape_b, zoom=zoom_b, origin=origin_b) + vol_inter = a.get_intersecting_volume(b) + vol_a = a.voxel_volume() * np.prod(shape_a) + vol_b = b.voxel_volume() * np.prod(shape_b) + self.assertLessEqual(vol_inter, vol_a + 1e-3) + self.assertLessEqual(vol_inter, vol_b + 1e-3)