diff --git a/AGENTS.md b/AGENTS.md index 745ce10..9e76207 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -124,7 +124,7 @@ Version bumping: `bumpver update --patch`, `--minor`, or `--major`. - Scripts that instantiate `SegmentChestTotalSegmentator` must guard the top-level invocation with `if __name__ == "__main__":` on Windows (`torch.multiprocessing` requires it). -- Single quotes for strings; double quotes for docstrings. Keep lines at or +- Double quotes for strings and docstrings. Keep lines at or below 88 characters. - Full type hints are required under strict mypy. Use `Optional[X]`, not `X | None`. diff --git a/CLAUDE.md b/CLAUDE.md index bf268f0..af531fa 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -148,7 +148,7 @@ Document via docstrings and inline comments. ## Code Style -- Single quotes for strings; double quotes for docstrings +- Double quotes for strings and docstrings - Full type hints (`mypy` strict; `disallow_untyped_defs = true`) - `Optional[X]` not `X | None` (ruff `UP007` suppressed) - Breaking changes are acceptable — backward compatibility is not a priority diff --git a/docs/api/utilities/index.rst b/docs/api/utilities/index.rst index 0e4d2c0..62277bb 100644 --- a/docs/api/utilities/index.rst +++ b/docs/api/utilities/index.rst @@ -21,6 +21,7 @@ Quick Links **Utility Modules**: * :doc:`image_tools` - Image processing utilities + * :doc:`labelmap_tools` - Labelmap to registration-mask conversion * :doc:`transform_tools` - Transform operations * :doc:`contour_tools` - Contour processing * :doc:`image_conversion` - 4D image to 3D time-series utilities @@ -34,6 +35,7 @@ Module Documentation :maxdepth: 2 image_tools + labelmap_tools transform_tools contour_tools image_conversion diff --git a/docs/api/utilities/labelmap_tools.rst b/docs/api/utilities/labelmap_tools.rst new file mode 100644 index 0000000..7b453b5 --- /dev/null +++ b/docs/api/utilities/labelmap_tools.rst @@ -0,0 +1,19 @@ +==================================== +Labelmap Tools +==================================== + +.. currentmodule:: physiomotion4d + +Convert segmentation labelmaps into binary registration masks, with optional +label exclusion and physically isotropic dilation. + +Module Reference +================ + +.. automodule:: physiomotion4d.labelmap_tools + :members: + :undoc-members: + +.. rubric:: Navigation + +:doc:`index` | :doc:`image_tools` | :doc:`transform_tools` diff --git a/docs/contributing.rst b/docs/contributing.rst index 4c62e97..d6b5d96 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -141,7 +141,7 @@ PhysioMotion4D follows strict code quality standards using modern, fast tooling. Formatting and Linting with Ruff --------------------------------- -We use **Ruff** for all formatting and linting (line length: 88, single quotes): +We use **Ruff** for all formatting and linting (line length: 88, double quotes): .. code-block:: bash diff --git a/docs/developer/registration_images.rst b/docs/developer/registration_images.rst index dab4078..1c0d023 100644 --- a/docs/developer/registration_images.rst +++ b/docs/developer/registration_images.rst @@ -25,7 +25,11 @@ Basic Pattern registered = registrar.get_registered_image() The result dictionary contains ``forward_transform``, ``inverse_transform``, -and ``loss``. +and ``loss``. Applying the right one is critical and direction-dependent: +``forward_transform`` warps the moving image onto the fixed grid, while +``inverse_transform`` warps moving points/landmarks into fixed space (image and +point warps use opposite transforms). See +:doc:`transform_conventions` for the full rules. Time Series =========== @@ -57,5 +61,6 @@ Development Notes See Also ======== +* :doc:`transform_conventions` * :doc:`../api/registration/index` * :doc:`workflows` diff --git a/docs/developer/registration_models.rst b/docs/developer/registration_models.rst index 3102f1a..466b904 100644 --- a/docs/developer/registration_models.rst +++ b/docs/developer/registration_models.rst @@ -45,9 +45,14 @@ Development Notes * Convert volumetric meshes to surfaces before surface registration when needed. * Treat ITK/PyVista coordinate transforms as high-risk and add focused tests. * Keep synthetic test meshes small and deterministic. +* ``RegisterModelsPCA`` returns ``forward_point_transform`` / + ``inverse_point_transform``. These are **point** transforms whose orientation + is opposite to the image-registration transforms; see + :doc:`transform_conventions` before applying them to images or meshes. See Also ======== +* :doc:`transform_conventions` * :doc:`../api/model_registration/index` * :doc:`workflows` diff --git a/docs/developer/transform_conventions.rst b/docs/developer/transform_conventions.rst new file mode 100644 index 0000000..ee80cf0 --- /dev/null +++ b/docs/developer/transform_conventions.rst @@ -0,0 +1,137 @@ +=============================== +Transform Direction Conventions +=============================== + +Registration in PhysioMotion4D produces a pair of transforms, and choosing the +wrong one of the pair is the single most common registration mistake. The rules +are simple but easy to get backwards, because **warping an image and warping a +point require opposite transforms**, and because **model (PCA) registration +returns its transforms in the opposite orientation from image registration**. + +Read this page before applying any transform to an image, mask, contour, or +landmark. + +The two transform families +=========================== + +Image registration + :class:`physiomotion4d.RegisterImagesANTS`, + :class:`physiomotion4d.RegisterImagesICON`, and + :class:`physiomotion4d.RegisterImagesGreedy` register a *moving* image to a + *fixed* image and return a dict with ``forward_transform`` and + ``inverse_transform``. :class:`physiomotion4d.RegisterTimeSeriesImages` + returns the list-valued ``forward_transforms`` / ``inverse_transforms``. + +Model (PCA) registration + :class:`physiomotion4d.RegisterModelsPCA` deforms a *template* model toward + a *target* (patient) and, via ``compute_pca_transforms()``, returns + ``forward_point_transform`` and ``inverse_point_transform``. These are + **point transforms**, oriented opposite to the image-registration transforms + (see `PCA point transforms`_ below). + +Image warping vs. point warping use opposite transforms +======================================================== + +ITK resampling is a *pull-back* operation. To build the warped image on the +fixed grid, :func:`TransformTools.transform_image` (an ``itk.ResampleImageFilter``) +visits every fixed-grid sample ``q`` and looks up the moving image at +``transform.TransformPoint(q)``. The transform it needs therefore maps +**fixed-space coordinates to moving-space coordinates**. + +Warping a *point* (landmark, contour vertex, mesh node) is a *push-forward* +operation: :func:`TransformTools.transform_pvcontour` / +:func:`TransformTools.transform_dataset` apply ``transform.TransformPoint(p)`` +directly to each input point. To move a moving-space landmark to its location in +the fixed image, the transform must map **moving-space coordinates to +fixed-space coordinates** -- the inverse of the image-warp transform. + +So for the **same** moving-to-fixed registration result: + +.. list-table:: Image registration: which transform to apply + :header-rows: 1 + :widths: 50 25 25 + + * - Goal + - Transform + - Helper + * - Warp the **moving image** into fixed space (onto the fixed grid) + - ``forward_transform`` + - :func:`TransformTools.transform_image` + * - Warp **moving points / contours / landmarks** into fixed space + - ``inverse_transform`` + - :func:`TransformTools.transform_pvcontour` + * - Warp the **fixed image** into moving space (e.g. time-series reconstruction) + - ``inverse_transform`` + - :func:`TransformTools.transform_image` + * - Warp **fixed points / contours / landmarks** into moving space + - ``forward_transform`` + - :func:`TransformTools.transform_pvcontour` + +The first two rows are the everyday case (warping the registered moving data +into the fixed/reference frame): the **image uses** ``forward_transform``, the +**points use** ``inverse_transform``. The last two rows are the mirror image; +:meth:`physiomotion4d.RegisterTimeSeriesImages.reconstruct_time_series` is the +canonical consumer of ``inverse_transform`` for image warping (it resamples the +fixed image back onto each moving frame's grid). + +.. note:: + + All three image-registration backends (ANTS, ICON, Greedy) follow this same + convention. ``transform_image(moving, forward_transform, fixed)`` is the + correct call to warp the moving image onto the fixed grid for every backend. + +PCA point transforms +==================== + +:class:`physiomotion4d.RegisterModelsPCA` builds ``forward_point_transform`` +directly from the template-to-target point displacement, so +``forward_point_transform.TransformPoint(template_point)`` returns the +corresponding *target* point. As a **point** map it goes template (moving) to +target (fixed) -- which is the same orientation as image registration's +``inverse_transform``, and therefore the **opposite** orientation of image +registration's ``forward_transform``. + +Concretely, treating the template as the moving object and the patient/target as +the fixed object: + +.. list-table:: Same goal, opposite transform names across the two families + :header-rows: 1 + :widths: 50 25 25 + + * - Goal + - Image registration + - PCA model registration + * - Warp the **image** (moving/template space -> fixed/target grid) + - ``forward_transform`` + - ``inverse_point_transform`` + * - Warp **points / meshes** (moving/template -> fixed/target) + - ``inverse_transform`` + - ``forward_point_transform`` + +In other words, ``forward_point_transform`` plays the role that +``inverse_transform`` plays for image registration, and +``inverse_point_transform`` plays the role of ``forward_transform``. Deforming +the template mesh onto the patient (the usual PCA use, performed internally by +``transform_template_model()`` and ``transform_point()``) uses +``forward_point_transform``; resampling an image with the PCA result uses +``inverse_point_transform``. + +Rule of thumb +============= + +* **Images pull back; points push forward.** For one registration result, the + image and the points always use the two *different* members of the transform + pair. +* **Image into the reference frame** -> ``forward_transform`` (image + registration) / ``inverse_point_transform`` (PCA). +* **Points into the reference frame** -> ``inverse_transform`` (image + registration) / ``forward_point_transform`` (PCA). +* When in doubt, warp a known landmark and a small image patch and confirm they + land in the same place before trusting a pipeline. + +See Also +======== + +* :doc:`registration_images` +* :doc:`registration_models` +* :doc:`utilities` diff --git a/docs/index.rst b/docs/index.rst index 458de99..320a9fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -162,6 +162,7 @@ per-tutorial implementation details. developer/segmentation developer/registration_images developer/registration_models + developer/transform_conventions developer/usd_generation developer/utilities diff --git a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py b/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py deleted file mode 100644 index addadea..0000000 --- a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python -# %% [markdown] -# # Compare registration speed: Greedy vs ANTs vs ICON -# -# This notebook times **Greedy**, **ANTs**, and **ICON** when registering two time points of CT from the Slicer-Heart-CT data (TruncalValve 4D CT). -# -# **Prerequisites:** Run `0-download_and_convert_4d_to_3d.py` first so that `data/Slicer-Heart-CT/` contains the 4D NRRD and the 3D slice series (`slice_000.mha`, `slice_001.mha`, ...), and `results/slice_fixed.mha` exists. - -# %% -import os -import time - -import itk -import matplotlib.pyplot as plt -import pandas as pd -from itk import TubeTK as ttk - -from physiomotion4d.test_tools import TestTools -from physiomotion4d.register_images_ants import RegisterImagesANTS -from physiomotion4d.register_images_greedy import RegisterImagesGreedy -from physiomotion4d.register_images_icon import RegisterImagesICON - -_HERE = os.path.dirname(os.path.abspath(__file__)) - -# %% -data_dir = os.path.join(_HERE, "..", "..", "data", "Slicer-Heart-CT") -output_dir = os.path.join(_HERE, "results") -os.makedirs(output_dir, exist_ok=True) - -# Fixed = reference time point; moving = time point to align to fixed -fixed_image_path = os.path.join(output_dir, "slice_fixed.mha") -moving_image_path = os.path.join(data_dir, "slice_000.mha") - -if not os.path.exists(fixed_image_path): - raise FileNotFoundError( - f"Fixed image not found: {fixed_image_path}. " - "Run 0-download_and_convert_4d_to_3d.py first." - ) -if not os.path.exists(moving_image_path): - raise FileNotFoundError( - f"Moving image not found: {moving_image_path}. " - "Run 0-download_and_convert_4d_to_3d.py first." - ) - -fixed_image = itk.imread(fixed_image_path) -moving_image = itk.imread(moving_image_path) -print(f"Fixed image: {itk.size(fixed_image)}, spacing {itk.spacing(fixed_image)}") -print(f"Moving image: {itk.size(moving_image)}, spacing {itk.spacing(moving_image)}") - -# %% [markdown] -# ## Optional: downsample for faster comparison -# -# Set `downsample_factor = 1.0` to use full resolution (slower). Use e.g. `0.5` to halve each dimension for a quicker run. - -# %% -downsample_factor = 0.5 # 1.0 = full resolution - -if downsample_factor != 1.0: - resampler_f = ttk.ResampleImage.New(Input=fixed_image) - resampler_f.SetResampleFactor([downsample_factor] * 3) - resampler_f.Update() - fixed_image = resampler_f.GetOutput() - - resampler_m = ttk.ResampleImage.New(Input=moving_image) - resampler_m.SetResampleFactor([downsample_factor] * 3) - resampler_m.Update() - moving_image = resampler_m.GetOutput() - print(f"Downsampled to factor {downsample_factor}") - print(f" Fixed: {itk.size(fixed_image)}") - print(f" Moving: {itk.size(moving_image)}") -else: - print("Using full resolution.") - -# %% [markdown] -# ## Run each method and record time -# -# All three use **deformable** registration (Greedy: affine + deformable; ANTs: SyN; ICON: deep learning). Settings are chosen for a fair comparison with reduced iterations so the notebook runs in a few minutes. - -# %% -results_list = [] - -# --- Greedy (deformable) --- -try: - reg_g = RegisterImagesGreedy() - reg_g.set_modality("ct") - reg_g.set_transform_type("Deformable") - reg_g.set_number_of_iterations([10, 5, 2]) - reg_g.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_g = reg_g.register(moving_image) - elapsed_g = time.perf_counter() - t0 - - loss_g = out_g.get("loss") - results_list.append( - { - "method": "Greedy", - "time_sec": round(elapsed_g, 2), - "loss": float(loss_g) if loss_g is not None else None, - } - ) - print(f"Greedy: {elapsed_g:.2f} s") -except Exception as e: - results_list.append({"method": "Greedy", "time_sec": None, "loss": None}) - print(f"Greedy: failed - {e}") - -# --- ANTs (deformable SyN) --- -try: - reg_a = RegisterImagesANTS() - reg_a.set_modality("ct") - reg_a.set_transform_type("Deformable") - reg_a.set_number_of_iterations([10, 5, 2]) # reduced for speed - reg_a.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_a = reg_a.register(moving_image) - elapsed_a = time.perf_counter() - t0 - - loss_a = out_a.get("loss") - results_list.append( - { - "method": "ANTs", - "time_sec": round(elapsed_a, 2), - "loss": float(loss_a) if loss_a is not None else None, - } - ) - print(f"ANTs: {elapsed_a:.2f} s") -except Exception as e: - results_list.append({"method": "ANTs", "time_sec": None, "loss": None}) - print(f"ANTs: failed - {e}") - -# --- ICON (deformable, GPU) --- -try: - reg_i = RegisterImagesICON() - reg_i.set_modality("ct") - reg_i.set_number_of_iterations(50) - reg_i.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_i = reg_i.register(moving_image) - elapsed_i = time.perf_counter() - t0 - - loss_i = out_i.get("loss") - results_list.append( - { - "method": "ICON", - "time_sec": round(elapsed_i, 2), - "loss": float(loss_i) if loss_i is not None else None, - } - ) - print(f"ICON: {elapsed_i:.2f} s") -except Exception as e: - results_list.append({"method": "ICON", "time_sec": None, "loss": None}) - print(f"ICON: failed - {e}") - -df = pd.DataFrame(results_list) - -# %% -print(df) - -# %% -fig, ax = plt.subplots(figsize=(6, 4)) -valid = df["time_sec"].notna() -if valid.any(): - methods = df.loc[valid, "method"] - times = df.loc[valid, "time_sec"] - ax.bar(methods, times, color=["#2ecc71", "#3498db", "#9b59b6"]) - ax.set_ylabel("Time (seconds)") - ax.set_title("Registration time: two time points (Slicer-Heart-CT)") - plt.tight_layout() - if not TestTools.running_as_test(): - plt.show() -else: - print("No successful runs to plot.") - -# %% [markdown] -# ## Notes -# -# - **Greedy**: CPU-based, often faster than ANTs for comparable quality; see [Greedy](https://greedy.readthedocs.io/) and [picsl-greedy](https://pypi.org/project/picsl-greedy/). -# - **ANTs**: CPU-based, very widely used; typically slower than Greedy for similar settings. -# - **ICON**: GPU-based (UniGradIcon); speed depends on GPU. Loss values are not directly comparable across methods. -# - For a quicker comparison, use `downsample_factor = 0.5` or reduce `number_of_iterations` further. diff --git a/experiments/LongitudinalRegistration/.gitignore b/experiments/LongitudinalRegistration/.gitignore index 19960a9..f850328 100644 --- a/experiments/LongitudinalRegistration/.gitignore +++ b/experiments/LongitudinalRegistration/.gitignore @@ -1 +1,2 @@ uniGradICON +fixed diff --git a/experiments/LongitudinalRegistration/1-finetune_icon.py b/experiments/LongitudinalRegistration/1-finetune_icon.py deleted file mode 100644 index 968a078..0000000 --- a/experiments/LongitudinalRegistration/1-finetune_icon.py +++ /dev/null @@ -1,186 +0,0 @@ -# %% [markdown] -# # Fine-tune uniGradICON on Duke 4D Gated CT Data -# -# Discovers per-patient gated CT images and their precomputed -# SegmentHeartSimpleware labelmaps and applies the project-wide fixed 80/20 -# train/test split (sort patients in ``ref_data_dir`` by filename; the first -# 80% are train, the last 20% are test). The train cohort is handed to -# :class:`WorkflowFineTuneICONRegistration`, which builds the paired dataset -# JSON, YAML config, and derived loss-function masks, then launches -# ``unigradicon.finetuning.finetune`` as a subprocess. -# -# ``2-recon_4d_icon_eval.py`` re-derives the same split from the same sorted -# patient list — no cached split file is needed. -# -# Each patient directory under ``src_data_dir_base`` is one ``subject_id``; -# all of that patient's gated time-point frames form a paired training group. -# Frames whose labelmap is missing on disk are dropped from the dataset. - -# %% -import os -from pathlib import Path - -import itk - -from physiomotion4d import WorkflowFineTuneICONRegistration -from physiomotion4d.register_images_icon import RegisterImagesICON - -# %% [markdown] -# ## 1. Configure data, output locations, and the train/test split - -# %% -ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") -src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") -segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") - -# Where the workflow writes the dataset JSON, YAML config, derived masks, and -# the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to -# ``output_dir / fine_tune_name``. -output_dir = Path("./results") -fine_tune_name = "icon_finetuned" - -# Fixed train/test split: sort patients in ``ref_data_dir`` by filename; -# first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies -# the same rule so the two scripts agree without a cached split record. -train_fraction = 0.8 - -# Local clone of uniGradICON (feat-add-finetuning branch) — prepended to -# PYTHONPATH so the subprocess picks up the local source instead of the -# installed package. Set to ``None`` to use the pip-installed unigradicon. -unigradicon_src_path: Path | None = Path(__file__).parent / "uniGradICON" / "src" - -# %% [markdown] -# ## 2. Enumerate patients and apply the fixed 80/20 split -# -# Sort ``ref_data_dir`` by filename to produce the canonical patient order. -# The first 80% become the train cohort; the last 20% are the held-out test -# cohort that ``2-recon_4d_icon_eval.py`` will evaluate. - -# %% -ref_files = sorted( - p - for p in ref_data_dir.iterdir() - if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] -) -all_patient_ids = [p.name[:6] for p in ref_files] -print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") - -if len(all_patient_ids) < 2: - raise FileNotFoundError( - f"Need at least 2 patients to form a train/test split; " - f"discovered {len(all_patient_ids)} under {ref_data_dir}" - ) - -n_train = max( - 1, - min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))), -) -train_subjects = all_patient_ids[:n_train] -test_subjects = all_patient_ids[n_train:] -print(f" Train (first {n_train}): {train_subjects}") -print(f" Test (last {len(test_subjects)}): {test_subjects}") - -# %% [markdown] -# ## 3. Gather the train cohort's gated frames and labelmaps -# -# For each train-cohort patient, list gated frames in -# ``src_data_dir_base / `` (excluding ``"nop"`` non-gated -# references) and pair each frame with its -# ``_labelmap.nii.gz`` under ``segmentation_dir_base / ``. -# Patients with no source directory or no valid frames are skipped here only -# — they remain part of the canonical train list above, but contribute no -# training data. Missing labelmaps are recorded as ``None`` so the workflow -# skips just that frame. - -# %% -train_image_files: list[list[str]] = [] -train_segmentation_files: list[list[str | None]] = [] -valid_train_subjects: list[str] = [] - -for patient_id in train_subjects: - src_dir = src_data_dir_base / patient_id - seg_dir = segmentation_dir_base / patient_id - - if not src_dir.is_dir(): - print(f" Skipping {patient_id}: source dir {src_dir} not found") - continue - - frame_names = sorted( - f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") - ) - if not frame_names: - print(f" Skipping {patient_id}: no valid frames in {src_dir}") - continue - - image_paths = [str(src_dir / f) for f in frame_names] - seg_paths: list[str | None] = [] - for f in frame_names: - labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") - seg_paths.append(str(labelmap) if labelmap.exists() else None) - - train_image_files.append(image_paths) - train_segmentation_files.append(seg_paths) - valid_train_subjects.append(patient_id) - - n_seg = sum(1 for s in seg_paths if s is not None) - print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") - -# %% [markdown] -# ## 4. Pre-compute loss-function masks next to each labelmap -# -# Use :meth:`RegisterImagesICON.create_mask` (``>0`` threshold + 5 mm -# physical-radius dilation) to derive each frame's binary heart-ROI mask and -# write it as ``_mask.nii.gz`` in the labelmap's own directory. -# Pre-computing here means the workflow does not have to re-derive masks -# during ``run_fine_tuning`` and the same masks are reused by downstream -# evaluation scripts. - -# %% -mask_dilation_mm = 5.0 -train_mask_files: list[list[str | None]] = [] -for image_paths, seg_paths in zip( - train_image_files, train_segmentation_files, strict=True -): - mask_paths: list[str | None] = [] - for seg_path in seg_paths: - if seg_path is None: - mask_paths.append(None) - continue - seg_p = Path(seg_path) - stem = seg_p.name - stem = stem[:-7] if stem.endswith(".nii.gz") else seg_p.stem - mask_p = seg_p.parent / f"{stem}_mask.nii.gz" - if not mask_p.exists(): - mask = RegisterImagesICON.create_mask( - itk.imread(str(seg_p)), dilation_mm=mask_dilation_mm - ) - itk.imwrite(mask, str(mask_p), compression=True) - mask_paths.append(str(mask_p)) - train_mask_files.append(mask_paths) - -# %% [markdown] -# ## 5. Fine-tune uniGradICON on the train cohort -# -# The workflow consumes both the labelmaps (for paired-with-seg training and -# ``use_label``) and the pre-computed masks (for ``loss_function_masking``) -# and launches ``unigradicon.finetuning.finetune`` as a subprocess. The -# final checkpoint lands at -# :meth:`WorkflowFineTuneICONRegistration.expected_weights_path`, which is -# the default ``--finetuned-weights-path`` read by ``2-recon_4d_icon_eval.py``. - -# %% -workflow = WorkflowFineTuneICONRegistration( - subject_image_files=train_image_files, - output_dir=output_dir, - fine_tune_name=fine_tune_name, - subject_ids=valid_train_subjects, - subject_segmentation_files=train_segmentation_files, - subject_mask_files=train_mask_files, - mask_dilation_mm=mask_dilation_mm, - unigradicon_src_path=unigradicon_src_path, - epochs=100, -) - -weights_path = workflow.run_fine_tuning() -print(f"\nFine-tuning complete. Expected weights at: {weights_path}") -print(f"Held-out test cohort (for 2-recon_4d_icon_eval.py): {test_subjects}") diff --git a/experiments/LongitudinalRegistration/1-initial_registration.py b/experiments/LongitudinalRegistration/1-initial_registration.py new file mode 100644 index 0000000..8775814 --- /dev/null +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -0,0 +1,614 @@ +# %% [markdown] +# Initial registration: compare ANTS vs Greedy vs ICON on the Duke gated CT cohort +# +# * :class:`RegisterImagesANTS` (CPU, SyN deformable) +# * :class:`RegisterImagesGreedy` (CPU, deformable) +# * :class:`RegisterImagesICON` (GPU, uniGradICON deformable) +# +# %% +import csv +import shutil +import time +from pathlib import Path +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d.labelmap_tools import LabelmapTools +from physiomotion4d.landmark_tools import LandmarkTools +from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") +segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") + +use_mask_list = [False] +use_labelmap_list = [False] + +# ICON only +use_mass_list = [False] + +methods_list = [ + ["Greedy"], +] +number_of_iterations_ANTS_list = [ + [40, 20, 10], +] +number_of_iterations_greedy_list = [ + [60, 20, 10], +] +number_of_iterations_ICON_list = [100] + +exclude_tokens = ["nop"] +ref_suffix = "_ref" +icon_weights_path: Optional[Path] = None +mask_dilation_mm = 3.0 +use_crop = False +fixed_image_resolution_mm = 0.0 + +debug_subjects = [] # ["pm0002", "pm0003", "pm0004"] + +labelmap_tools = LabelmapTools() +landmark_tools = LandmarkTools() +transform_tools = TransformTools() + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") + +if debug_subjects: + cohort = [pid for pid in all_patient_ids if pid in debug_subjects] + print( + f"DEBUG: restricting cohort to {debug_subjects} -> " + f"{len(cohort)} matching patients" + ) +else: + cohort = all_patient_ids + + +def per_label_dice( + fixed_labelmap: itk.Image, warped_labelmap: itk.Image +) -> dict[int, float]: + """Return ``{label_id: Dice}`` for every positive label present in + either the fixed or the warped labelmap. + + Arrays come back from :func:`itk.array_from_image` in shape + ``(Z, Y, X)`` (numpy reverses ITK's index order); we compare element-wise + so the axis convention does not matter as long as both labelmaps live + on the same reference grid (guaranteed because ``warped_labelmap`` was + resampled with ``fixed_labelmap`` as the reference image). + """ + fixed_array = itk.array_from_image(fixed_labelmap) + warped_array = itk.array_from_image(warped_labelmap) + labels = sorted( + {int(v) for v in np.unique(fixed_array)} + | {int(v) for v in np.unique(warped_array)} + ) + labels = [label for label in labels if label > 0] + + dice_by_label: dict[int, float] = {} + for label in labels: + a = fixed_array == label + b = warped_array == label + denom = int(a.sum()) + int(b.sum()) + if denom == 0: + continue + intersection = int(np.logical_and(a, b).sum()) + dice_by_label[label] = 2.0 * intersection / denom + return dice_by_label + + +def warp_landmarks( + inverse_transform: itk.Transform, + moving_landmarks: dict[str, tuple[float, float, float]], +) -> dict[str, tuple[float, float, float]]: + """Warp every moving landmark into reference space. + + Point/landmark warping uses ``inverse_transform`` -- the moving-space -> + fixed-space point map -- which is the opposite of the transform used to + warp the moving image onto the fixed grid (images pull back; points push + forward). Returns a ``{label: (x, y, z)}`` dict in LPS. See + docs/developer/transform_conventions. + """ + new_landmarks = {} + for name, point in moving_landmarks.items(): + new_point = inverse_transform.TransformPoint(np.array(point)) + new_landmarks[name] = tuple(np.array(new_point).tolist()) + return new_landmarks + + +def landmark_rms_errors( + warped_landmarks: dict[str, tuple[float, float, float]], + fixed_landmarks: dict[str, tuple[float, float, float]], +) -> list[tuple[str, float]]: + """Return per-landmark RMS Euclidean error in mm between the + reference-space ``warped_landmarks`` and the matching reference + landmarks, in sorted-name order. + """ + errors: list[tuple[str, float]] = [] + for name in fixed_landmarks.keys(): + if name not in warped_landmarks: + errors.append((name, float("nan"))) + continue + diff = 0 + for i in range(3): + diff += (warped_landmarks[name][i] - fixed_landmarks[name][i]) ** 2 + errors.append((name, float(np.sqrt(diff)))) + print(f"Landmark {name} RMS error: {errors[-1][1]:.4f} mm") + return errors + + +def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: + """Return the cached ``_labelmap_mask.nii.gz`` next to the + labelmap, or derive it via + :meth:`LabelmapTools.convert_labelmap_to_mask` (threshold ``>0`` plus + 3 mm physical-radius dilation) and write it out so subsequent runs and + the ICON eval reuse the same mask. + """ + if mask_path.exists(): + return itk.imread(str(mask_path)) + mask = labelmap_tools.convert_labelmap_to_mask( + labelmap, + dilation_in_mm=mask_dilation_mm, + exclude_labels=[1, 2, 3, 4], + # Interior chambers of the heart: LV, RV, LA, RA + ) + itk.imwrite(mask, str(mask_path), compression=True) + return mask + + +def crop_image_to_mask( + image: itk.Image, + mask: Optional[itk.Image] = None, + labelmap: Optional[itk.Image] = None, + margin_fraction: float = 0.1, +) -> dict[str, itk.Image]: + if mask is None: + mask_arr = itk.array_from_image(image) + print("No mask provided, using image as mask") + else: + mask_arr = itk.array_from_image(mask) + bounding_box = np.where(mask_arr > 0) + min_x = np.min(bounding_box[2]) + max_x = np.max(bounding_box[2]) + min_y = np.min(bounding_box[1]) + max_y = np.max(bounding_box[1]) + min_z = np.min(bounding_box[0]) + max_z = np.max(bounding_box[0]) + margin_x = int((max_x - min_x) * margin_fraction) + margin_y = int((max_y - min_y) * margin_fraction) + margin_z = int((max_z - min_z) * margin_fraction) + min_x -= margin_x + max_x += margin_x + min_y -= margin_y + max_y += margin_y + min_z -= margin_z + max_z += margin_z + if min_x < 0: + min_x = 0 + if min_y < 0: + min_y = 0 + if min_z < 0: + min_z = 0 + max_size = image.GetLargestPossibleRegion().GetSize() + if max_x >= max_size[0]: + max_x = max_size[0] - 1 + if max_y >= max_size[1]: + max_y = max_size[1] - 1 + if max_z >= max_size[2]: + max_z = max_size[2] - 1 + print(f"array shape: {mask_arr.shape}") + print( + f"min_x: {min_x}, max_x: {max_x}, min_y: {min_y}, max_y: {max_y}, min_z: {min_z}, max_z: {max_z}" + ) + new_image_arr = itk.array_from_image(image) + new_image_arr = new_image_arr[min_z:max_z, min_y:max_y, min_x:max_x] + new_origin = image.TransformIndexToPhysicalPoint( + [int(min_x), int(min_y), int(min_z)] + ) + new_image = itk.image_from_array(new_image_arr) + new_image.SetSpacing(image.GetSpacing()) + new_image.SetDirection(image.GetDirection()) + new_image.SetOrigin(new_origin) + + if labelmap is not None: + new_labelmap_arr = itk.array_from_image(labelmap) + new_labelmap_arr = new_labelmap_arr[min_z:max_z, min_y:max_y, min_x:max_x] + new_labelmap = itk.image_from_array(new_labelmap_arr) + new_labelmap.CopyInformation(new_image) + else: + new_labelmap = None + + if mask is not None: + new_mask_arr = itk.array_from_image(mask) + new_mask_arr = new_mask_arr[min_z:max_z, min_y:max_y, min_x:max_x] + new_mask = itk.image_from_array(new_mask_arr) + new_mask.CopyInformation(new_image) + else: + new_mask = None + + return { + "image": new_image, + "labelmap": new_labelmap, + "mask": new_mask, + } + + +# %% +_HERE = Path(__file__).parent +for subject_index, subject_id in enumerate(cohort): + print(f"\n=== Subject {subject_index + 1}/{len(cohort)}: {subject_id} ===") + src_dir = src_data_dir_base / subject_id + seg_dir = segmentation_dir_base / subject_id + + ref_file = next((p for p in ref_files if p.name.startswith(subject_id)), None) + ref_stem = ref_file.name[:-7] + ref_labelmap_path = seg_dir / f"{ref_stem}_labelmap.nii.gz" + ref_mask_path = seg_dir / f"{ref_stem}_labelmap_mask.nii.gz" + ref_landmark_path = seg_dir / f"{ref_stem}_landmark.mrk.json" + + fixed_output_dir = _HERE / "fixed" + fixed_output_dir.mkdir(parents=True, exist_ok=True) + + fixed_image = itk.imread(str(ref_file), pixel_type=itk.F) + + fixed_labelmap = None + if ref_labelmap_path.exists(): + fixed_labelmap = itk.imread(str(ref_labelmap_path)) + + fixed_mask = None + if ref_mask_path.exists(): + fixed_mask = load_or_derive_mask(fixed_labelmap, ref_mask_path) + + fixed_landmarks = None + if ref_landmark_path.exists(): + fixed_landmarks = landmark_tools.read_landmarks_3dslicer(ref_landmark_path) + + if use_crop: + cropped = crop_image_to_mask( + fixed_image, mask=fixed_mask, labelmap=fixed_labelmap + ) + fixed_image = cropped["image"] + fixed_labelmap = cropped["labelmap"] + fixed_mask = cropped["mask"] + + if fixed_image_resolution_mm > 0.0: + fixed_image_size = fixed_image.GetLargestPossibleRegion().GetSize() + fixed_image_size[0] = int( + fixed_image_size[0] + * fixed_image.GetSpacing()[0] + / fixed_image_resolution_mm + ) + fixed_image_size[1] = int( + fixed_image_size[1] + * fixed_image.GetSpacing()[1] + / fixed_image_resolution_mm + ) + fixed_image_size[2] = int( + fixed_image_size[2] + * fixed_image.GetSpacing()[2] + / fixed_image_resolution_mm + ) + fixed_image = itk.resample_image_filter( + fixed_image, + output_direction=fixed_image.GetDirection(), + output_origin=fixed_image.GetOrigin(), + size=fixed_image_size, + output_spacing=[ + fixed_image_resolution_mm, + fixed_image_resolution_mm, + fixed_image_resolution_mm, + ], + default_pixel_value=-1000, + ) + if fixed_labelmap is not None: + fixed_labelmap = itk.resample_image_filter( + fixed_labelmap, + output_parameters_from_image=fixed_image, + default_pixel_value=0, + interpolator=itk.NearestNeighborInterpolateImageFunction.New( + fixed_labelmap + ), + ) + if fixed_mask is not None: + fixed_mask = itk.resample_image_filter( + fixed_mask, + output_parameters_from_image=fixed_image, + default_pixel_value=0, + interpolator=itk.NearestNeighborInterpolateImageFunction.New( + fixed_mask + ), + ) + + print(f"Writing reference image to {f'{subject_id}_ref.nii.gz'}") + itk.imwrite( + fixed_image, + str(fixed_output_dir / f"{subject_id}_ref.nii.gz"), + compression=True, + ) + if fixed_labelmap is not None: + print(f"Writing reference labelmap to {f'{subject_id}_ref_labelmap.nii.gz'}") + itk.imwrite( + fixed_labelmap, + str(fixed_output_dir / f"{subject_id}_ref_labelmap.nii.gz"), + compression=True, + ) + if fixed_mask is not None: + print(f"Writing reference mask to {f'{subject_id}_ref_mask.nii.gz'}") + itk.imwrite( + fixed_mask, + str(fixed_output_dir / f"{subject_id}_ref_mask.nii.gz"), + compression=True, + ) + if fixed_landmarks is not None: + print( + f"Writing reference landmarks to {f'{subject_id}_ref_landmarks.mrk.json'}" + ) + landmark_tools.write_landmarks_3dslicer( + fixed_landmarks, + str(fixed_output_dir / f"{subject_id}_ref_landmarks.mrk.json"), + ) + + gated_files = sorted( + p + for p in src_dir.glob("*.nii.gz") + if not any(token in p.name for token in exclude_tokens) + and not p.name.endswith(f"{ref_suffix}.nii.gz") + ) + + print(f"Found {len(gated_files)} gated images under {src_dir}") + + for image_index, image_path in enumerate(gated_files): + stem = image_path.name[:-7] + + print(f"\n\n *** Processing {stem} ***\n\n") + + labelmap_path = seg_dir / f"{stem}_labelmap.nii.gz" + mask_path = seg_dir / f"{stem}_labelmap_mask.nii.gz" + landmark_path = seg_dir / f"{stem}_landmark.mrk.json" + + moving_image_name = str(image_path) + moving_image = itk.imread(moving_image_name, pixel_type=itk.F) + moving_labelmap = None + if fixed_labelmap is not None and labelmap_path.exists(): + moving_labelmap = itk.imread(str(labelmap_path)) + moving_mask = None + if fixed_mask is not None and mask_path.exists(): + moving_mask = load_or_derive_mask(moving_labelmap, mask_path) + moving_landmarks = None + if fixed_landmarks is not None and landmark_path.exists(): + moving_landmarks = landmark_tools.read_landmarks_3dslicer(landmark_path) + + if use_crop: + cropped = crop_image_to_mask( + moving_image, mask=moving_mask, labelmap=moving_labelmap + ) + moving_image = cropped["image"] + moving_labelmap = cropped["labelmap"] + moving_mask = cropped["mask"] + + for cond_index in range(len(use_mask_list)): + methods = methods_list[cond_index] + + use_mask = use_mask_list[cond_index] + use_labelmap = use_labelmap_list[cond_index] + use_mass = use_mass_list[cond_index] + + number_of_iterations_ANTS = number_of_iterations_ANTS_list[cond_index] + number_of_iterations_greedy = number_of_iterations_greedy_list[cond_index] + number_of_iterations_ICON = number_of_iterations_ICON_list[cond_index] + + cond = "_" + if use_mask: + cond += "m" + if use_labelmap: + cond += "l" + if use_mass: + cond += "p" + if use_crop: + cond += "c" + if cond == "_": + cond += "raw" + + print( + f"\n\n ***** {cond_index + 1}/{len(use_mask_list)}: results{cond} *****\n\n" + ) + + output_dir = _HERE / f"results{cond}" + detail_landmarks_file = output_dir / "registration_landmarks_init.csv" + detail_dice_file = output_dir / "registration_dice_init.csv" + if subject_index == 0 and image_index == 0 and cond_index == 0: + if output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + for method_index, method_name in enumerate(methods): + print(f"\n\n --- {method_name} --- \n\n") + if method_name == "ANTS": + reg = RegisterImagesANTS() + reg.set_number_of_iterations(number_of_iterations_ANTS) + num_iters_str = ".".join(str(n) for n in number_of_iterations_ANTS) + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + if use_labelmap: + reg.set_metric("MeanSquares") + else: + reg.set_metric("CC") + elif method_name == "Greedy": + reg = RegisterImagesGreedy() + reg.set_number_of_iterations(number_of_iterations_greedy) + print(f"Number of iterations: {number_of_iterations_greedy}") + num_iters_str = ".".join( + str(n) for n in number_of_iterations_greedy + ) + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + if use_labelmap: + reg.set_metric("MeanSquares") + else: + reg.set_metric("CC") + else: # ICON: GPU deep-learning deformable registration. + reg = RegisterImagesICON() + reg.set_number_of_iterations(number_of_iterations_ICON) + num_iters_str = str(number_of_iterations_ICON) + reg.set_multi_modality(False) + reg.set_mass_preservation(use_mass) + if icon_weights_path is not None: + reg.set_weights_path(str(icon_weights_path)) + reg.set_modality("ct") + reg.set_mask_dilation(0) # Already dilated + reg.set_fixed_image(fixed_image) + if use_mask: + reg.set_fixed_mask(fixed_mask) + if use_labelmap: + reg.set_fixed_labelmap(fixed_labelmap) + + method_dir_name = str(method_name.lower()) + "_" + num_iters_str + method_dir = output_dir / method_dir_name / subject_id + print(f"Method directory: {method_dir}") + method_dir.mkdir(parents=True, exist_ok=True) + print(f"Registering {stem} with {method_name}...") + time_start = time.perf_counter() + reg_result = reg.register( + moving_image=moving_image, + moving_mask=moving_mask if use_mask else None, + moving_labelmap=moving_labelmap if use_labelmap else None, + ) + time_elapsed = time.perf_counter() - time_start + print(f" ...finished registration in {time_elapsed:.1f}s") + + forward_transform = reg_result["forward_transform"] + inverse_transform = reg_result["inverse_transform"] + loss = float(reg_result["loss"]) + + print(f"Writing results to {method_dir / f'{stem}_init_*.*'}") + + itk.transformwrite( + forward_transform, + str(method_dir / f"{stem}_init_fwd.hdf"), + compression=True, + ) + itk.transformwrite( + inverse_transform, + str(method_dir / f"{stem}_init_inv.hdf"), + compression=True, + ) + + warped_image = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", + ) + itk.imwrite( + warped_image, + str(method_dir / f"{stem}_init.mha"), + compression=True, + ) + + average_dice = float("nan") + if fixed_labelmap is not None and moving_labelmap is not None: + warped_labelmap = transform_tools.transform_image( + moving_labelmap, + forward_transform, + fixed_labelmap, + interpolation_method="nearest", + ) + itk.imwrite( + warped_labelmap, + str(method_dir / f"{stem}_init_labelmap.mha"), + compression=True, + ) + dice_by_label = per_label_dice(fixed_labelmap, warped_labelmap) + with detail_dice_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + ["subject_id", "method", "stem", "label", "dice"] + ) + for label, dice in dice_by_label.items(): + writer.writerow( + [ + subject_id, + method_name + "_" + num_iters_str, + stem, + label, + dice, + ] + ) + if dice_by_label: + average_dice = float(np.mean(list(dice_by_label.values()))) + + if fixed_mask is not None and moving_mask is not None: + warped_mask = transform_tools.transform_image( + moving_mask, + forward_transform, + fixed_mask, + interpolation_method="nearest", + ) + itk.imwrite( + warped_mask, + str(method_dir / f"{stem}_init_mask.mha"), + compression=True, + ) + + average_rms_errors = float("nan") + if fixed_landmarks is not None and moving_landmarks is not None: + # Landmarks live in LPS world space, unaffected by cropping, so + # the uncropped moving_landmarks are warped here. + warped_landmarks = warp_landmarks( + inverse_transform, moving_landmarks + ) + landmark_tools.write_landmarks_3dslicer( + warped_landmarks, + str(method_dir / f"{stem}_init_landmarks.mrk.json"), + ) + rms_errors = landmark_rms_errors(warped_landmarks, fixed_landmarks) + with detail_landmarks_file.open( + "a", newline="", encoding="utf-8" + ) as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + [ + "subject_id", + "method", + "stem", + "name", + "rms_err_mm", + ] + ) + for name, error in rms_errors: + writer.writerow( + [ + subject_id, + method_name + "_" + num_iters_str, + stem, + name, + error, + ] + ) + rms_values = np.asarray( + [e for _, e in rms_errors], dtype=np.float64 + ) + if rms_values.size and not np.all(np.isnan(rms_values)): + average_rms_errors = float(np.nanmean(rms_values)) + + print( + f"Method: {method_name}_{num_iters_str}, Subject: {subject_id}, " + f"Timepoint: {stem}, time: {time_elapsed:.1f}s, " + f"loss: {loss:.4f}, Dice: {average_dice:.4f}, " + f"RMS(mm): {average_rms_errors:.4f}" + ) + +# %% diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py new file mode 100644 index 0000000..05e55db --- /dev/null +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -0,0 +1,266 @@ +# %% [markdown] +# # Fine-tune uniGradICON on Duke 4D Gated CT Data +# +# Discovers per-patient gated CT images and their precomputed +# SegmentHeartSimpleware labelmaps and applies the project-wide fixed 80/20 +# train/test split (sort patients in ``ref_data_dir`` by filename; the first +# 80% are train, the last 20% are test). The train cohort is handed to +# :class:`WorkflowFineTuneICONRegistration`, which builds the paired dataset +# JSON, YAML config, and derived loss-function masks, then launches +# ``unigradicon.finetuning.finetune`` as a subprocess. +# +# ``2-recon_4d_icon_eval.py`` re-derives the same split from the same sorted +# patient list — no cached split file is needed. +# +# Each patient directory under ``src_data_dir_base`` is one ``subject_id``; +# all of that patient's gated time-point frames form a paired training group. +# Frames whose labelmap is missing on disk are dropped from the dataset. +# +# In addition to the original ``gated_nii`` frames, each patient's training +# group is augmented with that patient's ANTS- and Greedy-warped frames +# written by ``1-initial_registration.py`` (warped image + labelmap per gated +# frame, under ``output_dir / / ``). Because the warped +# frames are merged into the *same* ``subject_id`` group, uniGradICON pairs the +# original gated frames and both backends' pre-registered frames together. + +# %% +import os +from pathlib import Path +from typing import Optional + +import itk + +from physiomotion4d import WorkflowFineTuneICONRegistration +from physiomotion4d.labelmap_tools import LabelmapTools + +# %% [markdown] +# ## 1. Configure data, output locations, and the train/test split + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") +labelmap_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") + +# Where the workflow writes the dataset JSON, YAML config, derived masks, and +# the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to +# ``output_dir / fine_tune_name``. +_HERE = Path(__file__).parent +output_dir = _HERE / "results_finetuning" +fine_tune_name = "icon_finetuning" + +# Pre-registration augmentation: ``1-initial_registration.py`` warps every gated +# moving frame into reference space with these backends and writes the warped +# image + labelmap under ``initial_registration_dir / .lower() / +# ``. Those warped frames are merged into each patient's training +# group below (section 4b). +initial_registration_dir = output_dir +initial_registration_methods = ["Greedy"] + +# Fixed train/test split: sort patients in ``ref_data_dir`` by filename; +# first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies +# the same rule so the two scripts agree without a cached split record. +train_fraction = 0.8 + +# Local clone of uniGradICON (feat-add-finetuning branch) — prepended to +# PYTHONPATH so the subprocess picks up the local source instead of the +# installed package. Set to ``None`` to use the pip-installed unigradicon. +unigradicon_src_path: Optional[Path] = Path(__file__).parent / "uniGradICON" / "src" + +# %% [markdown] +# ## 2. Enumerate patients and apply the fixed 80/20 split +# +# Sort ``ref_data_dir`` by filename to produce the canonical patient order. +# The first 80% become the train cohort; the last 20% are the held-out test +# cohort that ``2-recon_4d_icon_eval.py`` will evaluate. + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") + +if len(all_patient_ids) < 2: + raise FileNotFoundError( + f"Need at least 2 patients to form a train/test split; " + f"discovered {len(all_patient_ids)} under {ref_data_dir}" + ) + +n_train = max( + 1, + min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))), +) +train_subjects = all_patient_ids[:n_train] +test_subjects = all_patient_ids[n_train:] +print(f" Train (first {n_train}): {train_subjects}") +print(f" Test (last {len(test_subjects)}): {test_subjects}") + +# %% [markdown] +# ## 3. Gather the train cohort's gated frames and labelmaps and masks +# +# For each train-cohort patient, list gated frames in +# ``src_data_dir_base / `` (excluding ``"nop"`` non-gated +# references) and pair each frame with its +# ``_labelmap.nii.gz`` and ``_mask.nii.gz`` under ``labelmap_dir_base / ``. +# Patients with no source directory or no valid frames are skipped here only +# — they remain part of the canonical train list above, but contribute no +# training data. Missing labelmaps are recorded as ``None`` so the workflow +# skips just that frame. + +# %% +train_image_files: list[list[Path]] = [] +train_labelmap_files: list[list[Optional[Path]]] = [] +train_mask_files: list[list[Optional[Path]]] = [] +valid_train_subjects: list[str] = [] + +mask_dilation_mm = 3.0 +labelmap_tools = LabelmapTools() + + +# %% +def load_or_derive_mask(labelmap_path: Path) -> Optional[Path]: + """Create (or reuse) a loss-function mask next to ``labelmap_path``. + + Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm + via :meth:`LabelmapTools.convert_labelmap_to_mask`, writing the result as + ``_mask.nii.gz`` in the labelmap's own directory. Handles + both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha`` + (pre-registration warped labelmaps). Returns the mask path; existing + masks on disk are reused unmodified. + """ + if not labelmap_path.exists(): + return None + + name = labelmap_path.name + if name.endswith(".nii.gz"): + stem = name[:-7] + elif name.endswith(".mha"): + stem = name[:-4] + else: + stem = labelmap_path.stem + mask_p = labelmap_path.parent / f"{stem}_mask.nii.gz" + if not mask_p.exists(): + mask = labelmap_tools.convert_labelmap_to_mask( + itk.imread(str(labelmap_path)), dilation_in_mm=mask_dilation_mm + ) + itk.imwrite(mask, str(mask_p), compression=True) + return mask_p + + +# %% +def gather_warped_frames( + method_dir: Path, +) -> tuple[list[Path], list[Optional[Path]], list[Optional[Path]]]: + """Return ``(warped_image_paths, warped_labelmap_paths, warped_mask_paths)`` for one + ``initial_registration_dir / / `` directory. + + Enumerates the warped moving images (``.mha``), excluding the + ``_labelmap.mha`` and ``_mask.mha`` + companions, and pairs each with its ``_labelmap.mha`` (``None`` when + that labelmap is absent). Returns empty lists when ``method_dir`` does + not exist. + """ + if not method_dir.is_dir(): + return [], [], [] + companion_suffixes = ( + "_labelmap.mha", + "_mask.mha", + ) + image_paths: list[Path] = [] + labelmap_paths: list[Optional[Path]] = [] + mask_paths: list[Optional[Path]] = [] + for image in sorted(method_dir.glob("*.mha")): + if image.name.endswith(companion_suffixes): + continue + stem = image.name[:-4] + labelmap = method_dir / f"{stem}_labelmap.mha" + mask = method_dir / f"{stem}_mask.mha" + image_paths.append(image) + labelmap_paths.append(labelmap if labelmap.exists() else None) + mask_paths.append(mask if mask.exists() else None) + return image_paths, labelmap_paths, mask_paths + + +# %% +for patient_id in train_subjects: + src_dir = src_data_dir_base / patient_id + seg_dir = labelmap_dir_base / patient_id + + if not src_dir.is_dir(): + print(f" Skipping {patient_id}: source dir {src_dir} not found") + continue + + frame_names = sorted( + f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") + ) + if not frame_names: + print(f" Skipping {patient_id}: no valid frames in {src_dir}") + continue + + image_paths = [src_dir / f for f in frame_names] + labelmap_paths: list[Optional[Path]] = [] + mask_paths: list[Optional[Path]] = [] + for f in frame_names: + labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") + labelmap_paths.append(labelmap if labelmap.exists() else None) + mask = load_or_derive_mask(labelmap) + mask_paths.append(mask) + + train_image_files.append(image_paths) + train_labelmap_files.append(labelmap_paths) + train_mask_files.append(mask_paths) + valid_train_subjects.append(patient_id) + + n_seg = sum(1 for s in labelmap_paths if s is not None) + print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") + + +for subject_index, patient_id in enumerate(valid_train_subjects): + for method_name in initial_registration_methods: + method_dir = initial_registration_dir / method_name.lower() / patient_id + warped_images, warped_labelmaps, warped_masks = gather_warped_frames(method_dir) + if not warped_images: + print( + f" {patient_id}/{method_name}: no initial-registered frames " + f"in {method_dir}" + ) + continue + train_image_files[subject_index].extend(warped_images) + train_labelmap_files[subject_index].extend(warped_labelmaps) + train_mask_files[subject_index].extend(warped_masks) + n_warped = sum(1 for labelmap in warped_labelmaps if labelmap is not None) + print( + f" {patient_id}/{method_name}: +{len(warped_images)} warped frames, " + f"{n_warped} with labelmap" + ) + +# %% +workflow = WorkflowFineTuneICONRegistration( + subject_image_files=[ + [str(image_path) for image_path in image_paths] + for image_paths in train_image_files + ], + output_dir=output_dir, + fine_tune_name=fine_tune_name, + subject_ids=valid_train_subjects, + subject_labelmap_files=[ + [ + str(labelmap_path) if labelmap_path is not None else None + for labelmap_path in labelmap_paths + ] + for labelmap_paths in train_labelmap_files + ], + subject_mask_files=[ + [str(mask_path) if mask_path is not None else None for mask_path in mask_paths] + for mask_paths in train_mask_files + ], + mask_dilation_mm=mask_dilation_mm, + unigradicon_src_path=unigradicon_src_path, + epochs=500, +) + +weights_path = workflow.run_fine_tuning() +print(f"\nFine-tuning complete. Expected weights at: {weights_path}") +print(f"Held-out test cohort (for 2-recon_4d_icon_eval.py): {test_subjects}") diff --git a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py similarity index 79% rename from experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py rename to experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py index 83ff66c..6b61734 100644 --- a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py +++ b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py @@ -3,11 +3,11 @@ # # Enumerates the Duke patient cohort by sorting ``ref_images/`` and uses the # *last 20%* of patients as the held-out test set — the same fixed split -# applied by ``1-finetune_icon.py`` (first 80% train, last 20% test). For +# applied by ``2-finetune_icon.py`` (first 80% train, last 20% test). For # each test subject the 70th-percentile gated frame is selected as the # reference and every other frame is registered to it twice with # ``RegisterTimeSeriesImages``: once with the default uniGradICON weights and -# once with the finetuned checkpoint from ``1-finetune_icon.py``. The +# once with the finetuned checkpoint from ``2-finetune_icon.py``. The # resampler-convention inverse transform (which maps moving-grid points back # to reference-grid points) is applied to each time-point's precomputed # landmarks to land them in reference space, and the Euclidean error against @@ -25,8 +25,8 @@ import numpy as np from physiomotion4d import RegisterTimeSeriesImages +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.landmark_tools import LandmarkTools -from physiomotion4d.register_images_icon import RegisterImagesICON # %% [markdown] # ## 1. Hard-coded paths and configuration @@ -35,15 +35,21 @@ ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") timepoint_base_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_base_dir = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -output_dir = Path("./results") -finetuned_weights_path = Path( - "./results/icon_finetuned/checkpoints/Finetune_multi_final.trch" + +_HERE = Path(__file__).parent +output_dir = _HERE / "results" +finetuned_weights_path = ( + output_dir + / "icon_finetuned" + / "icon_finetuned_model" + / "checkpoints" + / "network_weights_final.trch" ) train_fraction = 0.8 -icon_iterations = 20 +icon_iterations = None reference_percentile = 0.70 -exclude_tokens = ("nop", "dia", "sys", "_ref") +exclude_tokens = ["nop"] timepoint_re = re.compile(r"_g(?P[0-9]{3})") methods: list[tuple[str, Optional[Path]]] = [ @@ -87,12 +93,13 @@ # Landmarks are read with :meth:`LandmarkTools.read_landmarks_3dslicer` — # they were written as ``_landmark.mrk.json`` (3D Slicer Markups JSON, # LPS) by ``0-cardiacGatedCT_segment_and_landmark.py``. Binary registration -# masks come from :meth:`RegisterImagesICON.create_mask` (``>0`` threshold -# plus 5 mm dilation by default), matching the loss-function masks used +# masks come from :meth:`LabelmapTools.convert_labelmap_to_mask` (``>0`` +# threshold plus 5 mm dilation), matching the loss-function masks used # during fine-tuning in ``1-finetune_icon.py``. # %% landmark_tools = LandmarkTools() +labelmap_tools = LabelmapTools() # %% [markdown] @@ -103,15 +110,20 @@ for subject_id in test_subjects: source_dir = timepoint_base_dir / subject_id + print(f"Source directory: {source_dir}") + seg_dir = segmentation_base_dir / subject_id + print(f"Segmentation directory: {seg_dir}") image_files = [ p for p in sorted(source_dir.glob("*.nii.gz")) if not any(t in p.name for t in exclude_tokens) ] + print(f"Found {len(image_files)} image files") stems = [p.name[:-7] for p in image_files] labelmap_files = [seg_dir / f"{s}_labelmap.nii.gz" for s in stems] + mask_files = [seg_dir / f"{s}_labelmap_mask.nii.gz" for s in stems] landmark_files = [seg_dir / f"{s}_landmark.mrk.json" for s in stems] timepoints = [timepoint_re.search(p.name).group("timepoint") for p in image_files] @@ -122,17 +134,34 @@ ) fixed_image = itk.imread(str(image_files[reference_index]), pixel_type=itk.F) - fixed_mask = RegisterImagesICON.create_mask( - itk.imread(str(labelmap_files[reference_index])) - ) + fixed_labelmap = itk.imread(str(labelmap_files[reference_index])) + if mask_files[reference_index].exists(): + fixed_mask = itk.imread(str(mask_files[reference_index])) + else: + fixed_mask = labelmap_tools.convert_labelmap_to_mask( + fixed_labelmap, dilation_in_mm=5.0 + ) + itk.imwrite(fixed_mask, str(mask_files[reference_index]), compression=True) reference_landmarks = landmark_tools.read_landmarks_3dslicer( landmark_files[reference_index] ) moving_images = [itk.imread(str(p), pixel_type=itk.F) for p in image_files] - moving_masks = [ - RegisterImagesICON.create_mask(itk.imread(str(p))) for p in labelmap_files + moving_labelmaps = [itk.imread(str(p)) for p in labelmap_files] + moving_landmarks = [ + landmark_tools.read_landmarks_3dslicer(str(p)) for p in landmark_files ] + moving_masks = [] + for index, p in enumerate(mask_files): + if not p.exists(): + mask = labelmap_tools.convert_labelmap_to_mask( + moving_labelmaps[index], dilation_in_mm=5.0 + ) + itk.imwrite(mask, str(p), compression=True) + moving_masks.append(mask) + else: + mask = itk.imread(str(p)) + moving_masks.append(mask) for method_name, weights_path in methods: print(f" Method: {method_name}") @@ -147,7 +176,7 @@ result = registrar.register_time_series( moving_images=moving_images, moving_masks=moving_masks, - moving_labelmaps=None, + moving_labelmaps=moving_labelmaps, reference_frame=reference_index, register_reference=False, prior_weight=0.0, @@ -178,9 +207,7 @@ # inverse_transform follows the ITK resampler convention — it maps # moving-grid points back to reference-grid points, which is what # we need to warp time-point landmarks into reference space. - timepoint_landmarks = landmark_tools.read_landmarks_3dslicer( - landmark_files[index] - ) + timepoint_landmarks = moving_landmarks[index] shared = sorted(timepoint_landmarks.keys() & reference_landmarks.keys()) errors: list[tuple[str, float]] = [] for name in shared: @@ -221,11 +248,14 @@ # ## 5. Write the wide-form per-timepoint summary CSV # %% -with summary_file.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) - writer.writeheader() - writer.writerows(summary_rows) -print(f"Wrote summary: {summary_file}") +if not summary_rows: + print("No summary rows to write") +else: + with summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) + writer.writeheader() + writer.writerows(summary_rows) + print(f"Wrote summary: {summary_file}") print(f"Wrote landmark details: {detail_file}") # %% [markdown] diff --git a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py b/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py deleted file mode 100644 index 5467cad..0000000 --- a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py +++ /dev/null @@ -1,700 +0,0 @@ -"""Compare longitudinal cardiac CT registration methods with landmarks. - -The experiment registers each gated time-point image to the high-resolution -reference image for the same subject. Input images are 3D CT volumes in LPS -world space. Landmarks are CSV rows with physical LPS coordinates -``Name,X,Y,Z`` in millimeters. - -Two accuracy modes are written: -1. Direct landmarks: reference landmarks are transformed into each time-point - image space with the inverse registration transform and compared to the - precomputed time-point landmarks. -2. Re-segmented landmarks: the reference image is warped into each time-point - image space, re-segmented with Simpleware, and the newly extracted landmarks - are compared to the precomputed time-point landmarks. -""" - -from __future__ import annotations - -import argparse -import csv -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import itk -import numpy as np - -from physiomotion4d import ( - RegisterTimeSeriesImages, - SegmentHeartSimpleware, - TransformTools, -) - -DEFAULT_REF_DIR = Path("d:/PhysioMotion4D/duke_data/ref_images") -DEFAULT_TIMEPOINT_BASE_DIR = Path("d:/PhysioMotion4D/duke_data/gated_nii") -DEFAULT_SEGMENTATION_BASE_DIR = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -DEFAULT_OUTPUT_DIR = Path("d:/PhysioMotion4D/duke_data/longitudinal_registration") -DEFAULT_EXCLUDE_TOKENS = ("nop", "dia", "sys", "_ref") -DEFAULT_SEGMENTATION_DIR = "results-labelmaps_and_landmarks" -DEFAULT_METHODS = ("ANTS", "greedy", "icon_default", "ants_icon_default") -TIMEPOINT_RE = re.compile(r"_g(?P[0-9]{3})") - - -Landmarks = dict[str, tuple[float, float, float]] - - -@dataclass(frozen=True) -class MethodSpec: - """Registration method plus optional ICON checkpoint.""" - - output_name: str - registration_method: str - icon_weights_path: Optional[Path] = None - - -@dataclass(frozen=True) -class ImageArtifacts: - """Input files associated with one image volume.""" - - image_file: Path - landmark_file: Optional[Path] - labelmap_file: Optional[Path] - timepoint: str - - -def nii_stem(path: Path) -> str: - """Return a stable stem for ``.nii.gz`` or single-suffix files.""" - if path.name.endswith(".nii.gz"): - return path.name[:-7] - return path.stem - - -def timepoint_from_name(path: Path) -> str: - """Extract the gated time-point tag from a filename.""" - match = TIMEPOINT_RE.search(path.name) - if match: - return match.group("timepoint") - return nii_stem(path) - - -def first_existing(paths: list[Path]) -> Optional[Path]: - """Return the first existing path from a candidate list.""" - for path in paths: - if path.exists(): - return path - return None - - -def landmark_candidates( - image_file: Path, - segmentation_dir: str, - artifact_dir: Optional[Path], -) -> list[Path]: - """Return likely landmark CSV paths for an image.""" - stem = nii_stem(image_file) - parent = image_file.parent - seg_parent = parent / segmentation_dir - candidates = [ - parent / f"{stem}_landmark.csv", - parent / f"{stem}_landmarks.csv", - seg_parent / f"{stem}_landmark.csv", - seg_parent / f"{stem}_landmarks.csv", - ] - if artifact_dir is not None: - candidates = [ - artifact_dir / f"{stem}_landmark.csv", - artifact_dir / f"{stem}_landmarks.csv", - *candidates, - ] - return candidates - - -def labelmap_candidates( - image_file: Path, - segmentation_dir: str, - artifact_dir: Optional[Path], -) -> list[Path]: - """Return likely labelmap paths for an image.""" - stem = nii_stem(image_file) - parent = image_file.parent - seg_parent = parent / segmentation_dir - candidates = [ - parent / f"{stem}_labelmap.nii.gz", - seg_parent / f"{stem}_labelmap.nii.gz", - ] - if artifact_dir is not None: - candidates = [artifact_dir / f"{stem}_labelmap.nii.gz", *candidates] - return candidates - - -def image_artifacts( - image_file: Path, - segmentation_dir: str, - artifact_dir: Optional[Path] = None, -) -> ImageArtifacts: - """Find landmarks and labelmaps associated with one image.""" - return ImageArtifacts( - image_file=image_file, - landmark_file=first_existing( - landmark_candidates(image_file, segmentation_dir, artifact_dir) - ), - labelmap_file=first_existing( - labelmap_candidates(image_file, segmentation_dir, artifact_dir) - ), - timepoint=timepoint_from_name(image_file), - ) - - -def read_landmarks(path: Path) -> Landmarks: - """Read physical LPS landmarks from ``Name,X,Y,Z`` CSV.""" - landmarks: Landmarks = {} - with path.open(newline="", encoding="utf-8-sig") as fh: - for row in csv.DictReader(fh): - landmarks[row["Name"]] = ( - float(row["X"]), - float(row["Y"]), - float(row["Z"]), - ) - return landmarks - - -def write_landmarks(path: Path, landmarks: Landmarks) -> None: - """Write physical LPS landmarks to ``Name,X,Y,Z`` CSV.""" - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.writer(fh) - writer.writerow(["Name", "X", "Y", "Z"]) - for name, coords in sorted(landmarks.items()): - writer.writerow([name, coords[0], coords[1], coords[2]]) - - -def transform_landmarks(landmarks: Landmarks, transform: itk.Transform) -> Landmarks: - """Apply an ITK physical-space transform to landmark coordinates.""" - transformed: Landmarks = {} - for name, point in landmarks.items(): - transformed_point = transform.TransformPoint(point) - transformed[name] = ( - float(transformed_point[0]), - float(transformed_point[1]), - float(transformed_point[2]), - ) - return transformed - - -def landmark_errors(source: Landmarks, target: Landmarks) -> dict[str, float]: - """Return per-landmark Euclidean errors in millimeters.""" - errors: dict[str, float] = {} - for name in sorted(source.keys() & target.keys()): - source_point = np.asarray(source[name], dtype=np.float64) - target_point = np.asarray(target[name], dtype=np.float64) - errors[name] = float(np.linalg.norm(source_point - target_point)) - return errors - - -def summarize_errors(errors: dict[str, float], prefix: str) -> dict[str, object]: - """Summarize landmark errors for one comparison mode.""" - if not errors: - return { - f"{prefix}_landmarks": 0, - f"{prefix}_mean_mm": "", - f"{prefix}_median_mm": "", - f"{prefix}_max_mm": "", - } - values = np.asarray(list(errors.values()), dtype=np.float64) - return { - f"{prefix}_landmarks": len(errors), - f"{prefix}_mean_mm": float(np.mean(values)), - f"{prefix}_median_mm": float(np.median(values)), - f"{prefix}_max_mm": float(np.max(values)), - } - - -def write_error_details( - path: Path, - subject_id: str, - method_name: str, - timepoint: str, - mode: str, - errors: dict[str, float], -) -> None: - """Append per-landmark errors to the detail CSV.""" - exists = path.exists() - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("a", newline="", encoding="utf-8") as fh: - fieldnames = ["subject_id", "method", "timepoint", "mode", "name", "error_mm"] - writer = csv.DictWriter(fh, fieldnames=fieldnames) - if not exists: - writer.writeheader() - for name, error in sorted(errors.items()): - writer.writerow( - { - "subject_id": subject_id, - "method": method_name, - "timepoint": timepoint, - "mode": mode, - "name": name, - "error_mm": error, - } - ) - - -def dice_by_label( - labelmap_a: itk.Image, - labelmap_b: itk.Image, -) -> dict[int, float]: - """Compute Dice scores for labels present in either 3D labelmap.""" - arr_a = itk.array_from_image(labelmap_a) - arr_b = itk.array_from_image(labelmap_b) - if arr_a.shape != arr_b.shape: - return {} - labels = sorted(set(np.unique(arr_a)).union(set(np.unique(arr_b))) - {0}) - scores: dict[int, float] = {} - for label in labels: - mask_a = arr_a == label - mask_b = arr_b == label - denom = int(mask_a.sum() + mask_b.sum()) - if denom > 0: - scores[int(label)] = float( - 2.0 * np.logical_and(mask_a, mask_b).sum() / denom - ) - return scores - - -def summarize_dice(scores: dict[int, float]) -> dict[str, object]: - """Summarize per-label Dice scores.""" - if not scores: - return {"dice_labels": 0, "dice_mean": "", "dice_min": ""} - values = np.asarray(list(scores.values()), dtype=np.float64) - return { - "dice_labels": len(scores), - "dice_mean": float(np.mean(values)), - "dice_min": float(np.min(values)), - } - - -def discover_subjects( - reference_dir: Path, - timepoint_base_dir: Path, - reference_pattern: str, - timepoint_pattern: str, - exclude_tokens: tuple[str, ...], - segmentation_dir: str, - segmentation_base_dir: Optional[Path], -) -> list[tuple[str, ImageArtifacts, list[ImageArtifacts]]]: - """Discover reference and time-point files for each subject.""" - if not reference_dir.exists(): - raise FileNotFoundError(f"Reference image directory not found: {reference_dir}") - if not timepoint_base_dir.exists(): - raise FileNotFoundError( - f"Time-point image base directory not found: {timepoint_base_dir}" - ) - - subjects: list[tuple[str, ImageArtifacts, list[ImageArtifacts]]] = [] - for reference_file in sorted(reference_dir.glob(reference_pattern)): - subject_id = reference_file.name[:6] - source_dir = timepoint_base_dir / subject_id - if not source_dir.exists(): - raise FileNotFoundError( - f"No time-point directory for {subject_id}: {source_dir}" - ) - artifact_dir = None - if segmentation_base_dir is not None: - candidate_dir = segmentation_base_dir / subject_id - if candidate_dir.exists(): - artifact_dir = candidate_dir - - reference_in_source = source_dir / reference_file.name - reference_artifacts = image_artifacts( - reference_in_source if reference_in_source.exists() else reference_file, - segmentation_dir, - artifact_dir, - ) - - timepoint_files = [ - path - for path in sorted(source_dir.glob(timepoint_pattern)) - if not any(token in path.name for token in exclude_tokens) - ] - timepoints = [ - image_artifacts(path, segmentation_dir, artifact_dir) - for path in timepoint_files - if path.is_file() - ] - subjects.append((subject_id, reference_artifacts, timepoints)) - return subjects - - -def build_method_specs( - method_names: list[str], - finetuned_weights_path: Optional[Path], -) -> list[MethodSpec]: - """Map output method labels to registrar methods and optional weights.""" - specs: list[MethodSpec] = [] - for method_name in method_names: - if method_name == "ANTS": - specs.append(MethodSpec(method_name, "ANTS")) - elif method_name == "greedy": - specs.append(MethodSpec(method_name, "greedy")) - elif method_name == "icon_default": - specs.append(MethodSpec(method_name, "ICON")) - elif method_name == "ants_icon_default": - specs.append(MethodSpec(method_name, "ANTS_ICON")) - elif method_name == "greedy_icon_default": - specs.append(MethodSpec(method_name, "greedy_ICON")) - elif method_name == "icon_finetuned": - specs.append(MethodSpec(method_name, "ICON", finetuned_weights_path)) - elif method_name == "ants_icon_finetuned": - specs.append(MethodSpec(method_name, "ANTS_ICON", finetuned_weights_path)) - elif method_name == "greedy_icon_finetuned": - specs.append(MethodSpec(method_name, "greedy_ICON", finetuned_weights_path)) - else: - raise ValueError(f"Unknown method: {method_name}") - - for spec in specs: - if "finetuned" in spec.output_name and spec.icon_weights_path is None: - raise ValueError(f"{spec.output_name} requires --finetuned-weights-path") - return specs - - -def configure_registrar( - method_spec: MethodSpec, - fixed_image: itk.Image, - fixed_labelmap: Optional[itk.Image], - ants_iterations: list[int], - greedy_iterations: list[int], - icon_iterations: int, -) -> RegisterTimeSeriesImages: - """Create and configure the time-series registrar.""" - registrar = RegisterTimeSeriesImages( - registration_method=method_spec.registration_method - ) - registrar.set_modality("ct") - registrar.set_fixed_image(fixed_image) - registrar.set_fixed_labelmap(fixed_labelmap) - registrar.set_number_of_iterations_ANTS(ants_iterations) - registrar.set_number_of_iterations_greedy(greedy_iterations) - registrar.set_number_of_iterations_ICON(icon_iterations) - if method_spec.icon_weights_path is not None: - registrar.registrar_ICON.set_weights_path(str(method_spec.icon_weights_path)) - return registrar - - -def write_summary(path: Path, rows: list[dict[str, object]]) -> None: - """Write experiment summary rows.""" - if not rows: - return - path.parent.mkdir(parents=True, exist_ok=True) - fieldnames = list(rows[0].keys()) - with path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(rows) - - -def run_method_for_subject( - subject_id: str, - reference_artifacts: ImageArtifacts, - timepoint_artifacts: list[ImageArtifacts], - method_spec: MethodSpec, - output_dir: Path, - run_resegmentation: bool, - ants_iterations: list[int], - greedy_iterations: list[int], - icon_iterations: int, - error_detail_file: Path, -) -> list[dict[str, object]]: - """Run one registration method for one subject and return summary rows.""" - if reference_artifacts.landmark_file is None: - raise FileNotFoundError( - f"Missing reference landmarks for {reference_artifacts.image_file}" - ) - - fixed_image = itk.imread(str(reference_artifacts.image_file), pixel_type=itk.F) - fixed_labelmap = None - if reference_artifacts.labelmap_file is not None: - fixed_labelmap = itk.imread(str(reference_artifacts.labelmap_file)) - - moving_images = [ - itk.imread(str(artifacts.image_file), pixel_type=itk.F) - for artifacts in timepoint_artifacts - ] - moving_labelmaps = None - if all(artifacts.labelmap_file is not None for artifacts in timepoint_artifacts): - moving_labelmaps = [ - itk.imread(str(artifacts.labelmap_file)) - for artifacts in timepoint_artifacts - ] - - registrar = configure_registrar( - method_spec, - fixed_image, - fixed_labelmap, - ants_iterations, - greedy_iterations, - icon_iterations, - ) - - result = registrar.register_time_series( - moving_images=moving_images, - moving_labelmaps=moving_labelmaps, - reference_frame=0, - register_reference=True, - prior_weight=0.0, - ) - - reference_landmarks = read_landmarks(reference_artifacts.landmark_file) - transform_tools = TransformTools() - segmenter = SegmentHeartSimpleware() if run_resegmentation else None - subject_method_dir = output_dir / method_spec.output_name / subject_id - subject_method_dir.mkdir(parents=True, exist_ok=True) - - rows: list[dict[str, object]] = [] - for index, artifacts in enumerate(timepoint_artifacts): - timepoint_dir = subject_method_dir / artifacts.timepoint - timepoint_dir.mkdir(parents=True, exist_ok=True) - - forward_transform = result["forward_transforms"][index] - inverse_transform = result["inverse_transforms"][index] - loss = result["losses"][index] - - forward_file = timepoint_dir / "time_to_reference.hdf" - inverse_file = timepoint_dir / "reference_to_time.hdf" - itk.transformwrite(forward_transform, str(forward_file), compression=True) - itk.transformwrite(inverse_transform, str(inverse_file), compression=True) - - moving_to_reference = transform_tools.transform_image( - moving_images[index], - forward_transform, - fixed_image, - ) - moving_to_reference_file = timepoint_dir / "time_to_reference.mha" - itk.imwrite( - moving_to_reference, str(moving_to_reference_file), compression=True - ) - - reference_to_time = transform_tools.transform_image( - fixed_image, - inverse_transform, - moving_images[index], - ) - reference_to_time_file = timepoint_dir / "reference_to_time.mha" - itk.imwrite(reference_to_time, str(reference_to_time_file), compression=True) - - row: dict[str, object] = { - "subject_id": subject_id, - "method": method_spec.output_name, - "timepoint": artifacts.timepoint, - "moving_image": str(artifacts.image_file), - "forward_transform": str(forward_file), - "inverse_transform": str(inverse_file), - "loss": float(loss), - } - - if artifacts.landmark_file is not None: - timepoint_landmarks = read_landmarks(artifacts.landmark_file) - direct_landmarks = transform_landmarks( - reference_landmarks, - inverse_transform, - ) - direct_errors = landmark_errors(direct_landmarks, timepoint_landmarks) - write_error_details( - error_detail_file, - subject_id, - method_spec.output_name, - artifacts.timepoint, - "direct", - direct_errors, - ) - row.update(summarize_errors(direct_errors, "direct")) - else: - row.update(summarize_errors({}, "direct")) - - if run_resegmentation and segmenter is not None: - segmentation = segmenter.segment( - reference_to_time, - contrast_enhanced_study=False, - ) - warped_labelmap = segmentation["labelmap"] - warped_labelmap_file = timepoint_dir / "reference_to_time_labelmap.nii.gz" - itk.imwrite(warped_labelmap, str(warped_labelmap_file), compression=True) - reseg_landmarks = segmenter.get_landmarks() - reseg_landmark_file = timepoint_dir / "reference_to_time_landmark.csv" - write_landmarks(reseg_landmark_file, reseg_landmarks) - row["resegmented_labelmap"] = str(warped_labelmap_file) - row["resegmented_landmarks"] = str(reseg_landmark_file) - - if artifacts.landmark_file is not None: - timepoint_landmarks = read_landmarks(artifacts.landmark_file) - reseg_errors = landmark_errors(reseg_landmarks, timepoint_landmarks) - write_error_details( - error_detail_file, - subject_id, - method_spec.output_name, - artifacts.timepoint, - "resegmented", - reseg_errors, - ) - row.update(summarize_errors(reseg_errors, "resegmented")) - else: - row.update(summarize_errors({}, "resegmented")) - - if artifacts.labelmap_file is not None: - timepoint_labelmap = itk.imread(str(artifacts.labelmap_file)) - row.update( - summarize_dice(dice_by_label(warped_labelmap, timepoint_labelmap)) - ) - else: - row.update(summarize_dice({})) - else: - row["resegmented_labelmap"] = "" - row["resegmented_landmarks"] = "" - row.update(summarize_errors({}, "resegmented")) - row.update(summarize_dice({})) - - rows.append(row) - - return rows - - -def parse_iterations(value: str) -> list[int]: - """Parse comma-separated multi-resolution iteration counts.""" - return [int(item.strip()) for item in value.split(",") if item.strip()] - - -def main() -> int: - """Run the longitudinal registration comparison experiment.""" - parser = argparse.ArgumentParser( - description="Compare ANTS, Greedy, and ICON longitudinal registration." - ) - parser.add_argument("--reference-dir", type=Path, default=DEFAULT_REF_DIR) - parser.add_argument( - "--timepoint-base-dir", - type=Path, - default=DEFAULT_TIMEPOINT_BASE_DIR, - ) - parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR) - parser.add_argument( - "--segmentation-base-dir", - type=Path, - default=DEFAULT_SEGMENTATION_BASE_DIR, - help="Directory with per-subject precomputed *_labelmap and *_landmark files.", - ) - parser.add_argument("--reference-pattern", default="pm00*.nii.gz") - parser.add_argument("--timepoint-pattern", default="*.nii.gz") - parser.add_argument("--segmentation-dir", default=DEFAULT_SEGMENTATION_DIR) - parser.add_argument( - "--exclude-token", - action="append", - default=list(DEFAULT_EXCLUDE_TOKENS), - help="Filename token to exclude from time-point inputs.", - ) - parser.add_argument( - "--methods", - nargs="+", - default=None, - help="Methods to run. Defaults include finetuned methods when weights are set.", - ) - parser.add_argument("--finetuned-weights-path", type=Path, default=None) - parser.add_argument("--max-subjects", type=int, default=None) - parser.add_argument("--max-timepoints", type=int, default=None) - parser.add_argument("--ANTS-iterations", default="30,15,7,3") - parser.add_argument("--greedy-iterations", default="30,15,7,3") - parser.add_argument("--ICON-iterations", type=int, default=20) - parser.add_argument( - "--skip-resegmentation", - action="store_true", - help="Skip Simpleware re-segmentation mode.", - ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Validate discovered files and planned methods without registration.", - ) - args = parser.parse_args() - - method_names = args.methods - if method_names is None: - method_names = list(DEFAULT_METHODS) - method_names.append("greedy_icon_default") - if args.finetuned_weights_path is not None: - method_names.extend( - [ - "icon_finetuned", - "ants_icon_finetuned", - "greedy_icon_finetuned", - ] - ) - - method_specs = build_method_specs(method_names, args.finetuned_weights_path) - subjects = discover_subjects( - args.reference_dir, - args.timepoint_base_dir, - args.reference_pattern, - args.timepoint_pattern, - tuple(args.exclude_token), - args.segmentation_dir, - args.segmentation_base_dir, - ) - if args.max_subjects is not None: - subjects = subjects[: args.max_subjects] - - if args.dry_run: - for subject_id, reference_artifacts, timepoint_artifacts in subjects: - if args.max_timepoints is not None: - timepoint_artifacts = timepoint_artifacts[: args.max_timepoints] - missing_landmarks = sum( - artifacts.landmark_file is None for artifacts in timepoint_artifacts - ) - missing_labelmaps = sum( - artifacts.labelmap_file is None for artifacts in timepoint_artifacts - ) - print( - f"{subject_id}: {len(timepoint_artifacts)} time points, " - f"reference_landmarks={reference_artifacts.landmark_file is not None}, " - f"reference_labelmap={reference_artifacts.labelmap_file is not None}, " - f"missing_time_landmarks={missing_landmarks}, " - f"missing_time_labelmaps={missing_labelmaps}" - ) - print("Methods: " + ", ".join(spec.output_name for spec in method_specs)) - return 0 - - summary_rows: list[dict[str, object]] = [] - detail_file = args.output_dir / "landmark_errors_by_point.csv" - if detail_file.exists(): - detail_file.unlink() - - for subject_id, reference_artifacts, timepoint_artifacts in subjects: - if args.max_timepoints is not None: - timepoint_artifacts = timepoint_artifacts[: args.max_timepoints] - if not timepoint_artifacts: - raise ValueError(f"No time-point images found for {subject_id}") - print( - f"Running {subject_id}: {len(timepoint_artifacts)} time points, " - f"{len(method_specs)} methods" - ) - for method_spec in method_specs: - print(f" Method: {method_spec.output_name}") - rows = run_method_for_subject( - subject_id, - reference_artifacts, - timepoint_artifacts, - method_spec, - args.output_dir, - not args.skip_resegmentation, - parse_iterations(args.ants_iterations), - parse_iterations(args.greedy_iterations), - args.icon_iterations, - detail_file, - ) - summary_rows.extend(rows) - write_summary(args.output_dir / "registration_summary.csv", summary_rows) - - print(f"Wrote summary: {args.output_dir / 'registration_summary.csv'}") - print(f"Wrote landmark details: {detail_file}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/experiments/LongitudinalRegistration/experiment_recon_4d.py b/experiments/LongitudinalRegistration/experiment_recon_4d.py deleted file mode 100644 index 7d9ce26..0000000 --- a/experiments/LongitudinalRegistration/experiment_recon_4d.py +++ /dev/null @@ -1,191 +0,0 @@ -# %% [markdown] -# # 4D CT Reconstruction Using RegisterTimeSeriesImages Class -# -# This script demonstrates the use of the `RegisterTimeSeriesImages` class -# to register a time series of CT images to a common reference frame. -# -# This is a refactored version of `reconstruct_4d_ct.ipynb` that uses -# the new class-based approach,including: -# - Registration of time series images using ANTs, Greedy, ICON, or combined -# ANTs/Greedy + ICON methods -# - Reconstruction of time series using the `reconstruct_time_series()` method -# - Optional upsampling to fixed image resolution while preserving spatial positioning -# - -# %% -# Import necessary libraries -######################################################## - -import os - -import itk -import numpy as np -from physiomotion4d import RegisterTimeSeriesImages - -# %% -# Identify reference images -######################################################## - -ref_data_dir = "d:/PhysioMotion4D/duke_data/ref_images" -src_data_dir_base = "d:/PhysioMotion4D/duke_data/gated_nii" -dest_data_dir_base = "d:/PhysioMotion4D/duke_data/recon4d" - -ref_files = [ - os.path.join(ref_data_dir, f) - for f in sorted(os.listdir(ref_data_dir)) - if f.startswith("pm00") and f.endswith(".nii.gz") -] - -print(f"Found {len(ref_files)} reference images") - -# %% -# Identify source data directories and files using reference image names -######################################################## - - -print(os.path.basename(ref_files[0])[:6]) -src_data_dirs = [] -src_data_files = [] -for ref_file in ref_files: - src_dir = os.path.join(src_data_dir_base, os.path.basename(ref_file)[:6]) - src_data_dirs.append(src_dir) - - file_list = sorted(os.listdir(src_dir)) - valid_file_list = [ - f - for f in file_list - if "dia" not in f - and "nop" not in f - and "sys" not in f - and f.endswith(".nii.gz") - ] - src_data_files.append(valid_file_list) - -print(f"Found {len(src_data_dirs)} source data directories") -for d, fs in zip(src_data_dirs, src_data_files): - print(f"{d}: {len(fs)} files") - for f in fs: - print(f" {f}") - -# %% -# Define registration function -######################################################## - - -def register_time_series( - reference_image_file: str, - source_image_dir: str, - source_image_files: list[str], - registration_method: str, -) -> None: - # ANTs registration - if registration_method in ["ANTS", "greedy"]: - number_of_iterations = [30, 15, 7, 3] - elif registration_method == "ICON": - number_of_iterations = 20 - elif registration_method in ["ANTS_ICON", "greedy_ICON"]: - number_of_iterations = [[30, 15, 7, 3], 20] - else: - raise ValueError(f"Invalid registration method: {registration_method}") - - # Create output dir - output_dir = os.path.join( - dest_data_dir_base, registration_method, os.path.basename(source_image_dir) - ) - os.makedirs(output_dir, exist_ok=True) - - # Read the reference image as the fixed image - fixed_image = itk.imread(reference_image_file, pixel_type=itk.F) - - images = [] - for file in source_image_files: - img = itk.imread(os.path.join(source_image_dir, file), pixel_type=itk.F) - images.append(img) - - reference_image_num = 7 - register_start_to_reference = True - if reference_image_file in source_image_files: - reference_image_num = source_image_files.index(reference_image_file) - register_start_to_reference = False - - portion_of_prior_transform_to_init_next_transform = 0.0 - - # Register the time series - registrar = RegisterTimeSeriesImages(registration_method=registration_method) - registrar.set_modality("ct") - registrar.set_fixed_image(fixed_image) - if registration_method == "ANTS": - registrar.set_number_of_iterations_ANTS(number_of_iterations) - elif registration_method == "greedy": - registrar.set_number_of_iterations_greedy(number_of_iterations) - elif registration_method == "ICON": - registrar.set_number_of_iterations_ICON(number_of_iterations) - elif registration_method == "ANTS_ICON": - registrar.set_number_of_iterations_ANTS(number_of_iterations[0]) - registrar.set_number_of_iterations_ICON(number_of_iterations[1]) - elif registration_method == "greedy_ICON": - registrar.set_number_of_iterations_greedy(number_of_iterations[0]) - registrar.set_number_of_iterations_ICON(number_of_iterations[1]) - else: - raise ValueError(f"Invalid registration method: {registration_method}") - - result = registrar.register_time_series( - moving_images=images, - reference_frame=reference_image_num, - register_reference=register_start_to_reference, - prior_weight=portion_of_prior_transform_to_init_next_transform, - ) - - upsampled_images = registrar.reconstruct_time_series( - moving_images=images, - inverse_transforms=result["inverse_transforms"], - upsample_to_fixed_resolution=True, - ) - - losses = result["losses"] - print("Registration complete!") - print(f" Average loss: {np.mean(losses):.6f}") - print(f" Min loss: {np.min(losses):.6f}") - print(f" Max loss: {np.max(losses):.6f}") - print("") - print("Saving results...") - output_file_basename = os.path.basename(reference_image_file)[:6] - for i, fwd_transform in enumerate(result["forward_transforms"]): - time_point_index = source_image_files[i].index("_g") + 2 - time_point = source_image_files[i][time_point_index : time_point_index + 3] - - output_file = f"{output_file_basename}_{time_point}_fwd.hdf" - itk.transformwrite( - fwd_transform, - os.path.join(output_dir, output_file), - compression=True, - ) - - inv_transform = result["inverse_transforms"][i] - output_file = f"{output_file_basename}_{time_point}_inv.hdf" - itk.transformwrite( - inv_transform, - os.path.join(output_dir, output_file), - compression=True, - ) - - output_file = f"{output_file_basename}_{time_point}_hrr.mha" - itk.imwrite( - upsampled_images[i], - os.path.join(output_dir, output_file), - compression=True, - ) - - -# %% -# Register time series -######################################################## - -for ref_file, src_dir, src_files in zip(ref_files, src_data_dirs, src_data_files): - register_time_series(ref_file, src_dir, src_files, "ANTS") - register_time_series(ref_file, src_dir, src_files, "greedy") - register_time_series(ref_file, src_dir, src_files, "ICON") - register_time_series(ref_file, src_dir, src_files, "ANTS_ICON") - register_time_series(ref_file, src_dir, src_files, "greedy_ICON") - -# %% diff --git a/experiments/LongitudinalRegistration/registration_results_analysis.py b/experiments/LongitudinalRegistration/registration_results_analysis.py new file mode 100644 index 0000000..562f61b --- /dev/null +++ b/experiments/LongitudinalRegistration/registration_results_analysis.py @@ -0,0 +1,281 @@ +"""Summarize registration Dice and landmark RMSE results across experiments. + +For every ``results_*`` directory under a base directory this script reads: + +- ``registration_dice_init.csv`` with columns + ``subject_id, method, stem, label, dice`` (one row per subject / time point + / anatomy label). +- ``registration_landmarks_init.csv`` with columns + ``subject_id, method, stem, name, rms_err_mm`` (one row per subject / time + point / landmark). + +It pools all subjects and all time points (rows) within each +``(results_dir, method)`` group and reports the mean, standard deviation, and +95th percentile of the Dice score per label (1..10) and across all labels, and +the same statistics of the landmark RMSE per landmark and across all landmarks. + +**Duplicate handling.** If the same ``(subject_id, method, stem, label)`` +combination appears more than once in a single Dice CSV, the *n*-th occurrence +is treated as if it came from a separate directory named +``{results_dir}_{n}`` (e.g. ``results_ml_2``, ``results_ml_3``). The same +applies for ``(subject_id, method, stem, name)`` duplicates in the landmark +CSV. + +Two summary tables are produced, each indexed by ``(results_dir, method)`` +(with any ``_n`` suffixes introduced by duplicate splitting): one for label +Dice scores and one for landmark RMSE. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd + +DICE_CSV = "registration_dice_init.csv" +LANDMARKS_CSV = "registration_landmarks_init.csv" + +ALL_KEY = "all" +STATS = ("mean", "std", "p95") + + +def _split_by_occurrence( + frame: pd.DataFrame, + key_cols: list[str], + dir_name: str, +) -> dict[str, pd.DataFrame]: + """Split ``frame`` into per-occurrence sub-frames keyed by a group name. + + Within ``frame``, each unique combination of ``key_cols`` may appear + multiple times. The *n*-th occurrence (1-indexed) of a given key is + assigned to group ``dir_name`` (for n=1) or ``{dir_name}_{n}`` (for n>1). + Sub-frames for different occurrence numbers are returned in a dict whose + keys follow that naming convention. + + Args: + frame: Input data frame, one row per observation. + key_cols: Columns that together identify a unique observation. + dir_name: Base name used to construct group keys. + + Returns: + Ordered dict mapping group name → sub-data-frame. + """ + occurrence = frame.groupby(key_cols, sort=False).cumcount() + 1 # 1-indexed + max_occ = int(occurrence.max()) + result: dict[str, pd.DataFrame] = {} + for n in range(1, max_occ + 1): + sub = frame[occurrence == n] + if sub.empty: + continue + group_name = dir_name if n == 1 else f"{dir_name}_{n}" + result[group_name] = sub.reset_index(drop=True) + return result + + +def _aggregate( + frames: dict[str, pd.DataFrame], + group_col: str, + value_col: str, +) -> pd.DataFrame: + """Aggregate one value column grouped by ``group_col`` for each (directory, method). + + Args: + frames: Mapping from ``results_*`` directory name to its data frame. + group_col: Column whose distinct values become summary categories + (``"label"`` for Dice, ``"name"`` for landmarks). + value_col: Column to summarize (``"dice"`` or ``"rms_err_mm"``). + + Returns: + A data frame indexed by ``(results_dir, method)`` with a two-level + column ``(category, stat)`` where ``category`` is each distinct group + value plus ``"all"`` and ``stat`` is one of ``mean``, ``std``, ``p95``. + Categories are pooled across all subjects and time points within each + ``(directory, method)`` group. + """ + + def _stats(series: pd.Series) -> dict[str, float]: + # Drop missing measurements so a single blank row does not poison the + # mean / std / percentile for an entire landmark or the pooled "all". + values = pd.to_numeric(series, errors="coerce").dropna().to_numpy(dtype=float) + if values.size == 0: + return {stat: float("nan") for stat in STATS} + return { + "mean": float(np.mean(values)), + "std": float(np.std(values, ddof=1)) if values.size > 1 else 0.0, + "p95": float(np.percentile(values, 95)), + } + + rows: dict[tuple[str, str], dict[tuple[object, str], float]] = {} + for dir_name, frame in frames.items(): + method_groups: list[tuple[str, pd.DataFrame]] + if "method" in frame.columns: + method_groups = [ + (str(m), sub) for m, sub in frame.groupby("method", sort=True) + ] + else: + method_groups = [("", frame)] + for method, mframe in method_groups: + row: dict[tuple[object, str], float] = {} + for category, group in mframe.groupby(group_col): + for stat, value in _stats(group[value_col]).items(): + row[(category, stat)] = value + for stat, value in _stats(mframe[value_col]).items(): + row[(ALL_KEY, stat)] = value + rows[(dir_name, method)] = row + + table = pd.DataFrame.from_dict(rows, orient="index") + table.index = pd.MultiIndex.from_tuples( + table.index, names=["results_dir", "method"] + ) + # from_dict yields a flat Index of tuples; promote to a MultiIndex so the + # reindex below aligns on (category, stat) rather than producing NaNs. + table.columns = pd.MultiIndex.from_tuples(table.columns) + + # Order columns: numeric/string categories sorted, "all" last. + categories = sorted( + {category for category, _ in table.columns if category != ALL_KEY}, + key=lambda c: (0, int(c)) if str(c).isdigit() else (1, str(c)), + ) + categories.append(ALL_KEY) + ordered = [(category, stat) for category in categories for stat in STATS] + table = table.reindex(columns=pd.MultiIndex.from_tuples(ordered)) + return table + + +def _method_summary( + dice_table: pd.DataFrame, + landmark_table: pd.DataFrame, +) -> pd.DataFrame: + """Collapse both summary tables to one row per ``(results_dir, method)``. + + Extracts the ``"all"`` slice (pooled over every label / landmark) from + each table and concatenates the columns side-by-side. + + Args: + dice_table: Output of :func:`_aggregate` for Dice scores. + landmark_table: Output of :func:`_aggregate` for landmark RMSE. + + Returns: + Data frame indexed by ``(results_dir, method)`` with flat columns + ``dice_mean``, ``dice_std``, ``dice_p95``, ``landmark_mean``, + ``landmark_std``, ``landmark_p95``. + """ + parts: list[pd.DataFrame] = [] + for table, prefix in ((dice_table, "dice"), (landmark_table, "landmark")): + if ALL_KEY in table.columns.get_level_values(0): + sub = table[ALL_KEY].copy() + sub.columns = [f"{prefix}_{c}" for c in sub.columns] + parts.append(sub) + if not parts: + return pd.DataFrame() + return pd.concat(parts, axis=1) + + +def summarize(base_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]: + """Build the Dice and landmark RMSE summary tables for ``base_dir``. + + Args: + base_dir: Directory containing one or more ``results_*`` subdirectories. + + Returns: + ``(dice_table, landmark_table)`` summary frames, each indexed by + ``(results_dir, method)``. + """ + result_dirs = sorted(p for p in base_dir.glob("results_*") if p.is_dir()) + if not result_dirs: + raise FileNotFoundError(f"No results_* directories found in {base_dir}") + + dice_frames: dict[str, pd.DataFrame] = {} + landmark_frames: dict[str, pd.DataFrame] = {} + for result_dir in result_dirs: + dice_path = result_dir / DICE_CSV + landmark_path = result_dir / LANDMARKS_CSV + if dice_path.is_file(): + dice_frames.update( + _split_by_occurrence( + pd.read_csv(dice_path), + ["subject_id", "method", "stem", "label"], + result_dir.name, + ) + ) + if landmark_path.is_file(): + landmark_frames.update( + _split_by_occurrence( + pd.read_csv(landmark_path), + ["subject_id", "method", "stem", "name"], + result_dir.name, + ) + ) + + dice_table = _aggregate(dice_frames, "label", "dice") + landmark_table = _aggregate(landmark_frames, "name", "rms_err_mm") + return dice_table, landmark_table + + +def _print_table(title: str, table: pd.DataFrame) -> None: + """Print a summary table with a heading at full width.""" + print(f"\n{'=' * 70}\n{title}\n{'=' * 70}") + with pd.option_context( + "display.max_columns", + None, + "display.width", + None, + "display.float_format", + lambda v: f"{v:.4f}", + ): + print(table) + + +def main(argv: Optional[list[str]] = None) -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "base_dir", + nargs="?", + type=Path, + default=Path(__file__).resolve().parent, + help="Directory containing results_* subdirectories " + "(default: this script's directory).", + ) + parser.add_argument( + "--out-dir", + type=Path, + default=None, + help="Optional directory to write summary CSV files into.", + ) + args = parser.parse_args(argv) + + dice_table, landmark_table = summarize(args.base_dir) + methods_table = _method_summary(dice_table, landmark_table) + + _print_table( + "Per-method summary (mean / std / 95th percentile across ALL labels " + "and ALL landmarks), grouped by results_dir + method", + methods_table, + ) + _print_table( + "Dice score by label (mean / std / 95th percentile), grouped by " + "results_dir + method, pooled over subjects and time points", + dice_table, + ) + _print_table( + "Landmark RMSE [mm] (mean / std / 95th percentile), grouped by " + "results_dir + method, pooled over subjects and time points", + landmark_table, + ) + + if args.out_dir is not None: + args.out_dir.mkdir(parents=True, exist_ok=True) + methods_out = args.out_dir / "summary_methods.csv" + dice_out = args.out_dir / "summary_dice.csv" + landmark_out = args.out_dir / "summary_landmarks.csv" + methods_table.to_csv(methods_out) + dice_table.to_csv(dice_out) + landmark_table.to_csv(landmark_out) + print(f"\nWrote {methods_out}\nWrote {dice_out}\nWrote {landmark_out}") + + +if __name__ == "__main__": + main() diff --git a/experiments/LongitudinalRegistration/registration_test.py b/experiments/LongitudinalRegistration/registration_test.py new file mode 100644 index 0000000..0c66ef3 --- /dev/null +++ b/experiments/LongitudinalRegistration/registration_test.py @@ -0,0 +1,134 @@ +# %% [markdown] +# # Registration test: pm0003 time point 20 -> time point 60 +# +# Registers pm0003 gated CT time point 20 (moving) to time point 60 +# (fixed) with deformable registration, then warps time point 20 into +# time point 60's space and writes it to disk. +# +# Switch backends by editing the single ``method`` variable below +# ("ANTS", "ICON", or "Greedy"). All paths are hard-coded; run the +# cells top to bottom. + +# %% +import time +from pathlib import Path + +import itk + +from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + +# %% [markdown] +# ## 1. Configuration and hard-coded paths +# +# Change ``method`` to switch backends. Time point 20 is the moving +# image; time point 60 is the fixed image. + +# %% +method = "Greedy" # one of: "ANTS", "ICON", "Greedy" + +data_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii/pm0003") +moving_path = data_dir / "pm0003_dupr_135-0094_135_4700_g020_s2.000_n0058_11.nii.gz" +fixed_path = data_dir / "pm0003_dupr_135-0094_135_4700_g060_s2.000_n0058_15.nii.gz" + +output_dir = Path(__file__).parent / "results" / "registration_test" +output_dir.mkdir(parents=True, exist_ok=True) +output_path = output_dir / f"pm0003_g020_to_g060_{method.lower()}.mha" + +# %% [markdown] +# ## 2. Load the fixed (time point 60) and moving (time point 20) images + +# %% +fixed_image = itk.imread(str(fixed_path), pixel_type=itk.F) +moving_image = itk.imread(str(moving_path), pixel_type=itk.F) +print(f"Fixed (g060): {fixed_path.name}") +print(f"Moving (g020): {moving_path.name}") + +# %% [markdown] +# ## 3. Build and configure the registration backend +# +# ANTS and Greedy share ``set_transform_type``/``set_metric`` and take a +# per-level iteration list; ICON takes a single iteration count and has +# no transform-type/metric setters. + +# %% +if method == "ANTS": + reg = RegisterImagesANTS() + reg.set_transform_type("Deformable") + reg.set_metric("MeanSquares") + reg.set_number_of_iterations([40, 20, 10]) +elif method == "Greedy": + reg = RegisterImagesGreedy() + reg.set_transform_type("Deformable") + # NCC (CC) beats SSD for same-modality CT; tighter update-field smoothing + # (first sigma) captures more cardiac motion while staying diffeomorphic. + reg.set_metric("CC") + reg.set_number_of_iterations([40, 20, 10]) + reg.deformable_smoothing = "1.0vox 0.5vox" +elif method == "ICON": + reg = RegisterImagesICON() + reg.set_number_of_iterations(50) +else: + raise ValueError(f"Unknown method: {method}") + +reg.set_modality("ct") +reg.set_fixed_image(fixed_image) + +# %% [markdown] +# ## 4. Register time point 20 to time point 60 + +# %% +t_start = time.perf_counter() +reg_result = reg.register(moving_image=moving_image) +elapsed = time.perf_counter() - t_start + +forward_transform = reg_result["forward_transform"] +loss = float(reg_result["loss"]) +print(f"{method} registration done in {elapsed:.1f} s, loss={loss:.4f}") + +# %% [markdown] +# ## 5. Warp time point 20 into time point 60's space and save +# +# ``forward_transform`` is the transform consumed by ``transform_image`` to +# resample the moving image onto the fixed grid (it supplies the fixed->moving +# sampling map the ITK resampler needs). ``inverse_transform`` is the opposite +# direction, used to warp the fixed image onto the moving grid (e.g. in +# ``RegisterTimeSeriesImages.reconstruct_time_series``). This holds for all +# three backends (ANTS, ICON, Greedy). + +# %% +transform_tools = TransformTools() +warp_t_start = time.perf_counter() +warped_image = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", +) +itk.imwrite(warped_image, str(output_path), compression=True) +warp_elapsed = time.perf_counter() - warp_t_start +print(f"Wrote warped time point 20 -> 60: {output_path}") + +# %% [markdown] +# ## 6. Timing report +# +# Wall-clock seconds for the registration and the warp/write step. The +# registration time dominates and is the figure to compare across backends; +# for ICON it includes the one-time network load on this first (and only) +# pair. + +# %% +total_elapsed = elapsed + warp_elapsed +print() +print("=" * 44) +print(f"Timing report ({method})") +print("=" * 44) +print(f"{'Step':<22}{'seconds':>12}") +print("-" * 44) +print(f"{'register':<22}{elapsed:>12.2f}") +print(f"{'warp + write':<22}{warp_elapsed:>12.2f}") +print("-" * 44) +print(f"{'total':<22}{total_elapsed:>12.2f}") +print("=" * 44) diff --git a/pyproject.toml b/pyproject.toml index cefcec4..70d0aa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,7 +280,6 @@ minversion = "7.0" addopts = [ "--strict-markers", "--strict-config", - "-W", "always", "--cov=physiomotion4d", "--cov-report=term-missing", "--cov-report=html", @@ -288,6 +287,17 @@ addopts = [ ] testpaths = ["tests"] pythonpath = ["."] +# Surface every warning by default ("always"), then silence the third-party +# ITK/SWIG binding noise emitted while wrapped C++ types are defined at import +# on CPython >=3.12 ("builtin type Swig... has no __module__ attribute"). +# Those come from the bindings, not our code, and would otherwise drown the +# warnings summary. Order matters: the trailing "ignore" entries are applied +# last and so take precedence over "always" for these specific messages. +filterwarnings = [ + "always", + "ignore:builtin type Swig", + "ignore:builtin type swigvarlink", +] markers = [ "unit: marks tests as unit tests (fast, isolated)", "integration: marks tests as integration tests (slower, multiple components)", @@ -398,12 +408,12 @@ lines-after-imports = 2 max-complexity = 10 [tool.ruff.lint.flake8-quotes] -inline-quotes = "single" +inline-quotes = "double" multiline-quotes = "double" docstring-quotes = "double" [tool.ruff.format] -quote-style = "single" +quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index 5d00e32..63fc3a7 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -43,6 +43,7 @@ # Utility classes from .image_tools import ImageTools +from .labelmap_tools import LabelmapTools from .landmark_tools import LandmarkTools # Base classes @@ -106,6 +107,7 @@ "PhysioMotion4DBase", # Utility classes "ImageTools", + "LabelmapTools", "LandmarkTools", "TestTools", "TransformTools", diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index 9f3028e..6e72efe 100644 --- a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -230,7 +230,7 @@ def create_mask_from_mesh( # Direction: use identity for now (axis-aligned), will be handled by resampling # Flip Z axis to match ITK convention - ref_dir = np.array(binary_image.GetDirection()) + ref_dir = itk.array_from_matrix(binary_image.GetDirection()) ref_dir[2, 2] = -ref_dir[2, 2] binary_image.SetDirection(ref_dir) diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index c6e4023..c031603 100644 --- a/src/physiomotion4d/image_tools.py +++ b/src/physiomotion4d/image_tools.py @@ -279,9 +279,12 @@ def flip_image( flip1 = False flip2 = False if flip_and_make_identity: - flip0 = np.array(in_image.GetDirection())[0, 0] < 0 - flip1 = np.array(in_image.GetDirection())[1, 1] < 0 - flip2 = np.array(in_image.GetDirection())[2, 2] < 0 + # itk.array_from_matrix avoids itk.Matrix.__array__, whose missing + # copy keyword triggers a numpy>=2.0 DeprecationWarning. + direction = itk.array_from_matrix(in_image.GetDirection()) + flip0 = direction[0, 0] < 0 + flip1 = direction[1, 1] < 0 + flip2 = direction[2, 2] < 0 if flip_x: flip0 = True if flip_y: diff --git a/src/physiomotion4d/labelmap_tools.py b/src/physiomotion4d/labelmap_tools.py new file mode 100644 index 0000000..08fe4e6 --- /dev/null +++ b/src/physiomotion4d/labelmap_tools.py @@ -0,0 +1,187 @@ +""" +Labelmap Tools for PhysioMotion4D + +This module provides the :class:`LabelmapTools` class with the definitive +utility for turning a multi-label (or binary) segmentation labelmap into a +binary registration mask, optionally excluding specific labels and dilating +the result by a physical radius in millimeters. +""" + +import logging +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + + +class LabelmapTools(PhysioMotion4DBase): + """ + Utilities for converting segmentation labelmaps into registration masks. + + A labelmap is an ``itk.Image`` of integer labels where ``0`` is background + and each positive value identifies an anatomical structure. A registration + mask is a binary ``itk.Image`` where every foreground voxel is ``1``. This + class centralizes the labelmap-to-mask conversion so that thresholding, + label exclusion, and physically isotropic dilation are performed + identically everywhere in the platform. + + Example: + >>> tools = LabelmapTools() + >>> # Binary mask of every labeled voxel, dilated 5 mm + >>> mask = tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) + >>> # Exclude the table/background labels 8 and 9 before masking + >>> mask = tools.convert_labelmap_to_mask( + ... labelmap, dilation_in_mm=5.0, exclude_labels=[8, 9] + ... ) + """ + + def __init__(self, log_level: int | str = logging.INFO) -> None: + """Initialize LabelmapTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + def convert_labelmap_to_mask( + self, + labelmap: itk.Image, + dilation_in_mm: float = 0.0, + exclude_labels: Optional[list[int]] = None, + ) -> itk.Image: + """Convert a labelmap into a binary registration mask. + + Any voxel whose label is in ``exclude_labels`` is set to background + first; every remaining non-zero voxel becomes foreground (``1``). The + binary mask is then dilated by ``dilation_in_mm`` millimeters of + physical radius. The radius is converted into per-axis voxel counts + from the labelmap's spacing so the dilation is physically isotropic + even on anisotropic grids; each per-axis count is clamped to at least + 1 voxel when ``dilation_in_mm > 0``. + + Axis ordering: the labelmap is a scalar 3D ``itk.Image`` in ITK + world-axis order (X, Y, Z). All thresholding is performed on the numpy + view (Z, Y, X) and written back through ``CopyInformation``, so origin, + spacing, and direction are preserved. + + Args: + labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel + that is not excluded is treated as foreground. + dilation_in_mm: Physical radius of the binary dilation in + millimeters. Pass ``0`` (or negative) to skip dilation and + return the raw thresholded mask. Default 0.0. + exclude_labels: Optional list of integer label values to force + to background before thresholding. When ``None`` (the default) + no labels are excluded. + + Returns: + ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as + ``labelmap`` (origin, spacing, direction copied from the input). + """ + arr = itk.array_from_image(labelmap) + if exclude_labels: + arr = np.where(np.isin(arr, exclude_labels), 0, arr) + mask_arr = (arr > 0).astype(np.uint8) + mask = itk.image_from_array(mask_arr) + mask.CopyInformation(labelmap) + + if dilation_in_mm <= 0: + return mask + + spacing = labelmap.GetSpacing() + radius = itk.Size[3]() + for i in range(3): + radius[i] = max(1, int(round(dilation_in_mm / float(spacing[i])))) + structuring_element = itk.FlatStructuringElement[3].Ball(radius) + return itk.binary_dilate_image_filter( + mask, kernel=structuring_element, foreground_value=1 + ) + + def create_distance_map( + self, + labelmap: itk.Image, + max_distance_mm: float = 20.0, + distance_scale: float = 5.0, + ) -> itk.Image: + """Encode a labelmap as a continuous label-plus-boundary-distance image. + + Each output voxel holds its original integer label plus a small + fractional offset that encodes how far the voxel lies from the nearest + boundary between two differently-labeled regions: + + value = label + min(distance_to_nearest_boundary_mm, + max_distance_mm) / distance_scale + + The boundary set is every voxel that 6-neighbors a voxel with a + different label (background label ``0`` participates, so the outer + surface of each structure is a boundary). The unsigned physical + distance from each voxel to that set is computed with + ``SignedMaurerDistanceMapImageFilter`` (taking the magnitude), clipped + to ``max_distance_mm``, divided by ``distance_scale``, and added to the + voxel's original label. + + With the defaults (``20`` mm clip, ``5`` scale) the fractional offset + stays in ``[0.0, 4.0]``, potentially passing adjacent integer labels but + emphasizing in medial alignment as well as boundary. + + The motivation is registration metrics such as Greedy's NCC: a raw + integer labelmap is piecewise-constant, so the local variance inside + each region is zero and NCC produces NaN gradients. Replacing it with + this continuous encoding gives every region a smoothly varying signal + while preserving label identity. + + Axis ordering: ``labelmap`` is a scalar 3D ``itk.Image`` in ITK + world-axis order (X, Y, Z). All work is done on the numpy view + (Z, Y, X) and written back through ``CopyInformation``, so origin, + spacing, and direction are preserved. + + Args: + labelmap: Multi-label (or binary) ``itk.Image`` of integer labels. + max_distance_mm: Distance clip, in millimeters. Default 20.0. + distance_scale: Divisor applied to the clipped distance before it + is added to the label. Default 5.0. With the default clip + this bounds the fractional offset to ``[0, 4.0]``. + + Returns: + ``itk.Image[itk.F, 3]`` in the same physical space as ``labelmap`` + (origin, spacing, direction copied from the input). + """ + labels = itk.array_from_image(labelmap) + + # A voxel is on a label boundary when it differs from a 6-connected + # neighbor along any axis. Mark both voxels straddling each change. + boundary = np.zeros(labels.shape, dtype=bool) + for axis in range(labels.ndim): + changed = np.diff(labels, axis=axis) != 0 + lower = [slice(None)] * labels.ndim + upper = [slice(None)] * labels.ndim + lower[axis] = slice(0, -1) + upper[axis] = slice(1, None) + boundary[tuple(lower)] |= changed + boundary[tuple(upper)] |= changed + + if boundary.any(): + boundary_image = itk.image_from_array(boundary.astype(np.uint8)) + boundary_image.CopyInformation(labelmap) + distance_filter = itk.SignedMaurerDistanceMapImageFilter.New( + Input=boundary_image + ) + distance_filter.SetSquaredDistance(False) + distance_filter.SetUseImageSpacing(True) + distance_filter.Update() + distance = np.abs( + itk.array_from_image(distance_filter.GetOutput()).astype(np.float32) + ) + else: + # No inter-label boundary exists (single uniform label); every + # voxel gets a zero offset. + distance = np.zeros(labels.shape, dtype=np.float32) + + offset = np.clip(distance, 0.0, max_distance_mm) / distance_scale + encoded = labels.astype(np.float32) + offset + + encoded_image = itk.image_from_array(encoded) + encoded_image.CopyInformation(labelmap) + return encoded_image diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 182afaf..d31552c 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -17,6 +17,7 @@ import itk import numpy as np from numpy.typing import NDArray + from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.transform_tools import TransformTools @@ -199,9 +200,7 @@ def _itk_to_ants_image( image_dimension = len(spatial_shape) - direction = np.asarray(itk_image.GetDirection()).reshape( - (image_dimension, image_dimension) - ) + direction = itk.array_from_matrix(itk_image.GetDirection()) spacing = list(itk_image.GetSpacing()) origin = list(itk_image.GetOrigin()) @@ -511,7 +510,7 @@ def registration_method( moving_image: itk.Image, moving_mask: Optional[itk.Image] = None, moving_labelmap: Optional[itk.Image] = None, - moving_image_pre: Optional[ants.ANTsImage] = None, + moving_image_pre: Optional[itk.Image] = None, initial_forward_transform: Optional[itk.Transform] = None, ) -> dict[str, Union[itk.Transform, float]]: """Register moving image to fixed image using ANTs registration algorithm. @@ -524,18 +523,25 @@ def registration_method( moving_image (itk.image): The 3D image to be registered/aligned. moving_mask (itk.image, optional): Binary mask defining the region of interest in the moving image - moving_image_pre (ants.core.ANTsImage, optional): Pre-processed moving image - in ANTs format. If None, preprocessing is performed automatically - initial_forward_transform (itk.Transform, optional): Initial transform from moving - to fixed space. Can be any ITK transform type (Affine, Rigid, - DisplacementField, Composite, etc.). Will be converted to ANTs - format automatically. The returned transforms will include this - initial transform composed with the registration result. + moving_image_pre (itk.Image, optional): Pre-processed moving image. + If None, preprocessing is performed automatically + initial_forward_transform (itk.Transform, optional): Initial + forward transform (same convention as the returned + forward_transform: used to warp the moving image onto the fixed + grid). Can be any ITK transform type (Affine, Rigid, + DisplacementField, Composite, etc.). It is applied by pre-warping + the moving image onto the fixed grid before registration; the + returned transforms compose this initial alignment with the + registration refinement. Returns: dict: Dictionary containing: - - "forward_transform": Transformation from moving to fixed - - "inverse_transform": Transformation from fixed to moving + - "forward_transform": Warps the moving image onto the fixed + grid (warping moving points/landmarks into fixed space uses + "inverse_transform" instead -- image and point warps use + opposite transforms; see + docs/developer/transform_conventions) + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Loss value from the registration Note: @@ -543,11 +549,13 @@ def registration_method( consistent. The forward and inverse transforms are stored separately by ANTs. - IMPORTANT: ANTs registration does NOT include the initial_transform - in its output fwdtransforms/invtransforms. This method automatically - composes the initial transform with the registration result, so the - returned transforms include both the initial alignment and - the registration refinement. + IMPORTANT: the initial transform is applied by pre-warping the + moving image onto the fixed grid (the same approach as + RegisterImagesICON) rather than via ants.registration's + initial_transform argument, which mishandles matrix (affine/ + translation) initials. This method composes the initial transform + with the registration result, so the returned transforms include + both the initial alignment and the registration refinement. Implementation details: - Uses ANTs registration with configurable transform types @@ -584,24 +592,29 @@ def registration_method( if self.fixed_image_pre is None: self.fixed_image_pre = self.preprocess(self.fixed_image, self.modality) - # Convert initial ITK transform to ANTs format if provided - initial_transform: str | list[str] = "identity" if initial_forward_transform is not None: - self.log_info("Converting initial ITK transform to ANTs format...") - initial_transform = self.itk_transform_to_ANTSfile( - itk_tfm=initial_forward_transform, - reference_image=self.fixed_image, - output_filename="initial_transform_temp.mat", + self.log_info("Pre-warping moving image with initial transform...") + transform_tools = TransformTools() + self.moving_image_pre = transform_tools.transform_image( + self.moving_image_pre, + initial_forward_transform, + self.fixed_image, ) - self.log_info("Initial transform converted successfully") + if self.moving_mask is not None: + self.moving_mask = transform_tools.transform_image( + self.moving_mask, + initial_forward_transform, + self.fixed_image, + interpolation_method="nearest", + ) transform_type = None if self.transform_type == "Deformable": transform_type = "antsRegistrationSyNQuick[so]" elif self.transform_type == "Affine": - transform_type = "antsRegistrationAffineQuick[so]" + transform_type = "Affine" elif self.transform_type == "Rigid": - transform_type = "antsRegistrationRigidQuick[so]" + transform_type = "Rigid" else: self.log_error("Invalid transform type: %s", self.transform_type) raise ValueError(f"Invalid transform type: {self.transform_type}") @@ -627,13 +640,36 @@ def registration_method( elif self.metric == "MeanSquares": syn_metric = "meansquares" + # antsRegistration --dimensionality 3 --float 0 \ + # --output [$thisfolder/pennTemplate_to_${sub}_,$thisfolder/pennTemplate_to_${sub}_Warped.nii.gz] \ + # --interpolation Linear \ + # --winsorize-image-intensities [0.005,0.995] \ + # --use-histogram-matching 0 \ + # --initial-moving-transform [$t1brain,$template,1] \ + # --transform Rigid[0.1] \ + # --metric MI[$t1brain,$template,1,32,Regular,0.25] \ + # --convergence [1000x500x250x100,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # --transform Affine[0.1] \ + # --metric MI[$t1brain,$template,1,32,Regular,0.25] \ + # --convergence [1000x500x250x100,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # --transform SyN[0.1,3,0] \ + # --metric CC[$t1brain,$template,1,4] \ + # --convergence [100x70x50x20,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # -x $brainlesionmask + if self.fixed_mask is not None and self.moving_mask is not None: registration_result = ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), mask=self._itk_to_ants_image(self.fixed_mask), moving=self._itk_to_ants_image(self.moving_image_pre), moving_mask=self._itk_to_ants_image(self.moving_mask), - initial_transform=[initial_transform], + initial_transform=["identity"], type_of_transform=transform_type, aff_metric=aff_metric, syn_metric=syn_metric, @@ -646,7 +682,7 @@ def registration_method( registration_result = ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), moving=self._itk_to_ants_image(self.moving_image_pre), - initial_transform=[initial_transform], + initial_transform=["identity"], type_of_transform=transform_type, aff_metric=aff_metric, syn_metric=syn_metric, diff --git a/src/physiomotion4d/register_images_base.py b/src/physiomotion4d/register_images_base.py index 463e26f..1c397fb 100644 --- a/src/physiomotion4d/register_images_base.py +++ b/src/physiomotion4d/register_images_base.py @@ -19,9 +19,8 @@ from typing import Any, Optional, Union import itk -import numpy as np -from itk import TubeTK as ttk +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools @@ -59,8 +58,8 @@ class and implement the register() method. ... def registration_method(self, moving_image, **kwargs): ... # Implement specific registration algorithm ... return { - ... 'forward_transform': tfm_forward, # Moving → Fixed - ... 'inverse_transform': tfm_inverse, # Fixed → Moving + ... 'forward_transform': tfm_forward, # warps moving image -> fixed grid + ... 'inverse_transform': tfm_inverse, # warps fixed image -> moving grid ... 'loss': 0.0, ... } >>> @@ -68,8 +67,8 @@ class and implement the register() method. >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> result = registrar.register(moving_image) - >>> forward_tfm = result['forward_transform'] # Moving → Fixed - >>> inverse_tfm = result['inverse_transform'] # Fixed → Moving + >>> forward_tfm = result['forward_transform'] # warps moving image -> fixed grid + >>> inverse_tfm = result['inverse_transform'] # warps fixed image -> moving grid """ def __init__(self, log_level: int | str = logging.INFO) -> None: @@ -84,6 +83,8 @@ def __init__(self, log_level: int | str = logging.INFO) -> None: """ super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.labelmap_tools = LabelmapTools(log_level=log_level) + self.net: Any = None self.modality: str = "ct" @@ -180,16 +181,10 @@ def set_fixed_mask(self, fixed_mask: Optional[itk.Image]) -> None: if self.fixed_image is None: raise ValueError("Fixed image must be set before setting a fixed mask.") - mask_arr = itk.GetArrayFromImage(fixed_mask) - mask_arr = np.where(mask_arr > 0, 1, 0) - self.fixed_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + self.fixed_mask = self.labelmap_tools.convert_labelmap_to_mask( + fixed_mask, dilation_in_mm=self.mask_dilation_mm + ) self.fixed_mask.CopyInformation(self.fixed_image) - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(self.fixed_mask) - imMath.Dilate( - int(self.mask_dilation_mm / self.fixed_image.GetSpacing()[0]), 1, 0 - ) - self.fixed_mask = imMath.GetOutputUChar() def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: """Set the fixed image labelmap (multi-label segmentation). @@ -251,8 +246,11 @@ def registration_method( Returns: dict: Dictionary containing: - - "forward_transform": Transform that warps moving image into fixed space - - "inverse_transform": Transform that warps fixed image into moving space + - "forward_transform": Warps the moving image onto the fixed + grid. Warping moving points/landmarks into fixed space uses + "inverse_transform" instead (see register() and + docs/developer/transform_conventions). + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Registration loss/metric value Raises: @@ -283,13 +281,24 @@ def register( Returns: dict: Dictionary containing transformation results: - - "forward_transform": Transforms moving image to fixed space (warps moving → fixed) - - "inverse_transform": Transforms fixed image to moving space (warps fixed → moving) + - "forward_transform": Warps the moving IMAGE onto the fixed + grid, i.e. transform_image(moving, forward_transform, fixed). + - "inverse_transform": Warps the fixed IMAGE onto the moving + grid, i.e. transform_image(fixed, inverse_transform, moving). - "loss": Registration loss/metric value Note: - - forward_transform: Use this to warp the moving image to match the fixed image - - inverse_transform: Use this to warp the fixed image to match the moving image + Image warps and point/landmark warps use OPPOSITE members of the + transform pair, because ITK image resampling pulls back (it maps a + fixed-grid sample to the moving image) while point transforms push + forward (they map a point to its corresponding location): + + - Warp the moving image into fixed space -> forward_transform + - Warp moving points/landmarks into fixed -> inverse_transform + - Warp the fixed image into moving space -> inverse_transform + - Warp fixed points/landmarks into moving -> forward_transform + + See docs/developer/transform_conventions for the full discussion. Raises: NotImplementedError: This method must be implemented by subclasses @@ -313,16 +322,10 @@ def register( new_moving_mask = moving_mask if moving_mask is not None: - mask_arr = itk.GetArrayFromImage(moving_mask) - mask_arr = np.where(mask_arr > 0, 1, 0) - new_moving_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + new_moving_mask = self.labelmap_tools.convert_labelmap_to_mask( + moving_mask, dilation_in_mm=self.mask_dilation_mm + ) new_moving_mask.CopyInformation(moving_image) - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(new_moving_mask) - imMath.Dilate( - int(self.mask_dilation_mm / moving_image.GetSpacing()[0]), 1, 0 - ) - new_moving_mask = imMath.GetOutputUChar() self.moving_image = moving_image self.moving_image_pre = moving_image_pre diff --git a/src/physiomotion4d/register_images_greedy.py b/src/physiomotion4d/register_images_greedy.py index d3caed8..9e3016f 100644 --- a/src/physiomotion4d/register_images_greedy.py +++ b/src/physiomotion4d/register_images_greedy.py @@ -12,6 +12,8 @@ from __future__ import annotations import logging +import os +import tempfile from typing import Any, Optional, Union import itk @@ -64,6 +66,15 @@ class RegisterImagesGreedy(RegisterImagesBase): deformable_smoothing: Smoothing sigmas for deformable (e.g. "2.0vox 0.5vox") """ + # picsl_greedy 0.0.12 segfaults when its multi-component metric (image + + # labelmap channels) allocates a working buffer for a fixed grid larger + # than roughly 100M voxels (empirically: 95M voxels succeeds, 104M crashes; + # single-channel metrics are unaffected at any size). When the labelmap + # channel is active, the metric inputs are isotropically downsampled to + # stay under this conservative cap. Greedy emits physical-space transforms, + # so a coarser metric grid only coarsens warp sampling, not the frame. + _MAX_METRIC_VOXELS = 90_000_000 + def __init__(self, log_level: int | str = logging.INFO) -> None: """Initialize the Greedy image registration class. @@ -135,6 +146,98 @@ def _greedy_iterations_str(self) -> str: """Format iterations as Greedy -n string (e.g. 40x20x10).""" return "x".join(str(i) for i in self.number_of_iterations) + def _metric_downsample_scale(self, reference_image: itk.Image) -> float: + """Per-axis scale that keeps ``reference_image`` under the voxel cap. + + Returns ``1.0`` when the grid already fits within + ``_MAX_METRIC_VOXELS``; otherwise returns the isotropic per-axis factor + ``(_MAX_METRIC_VOXELS / voxels) ** (1/3)`` (always < 1.0) so the + downsampled grid lands at or just below the cap. + + Args: + reference_image: The fixed metric image (X, Y, Z) whose voxel count + drives the Greedy multi-component buffer size. + + Returns: + Per-axis resampling scale in ``(0, 1]``. + """ + size = reference_image.GetLargestPossibleRegion().GetSize() + voxels = int(size[0]) * int(size[1]) * int(size[2]) + if voxels <= self._MAX_METRIC_VOXELS: + return 1.0 + scale = float((self._MAX_METRIC_VOXELS / voxels) ** (1.0 / 3.0)) + self.log_info( + "Greedy labelmap metric: downsampling %d-voxel fixed grid by " + "%.3f/axis to stay under the %d-voxel picsl_greedy crash threshold.", + voxels, + scale, + self._MAX_METRIC_VOXELS, + ) + return scale + + def _downsample_image( + self, image: itk.Image, scale: float, nearest: bool = False + ) -> itk.Image: + """Isotropically resample ``image`` by ``scale`` (no-op when >= 1.0). + + The physical extent is preserved exactly: the new per-axis spacing is + chosen so ``new_size * new_spacing == old_size * old_spacing``, so the + coarser grid covers the same world-space region with the same origin + and direction. Axis order is ITK world order (X, Y, Z). + + Args: + image: Scalar 3D ``itk.Image`` to resample. + scale: Per-axis factor in ``(0, 1]``; ``>= 1.0`` returns ``image`` + unchanged so the full-resolution path is untouched. + nearest: Use nearest-neighbor interpolation (for labelmaps and + masks) instead of linear. + + Returns: + The resampled ``itk.Image``, or ``image`` itself when ``scale`` is + ``>= 1.0``. + """ + if scale >= 1.0: + return image + size = image.GetLargestPossibleRegion().GetSize() + spacing = image.GetSpacing() + new_size = [max(1, int(round(int(size[i]) * scale))) for i in range(3)] + new_spacing = [float(spacing[i]) * int(size[i]) / new_size[i] for i in range(3)] + kwargs: dict[str, Any] = { + "output_origin": image.GetOrigin(), + "output_direction": image.GetDirection(), + "size": new_size, + "output_spacing": new_spacing, + } + if nearest: + kwargs["interpolator"] = itk.NearestNeighborInterpolateImageFunction.New( + image + ) + return itk.resample_image_filter(image, **kwargs) + + def _write_affine_matrix_file(self, mat_4x4: NDArray[np.float64]) -> str: + """Write a 4x4 RAS affine matrix to a temporary Greedy ``.mat`` file. + + Greedy's in-memory interface corrupts the heap when a numpy affine + matrix is supplied as an initial transform (``-ia``/``-it``); passing a + file path instead avoids that native crash. Greedy reads a plain-text + 4x4 RAS matrix, which is what ``numpy.savetxt`` writes here. + + Args: + mat_4x4: 4x4 affine matrix in RAS (Greedy) convention. + + Returns: + Path to the written temporary ``.mat`` file. The caller is + responsible for deleting it. + """ + mat_4x4 = np.asarray(mat_4x4, dtype=np.float64) + if mat_4x4.shape != (4, 4): + raise ValueError(f"Expected 4x4 matrix, got shape {mat_4x4.shape}") + fd, path = tempfile.mkstemp(suffix=".mat", prefix="greedy_aff_") + os.close(fd) + np.savetxt(path, mat_4x4, fmt="%.10f") + self.log_debug("Wrote Greedy affine init matrix to %s", path) + return path + def _matrix_to_itk_affine(self, mat_4x4: NDArray[np.float64]) -> itk.Transform: """Convert 4x4 affine matrix to ITK AffineTransform.""" mat_4x4 = np.asarray(mat_4x4, dtype=np.float64) @@ -174,48 +277,69 @@ def _registration_method_affine_or_rigid( self, fixed_sitk: Any, moving_sitk: Any, - fixed_mask_sitk: Optional[Any], - moving_mask_sitk: Optional[Any], iterations_str: str, metric_str: str, dof: int, + fixed_mask_sitk: Optional[Any] = None, + moving_mask_sitk: Optional[Any] = None, + fixed_labelmap_sitk: Optional[Any] = None, + moving_labelmap_sitk: Optional[Any] = None, initial_affine: Optional[NDArray[np.float64]] = None, ) -> tuple[NDArray[np.float64], float]: """Run Greedy affine or rigid registration. Returns (4x4 matrix, loss).""" Greedy3D = _try_import_greedy() g = Greedy3D() - cmd = f"-i fixed moving -a -dof {dof} -n {iterations_str} -m {metric_str} -o aff_out" + cmd = "-d 3" + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd += " -w 0.60" + cmd += " -i fixed moving" kwargs: dict[str, Any] = { "fixed": fixed_sitk, "moving": moving_sitk, - "aff_out": None, } + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd += " -w 0.40 -i fixed_labelmap moving_labelmap" + kwargs["fixed_labelmap"] = fixed_labelmap_sitk + kwargs["moving_labelmap"] = moving_labelmap_sitk + cmd += f" -a -dof {dof} -n {iterations_str} -m {metric_str} -o aff_out" + kwargs["aff_out"] = None if fixed_mask_sitk is not None and moving_mask_sitk is not None: cmd += " -gm fixed_mask -mm moving_mask" kwargs["fixed_mask"] = fixed_mask_sitk kwargs["moving_mask"] = moving_mask_sitk + # Greedy crashes (heap corruption) when an initial affine is passed as an + # in-memory matrix; write it to a temp file and pass the path instead. + initial_affine_file: Optional[str] = None if initial_affine is not None: - cmd += " -ia aff_initial" - kwargs["aff_initial"] = initial_affine + initial_affine_file = self._write_affine_matrix_file(initial_affine) + cmd += f" -ia {initial_affine_file}" - g.execute(cmd, **kwargs) + self.log_debug("Greedy affine/rigid command: %s", cmd) + try: + g.execute(cmd, **kwargs) + finally: + if initial_affine_file is not None: + os.remove(initial_affine_file) mat = np.array(g["aff_out"], dtype=np.float64) try: ml = g.metric_log() loss = float(ml[-1]["TotalPerPixelMetric"][-1]) if ml else 0.0 except Exception: loss = 0.0 + self.log_info("Greedy affine/rigid registration loss: %s", loss) return mat, loss def _registration_method_deformable( self, fixed_sitk: Any, moving_sitk: Any, - fixed_mask_sitk: Optional[Any], - moving_mask_sitk: Optional[Any], iterations_str: str, metric_str: str, + fixed_mask_sitk: Optional[Any] = None, + moving_mask_sitk: Optional[Any] = None, + fixed_labelmap_sitk: Optional[Any] = None, + moving_labelmap_sitk: Optional[Any] = None, initial_affine: Optional[NDArray[np.float64]] = None, ) -> tuple[Optional[NDArray[np.float64]], Any, float]: """Run Greedy deformable registration. Returns (affine 4x4 or None, warp_sitk, loss).""" @@ -224,37 +348,66 @@ def _registration_method_deformable( # Optional affine init (uses configured metric) if initial_affine is None: - cmd_aff = f"-i fixed moving -a -dof 6 -n {iterations_str} -m {metric_str} -o aff_init" - kwargs_aff = {"fixed": fixed_sitk, "moving": moving_sitk, "aff_init": None} + cmd_aff = "-d 3" + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_aff += " -w 0.60" + cmd_aff += " -i fixed moving" + kwargs_aff = { + "fixed": fixed_sitk, + "moving": moving_sitk, + } + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_aff += " -w 0.40 -i fixed_labelmap moving_labelmap" + kwargs_aff["fixed_labelmap"] = fixed_labelmap_sitk + kwargs_aff["moving_labelmap"] = moving_labelmap_sitk + cmd_aff += f" -a -dof 12 -n {iterations_str} -m {metric_str} -o aff_init" + kwargs_aff["aff_init"] = None if fixed_mask_sitk is not None and moving_mask_sitk is not None: cmd_aff += " -gm fixed_mask -mm moving_mask" kwargs_aff["fixed_mask"] = fixed_mask_sitk kwargs_aff["moving_mask"] = moving_mask_sitk + self.log_debug("Greedy deformable affine-init command: %s", cmd_aff) g.execute(cmd_aff, **kwargs_aff) initial_affine = np.array(g["aff_init"], dtype=np.float64) - - cmd_def = ( - f"-i fixed moving -it aff_init -n {iterations_str} " - f"-m {metric_str} -s {self.deformable_smoothing} -o warp_out" - ) + self.log_info("Greedy deformable affine init complete") + + # Greedy crashes (heap corruption) when the affine init is passed as an + # in-memory matrix via -it; write it to a temp file and pass the path. + initial_affine_file = self._write_affine_matrix_file(initial_affine) + cmd_def = "-d 3" + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_def += " -w 0.60" + cmd_def += " -i fixed moving" kwargs_def = { "fixed": fixed_sitk, "moving": moving_sitk, - "aff_init": initial_affine, - "warp_out": None, } + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_def += " -w 0.40 -i fixed_labelmap moving_labelmap" + kwargs_def["fixed_labelmap"] = fixed_labelmap_sitk + kwargs_def["moving_labelmap"] = moving_labelmap_sitk + cmd_def += ( + f" -it {initial_affine_file} -n {iterations_str}" + f" -m {metric_str} -s {self.deformable_smoothing} -o warp_out" + ) + kwargs_def["warp_out"] = None if fixed_mask_sitk is not None and moving_mask_sitk is not None: cmd_def += " -gm fixed_mask -mm moving_mask" kwargs_def["fixed_mask"] = fixed_mask_sitk kwargs_def["moving_mask"] = moving_mask_sitk - g.execute(cmd_def, **kwargs_def) + self.log_debug("Greedy deformable command: %s", cmd_def) + try: + g.execute(cmd_def, **kwargs_def) + finally: + os.remove(initial_affine_file) warp_out = g["warp_out"] try: ml = g.metric_log() loss = float(ml[-1]["TotalPerPixelMetric"][-1]) if ml else 0.0 except Exception: loss = 0.0 + self.log_info("Greedy deformable registration loss: %s", loss) return initial_affine, warp_out, loss def registration_method( @@ -270,20 +423,79 @@ def registration_method( Converts ITK images to SimpleITK, runs Greedy (affine and/or deformable), then converts outputs back to ITK transforms. Composes with initial_forward_transform when provided. + + Returns a dict with "forward_transform", "inverse_transform", and + "loss". As with the other image-registration backends, + forward_transform warps the moving image onto the fixed grid and + inverse_transform warps the fixed image onto the moving grid; point and + landmark warps use the opposite transform from image warps (see + docs/developer/transform_conventions). """ if self.fixed_image is None or self.fixed_image_pre is None: raise ValueError("Fixed image must be set before registration.") moving_pre = moving_image_pre if moving_image_pre is not None else moving_image - fixed_sitk = self._itk_to_sitk(self.fixed_image_pre) + + # The labelmap is added as a second Greedy metric channel only when both + # the fixed and moving labelmaps are present. That multi-component + # metric crashes picsl_greedy on large grids, so downsample every metric + # input by a single isotropic scale when (and only when) the channel is + # active; the single-channel path stays full resolution. + use_labelmap_channel = ( + self.fixed_labelmap is not None and moving_labelmap is not None + ) + metric_scale = ( + self._metric_downsample_scale(self.fixed_image_pre) + if use_labelmap_channel + else 1.0 + ) + + fixed_pre = self._downsample_image(self.fixed_image_pre, metric_scale) + moving_pre = self._downsample_image(moving_pre, metric_scale) + # warp_out lands on the (possibly downsampled) fixed grid; use the same + # grid as the displacement-field reference so shapes match. + displacement_reference = fixed_pre + fixed_sitk = self._itk_to_sitk(fixed_pre) moving_sitk = self._itk_to_sitk(moving_pre) + # Greedy applies one global metric to every input channel. A raw + # integer labelmap is piecewise-constant, so NCC sees zero local + # variance and emits NaN gradients (a native crash). Encode each + # labelmap as a continuous label-plus-boundary-distance field instead. + from physiomotion4d.labelmap_tools import LabelmapTools + + labelmap_tools = LabelmapTools() + fixed_labelmap_sitk = None + moving_labelmap_sitk = None + if self.fixed_labelmap is not None: + fixed_labelmap_ds = self._downsample_image( + self.fixed_labelmap, metric_scale, nearest=True + ) + fixed_labelmap_dist_map = labelmap_tools.create_distance_map( + fixed_labelmap_ds + ) + fixed_labelmap_sitk = self._itk_to_sitk(fixed_labelmap_dist_map) + if moving_labelmap is not None: + moving_labelmap_ds = self._downsample_image( + moving_labelmap, metric_scale, nearest=True + ) + moving_labelmap_dist_map = labelmap_tools.create_distance_map( + moving_labelmap_ds + ) + moving_labelmap_sitk = self._itk_to_sitk(moving_labelmap_dist_map) + fixed_mask_sitk = None moving_mask_sitk = None if self.fixed_mask is not None: - fixed_mask_sitk = self._itk_to_sitk(self.fixed_mask) + fixed_mask_ds = self._downsample_image( + self.fixed_mask, metric_scale, nearest=True + ) + fixed_mask_sitk = self._itk_to_sitk(fixed_mask_ds) if moving_mask is not None: - moving_mask_sitk = self._itk_to_sitk(moving_mask) + moving_mask_ds = self._downsample_image( + moving_mask, metric_scale, nearest=True + ) + moving_mask_sitk = self._itk_to_sitk(moving_mask_ds) iterations_str = self._greedy_iterations_str() metric_str = self._greedy_metric() @@ -314,10 +526,12 @@ def registration_method( mat, loss_val = self._registration_method_affine_or_rigid( fixed_sitk, moving_sitk, - fixed_mask_sitk, - moving_mask_sitk, - iterations_str, - metric_str, + fixed_mask_sitk=fixed_mask_sitk, + moving_mask_sitk=moving_mask_sitk, + fixed_labelmap_sitk=fixed_labelmap_sitk, + moving_labelmap_sitk=moving_labelmap_sitk, + iterations_str=iterations_str, + metric_str=metric_str, dof=6, initial_affine=initial_affine, ) @@ -329,10 +543,12 @@ def registration_method( mat, loss_val = self._registration_method_affine_or_rigid( fixed_sitk, moving_sitk, - fixed_mask_sitk, - moving_mask_sitk, - iterations_str, - metric_str, + fixed_mask_sitk=fixed_mask_sitk, + moving_mask_sitk=moving_mask_sitk, + fixed_labelmap_sitk=fixed_labelmap_sitk, + moving_labelmap_sitk=moving_labelmap_sitk, + iterations_str=iterations_str, + metric_str=metric_str, dof=12, initial_affine=initial_affine, ) @@ -345,10 +561,12 @@ def registration_method( aff_mat, warp_sitk, loss_val = self._registration_method_deformable( fixed_sitk, moving_sitk, - fixed_mask_sitk, - moving_mask_sitk, - iterations_str, - metric_str, + fixed_mask_sitk=fixed_mask_sitk, + moving_mask_sitk=moving_mask_sitk, + fixed_labelmap_sitk=fixed_labelmap_sitk, + moving_labelmap_sitk=moving_labelmap_sitk, + iterations_str=iterations_str, + metric_str=metric_str, initial_affine=initial_affine, ) aff_tfm = ( @@ -357,7 +575,7 @@ def registration_method( # warp_sitk can be displacement field (SimpleITK image) or numpy if hasattr(warp_sitk, "GetSize"): disp_tfm = self._sitk_warp_to_itk_displacement_transform( - warp_sitk, self.fixed_image + warp_sitk, displacement_reference ) else: # Assume numpy displacement field (z,y,x,3) @@ -365,19 +583,23 @@ def registration_method( image_tools = ImageTools() warp_arr = np.asarray(warp_sitk, dtype=np.float64) - ref = self.fixed_image + ref = displacement_reference disp_itk = image_tools.convert_array_to_image_of_vectors( warp_arr, ref, itk.D ) disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() disp_tfm.SetDisplacementField(disp_itk) - # Forward = moving -> fixed: first affine then deformable in Greedy + # forward_transform is consumed by transform_image(moving, ..., + # fixed) to warp the moving image onto the fixed grid, so it holds + # Greedy's raw affine+warp (Greedy applies the affine first, then + # the warp). inverse_transform is the numerically inverted field, + # used to warp the fixed image onto the moving grid. This matches + # RegisterImagesANTS/ICON and RegisterTimeSeriesImages. forward_composite = itk.CompositeTransform[itk.D, 3].New() if aff_tfm is not None: forward_composite.AddTransform(aff_tfm) forward_composite.AddTransform(disp_tfm) forward_transform = forward_composite - # Inverse: inverse warp then inverse affine inv_disp = TransformTools().invert_displacement_field_transform(disp_tfm) inv_aff = itk.AffineTransform[itk.D, 3].New() if aff_tfm is not None: diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 28bd03d..65c527e 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -88,6 +88,10 @@ def set_weights_path(self, weights_path: str) -> None: pretrained weights. Clears any previously loaded network so the new weights are applied on the next call to register(). + Also, use this to specify the path to store the downloaded weights. The + file must not exist for the weights to be downloaded correctly. Typical + suffix is ".trch". + Args: weights_path: Path to a uniGradICON checkpoint, e.g. "results/duke_4d_finetune/checkpoints/network_weights_100" @@ -185,16 +189,21 @@ def registration_method( Returns: dict: Dictionary containing: - - "forward_transform": transform moving image into fixed space - - "inverse_transform": transform fixed image to moving space + - "forward_transform": Warps the moving image onto the fixed + grid (warping moving points/landmarks into fixed space uses + "inverse_transform" instead -- image and point warps use + opposite transforms; see + docs/developer/transform_conventions) + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Loss value from the registration Note: The transformations are inverse consistent, meaning - forward_transform ≈ inverse(inverse_transform). - The inverse_transform is used to warp the fixed image - to the moving image space. The forward_transform is used - to warp the moving image to the fixed image space. + forward_transform is approximately inverse(inverse_transform). + Use forward_transform to warp the moving image onto the fixed grid, + and inverse_transform to warp the fixed image onto the moving grid. + Point/landmark warps use the opposite transform from image warps + (see docs/developer/transform_conventions). Implementation details: - Uses UniGradIcon with LNCC loss function @@ -346,42 +355,6 @@ def _image_to_resized_tensor( tensor, size=shape[2:], mode="trilinear", align_corners=False ) - @staticmethod - def create_mask(labelmap: itk.Image, dilation_mm: float = 5.0) -> itk.Image: - """Create a binary registration mask from a labelmap. - - Thresholds the labelmap at ``>0`` (so every non-zero label becomes - foreground) and dilates the result by ``dilation_mm`` millimeters of - physical radius. The radius is converted into per-axis voxel counts - from the labelmap's spacing so the dilation is physically isotropic - even on anisotropic grids; each per-axis count is clamped to at least - 1 voxel when ``dilation_mm > 0``. - - Args: - labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel - is treated as foreground. - dilation_mm: Physical radius of the binary dilation in - millimeters. Pass ``0`` (or negative) to skip dilation and - return the raw ``>0`` mask. Default 5.0 mm. - - Returns: - ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as - ``labelmap`` (origin, spacing, direction copied from the input). - """ - arr = (itk.array_from_image(labelmap) > 0).astype(np.uint8) - mask = itk.image_from_array(arr) - mask.CopyInformation(labelmap) - if dilation_mm <= 0: - return mask - spacing = labelmap.GetSpacing() - radius = itk.Size[3]() - for i in range(3): - radius[i] = max(1, int(round(dilation_mm / float(spacing[i])))) - structuring_element = itk.FlatStructuringElement[3].Ball(radius) - return itk.binary_dilate_image_filter( - mask, kernel=structuring_element, foreground_value=1 - ) - def _mask_to_resized_tensor( self, mask: itk.Image, shape: torch.Size ) -> torch.Tensor: diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py index 4b8090e..c03751f 100644 --- a/src/physiomotion4d/register_models_distance_maps.py +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -41,7 +41,7 @@ >>> >>> # Access results >>> aligned_model = result['registered_model'] - >>> forward_transform = result['forward_transform'] # Moving to fixed transform + >>> forward_transform = result['forward_transform'] # warps moving image -> fixed grid """ import logging @@ -49,9 +49,9 @@ import itk import pyvista as pv -from itk import TubeTK as ttk from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON @@ -74,8 +74,15 @@ class RegisterModelsDistanceMaps(PhysioMotion4DBase): - **Optional**: ICON deep learning refinement after any mode **Transform Convention:** - - forward_transform: Moving → fixed space transformation - - inverse_transform: Fixed → moving space transformation + These are the underlying image-registration (ANTs/ICON) transforms, so + they follow the image convention (see + docs/developer/transform_conventions): + + - forward_transform: warps the moving image/mask onto the fixed grid. + Warping the moving MODEL points/landmarks onto the fixed model uses + inverse_transform instead (image and point warps use opposite + transforms). + - inverse_transform: warps the fixed image/mask onto the moving grid. Attributes: moving_model (pv.PolyData): Surface model to be aligned @@ -148,6 +155,7 @@ def __init__( # Utilities self.transform_tools = TransformTools() self.contour_tools = ContourTools() + self.labelmap_tools = LabelmapTools(log_level=log_level) # Registration instances self.registrar_ANTS = RegisterImagesANTS(log_level=log_level) @@ -194,12 +202,9 @@ def _create_masks_from_models(self) -> None: mask = self.contour_tools.create_mask_from_mesh( self.fixed_model, self.reference_image ) - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int( - self.roi_dilation_mm / self.reference_image.GetSpacing()[0] + self.fixed_mask_roi_image = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=self.roi_dilation_mm ) - imMath.Dilate(dilation_voxels, 1, 0) - self.fixed_mask_roi_image = imMath.GetOutput() # Create moving mask self.moving_mask_image = self.contour_tools.create_distance_map( @@ -216,9 +221,9 @@ def _create_masks_from_models(self) -> None: mask = self.contour_tools.create_mask_from_mesh( self.moving_model, self.reference_image ) - imMath = ttk.ImageMath.New(self.moving_mask_image) - imMath.Dilate(dilation_voxels, 1, 0) - self.moving_mask_roi_image = imMath.GetOutputUChar() + self.moving_mask_roi_image = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=self.roi_dilation_mm + ) self.log_info("Mask generation complete") diff --git a/src/physiomotion4d/register_models_icp.py b/src/physiomotion4d/register_models_icp.py index b3b2b61..3236f02 100644 --- a/src/physiomotion4d/register_models_icp.py +++ b/src/physiomotion4d/register_models_icp.py @@ -61,10 +61,16 @@ class RegisterModelsICP(PhysioMotion4DBase): - **Affine transform type**: Centroid alignment → Rigid ICP → Affine ICP **Transform Convention:** - - forward_point_transform: moving → fixed space transformation - (This is the inverse of the transform used to wrap the moving image to the - fixed image) - - inverse_point_transform: moving → fixed space transformation + These are POINT transforms (applied with TransformPoint, e.g. via + TransformTools.transform_pvcontour), so their orientation is opposite to + the image-registration transforms (see + docs/developer/transform_conventions): + + - forward_point_transform: maps moving points -> fixed points; use it to + warp the moving model/landmarks onto the fixed model. This is the + inverse of the transform that would warp the moving IMAGE onto the + fixed grid. + - inverse_point_transform: maps fixed points -> moving points. Attributes: moving_model (pv.PolyData): Surface model to be aligned diff --git a/src/physiomotion4d/register_models_pca.py b/src/physiomotion4d/register_models_pca.py index d230e0a..339b34a 100644 --- a/src/physiomotion4d/register_models_pca.py +++ b/src/physiomotion4d/register_models_pca.py @@ -42,10 +42,15 @@ class RegisterModelsPCA(PhysioMotion4DBase): pca_coefficients (np.ndarray): Optimized PCA coefficients registered_model (pv.DataSet): Final registered and deformed model post_pca_transform (itk.Transform): Transform to apply after PCA registration - forward_point_transform (itk.DisplacementFieldTransform): Forward displacement field transform - (Does not include the post-PCA transform) - inverse_point_transform (itk.DisplacementFieldTransform): Inverse displacement field transform - (Does not include the post-PCA transform) + forward_point_transform (itk.DisplacementFieldTransform): POINT transform + mapping template points -> registered/target points; use it to warp + the template model/landmarks onto the target. Its orientation is + opposite to an image-registration forward_transform (see + docs/developer/transform_conventions). Does not include the post-PCA + transform. + inverse_point_transform (itk.DisplacementFieldTransform): POINT transform + mapping target points -> template points. Does not include the + post-PCA transform. Example: >>> # Load PCA model data @@ -741,8 +746,14 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict: Returns: Dictionary containing: - - 'forward_point_transform': Forward displacement field transform - - 'inverse_point_transform': Inverse displacement field transform + - 'forward_point_transform': POINT transform mapping template + points -> target points (warps the template onto the target) + - 'inverse_point_transform': POINT transform mapping target + points -> template points + + Note: + These are point transforms, oriented opposite to image-registration + transforms; see docs/developer/transform_conventions. """ assert self.registered_model_pca_deformation is not None, ( "PCA deformation must be computed" diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 5e14f5a..6ced80c 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -75,8 +75,8 @@ class RegisterTimeSeriesImages(RegisterImagesBase): ... prior_weight=0.5, ... ) >>> - >>> forward_tfms = result['forward_transforms'] # Moving → Fixed - >>> inverse_tfms = result['inverse_transforms'] # Fixed → Moving + >>> forward_tfms = result['forward_transforms'] # warp moving images -> fixed grid + >>> inverse_tfms = result['inverse_transforms'] # warp fixed image -> moving grids >>> losses = result['losses'] >>> >>> # Reconstruct time series with optional upsampling @@ -205,6 +205,16 @@ def set_fixed_mask(self, fixed_mask: Optional[itk.Image]) -> None: """ self.fixed_mask = fixed_mask + def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: + """Set a labelmap for the fixed image region of interest. + + This passes through to the underlying registration method. + + Args: + fixed_labelmap (Optional[itk.Image]): Labelmap defining ROI + """ + self.fixed_labelmap = fixed_labelmap + def register_time_series( self, moving_images: list[itk.Image], @@ -247,10 +257,14 @@ def register_time_series( Returns: dict: Dictionary containing results: - - "forward_transforms" (list[itk.Transform]): Transforms from moving to fixed - space for each image (warps moving → fixed) - - "inverse_transforms" (list[itk.Transform]): Transforms from fixed to moving - space for each image (warps fixed → moving) + - "forward_transforms" (list[itk.Transform]): one per image; + each warps its moving image onto the fixed grid (warping + moving points/landmarks into fixed space uses the matching + inverse transform instead -- see + docs/developer/transform_conventions) + - "inverse_transforms" (list[itk.Transform]): one per image; + each warps the fixed image onto that moving image's grid + (used by reconstruct_time_series) - "losses" (list[float]): Registration loss value for each image Raises: @@ -277,6 +291,7 @@ def register_time_series( >>> result = registrar.register_time_series( ... moving_images=image_list, ... moving_masks=mask_list, # Optional + ... moving_labelmaps=labelmap_list, # Optional ... reference_frame=5, ... register_reference=True, ... prior_weight=0.5, @@ -630,7 +645,8 @@ def reconstruct_time_series( Args: moving_images (list[itk.Image]): List of moving images to reconstruct inverse_transforms (list[itk.Transform]): List of inverse transforms - (one per moving image) from fixed space to moving space + (one per moving image), each used to warp the fixed image onto + that moving image's grid upsample_to_fixed_resolution (bool, optional): If True, reconstructed images will be upsampled to isotropic resolution (mean of fixed image's X and Y spacing) while maintaining their original origin and direction. diff --git a/src/physiomotion4d/segment_anatomy_base.py b/src/physiomotion4d/segment_anatomy_base.py index ca2a788..8368d6b 100644 --- a/src/physiomotion4d/segment_anatomy_base.py +++ b/src/physiomotion4d/segment_anatomy_base.py @@ -557,29 +557,6 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: """ raise NotImplementedError("This method should be implemented by the subclass.") - def dilate_mask(self, mask: itk.image, dilation: int) -> itk.image: - """ - Dilate a binary mask using morphological operations. - - Expands the mask regions by the specified number of pixels to create - larger regions of interest. Useful for creating candidate regions or - ensuring complete coverage of anatomical structures. - - Args: - mask (itk.image): The binary mask to dilate - dilation (int): Number of pixels to dilate in each direction - - Returns: - itk.image: The dilated binary mask - - Example: - >>> dilated_heart = segmenter.dilate_mask(heart_mask, 5) - """ - imMath = tube.ImageMath.New(mask) - imMath.Dilate(dilation, 1, 0) - dilated_mask = imMath.GetOutputUChar() - return dilated_mask - def segment( self, input_image: itk.image, diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index 15a1031..63ba60d 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -106,7 +106,7 @@ def __init__(self, log_level: int | str = logging.INFO): self._finalize_other_group() # Path to Simpleware Medical console executable - self.simpleware_exe_path = "C:/Program Files/Synopsys/Simpleware Medical/X-2025.06/ConsoleSimplewareMedical.exe" + self.simpleware_exe_path = "C:/Program Files/Synopsys/Simpleware Medical/Y-2026.03/ConsoleSimplewareMedical.exe" # Path to the Simpleware Python script for heart segmentation self.simpleware_script_path = os.path.join( @@ -338,8 +338,8 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: ) if mask_image is not None: - in_direction = np.array(preprocessed_image.GetDirection()) - out_direction = np.array(mask_image.GetDirection()) + in_direction = itk.array_from_matrix(preprocessed_image.GetDirection()) + out_direction = itk.array_from_matrix(mask_image.GetDirection()) flip = [False, False, False] for i in range(3): if np.sign(out_direction[i, i]) != np.sign(in_direction[i, i]): diff --git a/src/physiomotion4d/transform_tools.py b/src/physiomotion4d/transform_tools.py index 2a22aa2..a1f4197 100644 --- a/src/physiomotion4d/transform_tools.py +++ b/src/physiomotion4d/transform_tools.py @@ -89,12 +89,13 @@ def combine_displacement_field_transforms( mode: str = "compose", tfm1_blur_sigma: float = 0.0, tfm2_blur_sigma: float = 0.0, - ) -> itk.DisplacementFieldTransform: + ) -> itk.Transform: """ Compose two displacement field transforms. - Composes two displacement field transforms into a single displacement field - transform. + In ``add`` mode, returns a single displacement field transform with + weighted summed vectors. In ``compose`` mode, returns a composite + transform containing both weighted displacement field transforms. """ assert mode in ["add", "compose"], "Invalid mode" @@ -225,7 +226,7 @@ def convert_transform_to_displacement_field( field = tfm.GetDisplacementField() field_arr = itk.array_view_from_image(tfm.GetDisplacementField()) reference_image_arr = itk.array_view_from_image(reference_image) - if field_arr.shape[:2] != reference_image_arr.shape: + if field_arr.shape[:3] != reference_image_arr.shape: field_filter = itk.TransformToDisplacementFieldFilter[ itk.Image[itk.Vector[itk.F, 3], 3], TfmPrecision ].New() @@ -365,7 +366,7 @@ def transform_dataset( np.array(tfm.TransformPoint((float(p[0]), float(p[1]), float(p[2])))) for p in pnts ] - new_mesh.points = np.asarray(new_pnts, dtype=float) + new_mesh.points = np.asarray(new_pnts, dtype=float).reshape(-1, 3) if with_deformation_magnitude: if cp is not None: @@ -419,7 +420,7 @@ def transform_image( ... ) >>> # Transform label map preserving discrete values >>> warped_labels = transform_tools.transform_image( - ... labelmap, transform, reference, tfm_type='nearest' + ... labelmap, transform, reference, interpolation_method='nearest' ... ) """ # Handle case where tfm is a list (e.g., from itk.transformread) @@ -618,6 +619,7 @@ def combine_transforms_with_masks( sum_fields_arr = mask1_arr * field1_arr + mask2_arr * field2_arr denom = mask1_arr + mask2_arr + denom[denom == 0] = 1.0 combined_field_arr = sum_fields_arr / denom @@ -757,17 +759,22 @@ def generate_grid_image( width_max = width_min + line_width for i in range(grid_size): for j in range(grid_size): - min_idx0 = int(i * grid_spacing[0]) - width_min - max_idx0 = int(i * grid_spacing[0]) + width_max - min_idx1 = int(j * grid_spacing[1]) - width_min - max_idx1 = int(j * grid_spacing[1]) + width_max - img_arr[min_idx0:max_idx0, min_idx1:max_idx1, :] = img_arr_max - min_idx2 = int(j * grid_spacing[2]) - width_min - max_idx2 = int(j * grid_spacing[2]) + width_max - img_arr[min_idx0:max_idx0, :, min_idx2:max_idx2] = img_arr_max - min_idx1 = int(i * grid_spacing[1]) - width_min - max_idx1 = int(i * grid_spacing[1]) + width_max - img_arr[:, min_idx1:max_idx1, min_idx2:max_idx2] = img_arr_max + min_idx0 = max(0, int(i * grid_spacing[0]) - width_min) + max_idx0 = min(img_arr.shape[0], int(i * grid_spacing[0]) + width_max) + min_idx1 = max(0, int(j * grid_spacing[1]) - width_min) + max_idx1 = min(img_arr.shape[1], int(j * grid_spacing[1]) + width_max) + if min_idx0 < max_idx0 and min_idx1 < max_idx1: + img_arr[min_idx0:max_idx0, min_idx1:max_idx1, :] = img_arr_max + + min_idx2 = max(0, int(j * grid_spacing[2]) - width_min) + max_idx2 = min(img_arr.shape[2], int(j * grid_spacing[2]) + width_max) + if min_idx0 < max_idx0 and min_idx2 < max_idx2: + img_arr[min_idx0:max_idx0, :, min_idx2:max_idx2] = img_arr_max + + min_idx1 = max(0, int(i * grid_spacing[1]) - width_min) + max_idx1 = min(img_arr.shape[1], int(i * grid_spacing[1]) + width_max) + if min_idx1 < max_idx1 and min_idx2 < max_idx2: + img_arr[:, min_idx1:max_idx1, min_idx2:max_idx2] = img_arr_max grid_image = itk.image_from_array(img_arr) grid_image.CopyInformation(reference_image) diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 223a5c0..c16e9f9 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -42,8 +42,8 @@ import numpy as np import yaml +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.register_time_series_images import RegisterTimeSeriesImages from physiomotion4d.transform_tools import TransformTools @@ -79,7 +79,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): subject_ids (Optional[list[str]]): One ID per subject (e.g. patient identifiers). Written into the dataset JSON's ``subject_id`` field; falls back to synthetic ``subject_NNNN`` when ``None``. - subject_segmentation_files (Optional[list[list[Optional[str]]]]): + subject_labelmap_files (Optional[list[list[Optional[str]]]]): Per-subject multi-label segmentation/labelmap paths aligned with ``subject_image_files``. ``None`` (or per-image ``None``) means no segmentation for that image. If supplied for at least one image, @@ -88,14 +88,16 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): Per-subject binary mask paths aligned with ``subject_image_files``. When supplied for a frame these masks are used directly for loss-function masking; otherwise masks are derived from - ``subject_segmentation_files``. + ``subject_labelmap_files``. subject_landmark_files (Optional[list[list[Optional[str]]]]): Per-subject landmark CSV paths (``Name,X,Y,Z`` format) aligned with ``subject_image_files``. Recorded in the dataset JSON for traceability; not consumed by uniGradICON fine-tuning itself. mask_dilation_mm (float): Millimeters of physical-radius binary dilation applied to the >0 labelmap when deriving the loss-masking - binary mask via :meth:`RegisterImagesICON.create_mask`. + binary mask via :meth:`LabelmapTools.convert_labelmap_to_mask`. + mask_exclude_labels (Optional[list[int]]): Labels to exclude from the mask. + Default is None. mask_dir (Optional[Path]): Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk. @@ -113,7 +115,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): ... ], ... output_dir=Path('d:/PhysioMotion4D/icon_finetuned'), ... fine_tune_name='duke_4d_gated_icon_ft', - ... subject_segmentation_files=[ + ... subject_labelmap_files=[ ... ['pm0001/g000_labelmap.nii.gz', 'pm0001/g050_labelmap.nii.gz'], ... ['pm0002/g000_labelmap.nii.gz', 'pm0002/g050_labelmap.nii.gz'], ... ], @@ -138,7 +140,7 @@ def __init__( output_dir: Path, fine_tune_name: str, subject_ids: Optional[list[str]] = None, - subject_segmentation_files: Optional[list[list[Optional[str]]]] = None, + subject_labelmap_files: Optional[list[list[Optional[str]]]] = None, subject_mask_files: Optional[list[list[Optional[str]]]] = None, subject_landmark_files: Optional[list[list[Optional[str]]]] = None, epochs: int = 2000, @@ -148,13 +150,14 @@ def __init__( similarity: str = "lncc", lambda_value: float = 1.5, dice_loss_weight: float = 0.5, - lncc_sigma: int = 5, + lncc_sigma: int = 1, ct_window: tuple[float, float] = (-1000.0, 1000.0), is_ct: bool = True, gpus: Optional[list[int]] = None, eval_period: int = 10, save_period: int = 50, mask_dilation_mm: float = 5.0, + mask_exclude_labels: Optional[list[int]] = None, mask_dir: Optional[Path] = None, unigradicon_src_path: Optional[Path] = None, log_level: Union[int, str] = logging.INFO, @@ -174,15 +177,15 @@ def __init__( JSON's ``subject_id`` field so paired training groups frames that share an ID. ``None`` falls back to synthetic IDs of the form ``subject_0000``, ``subject_0001``, ... Must be unique. - subject_segmentation_files: Per-subject multi-label segmentation + subject_labelmap_files: Per-subject multi-label segmentation (labelmap) paths matching ``subject_image_files``. ``None`` - disables paired-with-seg training (no ``use_label``). + disables paired-with-seg training. Individual ``None`` entries inside the inner lists skip just those frames when paired-with-seg training is enabled. subject_mask_files: Per-subject binary mask paths matching ``subject_image_files``. When supplied these are used directly for ICON loss-function masking; otherwise masks are derived - from ``subject_segmentation_files`` via a >0 threshold and + from ``subject_labelmap_files`` via a >0 threshold and dilation by ``mask_dilation_mm``. Per-image ``None`` entries fall back to the derived mask for that frame (or skip it if no segmentation is available either). @@ -205,8 +208,8 @@ def __init__( mask_dilation_mm: Physical radius (millimeters) of binary dilation applied to the >0 labelmap when deriving the loss-masking binary mask via - :meth:`RegisterImagesICON.create_mask`. Ignored when no - segmentations are supplied. Default 5.0 mm. + :meth:`LabelmapTools.convert_labelmap_to_mask`. Ignored when + no segmentations are supplied. Default 5.0 mm. mask_dir: Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk @@ -220,7 +223,7 @@ def __init__( Raises: ValueError: If ``subject_image_files`` is empty. - ValueError: If ``subject_segmentation_files``, + ValueError: If ``subject_labelmap_files``, ``subject_mask_files``, or ``subject_landmark_files`` is provided with a shape that does not match ``subject_image_files``. @@ -243,8 +246,8 @@ def __init__( self._validate_companion_shape( subject_image_files, - subject_segmentation_files, - "subject_segmentation_files", + subject_labelmap_files, + "subject_labelmap_files", ) self._validate_companion_shape( subject_image_files, subject_mask_files, "subject_mask_files" @@ -255,10 +258,15 @@ def __init__( self.subject_image_files = subject_image_files self.subject_ids = subject_ids - self.subject_segmentation_files = subject_segmentation_files + self.subject_labelmap_files = subject_labelmap_files self.subject_mask_files = subject_mask_files self.subject_landmark_files = subject_landmark_files + self.use_segmentations: bool = subject_labelmap_files is not None + self.use_masks: bool = ( + subject_mask_files is not None or subject_labelmap_files is not None + ) + self.output_dir = Path(output_dir).resolve() self.fine_tune_name = fine_tune_name self.experiment_dir = self.output_dir / fine_tune_name @@ -277,14 +285,19 @@ def __init__( self.gpus = list(gpus) if gpus is not None else [0] self.eval_period = eval_period self.save_period = save_period + self.mask_exclude_labels = mask_exclude_labels self.mask_dilation_mm = float(mask_dilation_mm) self.unigradicon_src_path = ( Path(unigradicon_src_path) if unigradicon_src_path is not None else None ) self.transform_tools = TransformTools() + self.labelmap_tools = LabelmapTools(log_level=log_level) self.registrar: Optional[RegisterTimeSeriesImages] = None + self._use_segmentations: bool = self.use_segmentations + self._use_masks: bool = self.use_masks + self._dataset_json_path: Optional[Path] = None self._config_yaml_path: Optional[Path] = None @@ -309,48 +322,21 @@ def _validate_companion_shape( f"subject_image_files[{i}] length ({len(images)})" ) - @property - def uses_segmentations(self) -> bool: - """Whether at least one segmentation file is supplied for training. - - Drives the uniGradICON ``training.use_label`` flag. - """ - return self._any_non_none(self.subject_segmentation_files) - - @property - def uses_masks(self) -> bool: - """Whether the dataset will have a ``mask`` field on every kept entry. - - True when explicit masks are supplied OR when segmentations are supplied - (since masks are then derived). Drives the uniGradICON - ``training.loss_function_masking`` flag. - """ - return self._any_non_none(self.subject_mask_files) or self.uses_segmentations - - @staticmethod - def _any_non_none( - companion: Optional[list[list[Optional[str]]]], - ) -> bool: - """Return True when ``companion`` contains at least one non-``None`` entry.""" - if companion is None: - return False - for inner in companion: - for item in inner: - if item is not None: - return True - return False - @staticmethod def _posix(path: Union[str, Path]) -> str: """Return a forward-slashed string path (uniGradICON expects POSIX paths).""" return str(path).replace("\\", "/") - def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: + def _derive_mask( + self, + labelmap_path: Union[str, Path], + ) -> Path: """Create (or reuse) a dilated binary mask from a multi-label labelmap. Threshold the labelmap at ``>0`` and dilate by ``mask_dilation_mm`` mm - of physical radius via :meth:`RegisterImagesICON.create_mask` to widen - the ROI for loss-function masking. + of physical radius via + :meth:`LabelmapTools.convert_labelmap_to_mask` to widen the ROI for + loss-function masking. When :attr:`mask_dir` is ``None`` (the default) the mask is written next to the source labelmap as @@ -378,13 +364,19 @@ def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: return mask_path labelmap = itk.imread(str(labelmap_path)) - mask = RegisterImagesICON.create_mask( - labelmap, dilation_mm=self.mask_dilation_mm + mask = self.labelmap_tools.convert_labelmap_to_mask( + labelmap, + dilation_in_mm=self.mask_dilation_mm, + exclude_labels=self.mask_exclude_labels, ) itk.imwrite(mask, str(mask_path), compression=True) return mask_path - def prepare_dataset(self) -> Path: + def prepare_dataset( + self, + use_segmentations: Optional[bool] = None, + use_masks: Optional[bool] = None, + ) -> Path: """Write the uniGradICON dataset JSON from the configured file lists. Builds one entry per image with ``image``, optional ``segmentation``, @@ -392,7 +384,7 @@ def prepare_dataset(self) -> Path: ``subject_id`` derived from the inner-list index. Masks are taken from ``subject_mask_files`` when supplied for a frame; - otherwise they are derived from ``subject_segmentation_files`` via a + otherwise they are derived from ``subject_labelmap_files`` via a >0 threshold and ``mask_dilation_mm`` mm dilation. Frames are skipped (with a log warning) when a required companion (segmentation for paired-with-seg training, or mask for loss-function masking) is @@ -406,8 +398,14 @@ def prepare_dataset(self) -> Path: does not exist on disk. """ self.experiment_dir.mkdir(parents=True, exist_ok=True) - use_seg = self.uses_segmentations - use_mask = self.uses_masks + + if use_segmentations is None: + use_segmentations = self.use_segmentations + if use_masks is None: + use_masks = self.use_masks + + self._use_segmentations = use_segmentations + self._use_masks = use_masks dataset_entries: list[dict[str, str]] = [] for subject_index, image_files in enumerate(self.subject_image_files): @@ -416,16 +414,24 @@ def prepare_dataset(self) -> Path: if self.subject_ids is not None else f"subject_{subject_index:04d}" ) - seg_list = ( - self.subject_segmentation_files[subject_index] - if self.subject_segmentation_files is not None - else [None] * len(image_files) - ) - mask_list = ( - self.subject_mask_files[subject_index] - if self.subject_mask_files is not None - else [None] * len(image_files) - ) + seg_list: list[Optional[str]] + if not use_segmentations: + seg_list = [None] * len(image_files) + else: + seg_list = ( + self.subject_labelmap_files[subject_index] + if self.subject_labelmap_files is not None + else [None] * len(image_files) + ) + mask_list: list[Optional[str]] + if not use_masks: + mask_list = [None] * len(image_files) + else: + mask_list = ( + self.subject_mask_files[subject_index] + if self.subject_mask_files is not None + else [None] * len(image_files) + ) landmark_list = ( self.subject_landmark_files[subject_index] if self.subject_landmark_files is not None @@ -444,7 +450,7 @@ def prepare_dataset(self) -> Path: "subject_id": subject_id, } - if use_seg: + if use_segmentations: if seg_file is None or not Path(seg_file).exists(): self.log_warning( "Skipping %s: segmentation missing for paired-with-seg " @@ -455,7 +461,7 @@ def prepare_dataset(self) -> Path: continue entry["segmentation"] = self._posix(seg_file) - if use_mask: + if use_masks: if mask_file is not None and Path(mask_file).exists(): resolved_mask: Path = Path(mask_file) elif seg_file is not None and Path(seg_file).exists(): @@ -529,8 +535,8 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "lambda": self.lambda_value, "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, - "loss_function_masking": self.uses_masks, - "use_label": self.uses_segmentations, + "loss_function_masking": self._use_masks, + "use_label": self._use_segmentations, "roi_masking": False, }, "datasets": [ @@ -735,8 +741,8 @@ def apply_registration( self.log_info("ICON weights: %s", weights_path) fixed_mask = ( - RegisterImagesICON.create_mask( - reference_segmentation, dilation_mm=self.mask_dilation_mm + self.labelmap_tools.convert_labelmap_to_mask( + reference_segmentation, dilation_in_mm=self.mask_dilation_mm ) if reference_segmentation is not None else None @@ -745,8 +751,8 @@ def apply_registration( if moving_segmentations is not None: moving_masks = [ ( - RegisterImagesICON.create_mask( - seg, dilation_mm=self.mask_dilation_mm + self.labelmap_tools.convert_labelmap_to_mask( + seg, dilation_in_mm=self.mask_dilation_mm ) if seg is not None else None diff --git a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index a533f4f..0e2c384 100644 --- a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -30,6 +30,7 @@ import pyvista as pv from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON @@ -199,6 +200,7 @@ def __init__( # Utilities (needed for create_reference_image when patient_image is None) self.transform_tools = TransformTools() self.contour_tools = ContourTools() + self.labelmap_tools = LabelmapTools() if patient_image is not None: self.patient_image = patient_image @@ -319,11 +321,9 @@ def _auto_generate_mask( if dilate_mm is None: dilate_mm = self.mask_dilation_mm if dilate_mm > 0: - ttk = _load_tubetk() - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int(dilate_mm / self.patient_image.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - mask = imMath.GetOutputUChar() + mask = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=dilate_mm + ) self.log_info("Masks auto-generated successfully.") @@ -349,11 +349,9 @@ def _auto_generate_roi_mask( # Generate model ROI mask roi = None if dilate_mm > 0: - ttk = _load_tubetk() - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int(dilate_mm / mask.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - roi = imMath.GetOutputUChar() + roi = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=dilate_mm + ) else: roi = mask diff --git a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py index ebfb732..b3105c5 100644 --- a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py @@ -61,8 +61,10 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): registration_method (str): Registration method ('ANTS', 'ICON', or 'ANTS_ICON') number_of_iterations: Iterations for registration registrar (RegisterTimeSeriesImages): Internal registration object - forward_transforms (list[itk.Transform]): Forward transforms (moving → fixed) - inverse_transforms (list[itk.Transform]): Inverse transforms (fixed → moving) + forward_transforms (list[itk.Transform]): one per frame; each warps its + moving image onto the fixed grid + inverse_transforms (list[itk.Transform]): one per frame; each warps the + fixed image onto that frame's moving grid (used for reconstruction) losses (list[float]): Registration loss values reconstructed_images (list[itk.Image]): Reconstructed high-resolution images @@ -260,10 +262,11 @@ def register_time_series(self) -> dict: Returns: dict: Dictionary containing: - - 'forward_transforms' (list[itk.Transform]): Transforms from moving - to fixed space (warps moving → fixed) - - 'inverse_transforms' (list[itk.Transform]): Transforms from fixed - to moving space (warps fixed → moving) + - 'forward_transforms' (list[itk.Transform]): one per frame; + each warps its moving image onto the fixed grid + - 'inverse_transforms' (list[itk.Transform]): one per frame; + each warps the fixed image onto that frame's moving grid + (see docs/developer/transform_conventions) - 'losses' (list[float]): Registration loss value for each image Raises: diff --git a/tests/test_image_tools.py b/tests/test_image_tools.py index 7e48b7c..100a35c 100644 --- a/tests/test_image_tools.py +++ b/tests/test_image_tools.py @@ -428,7 +428,7 @@ def test_flip_and_make_identity_sets_direction_to_identity( direction = np.diag([-1.0, 1.0, 1.0]) itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction) out = image_tools.flip_image(itk_image, flip_and_make_identity=True) - out_direction = np.array(out.GetDirection()) + out_direction = itk.array_from_matrix(out.GetDirection()) identity = np.eye(3) assert np.allclose(out_direction, identity), ( "flip_and_make_identity should set direction to identity" @@ -450,7 +450,7 @@ def test_flip_and_make_identity_with_mask_sets_both_directions_to_identity( itk_image, in_mask=itk_mask, flip_and_make_identity=True ) for im, name in [(out_image, "image"), (out_mask, "mask")]: - dir_mat = np.array(im.GetDirection()) + dir_mat = itk.array_from_matrix(im.GetDirection()) assert np.allclose(dir_mat, np.eye(3)), ( f"flip_and_make_identity should set {name} direction to identity" ) diff --git a/tests/test_labelmap_tools.py b/tests/test_labelmap_tools.py new file mode 100644 index 0000000..6e1279c --- /dev/null +++ b/tests/test_labelmap_tools.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +""" +Tests for LabelmapTools functionality. + +Covers thresholding a multi-label labelmap into a binary registration mask, +physically isotropic dilation that respects per-axis spacing, and forcing +selected labels to background via ``exclude_labels``. +""" + +from __future__ import annotations + +import itk +import numpy as np +import pytest + +from physiomotion4d.labelmap_tools import LabelmapTools + + +class TestLabelmapTools: + """Test suite for LabelmapTools.convert_labelmap_to_mask.""" + + @pytest.fixture + def labelmap_tools(self) -> LabelmapTools: + """Create LabelmapTools instance.""" + return LabelmapTools() + + def test_threshold_without_dilation(self, labelmap_tools: LabelmapTools) -> None: + """Every non-zero label becomes foreground; no dilation grows it.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 # non-zero label id + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + mask = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=0.0) + mask_arr = itk.array_from_image(mask) + + assert set(np.unique(mask_arr).tolist()) == {0, 1} + assert int(mask_arr.sum()) == 1 + assert mask_arr[2, 2, 2] == 1 + + def test_dilation_grows_mask(self, labelmap_tools: LabelmapTools) -> None: + """Positive dilation_in_mm grows the mask but keeps the seed voxel.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 + labelmap = itk.image_from_array(arr) + # Unit isotropic spacing so dilation_in_mm == voxel radius. + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + dilated = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=1.0) + dilated_arr = itk.array_from_image(dilated) + + assert int(dilated_arr.sum()) > 1 + assert dilated_arr[2, 2, 2] == 1 + + def test_dilation_respects_anisotropic_spacing( + self, labelmap_tools: LabelmapTools + ) -> None: + """A 5 mm radius covers more voxels along the finely spaced axis.""" + arr = np.zeros((11, 11, 11), dtype=np.uint8) + arr[5, 5, 5] = 1 + labelmap = itk.image_from_array(arr) + # numpy axes are (Z, Y, X); ITK spacing is (X, Y, Z). Make X coarse + # (5 mm/voxel -> 1-voxel radius) and Z fine (1 mm/voxel -> 5-voxel + # radius) so the per-axis radius differs. + labelmap.SetSpacing([5.0, 1.0, 1.0]) + + dilated = itk.array_from_image( + labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) + ) + + # Z axis (numpy axis 0) reaches 5 voxels out; X axis (numpy axis 2) + # only 1 voxel out. + assert dilated[0, 5, 5] == 1 + assert dilated[10, 5, 5] == 1 + assert dilated[5, 5, 0] == 0 + assert dilated[5, 5, 10] == 0 + + def test_exclude_labels_removes_voxels(self, labelmap_tools: LabelmapTools) -> None: + """Excluded labels become background before thresholding.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[1, 1, 1] = 2 # kept + arr[3, 3, 3] = 7 # excluded + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + mask_arr = itk.array_from_image( + labelmap_tools.convert_labelmap_to_mask( + labelmap, dilation_in_mm=0.0, exclude_labels=[7] + ) + ) + + assert mask_arr[1, 1, 1] == 1 + assert mask_arr[3, 3, 3] == 0 + assert int(mask_arr.sum()) == 1 + + def test_preserves_image_information(self, labelmap_tools: LabelmapTools) -> None: + """Origin, spacing, and direction are copied from the labelmap.""" + arr = np.zeros((4, 4, 4), dtype=np.uint8) + arr[2, 2, 2] = 1 + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([0.5, 1.0, 2.0]) + labelmap.SetOrigin([10.0, -5.0, 3.0]) + + mask = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=0.0) + + assert list(mask.GetSpacing()) == [0.5, 1.0, 2.0] + assert list(mask.GetOrigin()) == [10.0, -5.0, 3.0] + assert np.allclose( + itk.array_from_matrix(mask.GetDirection()), + itk.array_from_matrix(labelmap.GetDirection()), + ) diff --git a/tests/test_register_images_ants.py b/tests/test_register_images_ants.py index 8102c58..61c778d 100644 --- a/tests/test_register_images_ants.py +++ b/tests/test_register_images_ants.py @@ -20,6 +20,27 @@ from physiomotion4d.transform_tools import TransformTools +def _foreground_ncc( + reference_arr: np.ndarray, warped_arr: np.ndarray, foreground: np.ndarray +) -> float: + """Normalized cross-correlation over a foreground mask (higher = better). + + Args: + reference_arr: Reference image array (e.g. the fixed image), axes (Z, Y, X). + warped_arr: Warped image array on the same grid/axes as ``reference_arr``. + foreground: Boolean mask (same shape) selecting the voxels to score. + + Returns: + NCC in [-1, 1] over the foreground voxels (nan if degenerate). + """ + a = reference_arr[foreground].astype(np.float64) + b = warped_arr[foreground].astype(np.float64) + a0 = a - a.mean() + b0 = b - b.mean() + denom = float(np.sqrt((a0 * a0).sum() * (b0 * b0).sum())) + return float((a0 * b0).sum() / denom) if denom > 0 else float("nan") + + @pytest.mark.slow class TestRegisterImagesANTS: """Test suite for ANTs-based image registration.""" @@ -309,6 +330,218 @@ def test_registration_with_initial_transform( print("Registration with initial transform complete") + def test_initial_transform_composition_metrics( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + test_directories: dict[str, Path], + ) -> None: + """Verify the initial_forward_transform composition path with metrics. + + Exercises the two initial-transform inputs the platform actually uses + (identity and a prior deformable forward_transform, as in prior-based + time-series registration) and confirms the composed forward_transform + warps the moving image onto the fixed grid. Scored with foreground NCC + over the brightest 30% of the fixed image (tissue/blood pool). See + docs/developer/transform_conventions. + + Asserted facts: + * a plain registration improves on the unregistered pair, + * an identity initial reproduces the baseline exactly (the + composition machinery is a structurally correct no-op; note an + identity AffineTransform is itself a matrix initial), + * a prior-deformable initial reaches the no-initial baseline quality + (the composition recovers the full transform). + + The initial transform is applied by pre-warping the moving image (as in + RegisterImagesICON), which keeps the composition self-consistent for any + initial transform type. + """ + output_dir = test_directories["output"] + reg_output_dir = output_dir / "registration_ANTS" + reg_output_dir.mkdir(exist_ok=True) + + # Pick two phases that are far apart in the cycle so there is real motion. + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + # Moving and fixed share the acquisition grid (split from one 4D image), + # so the moving array is directly comparable for the unregistered score. + moving_arr = itk.array_from_image(moving_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + + transform_tools = TransformTools() + + def warp_score(forward_transform: Any) -> float: + warped = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", + ) + return _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + + ncc_unregistered = _foreground_ncc(fixed_arr, moving_arr, foreground) + + # Baseline: no initial transform. + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + baseline = registrar_ANTS.register(moving_image=moving_image) + ncc_baseline = warp_score(baseline["forward_transform"]) + + # Identity initial: the composition machinery must be a no-op. + identity = itk.AffineTransform[itk.D, 3].New() + identity.SetIdentity() + registrar_identity = RegisterImagesANTS() + registrar_identity.set_modality("ct") + registrar_identity.set_fixed_image(fixed_image) + identity_result = registrar_identity.register( + moving_image=moving_image, initial_forward_transform=identity + ) + ncc_identity = warp_score(identity_result["forward_transform"]) + + # Prior deformable initial: the realistic time-series prior use case. + registrar_prior = RegisterImagesANTS() + registrar_prior.set_modality("ct") + registrar_prior.set_fixed_image(fixed_image) + prior_result = registrar_prior.register( + moving_image=moving_image, + initial_forward_transform=baseline["forward_transform"], + ) + ncc_prior = warp_score(prior_result["forward_transform"]) + + print("\nANTS initial-transform composition metrics (foreground NCC):") + print(f" unregistered: {ncc_unregistered:.4f}") + print(f" baseline (no initial): {ncc_baseline:.4f}") + print(f" identity initial: {ncc_identity:.4f}") + print(f" prior-deformable init: {ncc_prior:.4f}") + + warped_prior = transform_tools.transform_image( + moving_image, + prior_result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + itk.imwrite( + warped_prior, + str(reg_output_dir / "ants_warped_prior_initial.mha"), + compression=True, + ) + + # Registration must improve alignment over the unregistered pair. + assert ncc_baseline > ncc_unregistered, ( + f"Baseline registration did not improve alignment: " + f"{ncc_baseline:.4f} <= {ncc_unregistered:.4f}" + ) + # Identity initial must reproduce the baseline (composition is a no-op). + assert abs(ncc_identity - ncc_baseline) < 0.03, ( + f"Identity initial transform changed the result: " + f"identity={ncc_identity:.4f} vs baseline={ncc_baseline:.4f}" + ) + # A prior-deformable initial must reach the no-initial baseline quality + # (the composition recovers the full transform). + assert ncc_prior >= ncc_baseline - 0.03, ( + f"Prior-initial composition did not reach baseline quality: " + f"{ncc_prior:.4f} < {ncc_baseline:.4f} - 0.03" + ) + + def test_initial_transform_matrix_composition( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + ) -> None: + """A matrix (translation/affine) initial composes without corruption. + + Regression guard for the previously-broken matrix initial_transform + path: feeding a translation initial used to corrupt the composition + (foreground NCC far below the unregistered pair). With the moving image + pre-warped by the initial, the composed forward_transform must align the + moving image onto the fixed grid at least as well as the unregistered + pair. + """ + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + ncc_unregistered = _foreground_ncc( + fixed_arr, itk.array_from_image(moving_image), foreground + ) + + translation = itk.TranslationTransform[itk.D, 3].New() + translation.SetOffset([-5.0, -5.0, -5.0]) + + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + result = registrar_ANTS.register( + moving_image=moving_image, initial_forward_transform=translation + ) + + transform_tools = TransformTools() + warped = transform_tools.transform_image( + moving_image, + result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + ncc = _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + print( + f"\nMatrix-initial composed NCC={ncc:.4f} (unregistered={ncc_unregistered:.4f})" + ) + assert ncc > ncc_unregistered, ( + f"Matrix-initial composition worse than unregistered: " + f"{ncc:.4f} <= {ncc_unregistered:.4f}" + ) + + def test_affine_and_rigid_transform_types( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + ) -> None: + """Affine and Rigid transform types run and improve alignment. + + Regression guard for the ANTS preset names: ``set_transform_type`` + previously mapped Affine/Rigid to ``antsRegistration{Affine,Rigid}Quick`` + preset strings that do not exist in antspy, raising ValueError. Each + type must now run and warp the moving image onto the fixed grid at least + as well as the unregistered pair. + """ + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + ncc_unregistered = _foreground_ncc( + fixed_arr, itk.array_from_image(moving_image), foreground + ) + + transform_tools = TransformTools() + for transform_type in ("Rigid", "Affine"): + registrar = RegisterImagesANTS() + registrar.set_modality("ct") + registrar.set_transform_type(transform_type) + registrar.set_fixed_image(fixed_image) + result = registrar.register(moving_image=moving_image) + warped = transform_tools.transform_image( + moving_image, + result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + ncc = _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + print( + f"\n{transform_type} transform NCC={ncc:.4f} " + f"(unregistered={ncc_unregistered:.4f})" + ) + assert ncc > ncc_unregistered, ( + f"{transform_type} registration did not improve alignment: " + f"{ncc:.4f} <= {ncc_unregistered:.4f}" + ) + def test_multiple_registrations( self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: diff --git a/tests/test_transform_tools.py b/tests/test_transform_tools.py index 3b12635..75e16de 100644 --- a/tests/test_transform_tools.py +++ b/tests/test_transform_tools.py @@ -19,6 +19,24 @@ from physiomotion4d.transform_tools import TransformTools +def test_generate_grid_image_clamps_boundary_lines() -> None: + """ + Grid image clamps boundary slices for an ITK image with axes (X, Y, Z). + + The synthetic ITK image has axes (X, Y, Z) = (7, 6, 5), created from + NumPy array shape (Z, Y, X) = (5, 6, 7). + """ + image_arr = np.zeros((5, 6, 7), dtype=np.float32) + image_arr[-1, -1, -1] = 5.0 + image = itk.image_from_array(image_arr) + + grid_image = TransformTools().generate_grid_image(image, grid_size=2, line_width=3) + grid_arr = itk.array_from_image(grid_image) + + assert grid_arr.shape == image_arr.shape + assert grid_arr[0, 0, 0] == 5.0 + + @pytest.mark.slow class TestTransformTools: """Test suite for TransformTools functionality.""" diff --git a/tests/test_workflow_fine_tune_icon_registration.py b/tests/test_workflow_fine_tune_icon_registration.py index 11cde75..311fd96 100644 --- a/tests/test_workflow_fine_tune_icon_registration.py +++ b/tests/test_workflow_fine_tune_icon_registration.py @@ -20,7 +20,6 @@ import pytest import yaml -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.workflow_fine_tune_icon_registration import ( WorkflowFineTuneICONRegistration, ) @@ -47,7 +46,7 @@ def two_subject_dataset(tmp_path: Path) -> dict[str, Any]: output_dir = tmp_path / "ft_out" subject_image_files: list[list[str]] = [] - subject_segmentation_files: list[list[Optional[str]]] = [] + subject_labelmap_files: list[list[Optional[str]]] = [] for patient_id in ("pm0001", "pm0002"): pdir = data_dir / patient_id pdir.mkdir() @@ -61,14 +60,14 @@ def two_subject_dataset(tmp_path: Path) -> dict[str, Any]: images.append(str(image_path)) segs.append(str(label_path)) subject_image_files.append(images) - subject_segmentation_files.append(segs) + subject_labelmap_files.append(segs) return { "output_dir": output_dir, "fine_tune_name": "test_exp", "subject_ids": ["pm0001", "pm0002"], "subject_image_files": subject_image_files, - "subject_segmentation_files": subject_segmentation_files, + "subject_labelmap_files": subject_labelmap_files, } @@ -128,7 +127,7 @@ def test_init_rejects_mismatched_subject_ids_length(tmp_path: Path) -> None: ) -def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: +def test_use_segmentations_and_use_masks_flags(tmp_path: Path) -> None: """The two helper flags reflect supplied companions independently.""" base: dict[str, Any] = { "subject_image_files": [["a"]], @@ -136,45 +135,20 @@ def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: "fine_tune_name": "x", } none_wf = WorkflowFineTuneICONRegistration(**base) - assert not none_wf.uses_segmentations - assert not none_wf.uses_masks + assert not none_wf.use_segmentations + assert not none_wf.use_masks seg_only = WorkflowFineTuneICONRegistration( - **base, subject_segmentation_files=[["seg.nii.gz"]] + **base, subject_labelmap_files=[["seg.nii.gz"]] ) - assert seg_only.uses_segmentations - assert seg_only.uses_masks # derived from segs + assert seg_only.use_segmentations + assert seg_only.use_masks # derived from segs mask_only = WorkflowFineTuneICONRegistration( **base, subject_mask_files=[["mask.nii.gz"]] ) - assert not mask_only.uses_segmentations - assert mask_only.uses_masks - - -# --------------------------------------------------------------------------- -# RegisterImagesICON.create_mask (in-memory dilation, used by the workflow) -# --------------------------------------------------------------------------- - - -def test_create_mask_thresholds_and_dilates() -> None: - """Single-voxel labelmap becomes a binary mask whose dilation grows it.""" - arr = np.zeros((5, 5, 5), dtype=np.uint8) - arr[2, 2, 2] = 3 # non-zero label id - labelmap = itk.image_from_array(arr) - # Unit isotropic spacing so dilation_mm == voxel radius. - labelmap.SetSpacing([1.0, 1.0, 1.0]) - - no_dilate = RegisterImagesICON.create_mask(labelmap, dilation_mm=0.0) - no_dilate_arr = itk.array_from_image(no_dilate) - assert set(np.unique(no_dilate_arr).tolist()) == {0, 1} - assert int(no_dilate_arr.sum()) == 1 - - dilated = RegisterImagesICON.create_mask(labelmap, dilation_mm=1.0) - dilated_arr = itk.array_from_image(dilated) - assert int(dilated_arr.sum()) > 1 - # Original foreground voxel stays foreground. - assert dilated_arr[2, 2, 2] == 1 + assert not mask_only.use_segmentations + assert mask_only.use_masks # --------------------------------------------------------------------------- @@ -221,7 +195,7 @@ def test_prepare_dataset_skips_frames_with_missing_segmentation( subject_image_files=[[str(img_a), str(img_b)]], output_dir=tmp_path / "out", fine_tune_name="exp", - subject_segmentation_files=[[str(seg_a), None]], + subject_labelmap_files=[[str(seg_a), None]], log_level=logging.CRITICAL, ) dataset_json_path = workflow.prepare_dataset() @@ -246,7 +220,7 @@ def test_prepare_dataset_uses_explicit_mask_over_derived(tmp_path: Path) -> None subject_image_files=[[str(image)]], output_dir=tmp_path / "out", fine_tune_name="exp", - subject_segmentation_files=[[str(seg)]], + subject_labelmap_files=[[str(seg)]], subject_mask_files=[[str(explicit_mask)]], log_level=logging.CRITICAL, ) @@ -295,7 +269,7 @@ def test_prepare_dataset_derives_mask_next_to_labelmap_by_default( seg_files = [ Path(s) - for inner in workflow.subject_segmentation_files or [] + for inner in workflow.subject_labelmap_files or [] for s in inner if s is not None ] @@ -325,7 +299,7 @@ def test_prepare_dataset_derives_mask_under_explicit_mask_dir( # None of the labelmap-adjacent locations should have been written to. seg_files = [ Path(s) - for inner in workflow.subject_segmentation_files or [] + for inner in workflow.subject_labelmap_files or [] for s in inner if s is not None ] diff --git a/tutorials/tutorial_08_dirlab_pca_time_series.py b/tutorials/tutorial_08_dirlab_pca_time_series.py index 232e8ed..75cd063 100644 --- a/tutorials/tutorial_08_dirlab_pca_time_series.py +++ b/tutorials/tutorial_08_dirlab_pca_time_series.py @@ -170,9 +170,16 @@ def run_tutorial() -> dict[str, Any]: compression=True, ) + # Warp the reference-space fitted mesh into this phase's space. + # Warping reference -> phase POINTS uses the forward transform + # (the fixed -> moving point map), which is the opposite of the + # transform used to warp an image into phase space (images pull + # back, points push forward). The forward transform is named + # "phase_to_reference" after its image-warp role. See + # docs/developer/transform_conventions. phase_mesh = transform_tools.transform_pvcontour( fitted_reference_mesh, - reference_to_phase, + phase_to_reference, with_deformation_magnitude=True, ) phase_mesh_file = meshes_dir / f"{phase_name}_pca_fit.vtp"