From e6a8b027bc6b44af0fbc99244b5ceb6029d725c7 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Sat, 30 May 2026 08:13:37 -0400 Subject: [PATCH 01/10] ENH: Standardize forward/inverse transform conventions in registration Transform conventions - Add docs/developer/transform_conventions.rst documenting the rule that forward_transform warps the moving IMAGE onto the fixed grid, while warping moving points/landmarks into fixed space uses inverse_transform (image and point warps need opposite transforms), and that model (PCA) registration returns transforms in the opposite orientation from image registration. - Clarify docstrings/comments to this convention across register_images_{base,ants,greedy,icon}, register_models_{distance_maps, icp,pca}, register_time_series_images, and the workflow_* modules; link the new page from registration_images.rst, registration_models.rst, index.rst. - register_images_base: rework initial-transform handling to pre-warp the moving image onto the fixed grid before registration (matching the ICON backend) instead of composing an itk transform. Longitudinal registration experiment - Add 1-preregistration.py: register every gated frame to its patient reference with ANTS and Greedy, recording per-label Dice and landmark squared error, and writing the warped image, labelmap, mask, transforms, and deformation-grid visualization per frame. - 1-finetune_icon.py: merge each patient's ANTS- and Greedy-warped frames into the ICON fine-tuning group and consume the warped loss masks; use Optional[X] type hints. - Update 2-recon_4d_icon_eval.py and 3-run_registration_method_comparison.py for the convention; add tests/test_register_images_ants.py and registration_test.py; drop the obsolete experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py. Code style - Correct the documented string-quote standard to double quotes (the actual codebase standard) in pyproject ruff quote-style/inline-quotes, CLAUDE.md, AGENTS.md, and docs/contributing.rst. Co-Authored-By: Claude Opus 4.8 (1M context) --- AGENTS.md | 2 +- CLAUDE.md | 2 +- docs/contributing.rst | 2 +- docs/developer/registration_images.rst | 7 +- docs/developer/registration_models.rst | 5 + docs/developer/transform_conventions.rst | 137 ++++ docs/index.rst | 1 + .../test_compare_registration_speed.py | 182 ------ .../1-finetune_icon.py | 155 ++++- .../1-preregistration.py | 605 ++++++++++++++++++ .../2-recon_4d_icon_eval.py | 52 +- .../3-run_registration_method_comparison.py | 8 +- .../registration_test.py | 110 ++++ pyproject.toml | 4 +- src/physiomotion4d/register_images_ants.py | 86 ++- src/physiomotion4d/register_images_base.py | 34 +- src/physiomotion4d/register_images_greedy.py | 71 +- src/physiomotion4d/register_images_icon.py | 39 +- .../register_models_distance_maps.py | 13 +- src/physiomotion4d/register_models_icp.py | 14 +- src/physiomotion4d/register_models_pca.py | 23 +- .../register_time_series_images.py | 30 +- .../workflow_fine_tune_icon_registration.py | 4 +- .../workflow_reconstruct_highres_4d_ct.py | 15 +- tests/test_register_images_ants.py | 233 +++++++ .../tutorial_08_dirlab_pca_time_series.py | 9 +- 26 files changed, 1538 insertions(+), 305 deletions(-) create mode 100644 docs/developer/transform_conventions.rst delete mode 100644 experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py create mode 100644 experiments/LongitudinalRegistration/1-preregistration.py create mode 100644 experiments/LongitudinalRegistration/registration_test.py diff --git a/AGENTS.md b/AGENTS.md index 745ce10..9e76207 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -124,7 +124,7 @@ Version bumping: `bumpver update --patch`, `--minor`, or `--major`. - Scripts that instantiate `SegmentChestTotalSegmentator` must guard the top-level invocation with `if __name__ == "__main__":` on Windows (`torch.multiprocessing` requires it). -- Single quotes for strings; double quotes for docstrings. Keep lines at or +- Double quotes for strings and docstrings. Keep lines at or below 88 characters. - Full type hints are required under strict mypy. Use `Optional[X]`, not `X | None`. diff --git a/CLAUDE.md b/CLAUDE.md index bf268f0..af531fa 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -148,7 +148,7 @@ Document via docstrings and inline comments. ## Code Style -- Single quotes for strings; double quotes for docstrings +- Double quotes for strings and docstrings - Full type hints (`mypy` strict; `disallow_untyped_defs = true`) - `Optional[X]` not `X | None` (ruff `UP007` suppressed) - Breaking changes are acceptable — backward compatibility is not a priority diff --git a/docs/contributing.rst b/docs/contributing.rst index 4c62e97..d6b5d96 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -141,7 +141,7 @@ PhysioMotion4D follows strict code quality standards using modern, fast tooling. Formatting and Linting with Ruff --------------------------------- -We use **Ruff** for all formatting and linting (line length: 88, single quotes): +We use **Ruff** for all formatting and linting (line length: 88, double quotes): .. code-block:: bash diff --git a/docs/developer/registration_images.rst b/docs/developer/registration_images.rst index dab4078..1c0d023 100644 --- a/docs/developer/registration_images.rst +++ b/docs/developer/registration_images.rst @@ -25,7 +25,11 @@ Basic Pattern registered = registrar.get_registered_image() The result dictionary contains ``forward_transform``, ``inverse_transform``, -and ``loss``. +and ``loss``. Applying the right one is critical and direction-dependent: +``forward_transform`` warps the moving image onto the fixed grid, while +``inverse_transform`` warps moving points/landmarks into fixed space (image and +point warps use opposite transforms). See +:doc:`transform_conventions` for the full rules. Time Series =========== @@ -57,5 +61,6 @@ Development Notes See Also ======== +* :doc:`transform_conventions` * :doc:`../api/registration/index` * :doc:`workflows` diff --git a/docs/developer/registration_models.rst b/docs/developer/registration_models.rst index 3102f1a..466b904 100644 --- a/docs/developer/registration_models.rst +++ b/docs/developer/registration_models.rst @@ -45,9 +45,14 @@ Development Notes * Convert volumetric meshes to surfaces before surface registration when needed. * Treat ITK/PyVista coordinate transforms as high-risk and add focused tests. * Keep synthetic test meshes small and deterministic. +* ``RegisterModelsPCA`` returns ``forward_point_transform`` / + ``inverse_point_transform``. These are **point** transforms whose orientation + is opposite to the image-registration transforms; see + :doc:`transform_conventions` before applying them to images or meshes. See Also ======== +* :doc:`transform_conventions` * :doc:`../api/model_registration/index` * :doc:`workflows` diff --git a/docs/developer/transform_conventions.rst b/docs/developer/transform_conventions.rst new file mode 100644 index 0000000..ee80cf0 --- /dev/null +++ b/docs/developer/transform_conventions.rst @@ -0,0 +1,137 @@ +=============================== +Transform Direction Conventions +=============================== + +Registration in PhysioMotion4D produces a pair of transforms, and choosing the +wrong one of the pair is the single most common registration mistake. The rules +are simple but easy to get backwards, because **warping an image and warping a +point require opposite transforms**, and because **model (PCA) registration +returns its transforms in the opposite orientation from image registration**. + +Read this page before applying any transform to an image, mask, contour, or +landmark. + +The two transform families +=========================== + +Image registration + :class:`physiomotion4d.RegisterImagesANTS`, + :class:`physiomotion4d.RegisterImagesICON`, and + :class:`physiomotion4d.RegisterImagesGreedy` register a *moving* image to a + *fixed* image and return a dict with ``forward_transform`` and + ``inverse_transform``. :class:`physiomotion4d.RegisterTimeSeriesImages` + returns the list-valued ``forward_transforms`` / ``inverse_transforms``. + +Model (PCA) registration + :class:`physiomotion4d.RegisterModelsPCA` deforms a *template* model toward + a *target* (patient) and, via ``compute_pca_transforms()``, returns + ``forward_point_transform`` and ``inverse_point_transform``. These are + **point transforms**, oriented opposite to the image-registration transforms + (see `PCA point transforms`_ below). + +Image warping vs. point warping use opposite transforms +======================================================== + +ITK resampling is a *pull-back* operation. To build the warped image on the +fixed grid, :func:`TransformTools.transform_image` (an ``itk.ResampleImageFilter``) +visits every fixed-grid sample ``q`` and looks up the moving image at +``transform.TransformPoint(q)``. The transform it needs therefore maps +**fixed-space coordinates to moving-space coordinates**. + +Warping a *point* (landmark, contour vertex, mesh node) is a *push-forward* +operation: :func:`TransformTools.transform_pvcontour` / +:func:`TransformTools.transform_dataset` apply ``transform.TransformPoint(p)`` +directly to each input point. To move a moving-space landmark to its location in +the fixed image, the transform must map **moving-space coordinates to +fixed-space coordinates** -- the inverse of the image-warp transform. + +So for the **same** moving-to-fixed registration result: + +.. list-table:: Image registration: which transform to apply + :header-rows: 1 + :widths: 50 25 25 + + * - Goal + - Transform + - Helper + * - Warp the **moving image** into fixed space (onto the fixed grid) + - ``forward_transform`` + - :func:`TransformTools.transform_image` + * - Warp **moving points / contours / landmarks** into fixed space + - ``inverse_transform`` + - :func:`TransformTools.transform_pvcontour` + * - Warp the **fixed image** into moving space (e.g. time-series reconstruction) + - ``inverse_transform`` + - :func:`TransformTools.transform_image` + * - Warp **fixed points / contours / landmarks** into moving space + - ``forward_transform`` + - :func:`TransformTools.transform_pvcontour` + +The first two rows are the everyday case (warping the registered moving data +into the fixed/reference frame): the **image uses** ``forward_transform``, the +**points use** ``inverse_transform``. The last two rows are the mirror image; +:meth:`physiomotion4d.RegisterTimeSeriesImages.reconstruct_time_series` is the +canonical consumer of ``inverse_transform`` for image warping (it resamples the +fixed image back onto each moving frame's grid). + +.. note:: + + All three image-registration backends (ANTS, ICON, Greedy) follow this same + convention. ``transform_image(moving, forward_transform, fixed)`` is the + correct call to warp the moving image onto the fixed grid for every backend. + +PCA point transforms +==================== + +:class:`physiomotion4d.RegisterModelsPCA` builds ``forward_point_transform`` +directly from the template-to-target point displacement, so +``forward_point_transform.TransformPoint(template_point)`` returns the +corresponding *target* point. As a **point** map it goes template (moving) to +target (fixed) -- which is the same orientation as image registration's +``inverse_transform``, and therefore the **opposite** orientation of image +registration's ``forward_transform``. + +Concretely, treating the template as the moving object and the patient/target as +the fixed object: + +.. list-table:: Same goal, opposite transform names across the two families + :header-rows: 1 + :widths: 50 25 25 + + * - Goal + - Image registration + - PCA model registration + * - Warp the **image** (moving/template space -> fixed/target grid) + - ``forward_transform`` + - ``inverse_point_transform`` + * - Warp **points / meshes** (moving/template -> fixed/target) + - ``inverse_transform`` + - ``forward_point_transform`` + +In other words, ``forward_point_transform`` plays the role that +``inverse_transform`` plays for image registration, and +``inverse_point_transform`` plays the role of ``forward_transform``. Deforming +the template mesh onto the patient (the usual PCA use, performed internally by +``transform_template_model()`` and ``transform_point()``) uses +``forward_point_transform``; resampling an image with the PCA result uses +``inverse_point_transform``. + +Rule of thumb +============= + +* **Images pull back; points push forward.** For one registration result, the + image and the points always use the two *different* members of the transform + pair. +* **Image into the reference frame** -> ``forward_transform`` (image + registration) / ``inverse_point_transform`` (PCA). +* **Points into the reference frame** -> ``inverse_transform`` (image + registration) / ``forward_point_transform`` (PCA). +* When in doubt, warp a known landmark and a small image patch and confirm they + land in the same place before trusting a pipeline. + +See Also +======== + +* :doc:`registration_images` +* :doc:`registration_models` +* :doc:`utilities` diff --git a/docs/index.rst b/docs/index.rst index 458de99..320a9fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -162,6 +162,7 @@ per-tutorial implementation details. developer/segmentation developer/registration_images developer/registration_models + developer/transform_conventions developer/usd_generation developer/utilities diff --git a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py b/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py deleted file mode 100644 index addadea..0000000 --- a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python -# %% [markdown] -# # Compare registration speed: Greedy vs ANTs vs ICON -# -# This notebook times **Greedy**, **ANTs**, and **ICON** when registering two time points of CT from the Slicer-Heart-CT data (TruncalValve 4D CT). -# -# **Prerequisites:** Run `0-download_and_convert_4d_to_3d.py` first so that `data/Slicer-Heart-CT/` contains the 4D NRRD and the 3D slice series (`slice_000.mha`, `slice_001.mha`, ...), and `results/slice_fixed.mha` exists. - -# %% -import os -import time - -import itk -import matplotlib.pyplot as plt -import pandas as pd -from itk import TubeTK as ttk - -from physiomotion4d.test_tools import TestTools -from physiomotion4d.register_images_ants import RegisterImagesANTS -from physiomotion4d.register_images_greedy import RegisterImagesGreedy -from physiomotion4d.register_images_icon import RegisterImagesICON - -_HERE = os.path.dirname(os.path.abspath(__file__)) - -# %% -data_dir = os.path.join(_HERE, "..", "..", "data", "Slicer-Heart-CT") -output_dir = os.path.join(_HERE, "results") -os.makedirs(output_dir, exist_ok=True) - -# Fixed = reference time point; moving = time point to align to fixed -fixed_image_path = os.path.join(output_dir, "slice_fixed.mha") -moving_image_path = os.path.join(data_dir, "slice_000.mha") - -if not os.path.exists(fixed_image_path): - raise FileNotFoundError( - f"Fixed image not found: {fixed_image_path}. " - "Run 0-download_and_convert_4d_to_3d.py first." - ) -if not os.path.exists(moving_image_path): - raise FileNotFoundError( - f"Moving image not found: {moving_image_path}. " - "Run 0-download_and_convert_4d_to_3d.py first." - ) - -fixed_image = itk.imread(fixed_image_path) -moving_image = itk.imread(moving_image_path) -print(f"Fixed image: {itk.size(fixed_image)}, spacing {itk.spacing(fixed_image)}") -print(f"Moving image: {itk.size(moving_image)}, spacing {itk.spacing(moving_image)}") - -# %% [markdown] -# ## Optional: downsample for faster comparison -# -# Set `downsample_factor = 1.0` to use full resolution (slower). Use e.g. `0.5` to halve each dimension for a quicker run. - -# %% -downsample_factor = 0.5 # 1.0 = full resolution - -if downsample_factor != 1.0: - resampler_f = ttk.ResampleImage.New(Input=fixed_image) - resampler_f.SetResampleFactor([downsample_factor] * 3) - resampler_f.Update() - fixed_image = resampler_f.GetOutput() - - resampler_m = ttk.ResampleImage.New(Input=moving_image) - resampler_m.SetResampleFactor([downsample_factor] * 3) - resampler_m.Update() - moving_image = resampler_m.GetOutput() - print(f"Downsampled to factor {downsample_factor}") - print(f" Fixed: {itk.size(fixed_image)}") - print(f" Moving: {itk.size(moving_image)}") -else: - print("Using full resolution.") - -# %% [markdown] -# ## Run each method and record time -# -# All three use **deformable** registration (Greedy: affine + deformable; ANTs: SyN; ICON: deep learning). Settings are chosen for a fair comparison with reduced iterations so the notebook runs in a few minutes. - -# %% -results_list = [] - -# --- Greedy (deformable) --- -try: - reg_g = RegisterImagesGreedy() - reg_g.set_modality("ct") - reg_g.set_transform_type("Deformable") - reg_g.set_number_of_iterations([10, 5, 2]) - reg_g.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_g = reg_g.register(moving_image) - elapsed_g = time.perf_counter() - t0 - - loss_g = out_g.get("loss") - results_list.append( - { - "method": "Greedy", - "time_sec": round(elapsed_g, 2), - "loss": float(loss_g) if loss_g is not None else None, - } - ) - print(f"Greedy: {elapsed_g:.2f} s") -except Exception as e: - results_list.append({"method": "Greedy", "time_sec": None, "loss": None}) - print(f"Greedy: failed - {e}") - -# --- ANTs (deformable SyN) --- -try: - reg_a = RegisterImagesANTS() - reg_a.set_modality("ct") - reg_a.set_transform_type("Deformable") - reg_a.set_number_of_iterations([10, 5, 2]) # reduced for speed - reg_a.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_a = reg_a.register(moving_image) - elapsed_a = time.perf_counter() - t0 - - loss_a = out_a.get("loss") - results_list.append( - { - "method": "ANTs", - "time_sec": round(elapsed_a, 2), - "loss": float(loss_a) if loss_a is not None else None, - } - ) - print(f"ANTs: {elapsed_a:.2f} s") -except Exception as e: - results_list.append({"method": "ANTs", "time_sec": None, "loss": None}) - print(f"ANTs: failed - {e}") - -# --- ICON (deformable, GPU) --- -try: - reg_i = RegisterImagesICON() - reg_i.set_modality("ct") - reg_i.set_number_of_iterations(50) - reg_i.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_i = reg_i.register(moving_image) - elapsed_i = time.perf_counter() - t0 - - loss_i = out_i.get("loss") - results_list.append( - { - "method": "ICON", - "time_sec": round(elapsed_i, 2), - "loss": float(loss_i) if loss_i is not None else None, - } - ) - print(f"ICON: {elapsed_i:.2f} s") -except Exception as e: - results_list.append({"method": "ICON", "time_sec": None, "loss": None}) - print(f"ICON: failed - {e}") - -df = pd.DataFrame(results_list) - -# %% -print(df) - -# %% -fig, ax = plt.subplots(figsize=(6, 4)) -valid = df["time_sec"].notna() -if valid.any(): - methods = df.loc[valid, "method"] - times = df.loc[valid, "time_sec"] - ax.bar(methods, times, color=["#2ecc71", "#3498db", "#9b59b6"]) - ax.set_ylabel("Time (seconds)") - ax.set_title("Registration time: two time points (Slicer-Heart-CT)") - plt.tight_layout() - if not TestTools.running_as_test(): - plt.show() -else: - print("No successful runs to plot.") - -# %% [markdown] -# ## Notes -# -# - **Greedy**: CPU-based, often faster than ANTs for comparable quality; see [Greedy](https://greedy.readthedocs.io/) and [picsl-greedy](https://pypi.org/project/picsl-greedy/). -# - **ANTs**: CPU-based, very widely used; typically slower than Greedy for similar settings. -# - **ICON**: GPU-based (UniGradIcon); speed depends on GPU. Loss values are not directly comparable across methods. -# - For a quicker comparison, use `downsample_factor = 0.5` or reduce `number_of_iterations` further. diff --git a/experiments/LongitudinalRegistration/1-finetune_icon.py b/experiments/LongitudinalRegistration/1-finetune_icon.py index 968a078..a11118e 100644 --- a/experiments/LongitudinalRegistration/1-finetune_icon.py +++ b/experiments/LongitudinalRegistration/1-finetune_icon.py @@ -15,10 +15,18 @@ # Each patient directory under ``src_data_dir_base`` is one ``subject_id``; # all of that patient's gated time-point frames form a paired training group. # Frames whose labelmap is missing on disk are dropped from the dataset. +# +# In addition to the original ``gated_nii`` frames, each patient's training +# group is augmented with that patient's ANTS- and Greedy-warped frames +# written by ``1-preregistration.py`` (warped image + labelmap per gated +# frame, under ``output_dir / / ``). Because the warped +# frames are merged into the *same* ``subject_id`` group, uniGradICON pairs the +# original gated frames and both backends' pre-registered frames together. # %% import os from pathlib import Path +from typing import Optional import itk @@ -36,9 +44,18 @@ # Where the workflow writes the dataset JSON, YAML config, derived masks, and # the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to # ``output_dir / fine_tune_name``. -output_dir = Path("./results") +_HERE = Path(__file__).parent +output_dir = _HERE / "results" fine_tune_name = "icon_finetuned" +# Pre-registration augmentation: ``1-preregistration.py`` warps every gated +# moving frame into reference space with these backends and writes the warped +# image + labelmap under ``preregistration_dir / .lower() / +# ``. Those warped frames are merged into each patient's training +# group below (section 4b). +preregistration_dir = output_dir +preregistration_methods = ["ANTS", "greedy"] + # Fixed train/test split: sort patients in ``ref_data_dir`` by filename; # first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies # the same rule so the two scripts agree without a cached split record. @@ -47,7 +64,7 @@ # Local clone of uniGradICON (feat-add-finetuning branch) — prepended to # PYTHONPATH so the subprocess picks up the local source instead of the # installed package. Set to ``None`` to use the pip-installed unigradicon. -unigradicon_src_path: Path | None = Path(__file__).parent / "uniGradICON" / "src" +unigradicon_src_path: Optional[Path] = Path(__file__).parent / "uniGradICON" / "src" # %% [markdown] # ## 2. Enumerate patients and apply the fixed 80/20 split @@ -94,7 +111,7 @@ # %% train_image_files: list[list[str]] = [] -train_segmentation_files: list[list[str | None]] = [] +train_segmentation_files: list[list[Optional[str]]] = [] valid_train_subjects: list[str] = [] for patient_id in train_subjects: @@ -113,7 +130,7 @@ continue image_paths = [str(src_dir / f) for f in frame_names] - seg_paths: list[str | None] = [] + seg_paths: list[Optional[str]] = [] for f in frame_names: labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") seg_paths.append(str(labelmap) if labelmap.exists() else None) @@ -137,32 +154,120 @@ # %% mask_dilation_mm = 5.0 -train_mask_files: list[list[str | None]] = [] -for image_paths, seg_paths in zip( - train_image_files, train_segmentation_files, strict=True -): - mask_paths: list[str | None] = [] - for seg_path in seg_paths: - if seg_path is None: - mask_paths.append(None) + + +def derive_mask_for(labelmap_path: Path) -> str: + """Create (or reuse) a loss-function mask next to ``labelmap_path``. + + Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm + via :meth:`RegisterImagesICON.create_mask`, writing the result as + ``_mask.nii.gz`` in the labelmap's own directory. Handles + both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha`` + (pre-registration warped labelmaps). Returns the mask path as a string; + existing masks on disk are reused unmodified. + """ + name = labelmap_path.name + if name.endswith(".nii.gz"): + stem = name[:-7] + elif name.endswith(".mha"): + stem = name[:-4] + else: + stem = labelmap_path.stem + mask_p = labelmap_path.parent / f"{stem}_mask.nii.gz" + if not mask_p.exists(): + mask = RegisterImagesICON.create_mask( + itk.imread(str(labelmap_path)), dilation_mm=mask_dilation_mm + ) + itk.imwrite(mask, str(mask_p), compression=True) + return str(mask_p) + + +train_mask_files: list[list[Optional[str]]] = [] +for seg_paths in train_segmentation_files: + train_mask_files.append( + [derive_mask_for(Path(s)) if s is not None else None for s in seg_paths] + ) + +# %% [markdown] +# ## 4b. Merge ANTS / Greedy pre-registered frames into each training group +# +# ``1-preregistration.py`` warps every gated moving frame into reference space +# with the ANTS and Greedy backends, writing ``.mha`` (warped image), +# ``_labelmap.mha`` (warped labelmap), and ``_deformation_grid.mha`` +# under ``preregistration_dir / / ``. Here those warped +# frames + labelmaps (with derived loss masks) are appended to the *same* +# patient's training group, so uniGradICON pairs the original gated frames and +# both backends' pre-registered frames together (they share a ``subject_id``). +# Patients/methods with no pre-registration output on disk are skipped. + + +# %% +def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str]]]: + """Return ``(warped_image_paths, warped_labelmap_paths)`` for one + ``preregistration_dir / / `` directory. + + Enumerates the warped moving images (``.mha``), excluding the + ``_labelmap.mha``, ``_labelmap_mask.mha``, and ``_deformation_grid.mha`` + companions, and pairs each with its ``_labelmap.mha`` (``None`` when + that labelmap is absent). Returns empty lists when ``method_dir`` does + not exist. + """ + if not method_dir.is_dir(): + return [], [] + companion_suffixes = ( + "_labelmap.mha", + "_labelmap_mask.mha", + "_deformation_grid.mha", + ) + image_paths: list[str] = [] + labelmap_paths: list[Optional[str]] = [] + for mha in sorted(method_dir.glob("*.mha")): + if mha.name.endswith(companion_suffixes): + continue + stem = mha.name[:-4] + labelmap = method_dir / f"{stem}_labelmap.mha" + image_paths.append(str(mha)) + labelmap_paths.append(str(labelmap) if labelmap.exists() else None) + return image_paths, labelmap_paths + + +for subject_index, patient_id in enumerate(valid_train_subjects): + for method_name in preregistration_methods: + method_dir = preregistration_dir / method_name.lower() / patient_id + warped_images, warped_labelmaps = gather_warped_frames(method_dir) + if not warped_images: + print( + f" {patient_id}/{method_name}: no pre-registered frames " + f"in {method_dir}" + ) continue - seg_p = Path(seg_path) - stem = seg_p.name - stem = stem[:-7] if stem.endswith(".nii.gz") else seg_p.stem - mask_p = seg_p.parent / f"{stem}_mask.nii.gz" - if not mask_p.exists(): - mask = RegisterImagesICON.create_mask( - itk.imread(str(seg_p)), dilation_mm=mask_dilation_mm + warped_masks: list[Optional[str]] = [] + for lm in warped_labelmaps: + if lm is None: + warped_masks.append(None) + continue + # 1-preregistration.py writes the warped loss mask next to the + # warped labelmap; prefer it, deriving one only if it is absent. + warped_mask = Path(f"{lm[:-4]}_mask.mha") + warped_masks.append( + str(warped_mask) if warped_mask.exists() else derive_mask_for(Path(lm)) ) - itk.imwrite(mask, str(mask_p), compression=True) - mask_paths.append(str(mask_p)) - train_mask_files.append(mask_paths) + train_image_files[subject_index].extend(warped_images) + train_segmentation_files[subject_index].extend(warped_labelmaps) + train_mask_files[subject_index].extend(warped_masks) + n_seg = sum(1 for lm in warped_labelmaps if lm is not None) + print( + f" {patient_id}/{method_name}: +{len(warped_images)} warped frames, " + f"{n_seg} with labelmap" + ) # %% [markdown] # ## 5. Fine-tune uniGradICON on the train cohort # -# The workflow consumes both the labelmaps (for paired-with-seg training and -# ``use_label``) and the pre-computed masks (for ``loss_function_masking``) +# Each train group now holds the original gated frames plus the merged ANTS +# and Greedy pre-registered frames (section 4b). The workflow consumes both +# the labelmaps (for paired-with-seg training) and the pre-computed masks (for +# ``loss_function_masking``) # and launches ``unigradicon.finetuning.finetune`` as a subprocess. The # final checkpoint lands at # :meth:`WorkflowFineTuneICONRegistration.expected_weights_path`, which is @@ -178,7 +283,7 @@ subject_mask_files=train_mask_files, mask_dilation_mm=mask_dilation_mm, unigradicon_src_path=unigradicon_src_path, - epochs=100, + epochs=500, ) weights_path = workflow.run_fine_tuning() diff --git a/experiments/LongitudinalRegistration/1-preregistration.py b/experiments/LongitudinalRegistration/1-preregistration.py new file mode 100644 index 0000000..dfd8ac4 --- /dev/null +++ b/experiments/LongitudinalRegistration/1-preregistration.py @@ -0,0 +1,605 @@ +# %% [markdown] +# # Pre-registration: compare ANTS vs Greedy on the Duke gated CT cohort +# +# Registers every gated CT time-point of every Duke patient under +# ``ref_data_dir`` (100% of the cohort -- no train/test split) to that +# patient's reference image, using two backends in turn: +# +# * :class:`RegisterImagesANTS` (CPU, SyN deformable) +# * :class:`RegisterImagesGreedy` (CPU, deformable) +# +# For each frame the script records wall-clock registration time, writes +# the warped/resampled moving image to disk, warps the moving labelmap +# into reference space to compute per-label Dice, and warps the moving +# landmarks into reference space to compute squared-error landmark +# statistics (mm^2) against the reference landmarks. +# +# Inputs (same data as ``1-finetune_icon.py``): +# * ``ref_data_dir / pm*_ref.nii.gz`` -- per-patient reference CT +# * ``src_data_dir_base / / *.nii.gz`` -- gated CT frames +# * ``segmentation_dir_base / / _labelmap.nii.gz`` +# -- per-frame multi-label segmentations +# * ``segmentation_dir_base / / _labelmap_mask.nii.gz`` +# -- pre-computed loss-function masks (re-derived on the fly if absent, +# matching the 5 mm dilation used by ``1-finetune_icon.py``) +# * ``segmentation_dir_base / / _landmark.mrk.json`` +# -- per-frame 3D Slicer Markups landmarks in LPS +# +# Outputs under ``results/``: +# * ``ants///.mha`` and +# ``greedy///.mha`` -- warped moving image +# per time point, alongside the forward/inverse transforms (``.hdf``), +# a ``_deformation_grid.mha`` visualization of the registration +# deformation, the warped ``_labelmap.mha`` and its warped +# loss-function mask ``_labelmap_mask.mha``, +# and the warped ``_landmark.mrk.json`` +# * ``preregistration_landmarks.csv`` -- per-landmark squared errors +# * ``preregistration_dice.csv`` -- per-label Dice +# * ``preregistration_summary.csv`` -- per-(subject, method, timepoint) +# time, mean Dice, MSE, RMSE +# +# Run interactively cell-by-cell; all paths are hard-coded. + +# %% +import csv +import re +import time +from pathlib import Path +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d.landmark_tools import LandmarkTools +from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + +# %% [markdown] +# ## 1. Hard-coded paths and configuration + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") +segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") + +_HERE = Path(__file__).parent +output_dir = _HERE / "results" +output_dir.mkdir(parents=True, exist_ok=True) + +# Reference frames in gated_nii are named ``_ref.nii.gz``; every +# other ``.nii.gz`` (excluding ``nop`` non-gated references) is a gated +# time point. Timepoint tag ``g###`` is extracted from each filename. +exclude_tokens = ["nop"] +ref_suffix = "_ref" +timepoint_re = re.compile(r"_g(?P[0-9]{3})") + +# Mask dilation matches 1-finetune_icon.py so any masks we have to +# derive here are identical to the ones written by the fine-tune script. +mask_dilation_mm = 5.0 + +# Iteration schedules. Kept modest for a cohort-wide comparison; raise +# either list for higher accuracy at the cost of runtime. +number_of_iterations_ANTS = [20, 10, 5] +number_of_iterations_greedy = [40, 20, 10] + +methods: list[str] = ["ANTS", "greedy"] + +# Debug knob: when non-empty, only these patient IDs are processed. +# Set to ``[]`` (or ``None``) to run the full cohort. +debug_subjects: list[str] = ["pm0002"] + +detail_landmarks_file = output_dir / "preregistration_landmarks.csv" +detail_dice_file = output_dir / "preregistration_dice.csv" +summary_file = output_dir / "preregistration_summary.csv" +for previous in (detail_landmarks_file, detail_dice_file, summary_file): + if previous.exists(): + previous.unlink() + +# %% [markdown] +# ## 2. Enumerate the full patient cohort +# +# Sort ``ref_data_dir`` by filename so the patient order is stable. +# Every patient is processed -- no train/test split. + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") +if debug_subjects: + cohort = [pid for pid in all_patient_ids if pid in debug_subjects] + print( + f"DEBUG: restricting cohort to {debug_subjects} -> " + f"{len(cohort)} matching patients" + ) +else: + cohort = all_patient_ids +print(f"Patient cohort: {cohort}") + +# %% [markdown] +# ## 3. Helpers: labelmap warping, per-label Dice, landmark squared error + +# %% +landmark_tools = LandmarkTools() +transform_tools = TransformTools() + + +def per_label_dice( + fixed_labelmap: itk.Image, warped_labelmap: itk.Image +) -> dict[int, float]: + """Return ``{label_id: Dice}`` for every positive label present in + either the fixed or the warped labelmap. + + Arrays come back from :func:`itk.array_from_image` in shape + ``(Z, Y, X)`` (numpy reverses ITK's index order); we compare element-wise + so the axis convention does not matter as long as both labelmaps live + on the same reference grid (guaranteed because ``warped_labelmap`` was + resampled with ``fixed_labelmap`` as the reference image). + """ + fixed_array = itk.array_from_image(fixed_labelmap) + warped_array = itk.array_from_image(warped_labelmap) + labels = sorted( + {int(v) for v in np.unique(fixed_array)} + | {int(v) for v in np.unique(warped_array)} + ) + labels = [label for label in labels if label > 0] + + dice_by_label: dict[int, float] = {} + for label in labels: + a = fixed_array == label + b = warped_array == label + denom = int(a.sum()) + int(b.sum()) + if denom == 0: + continue + intersection = int(np.logical_and(a, b).sum()) + dice_by_label[label] = 2.0 * intersection / denom + return dice_by_label + + +def warp_landmarks( + inverse_transform: itk.Transform, + moving_landmarks: dict[str, tuple[float, float, float]], +) -> dict[str, tuple[float, float, float]]: + """Warp every moving landmark into reference space. + + Point/landmark warping uses ``inverse_transform`` -- the moving-space -> + fixed-space point map -- which is the opposite of the transform used to + warp the moving image onto the fixed grid (images pull back; points push + forward). Returns a ``{label: (x, y, z)}`` dict in LPS. See + docs/developer/transform_conventions. + """ + return { + name: tuple(float(c) for c in inverse_transform.TransformPoint(point)) + for name, point in moving_landmarks.items() + } + + +def landmark_squared_errors( + warped_landmarks: dict[str, tuple[float, float, float]], + reference_landmarks: dict[str, tuple[float, float, float]], +) -> list[tuple[str, float]]: + """Return per-landmark squared Euclidean error in mm^2 between the + reference-space ``warped_landmarks`` and the matching reference + landmarks, in sorted-name order. + """ + shared = sorted(warped_landmarks.keys() & reference_landmarks.keys()) + errors: list[tuple[str, float]] = [] + for name in shared: + diff = np.asarray(warped_landmarks[name], dtype=np.float64) - np.asarray( + reference_landmarks[name], dtype=np.float64 + ) + errors.append((name, float(np.dot(diff, diff)))) + return errors + + +def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: + """Return the cached ``_labelmap_mask.nii.gz`` next to the + labelmap, or derive it via :meth:`RegisterImagesICON.create_mask` + (threshold ``>0`` plus 5 mm physical-radius dilation) and write it + out so subsequent runs and the ICON eval reuse the same mask. + """ + if mask_path.exists(): + return itk.imread(str(mask_path)) + mask = RegisterImagesICON.create_mask(labelmap, dilation_mm=mask_dilation_mm) + itk.imwrite(mask, str(mask_path), compression=True) + return mask + + +# %% [markdown] +# ## 4. Drive the comparison: every patient x every method +# +# For each patient: load the reference image, labelmap, mask, and +# landmarks; load every gated frame (excluding ``nop`` and ``_ref``) with +# its labelmap, mask, and landmarks; then register each frame to the +# reference under both backends. Each frame starts from identity so the +# ANTS-vs-Greedy comparison is independent across frames. + +# %% +summary_rows: list[dict[str, object]] = [] + +# (subject_id, method, timepoint) for frames that produced no usable +# landmark errors -- either no landmark file or no labels shared with the +# reference. Echoed in a highlighted block at the end of the run. +frames_missing_landmarks: list[tuple[str, str, str]] = [] + +for subject_index, subject_id in enumerate(cohort): + print(f"\n=== Subject {subject_index + 1}/{len(cohort)}: {subject_id} ===") + src_dir = src_data_dir_base / subject_id + seg_dir = segmentation_dir_base / subject_id + + if not src_dir.is_dir(): + print(f" Skipping {subject_id}: source dir {src_dir} not found") + continue + if not seg_dir.is_dir(): + print(f" Skipping {subject_id}: segmentation dir {seg_dir} not found") + continue + + # Locate this patient's reference frame in gated_nii (matches the + # `_ref.nii.gz` filename under ref_data_dir). + ref_file = next((p for p in ref_files if p.name.startswith(subject_id)), None) + if ref_file is None: + print(f" Skipping {subject_id}: no reference image found") + continue + ref_stem = ref_file.name[:-7] + ref_labelmap_path = seg_dir / f"{ref_stem}_labelmap.nii.gz" + ref_mask_path = seg_dir / f"{ref_stem}_labelmap_mask.nii.gz" + ref_landmark_path = seg_dir / f"{ref_stem}_landmark.mrk.json" + if not ref_labelmap_path.exists() or not ref_landmark_path.exists(): + print( + f" Skipping {subject_id}: missing reference labelmap or " + f"landmarks under {seg_dir}" + ) + continue + + fixed_image = itk.imread(str(ref_file), pixel_type=itk.F) + fixed_labelmap = itk.imread(str(ref_labelmap_path)) + fixed_mask = load_or_derive_mask(fixed_labelmap, ref_mask_path) + reference_landmarks = landmark_tools.read_landmarks_3dslicer(ref_landmark_path) + + # Gated moving frames (exclude `nop` and the `_ref` frame itself). + gated_files = sorted( + p + for p in src_dir.glob("*.nii.gz") + if not any(token in p.name for token in exclude_tokens) + and not p.name.endswith(f"{ref_suffix}.nii.gz") + ) + moving_records: list[dict[str, object]] = [] + for image_path in gated_files: + stem = image_path.name[:-7] + labelmap_path = seg_dir / f"{stem}_labelmap.nii.gz" + mask_path = seg_dir / f"{stem}_labelmap_mask.nii.gz" + landmark_path = seg_dir / f"{stem}_landmark.mrk.json" + if not labelmap_path.exists(): + print(f" Dropping {stem}: no labelmap at {labelmap_path}") + continue + match = timepoint_re.search(image_path.name) + if match is None: + print(f" Dropping {stem}: no g### timepoint tag in name") + continue + moving_records.append( + { + "stem": stem, + "timepoint": match.group("timepoint"), + "image_path": image_path, + "labelmap_path": labelmap_path, + "mask_path": mask_path, + "landmark_path": landmark_path if landmark_path.exists() else None, + } + ) + if not moving_records: + print(f" Skipping {subject_id}: no usable gated frames") + continue + + print(f" {len(moving_records)} moving frames; reference {ref_file.name}") + + print(f" Loading {len(moving_records)} moving images / labelmaps / masks ...") + moving_images = [] + moving_labelmaps = [] + moving_masks = [] + moving_landmarks_list: list[Optional[dict[str, tuple[float, float, float]]]] = [] + for r_index, r in enumerate(moving_records): + print( + f" [{r_index + 1}/{len(moving_records)}] g{r['timepoint']} {r['stem']}" + ) + moving_image = itk.imread(str(r["image_path"]), pixel_type=itk.F) + labelmap = itk.imread(str(r["labelmap_path"])) + moving_images.append(moving_image) + moving_labelmaps.append(labelmap) + moving_masks.append(load_or_derive_mask(labelmap, r["mask_path"])) + landmark_path = r["landmark_path"] + if landmark_path is None: + moving_landmarks_list.append(None) + else: + moving_landmarks_list.append( + landmark_tools.read_landmarks_3dslicer(landmark_path) + ) + + for method_name in methods: + print(f"\n --- Method: {method_name} ---") + if method_name == "ANTS": + reg = RegisterImagesANTS() + reg.set_number_of_iterations(number_of_iterations_ANTS) + else: + reg = RegisterImagesGreedy() + reg.set_number_of_iterations(number_of_iterations_greedy) + reg.set_transform_type("Deformable") + # NCC ("CC") outperforms MeanSquares for same-modality CT registration. + reg.set_metric("CC") + reg.set_modality("ct") + reg.set_mask_dilation(mask_dilation_mm) + reg.set_fixed_image(fixed_image) + reg.set_fixed_mask(fixed_mask) + + method_dir = output_dir / method_name.lower() / subject_id + method_dir.mkdir(parents=True, exist_ok=True) + + method_t_start = time.perf_counter() + for index, record in enumerate(moving_records): + timepoint = record["timepoint"] + stem = record["stem"] + print( + f" [{method_name} {index + 1}/{len(moving_records)}] " + f"g{timepoint} registering ...", + flush=True, + ) + + frame_t_start = time.perf_counter() + reg_result = reg.register( + moving_image=moving_images[index], + moving_mask=moving_masks[index], + ) + frame_elapsed = time.perf_counter() - frame_t_start + + forward_transform = reg_result["forward_transform"] + inverse_transform = reg_result["inverse_transform"] + frame_loss = float(reg_result["loss"]) + print(f" done in {frame_elapsed:.1f} s, loss={frame_loss:.4f}") + + itk.transformwrite( + forward_transform, + str(method_dir / f"{stem}_fwd.hdf"), + compression=True, + ) + itk.transformwrite( + inverse_transform, + str(method_dir / f"{stem}_inv.hdf"), + compression=True, + ) + + # Visualize the deformation as a warped grid: a regular grid built + # in reference space, resampled through forward_transform -- the same + # transform used below to warp the moving image onto the fixed grid, + # so the grid and the warped image deform consistently. + deformation_grid = transform_tools.convert_field_to_grid_visualization( + forward_transform, fixed_image + ) + itk.imwrite( + deformation_grid, + str(method_dir / f"{stem}_deformation_grid.mha"), + compression=True, + ) + + # Warp the moving image into reference space and save it + # (forward_transform resamples the moving image onto the fixed grid). + warped_image = transform_tools.transform_image( + moving_images[index], + forward_transform, + fixed_image, + interpolation_method="linear", + ) + itk.imwrite( + warped_image, + str(method_dir / f"{stem}.mha"), + compression=True, + ) + + # Warp the moving labelmap onto the fixed grid (forward_transform; + # nearest neighbour preserves label IDs) for per-label Dice. + warped_labelmap = transform_tools.transform_image( + moving_labelmaps[index], + forward_transform, + fixed_labelmap, + interpolation_method="nearest", + ) + itk.imwrite( + warped_labelmap, + str(method_dir / f"{stem}_labelmap.mha"), + compression=True, + ) + + # Warp the moving loss-function mask onto the fixed grid + # (forward_transform; nearest neighbour preserves the binary ROI) + # so downstream fine-tuning reuses it instead of re-deriving a + # mask from the warped labelmap. + warped_mask = transform_tools.transform_image( + moving_masks[index], + forward_transform, + fixed_mask, + interpolation_method="nearest", + ) + itk.imwrite( + warped_mask, + str(method_dir / f"{stem}_labelmap_mask.mha"), + compression=True, + ) + + dice_by_label = per_label_dice(fixed_labelmap, warped_labelmap) + with detail_dice_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + ["subject_id", "method", "timepoint", "label", "dice"] + ) + for label, dice in dice_by_label.items(): + writer.writerow([subject_id, method_name, timepoint, label, dice]) + mean_dice = ( + float(np.mean(list(dice_by_label.values()))) + if dice_by_label + else float("nan") + ) + + # Warp the moving landmarks into reference space, save them next + # to the transforms, then score squared error vs the reference. + moving_landmarks = moving_landmarks_list[index] + if moving_landmarks is None: + sq_errors: list[tuple[str, float]] = [] + else: + warped_landmarks = warp_landmarks(inverse_transform, moving_landmarks) + landmark_tools.write_landmarks_3dslicer( + warped_landmarks, + str(method_dir / f"{stem}_landmark.mrk.json"), + ) + sq_errors = landmark_squared_errors( + warped_landmarks, reference_landmarks + ) + with detail_landmarks_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + [ + "subject_id", + "method", + "timepoint", + "name", + "sq_err_mm2", + ] + ) + for name, sq_err in sq_errors: + writer.writerow([subject_id, method_name, timepoint, name, sq_err]) + + sq_values = np.asarray([e for _, e in sq_errors], dtype=np.float64) + if sq_values.size: + mse_mm2 = float(np.mean(sq_values)) + rmse_mm = float(np.sqrt(mse_mm2)) + else: + mse_mm2 = float("nan") + rmse_mm = float("nan") + # Highlight frames with no usable landmarks so they are not + # silently scored as NaN in the CSV / summary table. + reason = ( + "no landmark file" + if moving_landmarks is None + else "no landmarks shared with reference" + ) + frames_missing_landmarks.append((subject_id, method_name, timepoint)) + print( + f" >>> WARNING: {subject_id} {method_name} " + f"g{timepoint} has NO landmarks ({reason})", + flush=True, + ) + + summary_rows.append( + { + "subject_id": subject_id, + "method": method_name, + "timepoint": timepoint, + "time_sec": float(frame_elapsed), + "loss": frame_loss, + "n_labels": int(len(dice_by_label)), + "mean_dice": mean_dice, + "n_landmarks": int(sq_values.size), + "mse_mm2": mse_mm2, + "rmse_mm": rmse_mm, + } + ) + + method_elapsed = time.perf_counter() - method_t_start + print( + f" [{method_name}] subject {subject_id} total " + f"{method_elapsed:.1f} s " + f"({method_elapsed / len(moving_records):.1f} s/frame)" + ) + +# %% [markdown] +# ## 5. Write the per-(subject, method, timepoint) summary CSV + +# %% +if summary_rows: + with summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) + writer.writeheader() + writer.writerows(summary_rows) + print(f"\nWrote summary: {summary_file}") + print(f"Wrote landmarks: {detail_landmarks_file}") + print(f"Wrote dice: {detail_dice_file}") +else: + print("\nNo frames processed; nothing to summarize.") + +# %% [markdown] +# ## 5b. Highlight frames that produced no landmark errors + +# %% +if frames_missing_landmarks: + banner = "!" * 70 + print(f"\n{banner}") + print( + f"WARNING: {len(frames_missing_landmarks)} frame(s) missing ALL " + f"landmarks (scored as NaN):" + ) + for subject_id, method_name, timepoint in frames_missing_landmarks: + print(f" - {subject_id} {method_name} g{timepoint}") + print(banner) +else: + print("\nAll processed frames had at least one scored landmark.") + +# %% [markdown] +# ## 6. Per-method aggregate table across the whole cohort +# +# Reports mean per-frame registration time, mean / median / p95 of the +# squared landmark errors (mm^2), the matching RMSE in mm, and the mean +# per-label Dice averaged across (subject, timepoint, label) entries. + +# %% +if summary_rows: + sq_by_method: dict[str, list[float]] = {} + with detail_landmarks_file.open(newline="", encoding="utf-8") as fh: + for row in csv.DictReader(fh): + sq_by_method.setdefault(row["method"], []).append(float(row["sq_err_mm2"])) + + dice_by_method: dict[str, list[float]] = {} + with detail_dice_file.open(newline="", encoding="utf-8") as fh: + for row in csv.DictReader(fh): + dice_by_method.setdefault(row["method"], []).append(float(row["dice"])) + + time_by_method: dict[str, list[float]] = {} + for row in summary_rows: + method_name = str(row["method"]) + time_by_method.setdefault(method_name, []).append(float(row["time_sec"])) + + header = ( + f"{'Method':<10}{'N_lm':>8}{'MSE(mm2)':>12}{'RMSE(mm)':>12}" + f"{'p95(mm2)':>12}{'meanDice':>12}{'sec/frame':>12}" + ) + print() + print("=" * len(header)) + print(f"Pre-registration comparison ({len(all_patient_ids)} patients)") + print("=" * len(header)) + print(header) + print("-" * len(header)) + for method_name in methods: + sq_arr = np.asarray(sq_by_method.get(method_name, []), dtype=np.float64) + dice_arr = np.asarray(dice_by_method.get(method_name, []), dtype=np.float64) + time_arr = np.asarray(time_by_method.get(method_name, []), dtype=np.float64) + if sq_arr.size == 0: + print(f"{method_name:<10}{0:>8}{'':>12}{'':>12}{'':>12}{'':>12}{'':>12}") + continue + mse = float(np.mean(sq_arr)) + rmse = float(np.sqrt(mse)) + p95 = float(np.percentile(sq_arr, 95)) + mean_dice_val = float(np.mean(dice_arr)) if dice_arr.size else float("nan") + mean_time = float(np.mean(time_arr)) if time_arr.size else float("nan") + print( + f"{method_name:<10}" + f"{sq_arr.size:>8}" + f"{mse:>12.3f}" + f"{rmse:>12.3f}" + f"{p95:>12.3f}" + f"{mean_dice_val:>12.3f}" + f"{mean_time:>12.2f}" + ) + print("=" * len(header)) diff --git a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py index 83ff66c..88cce72 100644 --- a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py +++ b/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py @@ -35,15 +35,21 @@ ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") timepoint_base_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_base_dir = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -output_dir = Path("./results") -finetuned_weights_path = Path( - "./results/icon_finetuned/checkpoints/Finetune_multi_final.trch" + +_HERE = Path(__file__).parent +output_dir = _HERE / "results" +finetuned_weights_path = ( + output_dir + / "icon_finetuned" + / "icon_finetuned_model" + / "checkpoints" + / "network_weights_final.trch" ) train_fraction = 0.8 -icon_iterations = 20 +icon_iterations = None reference_percentile = 0.70 -exclude_tokens = ("nop", "dia", "sys", "_ref") +exclude_tokens = ["nop"] timepoint_re = re.compile(r"_g(?P[0-9]{3})") methods: list[tuple[str, Optional[Path]]] = [ @@ -103,15 +109,20 @@ for subject_id in test_subjects: source_dir = timepoint_base_dir / subject_id + print(f"Source directory: {source_dir}") + seg_dir = segmentation_base_dir / subject_id + print(f"Segmentation directory: {seg_dir}") image_files = [ p for p in sorted(source_dir.glob("*.nii.gz")) if not any(t in p.name for t in exclude_tokens) ] + print(f"Found {len(image_files)} image files") stems = [p.name[:-7] for p in image_files] labelmap_files = [seg_dir / f"{s}_labelmap.nii.gz" for s in stems] + mask_files = [seg_dir / f"{s}_labelmap_mask.nii.gz" for s in stems] landmark_files = [seg_dir / f"{s}_landmark.mrk.json" for s in stems] timepoints = [timepoint_re.search(p.name).group("timepoint") for p in image_files] @@ -122,17 +133,32 @@ ) fixed_image = itk.imread(str(image_files[reference_index]), pixel_type=itk.F) - fixed_mask = RegisterImagesICON.create_mask( - itk.imread(str(labelmap_files[reference_index])) - ) + fixed_labelmap = itk.imread(str(labelmap_files[reference_index])) + if mask_files[reference_index].exists(): + fixed_mask = itk.imread(str(mask_files[reference_index])) + else: + fixed_mask = RegisterImagesICON.create_mask(fixed_labelmap, dilation_mm=5.0) + itk.imwrite(fixed_mask, str(mask_files[reference_index]), compression=True) reference_landmarks = landmark_tools.read_landmarks_3dslicer( landmark_files[reference_index] ) moving_images = [itk.imread(str(p), pixel_type=itk.F) for p in image_files] - moving_masks = [ - RegisterImagesICON.create_mask(itk.imread(str(p))) for p in labelmap_files + moving_labelmaps = [itk.imread(str(p)) for p in labelmap_files] + moving_landmarks = [ + landmark_tools.read_landmarks_3dslicer(str(p)) for p in landmark_files ] + moving_masks = [] + for index, p in enumerate(mask_files): + if not p.exists(): + mask = RegisterImagesICON.create_mask( + moving_labelmaps[index], dilation_mm=5.0 + ) + itk.imwrite(mask, str(p), compression=True) + moving_masks.append(mask) + else: + mask = itk.imread(str(p)) + moving_masks.append(mask) for method_name, weights_path in methods: print(f" Method: {method_name}") @@ -147,7 +173,7 @@ result = registrar.register_time_series( moving_images=moving_images, moving_masks=moving_masks, - moving_labelmaps=None, + moving_labelmaps=moving_labelmaps, reference_frame=reference_index, register_reference=False, prior_weight=0.0, @@ -178,9 +204,7 @@ # inverse_transform follows the ITK resampler convention — it maps # moving-grid points back to reference-grid points, which is what # we need to warp time-point landmarks into reference space. - timepoint_landmarks = landmark_tools.read_landmarks_3dslicer( - landmark_files[index] - ) + timepoint_landmarks = moving_landmarks[index] shared = sorted(timepoint_landmarks.keys() & reference_landmarks.keys()) errors: list[tuple[str, float]] = [] for name in shared: diff --git a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py b/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py index 5467cad..3a81f10 100644 --- a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py +++ b/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py @@ -494,9 +494,15 @@ def run_method_for_subject( if artifacts.landmark_file is not None: timepoint_landmarks = read_landmarks(artifacts.landmark_file) + # Warp the reference landmarks into the timepoint (moving) space to + # compare against this timepoint's landmarks. Warping reference -> + # time POINTS uses forward_transform (the fixed -> moving point map), + # which is the opposite of the reference_to_time IMAGE above (images + # pull back, points push forward). See + # docs/developer/transform_conventions. direct_landmarks = transform_landmarks( reference_landmarks, - inverse_transform, + forward_transform, ) direct_errors = landmark_errors(direct_landmarks, timepoint_landmarks) write_error_details( diff --git a/experiments/LongitudinalRegistration/registration_test.py b/experiments/LongitudinalRegistration/registration_test.py new file mode 100644 index 0000000..ad2219c --- /dev/null +++ b/experiments/LongitudinalRegistration/registration_test.py @@ -0,0 +1,110 @@ +# %% [markdown] +# # Registration test: pm0003 time point 20 -> time point 60 +# +# Registers pm0003 gated CT time point 20 (moving) to time point 60 +# (fixed) with deformable registration, then warps time point 20 into +# time point 60's space and writes it to disk. +# +# Switch backends by editing the single ``method`` variable below +# ("ANTS", "ICON", or "Greedy"). All paths are hard-coded; run the +# cells top to bottom. + +# %% +import time +from pathlib import Path + +import itk + +from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + +# %% [markdown] +# ## 1. Configuration and hard-coded paths +# +# Change ``method`` to switch backends. Time point 20 is the moving +# image; time point 60 is the fixed image. + +# %% +method = "Greedy" # one of: "ANTS", "ICON", "Greedy" + +data_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii/pm0003") +moving_path = data_dir / "pm0003_dupr_135-0094_135_4700_g020_s2.000_n0058_11.nii.gz" +fixed_path = data_dir / "pm0003_dupr_135-0094_135_4700_g060_s2.000_n0058_15.nii.gz" + +output_dir = Path(__file__).parent / "results" / "registration_test" +output_dir.mkdir(parents=True, exist_ok=True) +output_path = output_dir / f"pm0003_g020_to_g060_{method.lower()}.mha" + +# %% [markdown] +# ## 2. Load the fixed (time point 60) and moving (time point 20) images + +# %% +fixed_image = itk.imread(str(fixed_path), pixel_type=itk.F) +moving_image = itk.imread(str(moving_path), pixel_type=itk.F) +print(f"Fixed (g060): {fixed_path.name}") +print(f"Moving (g020): {moving_path.name}") + +# %% [markdown] +# ## 3. Build and configure the registration backend +# +# ANTS and Greedy share ``set_transform_type``/``set_metric`` and take a +# per-level iteration list; ICON takes a single iteration count and has +# no transform-type/metric setters. + +# %% +if method == "ANTS": + reg = RegisterImagesANTS() + reg.set_transform_type("Deformable") + reg.set_metric("MeanSquares") + reg.set_number_of_iterations([40, 20, 10]) +elif method == "Greedy": + reg = RegisterImagesGreedy() + reg.set_transform_type("Deformable") + # NCC (CC) beats SSD for same-modality CT; tighter update-field smoothing + # (first sigma) captures more cardiac motion while staying diffeomorphic. + reg.set_metric("CC") + reg.set_number_of_iterations([100, 80, 40]) + reg.deformable_smoothing = "1.0vox 0.5vox" +elif method == "ICON": + reg = RegisterImagesICON() + reg.set_number_of_iterations(50) +else: + raise ValueError(f"Unknown method: {method}") + +reg.set_modality("ct") +reg.set_fixed_image(fixed_image) + +# %% [markdown] +# ## 4. Register time point 20 to time point 60 + +# %% +t_start = time.perf_counter() +reg_result = reg.register(moving_image=moving_image) +elapsed = time.perf_counter() - t_start + +forward_transform = reg_result["forward_transform"] +loss = float(reg_result["loss"]) +print(f"{method} registration done in {elapsed:.1f} s, loss={loss:.4f}") + +# %% [markdown] +# ## 5. Warp time point 20 into time point 60's space and save +# +# ``forward_transform`` is the transform consumed by ``transform_image`` to +# resample the moving image onto the fixed grid (it supplies the fixed->moving +# sampling map the ITK resampler needs). ``inverse_transform`` is the opposite +# direction, used to warp the fixed image onto the moving grid (e.g. in +# ``RegisterTimeSeriesImages.reconstruct_time_series``). This holds for all +# three backends (ANTS, ICON, Greedy). + +# %% +transform_tools = TransformTools() +warped_image = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", +) +itk.imwrite(warped_image, str(output_path), compression=True) +print(f"Wrote warped time point 20 -> 60: {output_path}") diff --git a/pyproject.toml b/pyproject.toml index cefcec4..7e81117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -398,12 +398,12 @@ lines-after-imports = 2 max-complexity = 10 [tool.ruff.lint.flake8-quotes] -inline-quotes = "single" +inline-quotes = "double" multiline-quotes = "double" docstring-quotes = "double" [tool.ruff.format] -quote-style = "single" +quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 182afaf..5231d53 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -17,6 +17,7 @@ import itk import numpy as np from numpy.typing import NDArray + from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.transform_tools import TransformTools @@ -526,16 +527,23 @@ def registration_method( region of interest in the moving image moving_image_pre (ants.core.ANTsImage, optional): Pre-processed moving image in ANTs format. If None, preprocessing is performed automatically - initial_forward_transform (itk.Transform, optional): Initial transform from moving - to fixed space. Can be any ITK transform type (Affine, Rigid, - DisplacementField, Composite, etc.). Will be converted to ANTs - format automatically. The returned transforms will include this - initial transform composed with the registration result. + initial_forward_transform (itk.Transform, optional): Initial + forward transform (same convention as the returned + forward_transform: used to warp the moving image onto the fixed + grid). Can be any ITK transform type (Affine, Rigid, + DisplacementField, Composite, etc.). It is applied by pre-warping + the moving image onto the fixed grid before registration; the + returned transforms compose this initial alignment with the + registration refinement. Returns: dict: Dictionary containing: - - "forward_transform": Transformation from moving to fixed - - "inverse_transform": Transformation from fixed to moving + - "forward_transform": Warps the moving image onto the fixed + grid (warping moving points/landmarks into fixed space uses + "inverse_transform" instead -- image and point warps use + opposite transforms; see + docs/developer/transform_conventions) + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Loss value from the registration Note: @@ -543,11 +551,13 @@ def registration_method( consistent. The forward and inverse transforms are stored separately by ANTs. - IMPORTANT: ANTs registration does NOT include the initial_transform - in its output fwdtransforms/invtransforms. This method automatically - composes the initial transform with the registration result, so the - returned transforms include both the initial alignment and - the registration refinement. + IMPORTANT: the initial transform is applied by pre-warping the + moving image onto the fixed grid (the same approach as + RegisterImagesICON) rather than via ants.registration's + initial_transform argument, which mishandles matrix (affine/ + translation) initials. This method composes the initial transform + with the registration result, so the returned transforms include + both the initial alignment and the registration refinement. Implementation details: - Uses ANTs registration with configurable transform types @@ -584,24 +594,29 @@ def registration_method( if self.fixed_image_pre is None: self.fixed_image_pre = self.preprocess(self.fixed_image, self.modality) - # Convert initial ITK transform to ANTs format if provided - initial_transform: str | list[str] = "identity" + # Apply any initial transform by pre-warping the moving image onto the + # fixed grid (the same approach RegisterImagesICON uses), instead of + # passing it to ants.registration as an initial_transform. ANTS + # mishandles matrix (affine/translation) initial transforms, badly + # corrupting the result; pre-warping keeps the composition below + # self-consistent for any initial transform type. The registration then + # solves the residual and the composition recovers the full transform. if initial_forward_transform is not None: - self.log_info("Converting initial ITK transform to ANTs format...") - initial_transform = self.itk_transform_to_ANTSfile( - itk_tfm=initial_forward_transform, - reference_image=self.fixed_image, - output_filename="initial_transform_temp.mat", + self.log_info("Pre-warping moving image with initial transform...") + transform_tools = TransformTools() + self.moving_image_pre = transform_tools.transform_image( + self.moving_image_pre, + initial_forward_transform, + self.fixed_image, ) - self.log_info("Initial transform converted successfully") transform_type = None if self.transform_type == "Deformable": transform_type = "antsRegistrationSyNQuick[so]" elif self.transform_type == "Affine": - transform_type = "antsRegistrationAffineQuick[so]" + transform_type = "Affine" elif self.transform_type == "Rigid": - transform_type = "antsRegistrationRigidQuick[so]" + transform_type = "Rigid" else: self.log_error("Invalid transform type: %s", self.transform_type) raise ValueError(f"Invalid transform type: {self.transform_type}") @@ -627,13 +642,36 @@ def registration_method( elif self.metric == "MeanSquares": syn_metric = "meansquares" + # antsRegistration --dimensionality 3 --float 0 \ + # --output [$thisfolder/pennTemplate_to_${sub}_,$thisfolder/pennTemplate_to_${sub}_Warped.nii.gz] \ + # --interpolation Linear \ + # --winsorize-image-intensities [0.005,0.995] \ + # --use-histogram-matching 0 \ + # --initial-moving-transform [$t1brain,$template,1] \ + # --transform Rigid[0.1] \ + # --metric MI[$t1brain,$template,1,32,Regular,0.25] \ + # --convergence [1000x500x250x100,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # --transform Affine[0.1] \ + # --metric MI[$t1brain,$template,1,32,Regular,0.25] \ + # --convergence [1000x500x250x100,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # --transform SyN[0.1,3,0] \ + # --metric CC[$t1brain,$template,1,4] \ + # --convergence [100x70x50x20,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # -x $brainlesionmask + if self.fixed_mask is not None and self.moving_mask is not None: registration_result = ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), mask=self._itk_to_ants_image(self.fixed_mask), moving=self._itk_to_ants_image(self.moving_image_pre), moving_mask=self._itk_to_ants_image(self.moving_mask), - initial_transform=[initial_transform], + initial_transform=["identity"], type_of_transform=transform_type, aff_metric=aff_metric, syn_metric=syn_metric, @@ -646,7 +684,7 @@ def registration_method( registration_result = ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), moving=self._itk_to_ants_image(self.moving_image_pre), - initial_transform=[initial_transform], + initial_transform=["identity"], type_of_transform=transform_type, aff_metric=aff_metric, syn_metric=syn_metric, diff --git a/src/physiomotion4d/register_images_base.py b/src/physiomotion4d/register_images_base.py index 463e26f..d833bb3 100644 --- a/src/physiomotion4d/register_images_base.py +++ b/src/physiomotion4d/register_images_base.py @@ -59,8 +59,8 @@ class and implement the register() method. ... def registration_method(self, moving_image, **kwargs): ... # Implement specific registration algorithm ... return { - ... 'forward_transform': tfm_forward, # Moving → Fixed - ... 'inverse_transform': tfm_inverse, # Fixed → Moving + ... 'forward_transform': tfm_forward, # warps moving image -> fixed grid + ... 'inverse_transform': tfm_inverse, # warps fixed image -> moving grid ... 'loss': 0.0, ... } >>> @@ -68,8 +68,8 @@ class and implement the register() method. >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> result = registrar.register(moving_image) - >>> forward_tfm = result['forward_transform'] # Moving → Fixed - >>> inverse_tfm = result['inverse_transform'] # Fixed → Moving + >>> forward_tfm = result['forward_transform'] # warps moving image -> fixed grid + >>> inverse_tfm = result['inverse_transform'] # warps fixed image -> moving grid """ def __init__(self, log_level: int | str = logging.INFO) -> None: @@ -251,8 +251,11 @@ def registration_method( Returns: dict: Dictionary containing: - - "forward_transform": Transform that warps moving image into fixed space - - "inverse_transform": Transform that warps fixed image into moving space + - "forward_transform": Warps the moving image onto the fixed + grid. Warping moving points/landmarks into fixed space uses + "inverse_transform" instead (see register() and + docs/developer/transform_conventions). + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Registration loss/metric value Raises: @@ -283,13 +286,24 @@ def register( Returns: dict: Dictionary containing transformation results: - - "forward_transform": Transforms moving image to fixed space (warps moving → fixed) - - "inverse_transform": Transforms fixed image to moving space (warps fixed → moving) + - "forward_transform": Warps the moving IMAGE onto the fixed + grid, i.e. transform_image(moving, forward_transform, fixed). + - "inverse_transform": Warps the fixed IMAGE onto the moving + grid, i.e. transform_image(fixed, inverse_transform, moving). - "loss": Registration loss/metric value Note: - - forward_transform: Use this to warp the moving image to match the fixed image - - inverse_transform: Use this to warp the fixed image to match the moving image + Image warps and point/landmark warps use OPPOSITE members of the + transform pair, because ITK image resampling pulls back (it maps a + fixed-grid sample to the moving image) while point transforms push + forward (they map a point to its corresponding location): + + - Warp the moving image into fixed space -> forward_transform + - Warp moving points/landmarks into fixed -> inverse_transform + - Warp the fixed image into moving space -> inverse_transform + - Warp fixed points/landmarks into moving -> forward_transform + + See docs/developer/transform_conventions for the full discussion. Raises: NotImplementedError: This method must be implemented by subclasses diff --git a/src/physiomotion4d/register_images_greedy.py b/src/physiomotion4d/register_images_greedy.py index d3caed8..8a7a6cb 100644 --- a/src/physiomotion4d/register_images_greedy.py +++ b/src/physiomotion4d/register_images_greedy.py @@ -12,6 +12,8 @@ from __future__ import annotations import logging +import os +import tempfile from typing import Any, Optional, Union import itk @@ -135,6 +137,30 @@ def _greedy_iterations_str(self) -> str: """Format iterations as Greedy -n string (e.g. 40x20x10).""" return "x".join(str(i) for i in self.number_of_iterations) + def _write_affine_matrix_file(self, mat_4x4: NDArray[np.float64]) -> str: + """Write a 4x4 RAS affine matrix to a temporary Greedy ``.mat`` file. + + Greedy's in-memory interface corrupts the heap when a numpy affine + matrix is supplied as an initial transform (``-ia``/``-it``); passing a + file path instead avoids that native crash. Greedy reads a plain-text + 4x4 RAS matrix, which is what ``numpy.savetxt`` writes here. + + Args: + mat_4x4: 4x4 affine matrix in RAS (Greedy) convention. + + Returns: + Path to the written temporary ``.mat`` file. The caller is + responsible for deleting it. + """ + mat_4x4 = np.asarray(mat_4x4, dtype=np.float64) + if mat_4x4.shape != (4, 4): + raise ValueError(f"Expected 4x4 matrix, got shape {mat_4x4.shape}") + fd, path = tempfile.mkstemp(suffix=".mat", prefix="greedy_aff_") + os.close(fd) + np.savetxt(path, mat_4x4, fmt="%.10f") + self.log_debug("Wrote Greedy affine init matrix to %s", path) + return path + def _matrix_to_itk_affine(self, mat_4x4: NDArray[np.float64]) -> itk.Transform: """Convert 4x4 affine matrix to ITK AffineTransform.""" mat_4x4 = np.asarray(mat_4x4, dtype=np.float64) @@ -195,17 +221,26 @@ def _registration_method_affine_or_rigid( cmd += " -gm fixed_mask -mm moving_mask" kwargs["fixed_mask"] = fixed_mask_sitk kwargs["moving_mask"] = moving_mask_sitk + # Greedy crashes (heap corruption) when an initial affine is passed as an + # in-memory matrix; write it to a temp file and pass the path instead. + initial_affine_file: Optional[str] = None if initial_affine is not None: - cmd += " -ia aff_initial" - kwargs["aff_initial"] = initial_affine + initial_affine_file = self._write_affine_matrix_file(initial_affine) + cmd += f" -ia {initial_affine_file}" - g.execute(cmd, **kwargs) + self.log_debug("Greedy affine/rigid command: %s", cmd) + try: + g.execute(cmd, **kwargs) + finally: + if initial_affine_file is not None: + os.remove(initial_affine_file) mat = np.array(g["aff_out"], dtype=np.float64) try: ml = g.metric_log() loss = float(ml[-1]["TotalPerPixelMetric"][-1]) if ml else 0.0 except Exception: loss = 0.0 + self.log_info("Greedy affine/rigid registration loss: %s", loss) return mat, loss def _registration_method_deformable( @@ -230,17 +265,21 @@ def _registration_method_deformable( cmd_aff += " -gm fixed_mask -mm moving_mask" kwargs_aff["fixed_mask"] = fixed_mask_sitk kwargs_aff["moving_mask"] = moving_mask_sitk + self.log_debug("Greedy deformable affine-init command: %s", cmd_aff) g.execute(cmd_aff, **kwargs_aff) initial_affine = np.array(g["aff_init"], dtype=np.float64) + self.log_info("Greedy deformable affine init complete") + # Greedy crashes (heap corruption) when the affine init is passed as an + # in-memory matrix via -it; write it to a temp file and pass the path. + initial_affine_file = self._write_affine_matrix_file(initial_affine) cmd_def = ( - f"-i fixed moving -it aff_init -n {iterations_str} " + f"-i fixed moving -it {initial_affine_file} -n {iterations_str} " f"-m {metric_str} -s {self.deformable_smoothing} -o warp_out" ) kwargs_def = { "fixed": fixed_sitk, "moving": moving_sitk, - "aff_init": initial_affine, "warp_out": None, } if fixed_mask_sitk is not None and moving_mask_sitk is not None: @@ -248,13 +287,18 @@ def _registration_method_deformable( kwargs_def["fixed_mask"] = fixed_mask_sitk kwargs_def["moving_mask"] = moving_mask_sitk - g.execute(cmd_def, **kwargs_def) + self.log_debug("Greedy deformable command: %s", cmd_def) + try: + g.execute(cmd_def, **kwargs_def) + finally: + os.remove(initial_affine_file) warp_out = g["warp_out"] try: ml = g.metric_log() loss = float(ml[-1]["TotalPerPixelMetric"][-1]) if ml else 0.0 except Exception: loss = 0.0 + self.log_info("Greedy deformable registration loss: %s", loss) return initial_affine, warp_out, loss def registration_method( @@ -270,6 +314,13 @@ def registration_method( Converts ITK images to SimpleITK, runs Greedy (affine and/or deformable), then converts outputs back to ITK transforms. Composes with initial_forward_transform when provided. + + Returns a dict with "forward_transform", "inverse_transform", and + "loss". As with the other image-registration backends, + forward_transform warps the moving image onto the fixed grid and + inverse_transform warps the fixed image onto the moving grid; point and + landmark warps use the opposite transform from image warps (see + docs/developer/transform_conventions). """ if self.fixed_image is None or self.fixed_image_pre is None: raise ValueError("Fixed image must be set before registration.") @@ -371,13 +422,17 @@ def registration_method( ) disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() disp_tfm.SetDisplacementField(disp_itk) - # Forward = moving -> fixed: first affine then deformable in Greedy + # forward_transform is consumed by transform_image(moving, ..., + # fixed) to warp the moving image onto the fixed grid, so it holds + # Greedy's raw affine+warp (Greedy applies the affine first, then + # the warp). inverse_transform is the numerically inverted field, + # used to warp the fixed image onto the moving grid. This matches + # RegisterImagesANTS/ICON and RegisterTimeSeriesImages. forward_composite = itk.CompositeTransform[itk.D, 3].New() if aff_tfm is not None: forward_composite.AddTransform(aff_tfm) forward_composite.AddTransform(disp_tfm) forward_transform = forward_composite - # Inverse: inverse warp then inverse affine inv_disp = TransformTools().invert_displacement_field_transform(disp_tfm) inv_aff = itk.AffineTransform[itk.D, 3].New() if aff_tfm is not None: diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 28bd03d..4ab1d52 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -10,6 +10,8 @@ """ import logging +import pathlib +import sys from typing import Optional, Union import icon_registration as icon @@ -88,6 +90,10 @@ def set_weights_path(self, weights_path: str) -> None: pretrained weights. Clears any previously loaded network so the new weights are applied on the next call to register(). + Also, use this to specify the path to store the downloaded weights. The + file must not exist for the weights to be downloaded correctly. Typical + suffix is ".trch". + Args: weights_path: Path to a uniGradICON checkpoint, e.g. "results/duke_4d_finetune/checkpoints/network_weights_100" @@ -185,16 +191,21 @@ def registration_method( Returns: dict: Dictionary containing: - - "forward_transform": transform moving image into fixed space - - "inverse_transform": transform fixed image to moving space + - "forward_transform": Warps the moving image onto the fixed + grid (warping moving points/landmarks into fixed space uses + "inverse_transform" instead -- image and point warps use + opposite transforms; see + docs/developer/transform_conventions) + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Loss value from the registration Note: The transformations are inverse consistent, meaning - forward_transform ≈ inverse(inverse_transform). - The inverse_transform is used to warp the fixed image - to the moving image space. The forward_transform is used - to warp the moving image to the fixed image space. + forward_transform is approximately inverse(inverse_transform). + Use forward_transform to warp the moving image onto the fixed grid, + and inverse_transform to warp the fixed image onto the moving grid. + Point/landmark warps use the opposite transform from image warps + (see docs/developer/transform_conventions). Implementation details: - Uses UniGradIcon with LNCC loss function @@ -292,13 +303,29 @@ def _ensure_net(self) -> None: """ if self.net is not None: return + main_module = sys.modules.get("__main__") + main_file = getattr(main_module, "__file__", None) + top_dir = pathlib.Path.cwd() + if main_file is not None: + top_dir = pathlib.Path(main_file).resolve().parent if self.use_multi_modality: + if self.weights_path is None: + self.weights_path = str( + top_dir + / "network_weights" + / "multigradicon1.0" + / "Step_2_final.trch" + ) self.net = get_multigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, weights_location=self.weights_path, ) else: + if self.weights_path is None: + self.weights_path = str( + top_dir / "network_weights" / "unigradicon1.0" / "Step_2_final.trch" + ) self.net = get_unigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py index 4b8090e..2c30d4a 100644 --- a/src/physiomotion4d/register_models_distance_maps.py +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -41,7 +41,7 @@ >>> >>> # Access results >>> aligned_model = result['registered_model'] - >>> forward_transform = result['forward_transform'] # Moving to fixed transform + >>> forward_transform = result['forward_transform'] # warps moving image -> fixed grid """ import logging @@ -74,8 +74,15 @@ class RegisterModelsDistanceMaps(PhysioMotion4DBase): - **Optional**: ICON deep learning refinement after any mode **Transform Convention:** - - forward_transform: Moving → fixed space transformation - - inverse_transform: Fixed → moving space transformation + These are the underlying image-registration (ANTs/ICON) transforms, so + they follow the image convention (see + docs/developer/transform_conventions): + + - forward_transform: warps the moving image/mask onto the fixed grid. + Warping the moving MODEL points/landmarks onto the fixed model uses + inverse_transform instead (image and point warps use opposite + transforms). + - inverse_transform: warps the fixed image/mask onto the moving grid. Attributes: moving_model (pv.PolyData): Surface model to be aligned diff --git a/src/physiomotion4d/register_models_icp.py b/src/physiomotion4d/register_models_icp.py index b3b2b61..3236f02 100644 --- a/src/physiomotion4d/register_models_icp.py +++ b/src/physiomotion4d/register_models_icp.py @@ -61,10 +61,16 @@ class RegisterModelsICP(PhysioMotion4DBase): - **Affine transform type**: Centroid alignment → Rigid ICP → Affine ICP **Transform Convention:** - - forward_point_transform: moving → fixed space transformation - (This is the inverse of the transform used to wrap the moving image to the - fixed image) - - inverse_point_transform: moving → fixed space transformation + These are POINT transforms (applied with TransformPoint, e.g. via + TransformTools.transform_pvcontour), so their orientation is opposite to + the image-registration transforms (see + docs/developer/transform_conventions): + + - forward_point_transform: maps moving points -> fixed points; use it to + warp the moving model/landmarks onto the fixed model. This is the + inverse of the transform that would warp the moving IMAGE onto the + fixed grid. + - inverse_point_transform: maps fixed points -> moving points. Attributes: moving_model (pv.PolyData): Surface model to be aligned diff --git a/src/physiomotion4d/register_models_pca.py b/src/physiomotion4d/register_models_pca.py index d230e0a..339b34a 100644 --- a/src/physiomotion4d/register_models_pca.py +++ b/src/physiomotion4d/register_models_pca.py @@ -42,10 +42,15 @@ class RegisterModelsPCA(PhysioMotion4DBase): pca_coefficients (np.ndarray): Optimized PCA coefficients registered_model (pv.DataSet): Final registered and deformed model post_pca_transform (itk.Transform): Transform to apply after PCA registration - forward_point_transform (itk.DisplacementFieldTransform): Forward displacement field transform - (Does not include the post-PCA transform) - inverse_point_transform (itk.DisplacementFieldTransform): Inverse displacement field transform - (Does not include the post-PCA transform) + forward_point_transform (itk.DisplacementFieldTransform): POINT transform + mapping template points -> registered/target points; use it to warp + the template model/landmarks onto the target. Its orientation is + opposite to an image-registration forward_transform (see + docs/developer/transform_conventions). Does not include the post-PCA + transform. + inverse_point_transform (itk.DisplacementFieldTransform): POINT transform + mapping target points -> template points. Does not include the + post-PCA transform. Example: >>> # Load PCA model data @@ -741,8 +746,14 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict: Returns: Dictionary containing: - - 'forward_point_transform': Forward displacement field transform - - 'inverse_point_transform': Inverse displacement field transform + - 'forward_point_transform': POINT transform mapping template + points -> target points (warps the template onto the target) + - 'inverse_point_transform': POINT transform mapping target + points -> template points + + Note: + These are point transforms, oriented opposite to image-registration + transforms; see docs/developer/transform_conventions. """ assert self.registered_model_pca_deformation is not None, ( "PCA deformation must be computed" diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 5e14f5a..89dd02d 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -75,8 +75,8 @@ class RegisterTimeSeriesImages(RegisterImagesBase): ... prior_weight=0.5, ... ) >>> - >>> forward_tfms = result['forward_transforms'] # Moving → Fixed - >>> inverse_tfms = result['inverse_transforms'] # Fixed → Moving + >>> forward_tfms = result['forward_transforms'] # warp moving images -> fixed grid + >>> inverse_tfms = result['inverse_transforms'] # warp fixed image -> moving grids >>> losses = result['losses'] >>> >>> # Reconstruct time series with optional upsampling @@ -205,6 +205,16 @@ def set_fixed_mask(self, fixed_mask: Optional[itk.Image]) -> None: """ self.fixed_mask = fixed_mask + def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: + """Set a labelmap for the fixed image region of interest. + + This passes through to the underlying registration method. + + Args: + fixed_labelmap (itk.Image): Labelmap defining ROI + """ + self.fixed_labelmap = fixed_labelmap + def register_time_series( self, moving_images: list[itk.Image], @@ -247,10 +257,14 @@ def register_time_series( Returns: dict: Dictionary containing results: - - "forward_transforms" (list[itk.Transform]): Transforms from moving to fixed - space for each image (warps moving → fixed) - - "inverse_transforms" (list[itk.Transform]): Transforms from fixed to moving - space for each image (warps fixed → moving) + - "forward_transforms" (list[itk.Transform]): one per image; + each warps its moving image onto the fixed grid (warping + moving points/landmarks into fixed space uses the matching + inverse transform instead -- see + docs/developer/transform_conventions) + - "inverse_transforms" (list[itk.Transform]): one per image; + each warps the fixed image onto that moving image's grid + (used by reconstruct_time_series) - "losses" (list[float]): Registration loss value for each image Raises: @@ -277,6 +291,7 @@ def register_time_series( >>> result = registrar.register_time_series( ... moving_images=image_list, ... moving_masks=mask_list, # Optional + ... moving_labelmaps=labelmap_list, # Optional ... reference_frame=5, ... register_reference=True, ... prior_weight=0.5, @@ -630,7 +645,8 @@ def reconstruct_time_series( Args: moving_images (list[itk.Image]): List of moving images to reconstruct inverse_transforms (list[itk.Transform]): List of inverse transforms - (one per moving image) from fixed space to moving space + (one per moving image), each used to warp the fixed image onto + that moving image's grid upsample_to_fixed_resolution (bool, optional): If True, reconstructed images will be upsampled to isotropic resolution (mean of fixed image's X and Y spacing) while maintaining their original origin and direction. diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 223a5c0..0ab551d 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -148,7 +148,7 @@ def __init__( similarity: str = "lncc", lambda_value: float = 1.5, dice_loss_weight: float = 0.5, - lncc_sigma: int = 5, + lncc_sigma: int = 1, ct_window: tuple[float, float] = (-1000.0, 1000.0), is_ct: bool = True, gpus: Optional[list[int]] = None, @@ -530,7 +530,7 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, "loss_function_masking": self.uses_masks, - "use_label": self.uses_segmentations, + "use_label": False, "roi_masking": False, }, "datasets": [ diff --git a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py index ebfb732..b3105c5 100644 --- a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py @@ -61,8 +61,10 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): registration_method (str): Registration method ('ANTS', 'ICON', or 'ANTS_ICON') number_of_iterations: Iterations for registration registrar (RegisterTimeSeriesImages): Internal registration object - forward_transforms (list[itk.Transform]): Forward transforms (moving → fixed) - inverse_transforms (list[itk.Transform]): Inverse transforms (fixed → moving) + forward_transforms (list[itk.Transform]): one per frame; each warps its + moving image onto the fixed grid + inverse_transforms (list[itk.Transform]): one per frame; each warps the + fixed image onto that frame's moving grid (used for reconstruction) losses (list[float]): Registration loss values reconstructed_images (list[itk.Image]): Reconstructed high-resolution images @@ -260,10 +262,11 @@ def register_time_series(self) -> dict: Returns: dict: Dictionary containing: - - 'forward_transforms' (list[itk.Transform]): Transforms from moving - to fixed space (warps moving → fixed) - - 'inverse_transforms' (list[itk.Transform]): Transforms from fixed - to moving space (warps fixed → moving) + - 'forward_transforms' (list[itk.Transform]): one per frame; + each warps its moving image onto the fixed grid + - 'inverse_transforms' (list[itk.Transform]): one per frame; + each warps the fixed image onto that frame's moving grid + (see docs/developer/transform_conventions) - 'losses' (list[float]): Registration loss value for each image Raises: diff --git a/tests/test_register_images_ants.py b/tests/test_register_images_ants.py index 8102c58..61c778d 100644 --- a/tests/test_register_images_ants.py +++ b/tests/test_register_images_ants.py @@ -20,6 +20,27 @@ from physiomotion4d.transform_tools import TransformTools +def _foreground_ncc( + reference_arr: np.ndarray, warped_arr: np.ndarray, foreground: np.ndarray +) -> float: + """Normalized cross-correlation over a foreground mask (higher = better). + + Args: + reference_arr: Reference image array (e.g. the fixed image), axes (Z, Y, X). + warped_arr: Warped image array on the same grid/axes as ``reference_arr``. + foreground: Boolean mask (same shape) selecting the voxels to score. + + Returns: + NCC in [-1, 1] over the foreground voxels (nan if degenerate). + """ + a = reference_arr[foreground].astype(np.float64) + b = warped_arr[foreground].astype(np.float64) + a0 = a - a.mean() + b0 = b - b.mean() + denom = float(np.sqrt((a0 * a0).sum() * (b0 * b0).sum())) + return float((a0 * b0).sum() / denom) if denom > 0 else float("nan") + + @pytest.mark.slow class TestRegisterImagesANTS: """Test suite for ANTs-based image registration.""" @@ -309,6 +330,218 @@ def test_registration_with_initial_transform( print("Registration with initial transform complete") + def test_initial_transform_composition_metrics( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + test_directories: dict[str, Path], + ) -> None: + """Verify the initial_forward_transform composition path with metrics. + + Exercises the two initial-transform inputs the platform actually uses + (identity and a prior deformable forward_transform, as in prior-based + time-series registration) and confirms the composed forward_transform + warps the moving image onto the fixed grid. Scored with foreground NCC + over the brightest 30% of the fixed image (tissue/blood pool). See + docs/developer/transform_conventions. + + Asserted facts: + * a plain registration improves on the unregistered pair, + * an identity initial reproduces the baseline exactly (the + composition machinery is a structurally correct no-op; note an + identity AffineTransform is itself a matrix initial), + * a prior-deformable initial reaches the no-initial baseline quality + (the composition recovers the full transform). + + The initial transform is applied by pre-warping the moving image (as in + RegisterImagesICON), which keeps the composition self-consistent for any + initial transform type. + """ + output_dir = test_directories["output"] + reg_output_dir = output_dir / "registration_ANTS" + reg_output_dir.mkdir(exist_ok=True) + + # Pick two phases that are far apart in the cycle so there is real motion. + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + # Moving and fixed share the acquisition grid (split from one 4D image), + # so the moving array is directly comparable for the unregistered score. + moving_arr = itk.array_from_image(moving_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + + transform_tools = TransformTools() + + def warp_score(forward_transform: Any) -> float: + warped = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", + ) + return _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + + ncc_unregistered = _foreground_ncc(fixed_arr, moving_arr, foreground) + + # Baseline: no initial transform. + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + baseline = registrar_ANTS.register(moving_image=moving_image) + ncc_baseline = warp_score(baseline["forward_transform"]) + + # Identity initial: the composition machinery must be a no-op. + identity = itk.AffineTransform[itk.D, 3].New() + identity.SetIdentity() + registrar_identity = RegisterImagesANTS() + registrar_identity.set_modality("ct") + registrar_identity.set_fixed_image(fixed_image) + identity_result = registrar_identity.register( + moving_image=moving_image, initial_forward_transform=identity + ) + ncc_identity = warp_score(identity_result["forward_transform"]) + + # Prior deformable initial: the realistic time-series prior use case. + registrar_prior = RegisterImagesANTS() + registrar_prior.set_modality("ct") + registrar_prior.set_fixed_image(fixed_image) + prior_result = registrar_prior.register( + moving_image=moving_image, + initial_forward_transform=baseline["forward_transform"], + ) + ncc_prior = warp_score(prior_result["forward_transform"]) + + print("\nANTS initial-transform composition metrics (foreground NCC):") + print(f" unregistered: {ncc_unregistered:.4f}") + print(f" baseline (no initial): {ncc_baseline:.4f}") + print(f" identity initial: {ncc_identity:.4f}") + print(f" prior-deformable init: {ncc_prior:.4f}") + + warped_prior = transform_tools.transform_image( + moving_image, + prior_result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + itk.imwrite( + warped_prior, + str(reg_output_dir / "ants_warped_prior_initial.mha"), + compression=True, + ) + + # Registration must improve alignment over the unregistered pair. + assert ncc_baseline > ncc_unregistered, ( + f"Baseline registration did not improve alignment: " + f"{ncc_baseline:.4f} <= {ncc_unregistered:.4f}" + ) + # Identity initial must reproduce the baseline (composition is a no-op). + assert abs(ncc_identity - ncc_baseline) < 0.03, ( + f"Identity initial transform changed the result: " + f"identity={ncc_identity:.4f} vs baseline={ncc_baseline:.4f}" + ) + # A prior-deformable initial must reach the no-initial baseline quality + # (the composition recovers the full transform). + assert ncc_prior >= ncc_baseline - 0.03, ( + f"Prior-initial composition did not reach baseline quality: " + f"{ncc_prior:.4f} < {ncc_baseline:.4f} - 0.03" + ) + + def test_initial_transform_matrix_composition( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + ) -> None: + """A matrix (translation/affine) initial composes without corruption. + + Regression guard for the previously-broken matrix initial_transform + path: feeding a translation initial used to corrupt the composition + (foreground NCC far below the unregistered pair). With the moving image + pre-warped by the initial, the composed forward_transform must align the + moving image onto the fixed grid at least as well as the unregistered + pair. + """ + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + ncc_unregistered = _foreground_ncc( + fixed_arr, itk.array_from_image(moving_image), foreground + ) + + translation = itk.TranslationTransform[itk.D, 3].New() + translation.SetOffset([-5.0, -5.0, -5.0]) + + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + result = registrar_ANTS.register( + moving_image=moving_image, initial_forward_transform=translation + ) + + transform_tools = TransformTools() + warped = transform_tools.transform_image( + moving_image, + result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + ncc = _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + print( + f"\nMatrix-initial composed NCC={ncc:.4f} (unregistered={ncc_unregistered:.4f})" + ) + assert ncc > ncc_unregistered, ( + f"Matrix-initial composition worse than unregistered: " + f"{ncc:.4f} <= {ncc_unregistered:.4f}" + ) + + def test_affine_and_rigid_transform_types( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + ) -> None: + """Affine and Rigid transform types run and improve alignment. + + Regression guard for the ANTS preset names: ``set_transform_type`` + previously mapped Affine/Rigid to ``antsRegistration{Affine,Rigid}Quick`` + preset strings that do not exist in antspy, raising ValueError. Each + type must now run and warp the moving image onto the fixed grid at least + as well as the unregistered pair. + """ + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + ncc_unregistered = _foreground_ncc( + fixed_arr, itk.array_from_image(moving_image), foreground + ) + + transform_tools = TransformTools() + for transform_type in ("Rigid", "Affine"): + registrar = RegisterImagesANTS() + registrar.set_modality("ct") + registrar.set_transform_type(transform_type) + registrar.set_fixed_image(fixed_image) + result = registrar.register(moving_image=moving_image) + warped = transform_tools.transform_image( + moving_image, + result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + ncc = _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + print( + f"\n{transform_type} transform NCC={ncc:.4f} " + f"(unregistered={ncc_unregistered:.4f})" + ) + assert ncc > ncc_unregistered, ( + f"{transform_type} registration did not improve alignment: " + f"{ncc:.4f} <= {ncc_unregistered:.4f}" + ) + def test_multiple_registrations( self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: diff --git a/tutorials/tutorial_08_dirlab_pca_time_series.py b/tutorials/tutorial_08_dirlab_pca_time_series.py index 232e8ed..75cd063 100644 --- a/tutorials/tutorial_08_dirlab_pca_time_series.py +++ b/tutorials/tutorial_08_dirlab_pca_time_series.py @@ -170,9 +170,16 @@ def run_tutorial() -> dict[str, Any]: compression=True, ) + # Warp the reference-space fitted mesh into this phase's space. + # Warping reference -> phase POINTS uses the forward transform + # (the fixed -> moving point map), which is the opposite of the + # transform used to warp an image into phase space (images pull + # back, points push forward). The forward transform is named + # "phase_to_reference" after its image-warp role. See + # docs/developer/transform_conventions. phase_mesh = transform_tools.transform_pvcontour( fitted_reference_mesh, - reference_to_phase, + phase_to_reference, with_deformation_magnitude=True, ) phase_mesh_file = meshes_dir / f"{phase_name}_pca_fit.vtp" From cd1d3e99875618e6c2a5ae535ea5b3c4b94b04fc Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Sat, 30 May 2026 12:28:57 -0400 Subject: [PATCH 02/10] COMP: Remove warnings and fix failing tests --- docs/api/utilities/index.rst | 2 + docs/api/utilities/labelmap_tools.rst | 19 +++ .../1-finetune_icon.py | 11 +- .../1-preregistration.py | 22 ++-- .../2-recon_4d_icon_eval.py | 15 ++- pyproject.toml | 12 +- src/physiomotion4d/__init__.py | 2 + src/physiomotion4d/contour_tools.py | 2 +- src/physiomotion4d/image_tools.py | 9 +- src/physiomotion4d/labelmap_tools.py | 100 ++++++++++++++++ src/physiomotion4d/register_images_ants.py | 4 +- src/physiomotion4d/register_images_base.py | 29 ++--- src/physiomotion4d/register_images_icon.py | 36 ------ .../register_models_distance_maps.py | 16 ++- src/physiomotion4d/segment_anatomy_base.py | 23 ---- .../segment_heart_simpleware.py | 4 +- .../workflow_fine_tune_icon_registration.py | 50 +++++--- ...rkflow_fit_statistical_model_to_patient.py | 18 ++- tests/test_image_tools.py | 4 +- tests/test_labelmap_tools.py | 109 ++++++++++++++++++ ...st_workflow_fine_tune_icon_registration.py | 40 ++----- 21 files changed, 347 insertions(+), 180 deletions(-) create mode 100644 docs/api/utilities/labelmap_tools.rst create mode 100644 src/physiomotion4d/labelmap_tools.py create mode 100644 tests/test_labelmap_tools.py diff --git a/docs/api/utilities/index.rst b/docs/api/utilities/index.rst index 0e4d2c0..62277bb 100644 --- a/docs/api/utilities/index.rst +++ b/docs/api/utilities/index.rst @@ -21,6 +21,7 @@ Quick Links **Utility Modules**: * :doc:`image_tools` - Image processing utilities + * :doc:`labelmap_tools` - Labelmap to registration-mask conversion * :doc:`transform_tools` - Transform operations * :doc:`contour_tools` - Contour processing * :doc:`image_conversion` - 4D image to 3D time-series utilities @@ -34,6 +35,7 @@ Module Documentation :maxdepth: 2 image_tools + labelmap_tools transform_tools contour_tools image_conversion diff --git a/docs/api/utilities/labelmap_tools.rst b/docs/api/utilities/labelmap_tools.rst new file mode 100644 index 0000000..7b453b5 --- /dev/null +++ b/docs/api/utilities/labelmap_tools.rst @@ -0,0 +1,19 @@ +==================================== +Labelmap Tools +==================================== + +.. currentmodule:: physiomotion4d + +Convert segmentation labelmaps into binary registration masks, with optional +label exclusion and physically isotropic dilation. + +Module Reference +================ + +.. automodule:: physiomotion4d.labelmap_tools + :members: + :undoc-members: + +.. rubric:: Navigation + +:doc:`index` | :doc:`image_tools` | :doc:`transform_tools` diff --git a/experiments/LongitudinalRegistration/1-finetune_icon.py b/experiments/LongitudinalRegistration/1-finetune_icon.py index a11118e..8ed8f6d 100644 --- a/experiments/LongitudinalRegistration/1-finetune_icon.py +++ b/experiments/LongitudinalRegistration/1-finetune_icon.py @@ -31,7 +31,7 @@ import itk from physiomotion4d import WorkflowFineTuneICONRegistration -from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.labelmap_tools import LabelmapTools # %% [markdown] # ## 1. Configure data, output locations, and the train/test split @@ -145,7 +145,7 @@ # %% [markdown] # ## 4. Pre-compute loss-function masks next to each labelmap # -# Use :meth:`RegisterImagesICON.create_mask` (``>0`` threshold + 5 mm +# Use :meth:`LabelmapTools.convert_labelmap_to_mask` (``>0`` threshold + 5 mm # physical-radius dilation) to derive each frame's binary heart-ROI mask and # write it as ``_mask.nii.gz`` in the labelmap's own directory. # Pre-computing here means the workflow does not have to re-derive masks @@ -154,13 +154,14 @@ # %% mask_dilation_mm = 5.0 +labelmap_tools = LabelmapTools() def derive_mask_for(labelmap_path: Path) -> str: """Create (or reuse) a loss-function mask next to ``labelmap_path``. Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm - via :meth:`RegisterImagesICON.create_mask`, writing the result as + via :meth:`LabelmapTools.convert_labelmap_to_mask`, writing the result as ``_mask.nii.gz`` in the labelmap's own directory. Handles both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha`` (pre-registration warped labelmaps). Returns the mask path as a string; @@ -175,8 +176,8 @@ def derive_mask_for(labelmap_path: Path) -> str: stem = labelmap_path.stem mask_p = labelmap_path.parent / f"{stem}_mask.nii.gz" if not mask_p.exists(): - mask = RegisterImagesICON.create_mask( - itk.imread(str(labelmap_path)), dilation_mm=mask_dilation_mm + mask = labelmap_tools.convert_labelmap_to_mask( + itk.imread(str(labelmap_path)), dilation_in_mm=mask_dilation_mm ) itk.imwrite(mask, str(mask_p), compression=True) return str(mask_p) diff --git a/experiments/LongitudinalRegistration/1-preregistration.py b/experiments/LongitudinalRegistration/1-preregistration.py index dfd8ac4..fab754f 100644 --- a/experiments/LongitudinalRegistration/1-preregistration.py +++ b/experiments/LongitudinalRegistration/1-preregistration.py @@ -50,10 +50,10 @@ import itk import numpy as np +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.landmark_tools import LandmarkTools from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_greedy import RegisterImagesGreedy -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools # %% [markdown] @@ -78,6 +78,7 @@ # Mask dilation matches 1-finetune_icon.py so any masks we have to # derive here are identical to the ones written by the fine-tune script. mask_dilation_mm = 5.0 +labelmap_tools = LabelmapTools() # Iteration schedules. Kept modest for a cohort-wide comparison; raise # either list for higher accuracy at the cost of runtime. @@ -199,13 +200,20 @@ def landmark_squared_errors( def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: """Return the cached ``_labelmap_mask.nii.gz`` next to the - labelmap, or derive it via :meth:`RegisterImagesICON.create_mask` - (threshold ``>0`` plus 5 mm physical-radius dilation) and write it - out so subsequent runs and the ICON eval reuse the same mask. + labelmap, or derive it via + :meth:`LabelmapTools.convert_labelmap_to_mask` (threshold ``>0`` plus + 5 mm physical-radius dilation) and write it out so subsequent runs and + the ICON eval reuse the same mask. """ - if mask_path.exists(): - return itk.imread(str(mask_path)) - mask = RegisterImagesICON.create_mask(labelmap, dilation_mm=mask_dilation_mm) + # Force mask update + # if mask_path.exists(): + # return itk.imread(str(mask_path)) + mask = labelmap_tools.convert_labelmap_to_mask( + labelmap, + dilation_in_mm=mask_dilation_mm, + labels_to_exclude=[1, 2, 3, 4], + # These are labels for the interior of the heart chambers (the LV, RV, LA, RA) + ) itk.imwrite(mask, str(mask_path), compression=True) return mask diff --git a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py index 88cce72..bd462b4 100644 --- a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py +++ b/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py @@ -25,8 +25,8 @@ import numpy as np from physiomotion4d import RegisterTimeSeriesImages +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.landmark_tools import LandmarkTools -from physiomotion4d.register_images_icon import RegisterImagesICON # %% [markdown] # ## 1. Hard-coded paths and configuration @@ -93,12 +93,13 @@ # Landmarks are read with :meth:`LandmarkTools.read_landmarks_3dslicer` — # they were written as ``_landmark.mrk.json`` (3D Slicer Markups JSON, # LPS) by ``0-cardiacGatedCT_segment_and_landmark.py``. Binary registration -# masks come from :meth:`RegisterImagesICON.create_mask` (``>0`` threshold -# plus 5 mm dilation by default), matching the loss-function masks used +# masks come from :meth:`LabelmapTools.convert_labelmap_to_mask` (``>0`` +# threshold plus 5 mm dilation), matching the loss-function masks used # during fine-tuning in ``1-finetune_icon.py``. # %% landmark_tools = LandmarkTools() +labelmap_tools = LabelmapTools() # %% [markdown] @@ -137,7 +138,9 @@ if mask_files[reference_index].exists(): fixed_mask = itk.imread(str(mask_files[reference_index])) else: - fixed_mask = RegisterImagesICON.create_mask(fixed_labelmap, dilation_mm=5.0) + fixed_mask = labelmap_tools.convert_labelmap_to_mask( + fixed_labelmap, dilation_in_mm=5.0 + ) itk.imwrite(fixed_mask, str(mask_files[reference_index]), compression=True) reference_landmarks = landmark_tools.read_landmarks_3dslicer( landmark_files[reference_index] @@ -151,8 +154,8 @@ moving_masks = [] for index, p in enumerate(mask_files): if not p.exists(): - mask = RegisterImagesICON.create_mask( - moving_labelmaps[index], dilation_mm=5.0 + mask = labelmap_tools.convert_labelmap_to_mask( + moving_labelmaps[index], dilation_in_mm=5.0 ) itk.imwrite(mask, str(p), compression=True) moving_masks.append(mask) diff --git a/pyproject.toml b/pyproject.toml index 7e81117..70d0aa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,7 +280,6 @@ minversion = "7.0" addopts = [ "--strict-markers", "--strict-config", - "-W", "always", "--cov=physiomotion4d", "--cov-report=term-missing", "--cov-report=html", @@ -288,6 +287,17 @@ addopts = [ ] testpaths = ["tests"] pythonpath = ["."] +# Surface every warning by default ("always"), then silence the third-party +# ITK/SWIG binding noise emitted while wrapped C++ types are defined at import +# on CPython >=3.12 ("builtin type Swig... has no __module__ attribute"). +# Those come from the bindings, not our code, and would otherwise drown the +# warnings summary. Order matters: the trailing "ignore" entries are applied +# last and so take precedence over "always" for these specific messages. +filterwarnings = [ + "always", + "ignore:builtin type Swig", + "ignore:builtin type swigvarlink", +] markers = [ "unit: marks tests as unit tests (fast, isolated)", "integration: marks tests as integration tests (slower, multiple components)", diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index 5d00e32..63fc3a7 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -43,6 +43,7 @@ # Utility classes from .image_tools import ImageTools +from .labelmap_tools import LabelmapTools from .landmark_tools import LandmarkTools # Base classes @@ -106,6 +107,7 @@ "PhysioMotion4DBase", # Utility classes "ImageTools", + "LabelmapTools", "LandmarkTools", "TestTools", "TransformTools", diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index 9f3028e..6e72efe 100644 --- a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -230,7 +230,7 @@ def create_mask_from_mesh( # Direction: use identity for now (axis-aligned), will be handled by resampling # Flip Z axis to match ITK convention - ref_dir = np.array(binary_image.GetDirection()) + ref_dir = itk.array_from_matrix(binary_image.GetDirection()) ref_dir[2, 2] = -ref_dir[2, 2] binary_image.SetDirection(ref_dir) diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index c6e4023..c031603 100644 --- a/src/physiomotion4d/image_tools.py +++ b/src/physiomotion4d/image_tools.py @@ -279,9 +279,12 @@ def flip_image( flip1 = False flip2 = False if flip_and_make_identity: - flip0 = np.array(in_image.GetDirection())[0, 0] < 0 - flip1 = np.array(in_image.GetDirection())[1, 1] < 0 - flip2 = np.array(in_image.GetDirection())[2, 2] < 0 + # itk.array_from_matrix avoids itk.Matrix.__array__, whose missing + # copy keyword triggers a numpy>=2.0 DeprecationWarning. + direction = itk.array_from_matrix(in_image.GetDirection()) + flip0 = direction[0, 0] < 0 + flip1 = direction[1, 1] < 0 + flip2 = direction[2, 2] < 0 if flip_x: flip0 = True if flip_y: diff --git a/src/physiomotion4d/labelmap_tools.py b/src/physiomotion4d/labelmap_tools.py new file mode 100644 index 0000000..8e2d067 --- /dev/null +++ b/src/physiomotion4d/labelmap_tools.py @@ -0,0 +1,100 @@ +""" +Labelmap Tools for PhysioMotion4D + +This module provides the :class:`LabelmapTools` class with the definitive +utility for turning a multi-label (or binary) segmentation labelmap into a +binary registration mask, optionally excluding specific labels and dilating +the result by a physical radius in millimeters. +""" + +import logging +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + + +class LabelmapTools(PhysioMotion4DBase): + """ + Utilities for converting segmentation labelmaps into registration masks. + + A labelmap is an ``itk.Image`` of integer labels where ``0`` is background + and each positive value identifies an anatomical structure. A registration + mask is a binary ``itk.Image`` where every foreground voxel is ``1``. This + class centralizes the labelmap-to-mask conversion so that thresholding, + label exclusion, and physically isotropic dilation are performed + identically everywhere in the platform. + + Example: + >>> tools = LabelmapTools() + >>> # Binary mask of every labeled voxel, dilated 5 mm + >>> mask = tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) + >>> # Exclude the table/background labels 8 and 9 before masking + >>> mask = tools.convert_labelmap_to_mask( + ... labelmap, dilation_in_mm=5.0, labels_to_exclude=[8, 9] + ... ) + """ + + def __init__(self, log_level: int | str = logging.INFO) -> None: + """Initialize LabelmapTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + def convert_labelmap_to_mask( + self, + labelmap: itk.Image, + dilation_in_mm: float = 0.0, + labels_to_exclude: Optional[list[int]] = None, + ) -> itk.Image: + """Convert a labelmap into a binary registration mask. + + Any voxel whose label is in ``labels_to_exclude`` is set to background + first; every remaining non-zero voxel becomes foreground (``1``). The + binary mask is then dilated by ``dilation_in_mm`` millimeters of + physical radius. The radius is converted into per-axis voxel counts + from the labelmap's spacing so the dilation is physically isotropic + even on anisotropic grids; each per-axis count is clamped to at least + 1 voxel when ``dilation_in_mm > 0``. + + Axis ordering: the labelmap is a scalar 3D ``itk.Image`` in ITK + world-axis order (X, Y, Z). All thresholding is performed on the numpy + view (Z, Y, X) and written back through ``CopyInformation``, so origin, + spacing, and direction are preserved. + + Args: + labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel + that is not excluded is treated as foreground. + dilation_in_mm: Physical radius of the binary dilation in + millimeters. Pass ``0`` (or negative) to skip dilation and + return the raw thresholded mask. Default 0.0. + labels_to_exclude: Optional list of integer label values to force + to background before thresholding. When ``None`` (the default) + no labels are excluded. + + Returns: + ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as + ``labelmap`` (origin, spacing, direction copied from the input). + """ + arr = itk.array_from_image(labelmap) + if labels_to_exclude: + arr = np.where(np.isin(arr, labels_to_exclude), 0, arr) + mask_arr = (arr > 0).astype(np.uint8) + mask = itk.image_from_array(mask_arr) + mask.CopyInformation(labelmap) + + if dilation_in_mm <= 0: + return mask + + spacing = labelmap.GetSpacing() + radius = itk.Size[3]() + for i in range(3): + radius[i] = max(1, int(round(dilation_in_mm / float(spacing[i])))) + structuring_element = itk.FlatStructuringElement[3].Ball(radius) + return itk.binary_dilate_image_filter( + mask, kernel=structuring_element, foreground_value=1 + ) diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 5231d53..7412763 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -200,9 +200,7 @@ def _itk_to_ants_image( image_dimension = len(spatial_shape) - direction = np.asarray(itk_image.GetDirection()).reshape( - (image_dimension, image_dimension) - ) + direction = itk.array_from_matrix(itk_image.GetDirection()) spacing = list(itk_image.GetSpacing()) origin = list(itk_image.GetOrigin()) diff --git a/src/physiomotion4d/register_images_base.py b/src/physiomotion4d/register_images_base.py index d833bb3..1c397fb 100644 --- a/src/physiomotion4d/register_images_base.py +++ b/src/physiomotion4d/register_images_base.py @@ -19,9 +19,8 @@ from typing import Any, Optional, Union import itk -import numpy as np -from itk import TubeTK as ttk +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools @@ -84,6 +83,8 @@ def __init__(self, log_level: int | str = logging.INFO) -> None: """ super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.labelmap_tools = LabelmapTools(log_level=log_level) + self.net: Any = None self.modality: str = "ct" @@ -180,16 +181,10 @@ def set_fixed_mask(self, fixed_mask: Optional[itk.Image]) -> None: if self.fixed_image is None: raise ValueError("Fixed image must be set before setting a fixed mask.") - mask_arr = itk.GetArrayFromImage(fixed_mask) - mask_arr = np.where(mask_arr > 0, 1, 0) - self.fixed_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + self.fixed_mask = self.labelmap_tools.convert_labelmap_to_mask( + fixed_mask, dilation_in_mm=self.mask_dilation_mm + ) self.fixed_mask.CopyInformation(self.fixed_image) - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(self.fixed_mask) - imMath.Dilate( - int(self.mask_dilation_mm / self.fixed_image.GetSpacing()[0]), 1, 0 - ) - self.fixed_mask = imMath.GetOutputUChar() def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: """Set the fixed image labelmap (multi-label segmentation). @@ -327,16 +322,10 @@ def register( new_moving_mask = moving_mask if moving_mask is not None: - mask_arr = itk.GetArrayFromImage(moving_mask) - mask_arr = np.where(mask_arr > 0, 1, 0) - new_moving_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + new_moving_mask = self.labelmap_tools.convert_labelmap_to_mask( + moving_mask, dilation_in_mm=self.mask_dilation_mm + ) new_moving_mask.CopyInformation(moving_image) - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(new_moving_mask) - imMath.Dilate( - int(self.mask_dilation_mm / moving_image.GetSpacing()[0]), 1, 0 - ) - new_moving_mask = imMath.GetOutputUChar() self.moving_image = moving_image self.moving_image_pre = moving_image_pre diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 4ab1d52..3d28dad 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -373,42 +373,6 @@ def _image_to_resized_tensor( tensor, size=shape[2:], mode="trilinear", align_corners=False ) - @staticmethod - def create_mask(labelmap: itk.Image, dilation_mm: float = 5.0) -> itk.Image: - """Create a binary registration mask from a labelmap. - - Thresholds the labelmap at ``>0`` (so every non-zero label becomes - foreground) and dilates the result by ``dilation_mm`` millimeters of - physical radius. The radius is converted into per-axis voxel counts - from the labelmap's spacing so the dilation is physically isotropic - even on anisotropic grids; each per-axis count is clamped to at least - 1 voxel when ``dilation_mm > 0``. - - Args: - labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel - is treated as foreground. - dilation_mm: Physical radius of the binary dilation in - millimeters. Pass ``0`` (or negative) to skip dilation and - return the raw ``>0`` mask. Default 5.0 mm. - - Returns: - ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as - ``labelmap`` (origin, spacing, direction copied from the input). - """ - arr = (itk.array_from_image(labelmap) > 0).astype(np.uint8) - mask = itk.image_from_array(arr) - mask.CopyInformation(labelmap) - if dilation_mm <= 0: - return mask - spacing = labelmap.GetSpacing() - radius = itk.Size[3]() - for i in range(3): - radius[i] = max(1, int(round(dilation_mm / float(spacing[i])))) - structuring_element = itk.FlatStructuringElement[3].Ball(radius) - return itk.binary_dilate_image_filter( - mask, kernel=structuring_element, foreground_value=1 - ) - def _mask_to_resized_tensor( self, mask: itk.Image, shape: torch.Size ) -> torch.Tensor: diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py index 2c30d4a..c03751f 100644 --- a/src/physiomotion4d/register_models_distance_maps.py +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -49,9 +49,9 @@ import itk import pyvista as pv -from itk import TubeTK as ttk from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON @@ -155,6 +155,7 @@ def __init__( # Utilities self.transform_tools = TransformTools() self.contour_tools = ContourTools() + self.labelmap_tools = LabelmapTools(log_level=log_level) # Registration instances self.registrar_ANTS = RegisterImagesANTS(log_level=log_level) @@ -201,12 +202,9 @@ def _create_masks_from_models(self) -> None: mask = self.contour_tools.create_mask_from_mesh( self.fixed_model, self.reference_image ) - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int( - self.roi_dilation_mm / self.reference_image.GetSpacing()[0] + self.fixed_mask_roi_image = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=self.roi_dilation_mm ) - imMath.Dilate(dilation_voxels, 1, 0) - self.fixed_mask_roi_image = imMath.GetOutput() # Create moving mask self.moving_mask_image = self.contour_tools.create_distance_map( @@ -223,9 +221,9 @@ def _create_masks_from_models(self) -> None: mask = self.contour_tools.create_mask_from_mesh( self.moving_model, self.reference_image ) - imMath = ttk.ImageMath.New(self.moving_mask_image) - imMath.Dilate(dilation_voxels, 1, 0) - self.moving_mask_roi_image = imMath.GetOutputUChar() + self.moving_mask_roi_image = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=self.roi_dilation_mm + ) self.log_info("Mask generation complete") diff --git a/src/physiomotion4d/segment_anatomy_base.py b/src/physiomotion4d/segment_anatomy_base.py index ca2a788..8368d6b 100644 --- a/src/physiomotion4d/segment_anatomy_base.py +++ b/src/physiomotion4d/segment_anatomy_base.py @@ -557,29 +557,6 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: """ raise NotImplementedError("This method should be implemented by the subclass.") - def dilate_mask(self, mask: itk.image, dilation: int) -> itk.image: - """ - Dilate a binary mask using morphological operations. - - Expands the mask regions by the specified number of pixels to create - larger regions of interest. Useful for creating candidate regions or - ensuring complete coverage of anatomical structures. - - Args: - mask (itk.image): The binary mask to dilate - dilation (int): Number of pixels to dilate in each direction - - Returns: - itk.image: The dilated binary mask - - Example: - >>> dilated_heart = segmenter.dilate_mask(heart_mask, 5) - """ - imMath = tube.ImageMath.New(mask) - imMath.Dilate(dilation, 1, 0) - dilated_mask = imMath.GetOutputUChar() - return dilated_mask - def segment( self, input_image: itk.image, diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index 15a1031..1d59e4a 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -338,8 +338,8 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: ) if mask_image is not None: - in_direction = np.array(preprocessed_image.GetDirection()) - out_direction = np.array(mask_image.GetDirection()) + in_direction = itk.array_from_matrix(preprocessed_image.GetDirection()) + out_direction = itk.array_from_matrix(mask_image.GetDirection()) flip = [False, False, False] for i in range(3): if np.sign(out_direction[i, i]) != np.sign(in_direction[i, i]): diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 0ab551d..8f0c6b0 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -42,8 +42,8 @@ import numpy as np import yaml +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.register_time_series_images import RegisterTimeSeriesImages from physiomotion4d.transform_tools import TransformTools @@ -95,7 +95,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): traceability; not consumed by uniGradICON fine-tuning itself. mask_dilation_mm (float): Millimeters of physical-radius binary dilation applied to the >0 labelmap when deriving the loss-masking - binary mask via :meth:`RegisterImagesICON.create_mask`. + binary mask via :meth:`LabelmapTools.convert_labelmap_to_mask`. mask_dir (Optional[Path]): Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk. @@ -205,8 +205,8 @@ def __init__( mask_dilation_mm: Physical radius (millimeters) of binary dilation applied to the >0 labelmap when deriving the loss-masking binary mask via - :meth:`RegisterImagesICON.create_mask`. Ignored when no - segmentations are supplied. Default 5.0 mm. + :meth:`LabelmapTools.convert_labelmap_to_mask`. Ignored when + no segmentations are supplied. Default 5.0 mm. mask_dir: Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk @@ -283,6 +283,7 @@ def __init__( ) self.transform_tools = TransformTools() + self.labelmap_tools = LabelmapTools(log_level=log_level) self.registrar: Optional[RegisterTimeSeriesImages] = None self._dataset_json_path: Optional[Path] = None @@ -310,7 +311,7 @@ def _validate_companion_shape( ) @property - def uses_segmentations(self) -> bool: + def use_segmentations(self) -> bool: """Whether at least one segmentation file is supplied for training. Drives the uniGradICON ``training.use_label`` flag. @@ -318,14 +319,24 @@ def uses_segmentations(self) -> bool: return self._any_non_none(self.subject_segmentation_files) @property - def uses_masks(self) -> bool: + def use_masks(self) -> bool: """Whether the dataset will have a ``mask`` field on every kept entry. True when explicit masks are supplied OR when segmentations are supplied (since masks are then derived). Drives the uniGradICON ``training.loss_function_masking`` flag. """ - return self._any_non_none(self.subject_mask_files) or self.uses_segmentations + return self._any_non_none(self.subject_mask_files) or self.use_segmentations + + @property + def use_label(self) -> bool: + """Whether uniGradICON trains with label supervision. + + Drives the uniGradICON ``training.use_label`` flag. True exactly when + segmentations are supplied, since the dataset then carries a ``label`` + field on every kept entry. + """ + return self.use_segmentations @staticmethod def _any_non_none( @@ -349,8 +360,9 @@ def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: """Create (or reuse) a dilated binary mask from a multi-label labelmap. Threshold the labelmap at ``>0`` and dilate by ``mask_dilation_mm`` mm - of physical radius via :meth:`RegisterImagesICON.create_mask` to widen - the ROI for loss-function masking. + of physical radius via + :meth:`LabelmapTools.convert_labelmap_to_mask` to widen the ROI for + loss-function masking. When :attr:`mask_dir` is ``None`` (the default) the mask is written next to the source labelmap as @@ -378,8 +390,8 @@ def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: return mask_path labelmap = itk.imread(str(labelmap_path)) - mask = RegisterImagesICON.create_mask( - labelmap, dilation_mm=self.mask_dilation_mm + mask = self.labelmap_tools.convert_labelmap_to_mask( + labelmap, dilation_in_mm=self.mask_dilation_mm ) itk.imwrite(mask, str(mask_path), compression=True) return mask_path @@ -406,8 +418,8 @@ def prepare_dataset(self) -> Path: does not exist on disk. """ self.experiment_dir.mkdir(parents=True, exist_ok=True) - use_seg = self.uses_segmentations - use_mask = self.uses_masks + use_seg = self.use_segmentations + use_mask = self.use_masks dataset_entries: list[dict[str, str]] = [] for subject_index, image_files in enumerate(self.subject_image_files): @@ -529,8 +541,8 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "lambda": self.lambda_value, "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, - "loss_function_masking": self.uses_masks, - "use_label": False, + "loss_function_masking": self.use_masks, + "use_label": self.use_label, "roi_masking": False, }, "datasets": [ @@ -735,8 +747,8 @@ def apply_registration( self.log_info("ICON weights: %s", weights_path) fixed_mask = ( - RegisterImagesICON.create_mask( - reference_segmentation, dilation_mm=self.mask_dilation_mm + self.labelmap_tools.convert_labelmap_to_mask( + reference_segmentation, dilation_in_mm=self.mask_dilation_mm ) if reference_segmentation is not None else None @@ -745,8 +757,8 @@ def apply_registration( if moving_segmentations is not None: moving_masks = [ ( - RegisterImagesICON.create_mask( - seg, dilation_mm=self.mask_dilation_mm + self.labelmap_tools.convert_labelmap_to_mask( + seg, dilation_in_mm=self.mask_dilation_mm ) if seg is not None else None diff --git a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index a533f4f..0e2c384 100644 --- a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -30,6 +30,7 @@ import pyvista as pv from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON @@ -199,6 +200,7 @@ def __init__( # Utilities (needed for create_reference_image when patient_image is None) self.transform_tools = TransformTools() self.contour_tools = ContourTools() + self.labelmap_tools = LabelmapTools() if patient_image is not None: self.patient_image = patient_image @@ -319,11 +321,9 @@ def _auto_generate_mask( if dilate_mm is None: dilate_mm = self.mask_dilation_mm if dilate_mm > 0: - ttk = _load_tubetk() - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int(dilate_mm / self.patient_image.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - mask = imMath.GetOutputUChar() + mask = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=dilate_mm + ) self.log_info("Masks auto-generated successfully.") @@ -349,11 +349,9 @@ def _auto_generate_roi_mask( # Generate model ROI mask roi = None if dilate_mm > 0: - ttk = _load_tubetk() - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int(dilate_mm / mask.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - roi = imMath.GetOutputUChar() + roi = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=dilate_mm + ) else: roi = mask diff --git a/tests/test_image_tools.py b/tests/test_image_tools.py index 7e48b7c..100a35c 100644 --- a/tests/test_image_tools.py +++ b/tests/test_image_tools.py @@ -428,7 +428,7 @@ def test_flip_and_make_identity_sets_direction_to_identity( direction = np.diag([-1.0, 1.0, 1.0]) itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction) out = image_tools.flip_image(itk_image, flip_and_make_identity=True) - out_direction = np.array(out.GetDirection()) + out_direction = itk.array_from_matrix(out.GetDirection()) identity = np.eye(3) assert np.allclose(out_direction, identity), ( "flip_and_make_identity should set direction to identity" @@ -450,7 +450,7 @@ def test_flip_and_make_identity_with_mask_sets_both_directions_to_identity( itk_image, in_mask=itk_mask, flip_and_make_identity=True ) for im, name in [(out_image, "image"), (out_mask, "mask")]: - dir_mat = np.array(im.GetDirection()) + dir_mat = itk.array_from_matrix(im.GetDirection()) assert np.allclose(dir_mat, np.eye(3)), ( f"flip_and_make_identity should set {name} direction to identity" ) diff --git a/tests/test_labelmap_tools.py b/tests/test_labelmap_tools.py new file mode 100644 index 0000000..627c9e3 --- /dev/null +++ b/tests/test_labelmap_tools.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +""" +Tests for LabelmapTools functionality. + +Covers thresholding a multi-label labelmap into a binary registration mask, +physically isotropic dilation that respects per-axis spacing, and forcing +selected labels to background via ``labels_to_exclude``. +""" + +from __future__ import annotations + +import itk +import numpy as np +import pytest + +from physiomotion4d.labelmap_tools import LabelmapTools + + +class TestLabelmapTools: + """Test suite for LabelmapTools.convert_labelmap_to_mask.""" + + @pytest.fixture + def labelmap_tools(self) -> LabelmapTools: + """Create LabelmapTools instance.""" + return LabelmapTools() + + def test_threshold_without_dilation(self, labelmap_tools: LabelmapTools) -> None: + """Every non-zero label becomes foreground; no dilation grows it.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 # non-zero label id + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + mask = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=0.0) + mask_arr = itk.array_from_image(mask) + + assert set(np.unique(mask_arr).tolist()) == {0, 1} + assert int(mask_arr.sum()) == 1 + assert mask_arr[2, 2, 2] == 1 + + def test_dilation_grows_mask(self, labelmap_tools: LabelmapTools) -> None: + """Positive dilation_in_mm grows the mask but keeps the seed voxel.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 + labelmap = itk.image_from_array(arr) + # Unit isotropic spacing so dilation_in_mm == voxel radius. + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + dilated = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=1.0) + dilated_arr = itk.array_from_image(dilated) + + assert int(dilated_arr.sum()) > 1 + assert dilated_arr[2, 2, 2] == 1 + + def test_dilation_respects_anisotropic_spacing( + self, labelmap_tools: LabelmapTools + ) -> None: + """A 5 mm radius covers more voxels along the finely spaced axis.""" + arr = np.zeros((11, 11, 11), dtype=np.uint8) + arr[5, 5, 5] = 1 + labelmap = itk.image_from_array(arr) + # numpy axes are (Z, Y, X); ITK spacing is (X, Y, Z). Make X coarse + # (5 mm/voxel -> 1-voxel radius) and Z fine (1 mm/voxel -> 5-voxel + # radius) so the per-axis radius differs. + labelmap.SetSpacing([5.0, 1.0, 1.0]) + + dilated = itk.array_from_image( + labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) + ) + + # Z axis (numpy axis 0) reaches 5 voxels out; X axis (numpy axis 2) + # only 1 voxel out. + assert dilated[0, 5, 5] == 1 + assert dilated[10, 5, 5] == 1 + assert dilated[5, 5, 0] == 0 + assert dilated[5, 5, 10] == 0 + + def test_labels_to_exclude_removes_voxels( + self, labelmap_tools: LabelmapTools + ) -> None: + """Excluded labels become background before thresholding.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[1, 1, 1] = 2 # kept + arr[3, 3, 3] = 7 # excluded + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + mask_arr = itk.array_from_image( + labelmap_tools.convert_labelmap_to_mask( + labelmap, dilation_in_mm=0.0, labels_to_exclude=[7] + ) + ) + + assert mask_arr[1, 1, 1] == 1 + assert mask_arr[3, 3, 3] == 0 + assert int(mask_arr.sum()) == 1 + + def test_preserves_image_information(self, labelmap_tools: LabelmapTools) -> None: + """Origin, spacing, and direction are copied from the labelmap.""" + arr = np.zeros((4, 4, 4), dtype=np.uint8) + arr[2, 2, 2] = 1 + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([0.5, 1.0, 2.0]) + labelmap.SetOrigin([10.0, -5.0, 3.0]) + + mask = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=0.0) + + assert list(mask.GetSpacing()) == [0.5, 1.0, 2.0] + assert list(mask.GetOrigin()) == [10.0, -5.0, 3.0] diff --git a/tests/test_workflow_fine_tune_icon_registration.py b/tests/test_workflow_fine_tune_icon_registration.py index 11cde75..a90d9ea 100644 --- a/tests/test_workflow_fine_tune_icon_registration.py +++ b/tests/test_workflow_fine_tune_icon_registration.py @@ -20,7 +20,6 @@ import pytest import yaml -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.workflow_fine_tune_icon_registration import ( WorkflowFineTuneICONRegistration, ) @@ -128,7 +127,7 @@ def test_init_rejects_mismatched_subject_ids_length(tmp_path: Path) -> None: ) -def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: +def test_use_segmentations_and_use_masks_flags(tmp_path: Path) -> None: """The two helper flags reflect supplied companions independently.""" base: dict[str, Any] = { "subject_image_files": [["a"]], @@ -136,45 +135,20 @@ def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: "fine_tune_name": "x", } none_wf = WorkflowFineTuneICONRegistration(**base) - assert not none_wf.uses_segmentations - assert not none_wf.uses_masks + assert not none_wf.use_segmentations + assert not none_wf.use_masks seg_only = WorkflowFineTuneICONRegistration( **base, subject_segmentation_files=[["seg.nii.gz"]] ) - assert seg_only.uses_segmentations - assert seg_only.uses_masks # derived from segs + assert seg_only.use_segmentations + assert seg_only.use_masks # derived from segs mask_only = WorkflowFineTuneICONRegistration( **base, subject_mask_files=[["mask.nii.gz"]] ) - assert not mask_only.uses_segmentations - assert mask_only.uses_masks - - -# --------------------------------------------------------------------------- -# RegisterImagesICON.create_mask (in-memory dilation, used by the workflow) -# --------------------------------------------------------------------------- - - -def test_create_mask_thresholds_and_dilates() -> None: - """Single-voxel labelmap becomes a binary mask whose dilation grows it.""" - arr = np.zeros((5, 5, 5), dtype=np.uint8) - arr[2, 2, 2] = 3 # non-zero label id - labelmap = itk.image_from_array(arr) - # Unit isotropic spacing so dilation_mm == voxel radius. - labelmap.SetSpacing([1.0, 1.0, 1.0]) - - no_dilate = RegisterImagesICON.create_mask(labelmap, dilation_mm=0.0) - no_dilate_arr = itk.array_from_image(no_dilate) - assert set(np.unique(no_dilate_arr).tolist()) == {0, 1} - assert int(no_dilate_arr.sum()) == 1 - - dilated = RegisterImagesICON.create_mask(labelmap, dilation_mm=1.0) - dilated_arr = itk.array_from_image(dilated) - assert int(dilated_arr.sum()) > 1 - # Original foreground voxel stays foreground. - assert dilated_arr[2, 2, 2] == 1 + assert not mask_only.use_segmentations + assert mask_only.use_masks # --------------------------------------------------------------------------- From e6a789fe72de2497d4bec1e8be70ca455b500b39 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Mon, 1 Jun 2026 15:41:13 -0400 Subject: [PATCH 03/10] ENH: Registration timing estimation --- ...istration.py => 1-initial_registration.py} | 273 +++++++++++++++--- ...{1-finetune_icon.py => 2-finetune_icon.py} | 0 ...d_icon_eval.py => 3-recon_4d_icon_eval.py} | 0 ...d_comparison.py => 4-recon_4d_all_eval.py} | 0 .../registration_test.py | 26 +- src/physiomotion4d/labelmap_tools.py | 12 +- .../workflow_fine_tune_icon_registration.py | 104 +++---- tests/test_labelmap_tools.py | 8 +- 8 files changed, 310 insertions(+), 113 deletions(-) rename experiments/LongitudinalRegistration/{1-preregistration.py => 1-initial_registration.py} (70%) rename experiments/LongitudinalRegistration/{1-finetune_icon.py => 2-finetune_icon.py} (100%) rename experiments/LongitudinalRegistration/{2-recon_4d_icon_eval.py => 3-recon_4d_icon_eval.py} (100%) rename experiments/LongitudinalRegistration/{3-run_registration_method_comparison.py => 4-recon_4d_all_eval.py} (100%) diff --git a/experiments/LongitudinalRegistration/1-preregistration.py b/experiments/LongitudinalRegistration/1-initial_registration.py similarity index 70% rename from experiments/LongitudinalRegistration/1-preregistration.py rename to experiments/LongitudinalRegistration/1-initial_registration.py index fab754f..e9074b7 100644 --- a/experiments/LongitudinalRegistration/1-preregistration.py +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -1,12 +1,13 @@ # %% [markdown] -# # Pre-registration: compare ANTS vs Greedy on the Duke gated CT cohort +# # Pre-registration: compare ANTS vs Greedy vs ICON on the Duke gated CT cohort # # Registers every gated CT time-point of every Duke patient under # ``ref_data_dir`` (100% of the cohort -- no train/test split) to that -# patient's reference image, using two backends in turn: +# patient's reference image, using three backends in turn: # # * :class:`RegisterImagesANTS` (CPU, SyN deformable) # * :class:`RegisterImagesGreedy` (CPU, deformable) +# * :class:`RegisterImagesICON` (GPU, uniGradICON deformable) # # For each frame the script records wall-clock registration time, writes # the warped/resampled moving image to disk, warps the moving labelmap @@ -21,22 +22,27 @@ # -- per-frame multi-label segmentations # * ``segmentation_dir_base / / _labelmap_mask.nii.gz`` # -- pre-computed loss-function masks (re-derived on the fly if absent, -# matching the 5 mm dilation used by ``1-finetune_icon.py``) +# matching the 3 mm dilation used by ``1-finetune_icon.py``) # * ``segmentation_dir_base / / _landmark.mrk.json`` # -- per-frame 3D Slicer Markups landmarks in LPS # # Outputs under ``results/``: -# * ``ants///.mha`` and -# ``greedy///.mha`` -- warped moving image +# * ``ants///.mha``, +# ``greedy///.mha`` and +# ``icon///.mha`` -- warped moving image # per time point, alongside the forward/inverse transforms (``.hdf``), -# a ``_deformation_grid.mha`` visualization of the registration -# deformation, the warped ``_labelmap.mha`` and its warped +# the warped ``_labelmap.mha`` and its warped # loss-function mask ``_labelmap_mask.mha``, # and the warped ``_landmark.mrk.json`` -# * ``preregistration_landmarks.csv`` -- per-landmark squared errors -# * ``preregistration_dice.csv`` -- per-label Dice -# * ``preregistration_summary.csv`` -- per-(subject, method, timepoint) -# time, mean Dice, MSE, RMSE +# * ``registration_landmarks_.csv`` -- per-landmark squared errors +# * ``registration_dice_.csv`` -- per-label Dice +# * ``registration_summary_.csv`` -- per-(subject, method, timepoint) +# registration time, per-frame total time, mean Dice, MSE, RMSE +# * ``registration_timing_.csv`` -- per-step wall-clock seconds, +# appended live as each frame's steps complete (register, write_transforms, +# warp_image, warp_labelmap, warp_mask, dice, landmarks, frame_total) +# * ``registration_timing_summary_.csv`` -- per-(method, step) count, +# mean, and total seconds, written once at the end of the run # # Run interactively cell-by-cell; all paths are hard-coded. @@ -54,6 +60,7 @@ from physiomotion4d.landmark_tools import LandmarkTools from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools # %% [markdown] @@ -77,24 +84,44 @@ # Mask dilation matches 1-finetune_icon.py so any masks we have to # derive here are identical to the ones written by the fine-tune script. -mask_dilation_mm = 5.0 +mask_dilation_mm = 3.0 labelmap_tools = LabelmapTools() # Iteration schedules. Kept modest for a cohort-wide comparison; raise -# either list for higher accuracy at the cost of runtime. -number_of_iterations_ANTS = [20, 10, 5] +# either list for higher accuracy at the cost of runtime. ANTS and Greedy +# take a multi-resolution list; ICON takes a single per-pair iterative +# optimization step count (0 disables it, using the pretrained forward pass +# alone). +number_of_iterations_ANTS = [40, 20, 10] number_of_iterations_greedy = [40, 20, 10] +number_of_iterations_ICON = 50 -methods: list[str] = ["ANTS", "greedy"] +# Optional uniGradICON checkpoint (".trch") to load instead of the default +# pretrained weights under ``network_weights/unigradicon1.0/``. When None, +# the default pretrained weights are used. +icon_weights_path: Optional[Path] = None + +methods: list[str] = ["ANTS", "Greedy", "ICON"] # Debug knob: when non-empty, only these patient IDs are processed. # Set to ``[]`` (or ``None``) to run the full cohort. -debug_subjects: list[str] = ["pm0002"] - -detail_landmarks_file = output_dir / "preregistration_landmarks.csv" -detail_dice_file = output_dir / "preregistration_dice.csv" -summary_file = output_dir / "preregistration_summary.csv" -for previous in (detail_landmarks_file, detail_dice_file, summary_file): +debug_subjects: list[str] = [] # ["pm0002"] + +run_stamp = time.time() +detail_landmarks_file = output_dir / f"registration_landmarks_{run_stamp}.csv" +detail_dice_file = output_dir / f"registration_dice_{run_stamp}.csv" +summary_file = output_dir / f"registration_summary_{run_stamp}.csv" +# Per-step wall-clock times, appended live as each frame's steps complete. +timing_detail_file = output_dir / f"registration_timing_{run_stamp}.csv" +# Per-(method, step) timing aggregates, written once at the end of the run. +timing_summary_file = output_dir / f"registration_timing_summary_{run_stamp}.csv" +for previous in ( + detail_landmarks_file, + detail_dice_file, + summary_file, + timing_detail_file, + timing_summary_file, +): if previous.exists(): previous.unlink() @@ -129,6 +156,41 @@ landmark_tools = LandmarkTools() transform_tools = TransformTools() +# Per-step timing records (subject, method, timepoint, step, seconds), +# accumulated in memory for the end-of-run timing summary and mirrored live +# into timing_detail_file as each step finishes. +timing_rows: list[dict[str, object]] = [] + + +def record_step_time( + subject_id: str, + method_name: str, + timepoint: str, + step: str, + seconds: float, +) -> None: + """Report a single processing step's wall-clock time. + + Prints the time immediately, appends a row to ``timing_detail_file`` so + progress is visible while the run is still going, and stores the same + row in ``timing_rows`` for the end-of-run timing summary. + """ + print(f" [time] {step:<18}{seconds:8.2f} s", flush=True) + timing_rows.append( + { + "subject_id": subject_id, + "method": method_name, + "timepoint": timepoint, + "step": step, + "seconds": float(seconds), + } + ) + with timing_detail_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow(["subject_id", "method", "timepoint", "step", "seconds"]) + writer.writerow([subject_id, method_name, timepoint, step, f"{seconds:.6f}"]) + def per_label_dice( fixed_labelmap: itk.Image, warped_labelmap: itk.Image @@ -202,7 +264,7 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: """Return the cached ``_labelmap_mask.nii.gz`` next to the labelmap, or derive it via :meth:`LabelmapTools.convert_labelmap_to_mask` (threshold ``>0`` plus - 5 mm physical-radius dilation) and write it out so subsequent runs and + 3 mm physical-radius dilation) and write it out so subsequent runs and the ICON eval reuse the same mask. """ # Force mask update @@ -211,7 +273,7 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: mask = labelmap_tools.convert_labelmap_to_mask( labelmap, dilation_in_mm=mask_dilation_mm, - labels_to_exclude=[1, 2, 3, 4], + exclude_labels=[1, 2, 3, 4], # These are labels for the interior of the heart chambers (the LV, RV, LA, RA) ) itk.imwrite(mask, str(mask_path), compression=True) @@ -332,12 +394,20 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: if method_name == "ANTS": reg = RegisterImagesANTS() reg.set_number_of_iterations(number_of_iterations_ANTS) - else: + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + reg.set_metric("CC") + elif method_name == "Greedy": reg = RegisterImagesGreedy() reg.set_number_of_iterations(number_of_iterations_greedy) - reg.set_transform_type("Deformable") - # NCC ("CC") outperforms MeanSquares for same-modality CT registration. - reg.set_metric("CC") + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + reg.set_metric("CC") + else: # ICON: GPU deep-learning deformable registration. + reg = RegisterImagesICON() + reg.set_number_of_iterations(number_of_iterations_ICON) + if icon_weights_path is not None: + reg.set_weights_path(str(icon_weights_path)) reg.set_modality("ct") reg.set_mask_dilation(mask_dilation_mm) reg.set_fixed_image(fixed_image) @@ -356,7 +426,8 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: flush=True, ) - frame_t_start = time.perf_counter() + frame_total_start = time.perf_counter() + frame_t_start = frame_total_start reg_result = reg.register( moving_image=moving_images[index], moving_mask=moving_masks[index], @@ -367,7 +438,11 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: inverse_transform = reg_result["inverse_transform"] frame_loss = float(reg_result["loss"]) print(f" done in {frame_elapsed:.1f} s, loss={frame_loss:.4f}") + record_step_time( + subject_id, method_name, timepoint, "register", frame_elapsed + ) + step_t_start = time.perf_counter() itk.transformwrite( forward_transform, str(method_dir / f"{stem}_fwd.hdf"), @@ -378,22 +453,17 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: str(method_dir / f"{stem}_inv.hdf"), compression=True, ) - - # Visualize the deformation as a warped grid: a regular grid built - # in reference space, resampled through forward_transform -- the same - # transform used below to warp the moving image onto the fixed grid, - # so the grid and the warped image deform consistently. - deformation_grid = transform_tools.convert_field_to_grid_visualization( - forward_transform, fixed_image - ) - itk.imwrite( - deformation_grid, - str(method_dir / f"{stem}_deformation_grid.mha"), - compression=True, + record_step_time( + subject_id, + method_name, + timepoint, + "write_transforms", + time.perf_counter() - step_t_start, ) # Warp the moving image into reference space and save it # (forward_transform resamples the moving image onto the fixed grid). + step_t_start = time.perf_counter() warped_image = transform_tools.transform_image( moving_images[index], forward_transform, @@ -405,9 +475,17 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: str(method_dir / f"{stem}.mha"), compression=True, ) + record_step_time( + subject_id, + method_name, + timepoint, + "warp_image", + time.perf_counter() - step_t_start, + ) # Warp the moving labelmap onto the fixed grid (forward_transform; # nearest neighbour preserves label IDs) for per-label Dice. + step_t_start = time.perf_counter() warped_labelmap = transform_tools.transform_image( moving_labelmaps[index], forward_transform, @@ -419,11 +497,19 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: str(method_dir / f"{stem}_labelmap.mha"), compression=True, ) + record_step_time( + subject_id, + method_name, + timepoint, + "warp_labelmap", + time.perf_counter() - step_t_start, + ) # Warp the moving loss-function mask onto the fixed grid # (forward_transform; nearest neighbour preserves the binary ROI) # so downstream fine-tuning reuses it instead of re-deriving a # mask from the warped labelmap. + step_t_start = time.perf_counter() warped_mask = transform_tools.transform_image( moving_masks[index], forward_transform, @@ -435,7 +521,15 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: str(method_dir / f"{stem}_labelmap_mask.mha"), compression=True, ) + record_step_time( + subject_id, + method_name, + timepoint, + "warp_mask", + time.perf_counter() - step_t_start, + ) + step_t_start = time.perf_counter() dice_by_label = per_label_dice(fixed_labelmap, warped_labelmap) with detail_dice_file.open("a", newline="", encoding="utf-8") as fh: writer = csv.writer(fh) @@ -450,9 +544,17 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: if dice_by_label else float("nan") ) + record_step_time( + subject_id, + method_name, + timepoint, + "dice", + time.perf_counter() - step_t_start, + ) # Warp the moving landmarks into reference space, save them next # to the transforms, then score squared error vs the reference. + step_t_start = time.perf_counter() moving_landmarks = moving_landmarks_list[index] if moving_landmarks is None: sq_errors: list[tuple[str, float]] = [] @@ -479,6 +581,13 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: ) for name, sq_err in sq_errors: writer.writerow([subject_id, method_name, timepoint, name, sq_err]) + record_step_time( + subject_id, + method_name, + timepoint, + "landmarks", + time.perf_counter() - step_t_start, + ) sq_values = np.asarray([e for _, e in sq_errors], dtype=np.float64) if sq_values.size: @@ -501,12 +610,18 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: flush=True, ) + frame_total = time.perf_counter() - frame_total_start + record_step_time( + subject_id, method_name, timepoint, "frame_total", frame_total + ) + summary_rows.append( { "subject_id": subject_id, "method": method_name, "timepoint": timepoint, "time_sec": float(frame_elapsed), + "frame_total_sec": float(frame_total), "loss": frame_loss, "n_labels": int(len(dice_by_label)), "mean_dice": mean_dice, @@ -535,6 +650,7 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: print(f"\nWrote summary: {summary_file}") print(f"Wrote landmarks: {detail_landmarks_file}") print(f"Wrote dice: {detail_dice_file}") + print(f"Wrote timing: {timing_detail_file}") else: print("\nNo frames processed; nothing to summarize.") @@ -611,3 +727,80 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: f"{mean_time:>12.2f}" ) print("=" * len(header)) + +# %% [markdown] +# ## 7. Per-(method, step) timing summary +# +# Aggregates the live per-step timings into mean and total wall-clock +# seconds per (method, step), printed as a table and written to +# ``timing_summary_file``. ``frame_total`` is the end-to-end per-frame +# time (register + all warps/writes + scoring); the other rows are its +# components. + +# %% +if timing_rows: + # Preserve the pipeline order in which steps are timed; any unexpected + # step name is appended in first-seen order so nothing is dropped. + step_order = [ + "register", + "write_transforms", + "warp_image", + "warp_labelmap", + "warp_mask", + "dice", + "landmarks", + "frame_total", + ] + seconds_by_method_step: dict[str, dict[str, list[float]]] = {} + for row in timing_rows: + method_name = str(row["method"]) + step = str(row["step"]) + seconds = float(row["seconds"]) + seconds_by_method_step.setdefault(method_name, {}).setdefault(step, []).append( + seconds + ) + if step not in step_order: + step_order.append(step) + + timing_summary_rows: list[dict[str, object]] = [] + timing_header = ( + f"{'Method':<10}{'Step':<18}{'N':>6}{'mean_sec':>12}{'total_sec':>12}" + ) + print() + print("=" * len(timing_header)) + print("Timing summary (wall-clock seconds)") + print("=" * len(timing_header)) + print(timing_header) + print("-" * len(timing_header)) + for method_name in methods: + step_times = seconds_by_method_step.get(method_name, {}) + if not step_times: + continue + for step in step_order: + values = step_times.get(step) + if not values: + continue + arr = np.asarray(values, dtype=np.float64) + mean_sec = float(np.mean(arr)) + total_sec = float(np.sum(arr)) + timing_summary_rows.append( + { + "method": method_name, + "step": step, + "n": int(arr.size), + "mean_sec": mean_sec, + "total_sec": total_sec, + } + ) + print( + f"{method_name:<10}{step:<18}{arr.size:>6}" + f"{mean_sec:>12.2f}{total_sec:>12.2f}" + ) + print("-" * len(timing_header)) + print("=" * len(timing_header)) + + with timing_summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(timing_summary_rows[0].keys())) + writer.writeheader() + writer.writerows(timing_summary_rows) + print(f"Wrote timing summary: {timing_summary_file}") diff --git a/experiments/LongitudinalRegistration/1-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py similarity index 100% rename from experiments/LongitudinalRegistration/1-finetune_icon.py rename to experiments/LongitudinalRegistration/2-finetune_icon.py diff --git a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py similarity index 100% rename from experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py rename to experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py diff --git a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py b/experiments/LongitudinalRegistration/4-recon_4d_all_eval.py similarity index 100% rename from experiments/LongitudinalRegistration/3-run_registration_method_comparison.py rename to experiments/LongitudinalRegistration/4-recon_4d_all_eval.py diff --git a/experiments/LongitudinalRegistration/registration_test.py b/experiments/LongitudinalRegistration/registration_test.py index ad2219c..0c66ef3 100644 --- a/experiments/LongitudinalRegistration/registration_test.py +++ b/experiments/LongitudinalRegistration/registration_test.py @@ -65,7 +65,7 @@ # NCC (CC) beats SSD for same-modality CT; tighter update-field smoothing # (first sigma) captures more cardiac motion while staying diffeomorphic. reg.set_metric("CC") - reg.set_number_of_iterations([100, 80, 40]) + reg.set_number_of_iterations([40, 20, 10]) reg.deformable_smoothing = "1.0vox 0.5vox" elif method == "ICON": reg = RegisterImagesICON() @@ -100,6 +100,7 @@ # %% transform_tools = TransformTools() +warp_t_start = time.perf_counter() warped_image = transform_tools.transform_image( moving_image, forward_transform, @@ -107,4 +108,27 @@ interpolation_method="linear", ) itk.imwrite(warped_image, str(output_path), compression=True) +warp_elapsed = time.perf_counter() - warp_t_start print(f"Wrote warped time point 20 -> 60: {output_path}") + +# %% [markdown] +# ## 6. Timing report +# +# Wall-clock seconds for the registration and the warp/write step. The +# registration time dominates and is the figure to compare across backends; +# for ICON it includes the one-time network load on this first (and only) +# pair. + +# %% +total_elapsed = elapsed + warp_elapsed +print() +print("=" * 44) +print(f"Timing report ({method})") +print("=" * 44) +print(f"{'Step':<22}{'seconds':>12}") +print("-" * 44) +print(f"{'register':<22}{elapsed:>12.2f}") +print(f"{'warp + write':<22}{warp_elapsed:>12.2f}") +print("-" * 44) +print(f"{'total':<22}{total_elapsed:>12.2f}") +print("=" * 44) diff --git a/src/physiomotion4d/labelmap_tools.py b/src/physiomotion4d/labelmap_tools.py index 8e2d067..ff5ad78 100644 --- a/src/physiomotion4d/labelmap_tools.py +++ b/src/physiomotion4d/labelmap_tools.py @@ -33,7 +33,7 @@ class centralizes the labelmap-to-mask conversion so that thresholding, >>> mask = tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) >>> # Exclude the table/background labels 8 and 9 before masking >>> mask = tools.convert_labelmap_to_mask( - ... labelmap, dilation_in_mm=5.0, labels_to_exclude=[8, 9] + ... labelmap, dilation_in_mm=5.0, exclude_labels=[8, 9] ... ) """ @@ -49,11 +49,11 @@ def convert_labelmap_to_mask( self, labelmap: itk.Image, dilation_in_mm: float = 0.0, - labels_to_exclude: Optional[list[int]] = None, + exclude_labels: Optional[list[int]] = None, ) -> itk.Image: """Convert a labelmap into a binary registration mask. - Any voxel whose label is in ``labels_to_exclude`` is set to background + Any voxel whose label is in ``exclude_labels`` is set to background first; every remaining non-zero voxel becomes foreground (``1``). The binary mask is then dilated by ``dilation_in_mm`` millimeters of physical radius. The radius is converted into per-axis voxel counts @@ -72,7 +72,7 @@ def convert_labelmap_to_mask( dilation_in_mm: Physical radius of the binary dilation in millimeters. Pass ``0`` (or negative) to skip dilation and return the raw thresholded mask. Default 0.0. - labels_to_exclude: Optional list of integer label values to force + exclude_labels: Optional list of integer label values to force to background before thresholding. When ``None`` (the default) no labels are excluded. @@ -81,8 +81,8 @@ def convert_labelmap_to_mask( ``labelmap`` (origin, spacing, direction copied from the input). """ arr = itk.array_from_image(labelmap) - if labels_to_exclude: - arr = np.where(np.isin(arr, labels_to_exclude), 0, arr) + if exclude_labels: + arr = np.where(np.isin(arr, exclude_labels), 0, arr) mask_arr = (arr > 0).astype(np.uint8) mask = itk.image_from_array(mask_arr) mask.CopyInformation(labelmap) diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 8f0c6b0..3b91e0a 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -96,6 +96,8 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): mask_dilation_mm (float): Millimeters of physical-radius binary dilation applied to the >0 labelmap when deriving the loss-masking binary mask via :meth:`LabelmapTools.convert_labelmap_to_mask`. + mask_exclude_labels (Optional[list[int]]): Labels to exclude from the mask. + Default is None. mask_dir (Optional[Path]): Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk. @@ -155,6 +157,7 @@ def __init__( eval_period: int = 10, save_period: int = 50, mask_dilation_mm: float = 5.0, + mask_exclude_labels: Optional[list[int]] = None, mask_dir: Optional[Path] = None, unigradicon_src_path: Optional[Path] = None, log_level: Union[int, str] = logging.INFO, @@ -176,7 +179,7 @@ def __init__( form ``subject_0000``, ``subject_0001``, ... Must be unique. subject_segmentation_files: Per-subject multi-label segmentation (labelmap) paths matching ``subject_image_files``. ``None`` - disables paired-with-seg training (no ``use_label``). + disables paired-with-seg training. Individual ``None`` entries inside the inner lists skip just those frames when paired-with-seg training is enabled. subject_mask_files: Per-subject binary mask paths matching @@ -277,6 +280,7 @@ def __init__( self.gpus = list(gpus) if gpus is not None else [0] self.eval_period = eval_period self.save_period = save_period + self.mask_exclude_labels = mask_exclude_labels self.mask_dilation_mm = float(mask_dilation_mm) self.unigradicon_src_path = ( Path(unigradicon_src_path) if unigradicon_src_path is not None else None @@ -286,6 +290,9 @@ def __init__( self.labelmap_tools = LabelmapTools(log_level=log_level) self.registrar: Optional[RegisterTimeSeriesImages] = None + self._use_segmentations: Optional[bool] = None + self._use_masks: Optional[bool] = None + self._dataset_json_path: Optional[Path] = None self._config_yaml_path: Optional[Path] = None @@ -310,53 +317,15 @@ def _validate_companion_shape( f"subject_image_files[{i}] length ({len(images)})" ) - @property - def use_segmentations(self) -> bool: - """Whether at least one segmentation file is supplied for training. - - Drives the uniGradICON ``training.use_label`` flag. - """ - return self._any_non_none(self.subject_segmentation_files) - - @property - def use_masks(self) -> bool: - """Whether the dataset will have a ``mask`` field on every kept entry. - - True when explicit masks are supplied OR when segmentations are supplied - (since masks are then derived). Drives the uniGradICON - ``training.loss_function_masking`` flag. - """ - return self._any_non_none(self.subject_mask_files) or self.use_segmentations - - @property - def use_label(self) -> bool: - """Whether uniGradICON trains with label supervision. - - Drives the uniGradICON ``training.use_label`` flag. True exactly when - segmentations are supplied, since the dataset then carries a ``label`` - field on every kept entry. - """ - return self.use_segmentations - - @staticmethod - def _any_non_none( - companion: Optional[list[list[Optional[str]]]], - ) -> bool: - """Return True when ``companion`` contains at least one non-``None`` entry.""" - if companion is None: - return False - for inner in companion: - for item in inner: - if item is not None: - return True - return False - @staticmethod def _posix(path: Union[str, Path]) -> str: """Return a forward-slashed string path (uniGradICON expects POSIX paths).""" return str(path).replace("\\", "/") - def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: + def _derive_mask( + self, + labelmap_path: Union[str, Path], + ) -> Path: """Create (or reuse) a dilated binary mask from a multi-label labelmap. Threshold the labelmap at ``>0`` and dilate by ``mask_dilation_mm`` mm @@ -391,12 +360,16 @@ def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: labelmap = itk.imread(str(labelmap_path)) mask = self.labelmap_tools.convert_labelmap_to_mask( - labelmap, dilation_in_mm=self.mask_dilation_mm + labelmap, + dilation_in_mm=self.mask_dilation_mm, + exclude_labels=self.mask_exclude_labels, ) itk.imwrite(mask, str(mask_path), compression=True) return mask_path - def prepare_dataset(self) -> Path: + def prepare_dataset( + self, use_segmentations: bool = True, use_masks: bool = True + ) -> Path: """Write the uniGradICON dataset JSON from the configured file lists. Builds one entry per image with ``image``, optional ``segmentation``, @@ -418,8 +391,9 @@ def prepare_dataset(self) -> Path: does not exist on disk. """ self.experiment_dir.mkdir(parents=True, exist_ok=True) - use_seg = self.use_segmentations - use_mask = self.use_masks + + self._use_segmentations = use_segmentations + self._use_masks = use_masks dataset_entries: list[dict[str, str]] = [] for subject_index, image_files in enumerate(self.subject_image_files): @@ -428,16 +402,24 @@ def prepare_dataset(self) -> Path: if self.subject_ids is not None else f"subject_{subject_index:04d}" ) - seg_list = ( - self.subject_segmentation_files[subject_index] - if self.subject_segmentation_files is not None - else [None] * len(image_files) - ) - mask_list = ( - self.subject_mask_files[subject_index] - if self.subject_mask_files is not None - else [None] * len(image_files) - ) + seg_list: list[Optional[str]] + if not use_segmentations: + seg_list = [None] * len(image_files) + else: + seg_list = ( + self.subject_segmentation_files[subject_index] + if self.subject_segmentation_files is not None + else [None] * len(image_files) + ) + mask_list: list[Optional[str]] + if not use_masks: + mask_list = [None] * len(image_files) + else: + mask_list = ( + self.subject_mask_files[subject_index] + if self.subject_mask_files is not None + else [None] * len(image_files) + ) landmark_list = ( self.subject_landmark_files[subject_index] if self.subject_landmark_files is not None @@ -456,7 +438,7 @@ def prepare_dataset(self) -> Path: "subject_id": subject_id, } - if use_seg: + if use_segmentations: if seg_file is None or not Path(seg_file).exists(): self.log_warning( "Skipping %s: segmentation missing for paired-with-seg " @@ -467,7 +449,7 @@ def prepare_dataset(self) -> Path: continue entry["segmentation"] = self._posix(seg_file) - if use_mask: + if use_masks: if mask_file is not None and Path(mask_file).exists(): resolved_mask: Path = Path(mask_file) elif seg_file is not None and Path(seg_file).exists(): @@ -541,8 +523,8 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "lambda": self.lambda_value, "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, - "loss_function_masking": self.use_masks, - "use_label": self.use_label, + "loss_function_masking": self._use_masks, + "use_label": False, "roi_masking": False, }, "datasets": [ diff --git a/tests/test_labelmap_tools.py b/tests/test_labelmap_tools.py index 627c9e3..e6f15cf 100644 --- a/tests/test_labelmap_tools.py +++ b/tests/test_labelmap_tools.py @@ -4,7 +4,7 @@ Covers thresholding a multi-label labelmap into a binary registration mask, physically isotropic dilation that respects per-axis spacing, and forcing -selected labels to background via ``labels_to_exclude``. +selected labels to background via ``exclude_labels``. """ from __future__ import annotations @@ -75,9 +75,7 @@ def test_dilation_respects_anisotropic_spacing( assert dilated[5, 5, 0] == 0 assert dilated[5, 5, 10] == 0 - def test_labels_to_exclude_removes_voxels( - self, labelmap_tools: LabelmapTools - ) -> None: + def test_exclude_labels_removes_voxels(self, labelmap_tools: LabelmapTools) -> None: """Excluded labels become background before thresholding.""" arr = np.zeros((5, 5, 5), dtype=np.uint8) arr[1, 1, 1] = 2 # kept @@ -87,7 +85,7 @@ def test_labels_to_exclude_removes_voxels( mask_arr = itk.array_from_image( labelmap_tools.convert_labelmap_to_mask( - labelmap, dilation_in_mm=0.0, labels_to_exclude=[7] + labelmap, dilation_in_mm=0.0, exclude_labels=[7] ) ) From fc3cb5491040995db730f0c2454189ddc6013079 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Thu, 4 Jun 2026 11:07:57 -0400 Subject: [PATCH 04/10] ENH: Updated results reporting for registration method selection --- .../LongitudinalRegistration/.gitignore | 1 + .../1-initial_registration.py | 1105 +++++++---------- .../2-finetune_icon.py | 188 ++- .../4-recon_4d_all_eval.py | 706 ----------- .../experiment_recon_4d.py | 191 --- .../registration_results_analysis.py | 244 ++++ src/physiomotion4d/labelmap_tools.py | 87 ++ src/physiomotion4d/register_images_greedy.py | 227 +++- 8 files changed, 1074 insertions(+), 1675 deletions(-) delete mode 100644 experiments/LongitudinalRegistration/4-recon_4d_all_eval.py delete mode 100644 experiments/LongitudinalRegistration/experiment_recon_4d.py create mode 100644 experiments/LongitudinalRegistration/registration_results_analysis.py diff --git a/experiments/LongitudinalRegistration/.gitignore b/experiments/LongitudinalRegistration/.gitignore index 19960a9..f850328 100644 --- a/experiments/LongitudinalRegistration/.gitignore +++ b/experiments/LongitudinalRegistration/.gitignore @@ -1 +1,2 @@ uniGradICON +fixed diff --git a/experiments/LongitudinalRegistration/1-initial_registration.py b/experiments/LongitudinalRegistration/1-initial_registration.py index e9074b7..197c0c2 100644 --- a/experiments/LongitudinalRegistration/1-initial_registration.py +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -1,54 +1,13 @@ # %% [markdown] -# # Pre-registration: compare ANTS vs Greedy vs ICON on the Duke gated CT cohort -# -# Registers every gated CT time-point of every Duke patient under -# ``ref_data_dir`` (100% of the cohort -- no train/test split) to that -# patient's reference image, using three backends in turn: +# Initial registration: compare ANTS vs Greedy vs ICON on the Duke gated CT cohort # # * :class:`RegisterImagesANTS` (CPU, SyN deformable) # * :class:`RegisterImagesGreedy` (CPU, deformable) # * :class:`RegisterImagesICON` (GPU, uniGradICON deformable) # -# For each frame the script records wall-clock registration time, writes -# the warped/resampled moving image to disk, warps the moving labelmap -# into reference space to compute per-label Dice, and warps the moving -# landmarks into reference space to compute squared-error landmark -# statistics (mm^2) against the reference landmarks. -# -# Inputs (same data as ``1-finetune_icon.py``): -# * ``ref_data_dir / pm*_ref.nii.gz`` -- per-patient reference CT -# * ``src_data_dir_base / / *.nii.gz`` -- gated CT frames -# * ``segmentation_dir_base / / _labelmap.nii.gz`` -# -- per-frame multi-label segmentations -# * ``segmentation_dir_base / / _labelmap_mask.nii.gz`` -# -- pre-computed loss-function masks (re-derived on the fly if absent, -# matching the 3 mm dilation used by ``1-finetune_icon.py``) -# * ``segmentation_dir_base / / _landmark.mrk.json`` -# -- per-frame 3D Slicer Markups landmarks in LPS -# -# Outputs under ``results/``: -# * ``ants///.mha``, -# ``greedy///.mha`` and -# ``icon///.mha`` -- warped moving image -# per time point, alongside the forward/inverse transforms (``.hdf``), -# the warped ``_labelmap.mha`` and its warped -# loss-function mask ``_labelmap_mask.mha``, -# and the warped ``_landmark.mrk.json`` -# * ``registration_landmarks_.csv`` -- per-landmark squared errors -# * ``registration_dice_.csv`` -- per-label Dice -# * ``registration_summary_.csv`` -- per-(subject, method, timepoint) -# registration time, per-frame total time, mean Dice, MSE, RMSE -# * ``registration_timing_.csv`` -- per-step wall-clock seconds, -# appended live as each frame's steps complete (register, write_transforms, -# warp_image, warp_labelmap, warp_mask, dice, landmarks, frame_total) -# * ``registration_timing_summary_.csv`` -- per-(method, step) count, -# mean, and total seconds, written once at the end of the run -# -# Run interactively cell-by-cell; all paths are hard-coded. - # %% import csv -import re +import shutil import time from pathlib import Path from typing import Optional @@ -63,73 +22,61 @@ from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools -# %% [markdown] -# ## 1. Hard-coded paths and configuration - # %% ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -_HERE = Path(__file__).parent -output_dir = _HERE / "results" -output_dir.mkdir(parents=True, exist_ok=True) +use_mask_list = [False, False, False, False, False, False, False, False] +use_labelmap_list = [True, False, True, False, True, False, True, False] + +# ICON only +use_mass_list = [False, False, False, False, False, False, False, False] + +methods_list = [ + ["Greedy"], + ["Greedy"], + ["Greedy"], + ["Greedy"], + ["Greedy"], + ["Greedy"], + ["Greedy"], + ["Greedy"], +] +number_of_iterations_ANTS_list = [ + [40, 20, 10], + [40, 20, 10], + [40, 20, 10], + [40, 20, 10], + [40, 20, 10], + [40, 20, 10], + [40, 20, 10], + [40, 20, 10], +] +number_of_iterations_greedy_list = [ + [40, 20, 10], + [40, 20, 10], + [80, 20, 10], + [80, 20, 10], + [40, 40, 10], + [40, 40, 10], + [80, 40, 5], + [80, 40, 5], +] +number_of_iterations_ICON_list = [100, 100, 100, 100, 100, 100, 100, 100] -# Reference frames in gated_nii are named ``_ref.nii.gz``; every -# other ``.nii.gz`` (excluding ``nop`` non-gated references) is a gated -# time point. Timepoint tag ``g###`` is extracted from each filename. exclude_tokens = ["nop"] ref_suffix = "_ref" -timepoint_re = re.compile(r"_g(?P[0-9]{3})") - -# Mask dilation matches 1-finetune_icon.py so any masks we have to -# derive here are identical to the ones written by the fine-tune script. -mask_dilation_mm = 3.0 -labelmap_tools = LabelmapTools() - -# Iteration schedules. Kept modest for a cohort-wide comparison; raise -# either list for higher accuracy at the cost of runtime. ANTS and Greedy -# take a multi-resolution list; ICON takes a single per-pair iterative -# optimization step count (0 disables it, using the pretrained forward pass -# alone). -number_of_iterations_ANTS = [40, 20, 10] -number_of_iterations_greedy = [40, 20, 10] -number_of_iterations_ICON = 50 - -# Optional uniGradICON checkpoint (".trch") to load instead of the default -# pretrained weights under ``network_weights/unigradicon1.0/``. When None, -# the default pretrained weights are used. icon_weights_path: Optional[Path] = None +mask_dilation_mm = 3.0 +use_crop = False +fixed_image_resolution_mm = 0.0 -methods: list[str] = ["ANTS", "Greedy", "ICON"] - -# Debug knob: when non-empty, only these patient IDs are processed. -# Set to ``[]`` (or ``None``) to run the full cohort. -debug_subjects: list[str] = [] # ["pm0002"] - -run_stamp = time.time() -detail_landmarks_file = output_dir / f"registration_landmarks_{run_stamp}.csv" -detail_dice_file = output_dir / f"registration_dice_{run_stamp}.csv" -summary_file = output_dir / f"registration_summary_{run_stamp}.csv" -# Per-step wall-clock times, appended live as each frame's steps complete. -timing_detail_file = output_dir / f"registration_timing_{run_stamp}.csv" -# Per-(method, step) timing aggregates, written once at the end of the run. -timing_summary_file = output_dir / f"registration_timing_summary_{run_stamp}.csv" -for previous in ( - detail_landmarks_file, - detail_dice_file, - summary_file, - timing_detail_file, - timing_summary_file, -): - if previous.exists(): - previous.unlink() +debug_subjects = [] # ["pm0002", "pm0003", "pm0004"] -# %% [markdown] -# ## 2. Enumerate the full patient cohort -# -# Sort ``ref_data_dir`` by filename so the patient order is stable. -# Every patient is processed -- no train/test split. +labelmap_tools = LabelmapTools() +landmark_tools = LandmarkTools() +transform_tools = TransformTools() # %% ref_files = sorted( @@ -139,6 +86,7 @@ ) all_patient_ids = [p.name[:6] for p in ref_files] print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") + if debug_subjects: cohort = [pid for pid in all_patient_ids if pid in debug_subjects] print( @@ -147,49 +95,6 @@ ) else: cohort = all_patient_ids -print(f"Patient cohort: {cohort}") - -# %% [markdown] -# ## 3. Helpers: labelmap warping, per-label Dice, landmark squared error - -# %% -landmark_tools = LandmarkTools() -transform_tools = TransformTools() - -# Per-step timing records (subject, method, timepoint, step, seconds), -# accumulated in memory for the end-of-run timing summary and mirrored live -# into timing_detail_file as each step finishes. -timing_rows: list[dict[str, object]] = [] - - -def record_step_time( - subject_id: str, - method_name: str, - timepoint: str, - step: str, - seconds: float, -) -> None: - """Report a single processing step's wall-clock time. - - Prints the time immediately, appends a row to ``timing_detail_file`` so - progress is visible while the run is still going, and stores the same - row in ``timing_rows`` for the end-of-run timing summary. - """ - print(f" [time] {step:<18}{seconds:8.2f} s", flush=True) - timing_rows.append( - { - "subject_id": subject_id, - "method": method_name, - "timepoint": timepoint, - "step": step, - "seconds": float(seconds), - } - ) - with timing_detail_file.open("a", newline="", encoding="utf-8") as fh: - writer = csv.writer(fh) - if fh.tell() == 0: - writer.writerow(["subject_id", "method", "timepoint", "step", "seconds"]) - writer.writerow([subject_id, method_name, timepoint, step, f"{seconds:.6f}"]) def per_label_dice( @@ -236,27 +141,31 @@ def warp_landmarks( forward). Returns a ``{label: (x, y, z)}`` dict in LPS. See docs/developer/transform_conventions. """ - return { - name: tuple(float(c) for c in inverse_transform.TransformPoint(point)) - for name, point in moving_landmarks.items() - } + new_landmarks = {} + for name, point in moving_landmarks.items(): + new_point = inverse_transform.TransformPoint(np.array(point)) + new_landmarks[name] = tuple(np.array(new_point).tolist()) + return new_landmarks -def landmark_squared_errors( +def landmark_rms_errors( warped_landmarks: dict[str, tuple[float, float, float]], - reference_landmarks: dict[str, tuple[float, float, float]], + fixed_landmarks: dict[str, tuple[float, float, float]], ) -> list[tuple[str, float]]: - """Return per-landmark squared Euclidean error in mm^2 between the + """Return per-landmark RMS Euclidean error in mm between the reference-space ``warped_landmarks`` and the matching reference landmarks, in sorted-name order. """ - shared = sorted(warped_landmarks.keys() & reference_landmarks.keys()) errors: list[tuple[str, float]] = [] - for name in shared: - diff = np.asarray(warped_landmarks[name], dtype=np.float64) - np.asarray( - reference_landmarks[name], dtype=np.float64 - ) - errors.append((name, float(np.dot(diff, diff)))) + for name in fixed_landmarks.keys(): + if name not in warped_landmarks: + errors.append((name, float("nan"))) + continue + diff = 0 + for i in range(3): + diff += (warped_landmarks[name][i] - fixed_landmarks[name][i]) ** 2 + errors.append((name, float(np.sqrt(diff)))) + print(f"Landmark {name} RMS error: {errors[-1][1]:.4f} mm") return errors @@ -267,540 +176,460 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: 3 mm physical-radius dilation) and write it out so subsequent runs and the ICON eval reuse the same mask. """ - # Force mask update # if mask_path.exists(): - # return itk.imread(str(mask_path)) + # return itk.imread(str(mask_path)) mask = labelmap_tools.convert_labelmap_to_mask( labelmap, dilation_in_mm=mask_dilation_mm, exclude_labels=[1, 2, 3, 4], - # These are labels for the interior of the heart chambers (the LV, RV, LA, RA) + # Interior chambers of the heart: LV, RV, LA, RA ) itk.imwrite(mask, str(mask_path), compression=True) return mask -# %% [markdown] -# ## 4. Drive the comparison: every patient x every method -# -# For each patient: load the reference image, labelmap, mask, and -# landmarks; load every gated frame (excluding ``nop`` and ``_ref``) with -# its labelmap, mask, and landmarks; then register each frame to the -# reference under both backends. Each frame starts from identity so the -# ANTS-vs-Greedy comparison is independent across frames. +def crop_image_to_mask( + image: itk.Image, + mask: Optional[itk.Image] = None, + labelmap: Optional[itk.Image] = None, + margin_fraction: float = 0.1, +) -> dict[str, itk.Image]: + if mask is None: + mask_arr = itk.array_from_image(image) + print("No mask provided, using image as mask") + else: + mask_arr = itk.array_from_image(mask) + bounding_box = np.where(mask_arr > 0) + min_x = np.min(bounding_box[2]) + max_x = np.max(bounding_box[2]) + min_y = np.min(bounding_box[1]) + max_y = np.max(bounding_box[1]) + min_z = np.min(bounding_box[0]) + max_z = np.max(bounding_box[0]) + margin_x = int((max_x - min_x) * margin_fraction) + margin_y = int((max_y - min_y) * margin_fraction) + margin_z = int((max_z - min_z) * margin_fraction) + min_x -= margin_x + max_x += margin_x + min_y -= margin_y + max_y += margin_y + min_z -= margin_z + max_z += margin_z + if min_x < 0: + min_x = 0 + if min_y < 0: + min_y = 0 + if min_z < 0: + min_z = 0 + max_size = image.GetLargestPossibleRegion().GetSize() + if max_x >= max_size[0]: + max_x = max_size[0] - 1 + if max_y >= max_size[1]: + max_y = max_size[1] - 1 + if max_z >= max_size[2]: + max_z = max_size[2] - 1 + print(f"array shape: {mask_arr.shape}") + print( + f"min_x: {min_x}, max_x: {max_x}, min_y: {min_y}, max_y: {max_y}, min_z: {min_z}, max_z: {max_z}" + ) + new_image_arr = itk.array_from_image(image) + new_image_arr = new_image_arr[min_z:max_z, min_y:max_y, min_x:max_x] + new_origin = image.TransformIndexToPhysicalPoint( + [int(min_x), int(min_y), int(min_z)] + ) + new_image = itk.image_from_array(new_image_arr) + new_image.SetSpacing(image.GetSpacing()) + new_image.SetDirection(image.GetDirection()) + new_image.SetOrigin(new_origin) + + if labelmap is not None: + new_labelmap_arr = itk.array_from_image(labelmap) + new_labelmap_arr = new_labelmap_arr[min_z:max_z, min_y:max_y, min_x:max_x] + new_labelmap = itk.image_from_array(new_labelmap_arr) + new_labelmap.CopyInformation(new_image) + else: + new_labelmap = None + + if mask is not None: + new_mask_arr = itk.array_from_image(mask) + new_mask_arr = new_mask_arr[min_z:max_z, min_y:max_y, min_x:max_x] + new_mask = itk.image_from_array(new_mask_arr) + new_mask.CopyInformation(new_image) + else: + new_mask = None -# %% -summary_rows: list[dict[str, object]] = [] + return { + "image": new_image, + "labelmap": new_labelmap, + "mask": new_mask, + } -# (subject_id, method, timepoint) for frames that produced no usable -# landmark errors -- either no landmark file or no labels shared with the -# reference. Echoed in a highlighted block at the end of the run. -frames_missing_landmarks: list[tuple[str, str, str]] = [] +# %% +_HERE = Path(__file__).parent for subject_index, subject_id in enumerate(cohort): print(f"\n=== Subject {subject_index + 1}/{len(cohort)}: {subject_id} ===") src_dir = src_data_dir_base / subject_id seg_dir = segmentation_dir_base / subject_id - if not src_dir.is_dir(): - print(f" Skipping {subject_id}: source dir {src_dir} not found") - continue - if not seg_dir.is_dir(): - print(f" Skipping {subject_id}: segmentation dir {seg_dir} not found") - continue - - # Locate this patient's reference frame in gated_nii (matches the - # `_ref.nii.gz` filename under ref_data_dir). ref_file = next((p for p in ref_files if p.name.startswith(subject_id)), None) - if ref_file is None: - print(f" Skipping {subject_id}: no reference image found") - continue ref_stem = ref_file.name[:-7] ref_labelmap_path = seg_dir / f"{ref_stem}_labelmap.nii.gz" ref_mask_path = seg_dir / f"{ref_stem}_labelmap_mask.nii.gz" ref_landmark_path = seg_dir / f"{ref_stem}_landmark.mrk.json" - if not ref_labelmap_path.exists() or not ref_landmark_path.exists(): - print( - f" Skipping {subject_id}: missing reference labelmap or " - f"landmarks under {seg_dir}" - ) - continue + + fixed_output_dir = _HERE / "fixed" + fixed_output_dir.mkdir(parents=True, exist_ok=True) fixed_image = itk.imread(str(ref_file), pixel_type=itk.F) - fixed_labelmap = itk.imread(str(ref_labelmap_path)) - fixed_mask = load_or_derive_mask(fixed_labelmap, ref_mask_path) - reference_landmarks = landmark_tools.read_landmarks_3dslicer(ref_landmark_path) - # Gated moving frames (exclude `nop` and the `_ref` frame itself). + fixed_labelmap = None + if ref_labelmap_path.exists(): + fixed_labelmap = itk.imread(str(ref_labelmap_path)) + + fixed_mask = None + if ref_mask_path.exists(): + fixed_mask = load_or_derive_mask(fixed_labelmap, ref_mask_path) + + fixed_landmarks = None + if ref_landmark_path.exists(): + fixed_landmarks = landmark_tools.read_landmarks_3dslicer(ref_landmark_path) + + if use_crop: + cropped = crop_image_to_mask( + fixed_image, mask=fixed_mask, labelmap=fixed_labelmap + ) + fixed_image = cropped["image"] + fixed_labelmap = cropped["labelmap"] + fixed_mask = cropped["mask"] + + if fixed_image_resolution_mm > 0.0: + fixed_image_size = fixed_image.GetLargestPossibleRegion().GetSize() + fixed_image_size[0] = int( + fixed_image_size[0] + * fixed_image.GetSpacing()[0] + / fixed_image_resolution_mm + ) + fixed_image_size[1] = int( + fixed_image_size[1] + * fixed_image.GetSpacing()[1] + / fixed_image_resolution_mm + ) + fixed_image_size[2] = int( + fixed_image_size[2] + * fixed_image.GetSpacing()[2] + / fixed_image_resolution_mm + ) + fixed_image = itk.resample_image_filter( + fixed_image, + output_direction=fixed_image.GetDirection(), + output_origin=fixed_image.GetOrigin(), + size=fixed_image_size, + output_spacing=[ + fixed_image_resolution_mm, + fixed_image_resolution_mm, + fixed_image_resolution_mm, + ], + default_pixel_value=-1000, + ) + if fixed_labelmap is not None: + fixed_labelmap = itk.resample_image_filter( + fixed_labelmap, + output_parameters_from_image=fixed_image, + default_pixel_value=0, + interpolator=itk.NearestNeighborInterpolateImageFunction.New( + fixed_labelmap + ), + ) + if fixed_mask is not None: + fixed_mask = itk.resample_image_filter( + fixed_mask, + output_parameters_from_image=fixed_image, + default_pixel_value=0, + interpolator=itk.NearestNeighborInterpolateImageFunction.New( + fixed_mask + ), + ) + + print(f"Writing reference image to {f'{subject_id}_ref.nii.gz'}") + itk.imwrite( + fixed_image, + str(fixed_output_dir / f"{subject_id}_ref.nii.gz"), + compression=True, + ) + if fixed_labelmap is not None: + print(f"Writing reference labelmap to {f'{subject_id}_ref_labelmap.nii.gz'}") + itk.imwrite( + fixed_labelmap, + str(fixed_output_dir / f"{subject_id}_ref_labelmap.nii.gz"), + compression=True, + ) + if fixed_mask is not None: + print(f"Writing reference mask to {f'{subject_id}_ref_mask.nii.gz'}") + itk.imwrite( + fixed_mask, + str(fixed_output_dir / f"{subject_id}_ref_mask.nii.gz"), + compression=True, + ) + if fixed_landmarks is not None: + print( + f"Writing reference landmarks to {f'{subject_id}_ref_landmarks.mrk.json'}" + ) + landmark_tools.write_landmarks_3dslicer( + fixed_landmarks, + str(fixed_output_dir / f"{subject_id}_ref_landmarks.mrk.json"), + ) + gated_files = sorted( p for p in src_dir.glob("*.nii.gz") if not any(token in p.name for token in exclude_tokens) and not p.name.endswith(f"{ref_suffix}.nii.gz") ) - moving_records: list[dict[str, object]] = [] - for image_path in gated_files: + + print(f"Found {len(gated_files)} gated images under {src_dir}") + + for image_index, image_path in enumerate(gated_files): stem = image_path.name[:-7] + + print(f"\n\n *** Processing {stem} ***\n\n") + labelmap_path = seg_dir / f"{stem}_labelmap.nii.gz" mask_path = seg_dir / f"{stem}_labelmap_mask.nii.gz" landmark_path = seg_dir / f"{stem}_landmark.mrk.json" - if not labelmap_path.exists(): - print(f" Dropping {stem}: no labelmap at {labelmap_path}") - continue - match = timepoint_re.search(image_path.name) - if match is None: - print(f" Dropping {stem}: no g### timepoint tag in name") - continue - moving_records.append( - { - "stem": stem, - "timepoint": match.group("timepoint"), - "image_path": image_path, - "labelmap_path": labelmap_path, - "mask_path": mask_path, - "landmark_path": landmark_path if landmark_path.exists() else None, - } - ) - if not moving_records: - print(f" Skipping {subject_id}: no usable gated frames") - continue - - print(f" {len(moving_records)} moving frames; reference {ref_file.name}") - - print(f" Loading {len(moving_records)} moving images / labelmaps / masks ...") - moving_images = [] - moving_labelmaps = [] - moving_masks = [] - moving_landmarks_list: list[Optional[dict[str, tuple[float, float, float]]]] = [] - for r_index, r in enumerate(moving_records): - print( - f" [{r_index + 1}/{len(moving_records)}] g{r['timepoint']} {r['stem']}" - ) - moving_image = itk.imread(str(r["image_path"]), pixel_type=itk.F) - labelmap = itk.imread(str(r["labelmap_path"])) - moving_images.append(moving_image) - moving_labelmaps.append(labelmap) - moving_masks.append(load_or_derive_mask(labelmap, r["mask_path"])) - landmark_path = r["landmark_path"] - if landmark_path is None: - moving_landmarks_list.append(None) - else: - moving_landmarks_list.append( - landmark_tools.read_landmarks_3dslicer(landmark_path) - ) - - for method_name in methods: - print(f"\n --- Method: {method_name} ---") - if method_name == "ANTS": - reg = RegisterImagesANTS() - reg.set_number_of_iterations(number_of_iterations_ANTS) - reg.set_transform_type("Deformable") - # NCC ("CC") beats MeanSquares for same-modality CT registration. - reg.set_metric("CC") - elif method_name == "Greedy": - reg = RegisterImagesGreedy() - reg.set_number_of_iterations(number_of_iterations_greedy) - reg.set_transform_type("Deformable") - # NCC ("CC") beats MeanSquares for same-modality CT registration. - reg.set_metric("CC") - else: # ICON: GPU deep-learning deformable registration. - reg = RegisterImagesICON() - reg.set_number_of_iterations(number_of_iterations_ICON) - if icon_weights_path is not None: - reg.set_weights_path(str(icon_weights_path)) - reg.set_modality("ct") - reg.set_mask_dilation(mask_dilation_mm) - reg.set_fixed_image(fixed_image) - reg.set_fixed_mask(fixed_mask) - - method_dir = output_dir / method_name.lower() / subject_id - method_dir.mkdir(parents=True, exist_ok=True) - - method_t_start = time.perf_counter() - for index, record in enumerate(moving_records): - timepoint = record["timepoint"] - stem = record["stem"] - print( - f" [{method_name} {index + 1}/{len(moving_records)}] " - f"g{timepoint} registering ...", - flush=True, - ) - frame_total_start = time.perf_counter() - frame_t_start = frame_total_start - reg_result = reg.register( - moving_image=moving_images[index], - moving_mask=moving_masks[index], - ) - frame_elapsed = time.perf_counter() - frame_t_start - - forward_transform = reg_result["forward_transform"] - inverse_transform = reg_result["inverse_transform"] - frame_loss = float(reg_result["loss"]) - print(f" done in {frame_elapsed:.1f} s, loss={frame_loss:.4f}") - record_step_time( - subject_id, method_name, timepoint, "register", frame_elapsed + moving_image_name = str(image_path) + moving_image = itk.imread(moving_image_name, pixel_type=itk.F) + moving_labelmap = None + if fixed_labelmap is not None and labelmap_path.exists(): + moving_labelmap = itk.imread(str(labelmap_path)) + moving_mask = None + if fixed_mask is not None and mask_path.exists(): + moving_mask = load_or_derive_mask(moving_labelmap, mask_path) + moving_landmarks = None + if fixed_landmarks is not None and landmark_path.exists(): + moving_landmarks = landmark_tools.read_landmarks_3dslicer(landmark_path) + + if use_crop: + cropped = crop_image_to_mask( + moving_image, mask=moving_mask, labelmap=moving_labelmap ) + moving_image = cropped["image"] + moving_labelmap = cropped["labelmap"] + moving_mask = cropped["mask"] + + for cond_index in range(len(use_mask_list)): + methods = methods_list[cond_index] + + use_mask = use_mask_list[cond_index] + use_labelmap = use_labelmap_list[cond_index] + use_mass = use_mass_list[cond_index] + + number_of_iterations_ANTS = number_of_iterations_ANTS_list[cond_index] + number_of_iterations_greedy = number_of_iterations_greedy_list[cond_index] + number_of_iterations_ICON = number_of_iterations_ICON_list[cond_index] + + cond = "_" + if use_mask: + cond += "m" + if use_labelmap: + cond += "l" + if use_mass: + cond += "p" + if use_crop: + cond += "c" + if cond == "_": + cond += "raw" - step_t_start = time.perf_counter() - itk.transformwrite( - forward_transform, - str(method_dir / f"{stem}_fwd.hdf"), - compression=True, - ) - itk.transformwrite( - inverse_transform, - str(method_dir / f"{stem}_inv.hdf"), - compression=True, - ) - record_step_time( - subject_id, - method_name, - timepoint, - "write_transforms", - time.perf_counter() - step_t_start, - ) - - # Warp the moving image into reference space and save it - # (forward_transform resamples the moving image onto the fixed grid). - step_t_start = time.perf_counter() - warped_image = transform_tools.transform_image( - moving_images[index], - forward_transform, - fixed_image, - interpolation_method="linear", - ) - itk.imwrite( - warped_image, - str(method_dir / f"{stem}.mha"), - compression=True, - ) - record_step_time( - subject_id, - method_name, - timepoint, - "warp_image", - time.perf_counter() - step_t_start, + print( + f"\n\n ***** {cond_index + 1}/{len(use_mask_list)}: results{cond} *****\n\n" ) - # Warp the moving labelmap onto the fixed grid (forward_transform; - # nearest neighbour preserves label IDs) for per-label Dice. - step_t_start = time.perf_counter() - warped_labelmap = transform_tools.transform_image( - moving_labelmaps[index], - forward_transform, - fixed_labelmap, - interpolation_method="nearest", - ) - itk.imwrite( - warped_labelmap, - str(method_dir / f"{stem}_labelmap.mha"), - compression=True, - ) - record_step_time( - subject_id, - method_name, - timepoint, - "warp_labelmap", - time.perf_counter() - step_t_start, - ) + output_dir = _HERE / f"results{cond}" + detail_landmarks_file = output_dir / "registration_landmarks_init.csv" + detail_dice_file = output_dir / "registration_dice_init.csv" + if subject_index == 0 and image_index == 0 and cond_index == 0: + if output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + for method_index, method_name in enumerate(methods): + print(f"\n\n --- {method_name} --- \n\n") + if method_name == "ANTS": + reg = RegisterImagesANTS() + reg.set_number_of_iterations(number_of_iterations_ANTS) + num_iters_str = ".".join(str(n) for n in number_of_iterations_ANTS) + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + if use_labelmap: + reg.set_metric("MeanSquares") + else: + reg.set_metric("CC") + elif method_name == "Greedy": + reg = RegisterImagesGreedy() + reg.set_number_of_iterations(number_of_iterations_greedy) + print(f"Number of iterations: {number_of_iterations_greedy}") + num_iters_str = ".".join( + str(n) for n in number_of_iterations_greedy + ) + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + if use_labelmap: + reg.set_metric("MeanSquares") + else: + reg.set_metric("CC") + else: # ICON: GPU deep-learning deformable registration. + reg = RegisterImagesICON() + reg.set_number_of_iterations(number_of_iterations_ICON) + num_iters_str = ".".join(str(n) for n in number_of_iterations_ICON) + reg.set_multi_modality(False) + reg.set_mass_preservation(use_mass) + if icon_weights_path is not None: + reg.set_weights_path(str(icon_weights_path)) + reg.set_modality("ct") + reg.set_mask_dilation(0) # Already dilated + reg.set_fixed_image(fixed_image) + if use_mask: + reg.set_fixed_mask(fixed_mask) + if use_labelmap: + reg.set_fixed_labelmap(fixed_labelmap) + + method_dir_name = str(method_name.lower()) + "_" + num_iters_str + method_dir = output_dir / method_dir_name / subject_id + print(f"Method directory: {method_dir}") + method_dir.mkdir(parents=True, exist_ok=True) + print(f"Registering {stem} with {method_name}...") + time_start = time.perf_counter() + reg_result = reg.register( + moving_image=moving_image, + moving_mask=moving_mask if use_mask else None, + moving_labelmap=moving_labelmap if use_labelmap else None, + ) + time_elapsed = time.perf_counter() - time_start + print(f" ...finished registration in {time_elapsed:.1f}s") - # Warp the moving loss-function mask onto the fixed grid - # (forward_transform; nearest neighbour preserves the binary ROI) - # so downstream fine-tuning reuses it instead of re-deriving a - # mask from the warped labelmap. - step_t_start = time.perf_counter() - warped_mask = transform_tools.transform_image( - moving_masks[index], - forward_transform, - fixed_mask, - interpolation_method="nearest", - ) - itk.imwrite( - warped_mask, - str(method_dir / f"{stem}_labelmap_mask.mha"), - compression=True, - ) - record_step_time( - subject_id, - method_name, - timepoint, - "warp_mask", - time.perf_counter() - step_t_start, - ) + forward_transform = reg_result["forward_transform"] + inverse_transform = reg_result["inverse_transform"] + loss = float(reg_result["loss"]) - step_t_start = time.perf_counter() - dice_by_label = per_label_dice(fixed_labelmap, warped_labelmap) - with detail_dice_file.open("a", newline="", encoding="utf-8") as fh: - writer = csv.writer(fh) - if fh.tell() == 0: - writer.writerow( - ["subject_id", "method", "timepoint", "label", "dice"] - ) - for label, dice in dice_by_label.items(): - writer.writerow([subject_id, method_name, timepoint, label, dice]) - mean_dice = ( - float(np.mean(list(dice_by_label.values()))) - if dice_by_label - else float("nan") - ) - record_step_time( - subject_id, - method_name, - timepoint, - "dice", - time.perf_counter() - step_t_start, - ) + print(f"Writing results to {method_dir / f'{stem}_init_*.*'}") - # Warp the moving landmarks into reference space, save them next - # to the transforms, then score squared error vs the reference. - step_t_start = time.perf_counter() - moving_landmarks = moving_landmarks_list[index] - if moving_landmarks is None: - sq_errors: list[tuple[str, float]] = [] - else: - warped_landmarks = warp_landmarks(inverse_transform, moving_landmarks) - landmark_tools.write_landmarks_3dslicer( - warped_landmarks, - str(method_dir / f"{stem}_landmark.mrk.json"), + itk.transformwrite( + forward_transform, + str(method_dir / f"{stem}_init_fwd.hdf"), + compression=True, ) - sq_errors = landmark_squared_errors( - warped_landmarks, reference_landmarks + itk.transformwrite( + inverse_transform, + str(method_dir / f"{stem}_init_inv.hdf"), + compression=True, ) - with detail_landmarks_file.open("a", newline="", encoding="utf-8") as fh: - writer = csv.writer(fh) - if fh.tell() == 0: - writer.writerow( - [ - "subject_id", - "method", - "timepoint", - "name", - "sq_err_mm2", - ] - ) - for name, sq_err in sq_errors: - writer.writerow([subject_id, method_name, timepoint, name, sq_err]) - record_step_time( - subject_id, - method_name, - timepoint, - "landmarks", - time.perf_counter() - step_t_start, - ) - sq_values = np.asarray([e for _, e in sq_errors], dtype=np.float64) - if sq_values.size: - mse_mm2 = float(np.mean(sq_values)) - rmse_mm = float(np.sqrt(mse_mm2)) - else: - mse_mm2 = float("nan") - rmse_mm = float("nan") - # Highlight frames with no usable landmarks so they are not - # silently scored as NaN in the CSV / summary table. - reason = ( - "no landmark file" - if moving_landmarks is None - else "no landmarks shared with reference" + warped_image = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", ) - frames_missing_landmarks.append((subject_id, method_name, timepoint)) - print( - f" >>> WARNING: {subject_id} {method_name} " - f"g{timepoint} has NO landmarks ({reason})", - flush=True, + itk.imwrite( + warped_image, + str(method_dir / f"{stem}_init.mha"), + compression=True, ) - frame_total = time.perf_counter() - frame_total_start - record_step_time( - subject_id, method_name, timepoint, "frame_total", frame_total - ) - - summary_rows.append( - { - "subject_id": subject_id, - "method": method_name, - "timepoint": timepoint, - "time_sec": float(frame_elapsed), - "frame_total_sec": float(frame_total), - "loss": frame_loss, - "n_labels": int(len(dice_by_label)), - "mean_dice": mean_dice, - "n_landmarks": int(sq_values.size), - "mse_mm2": mse_mm2, - "rmse_mm": rmse_mm, - } - ) - - method_elapsed = time.perf_counter() - method_t_start - print( - f" [{method_name}] subject {subject_id} total " - f"{method_elapsed:.1f} s " - f"({method_elapsed / len(moving_records):.1f} s/frame)" - ) - -# %% [markdown] -# ## 5. Write the per-(subject, method, timepoint) summary CSV - -# %% -if summary_rows: - with summary_file.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) - writer.writeheader() - writer.writerows(summary_rows) - print(f"\nWrote summary: {summary_file}") - print(f"Wrote landmarks: {detail_landmarks_file}") - print(f"Wrote dice: {detail_dice_file}") - print(f"Wrote timing: {timing_detail_file}") -else: - print("\nNo frames processed; nothing to summarize.") - -# %% [markdown] -# ## 5b. Highlight frames that produced no landmark errors - -# %% -if frames_missing_landmarks: - banner = "!" * 70 - print(f"\n{banner}") - print( - f"WARNING: {len(frames_missing_landmarks)} frame(s) missing ALL " - f"landmarks (scored as NaN):" - ) - for subject_id, method_name, timepoint in frames_missing_landmarks: - print(f" - {subject_id} {method_name} g{timepoint}") - print(banner) -else: - print("\nAll processed frames had at least one scored landmark.") - -# %% [markdown] -# ## 6. Per-method aggregate table across the whole cohort -# -# Reports mean per-frame registration time, mean / median / p95 of the -# squared landmark errors (mm^2), the matching RMSE in mm, and the mean -# per-label Dice averaged across (subject, timepoint, label) entries. + average_dice = float("nan") + if fixed_labelmap is not None and moving_labelmap is not None: + warped_labelmap = transform_tools.transform_image( + moving_labelmap, + forward_transform, + fixed_labelmap, + interpolation_method="nearest", + ) + itk.imwrite( + warped_labelmap, + str(method_dir / f"{stem}_init_labelmap.mha"), + compression=True, + ) + dice_by_label = per_label_dice(fixed_labelmap, warped_labelmap) + with detail_dice_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + ["subject_id", "method", "stem", "label", "dice"] + ) + for label, dice in dice_by_label.items(): + writer.writerow( + [ + subject_id, + method_name + "_" + num_iters_str, + stem, + label, + dice, + ] + ) + if dice_by_label: + average_dice = float(np.mean(list(dice_by_label.values()))) + + if fixed_mask is not None and moving_mask is not None: + warped_mask = transform_tools.transform_image( + moving_mask, + forward_transform, + fixed_mask, + interpolation_method="nearest", + ) + itk.imwrite( + warped_mask, + str(method_dir / f"{stem}_init_mask.mha"), + compression=True, + ) -# %% -if summary_rows: - sq_by_method: dict[str, list[float]] = {} - with detail_landmarks_file.open(newline="", encoding="utf-8") as fh: - for row in csv.DictReader(fh): - sq_by_method.setdefault(row["method"], []).append(float(row["sq_err_mm2"])) - - dice_by_method: dict[str, list[float]] = {} - with detail_dice_file.open(newline="", encoding="utf-8") as fh: - for row in csv.DictReader(fh): - dice_by_method.setdefault(row["method"], []).append(float(row["dice"])) - - time_by_method: dict[str, list[float]] = {} - for row in summary_rows: - method_name = str(row["method"]) - time_by_method.setdefault(method_name, []).append(float(row["time_sec"])) - - header = ( - f"{'Method':<10}{'N_lm':>8}{'MSE(mm2)':>12}{'RMSE(mm)':>12}" - f"{'p95(mm2)':>12}{'meanDice':>12}{'sec/frame':>12}" - ) - print() - print("=" * len(header)) - print(f"Pre-registration comparison ({len(all_patient_ids)} patients)") - print("=" * len(header)) - print(header) - print("-" * len(header)) - for method_name in methods: - sq_arr = np.asarray(sq_by_method.get(method_name, []), dtype=np.float64) - dice_arr = np.asarray(dice_by_method.get(method_name, []), dtype=np.float64) - time_arr = np.asarray(time_by_method.get(method_name, []), dtype=np.float64) - if sq_arr.size == 0: - print(f"{method_name:<10}{0:>8}{'':>12}{'':>12}{'':>12}{'':>12}{'':>12}") - continue - mse = float(np.mean(sq_arr)) - rmse = float(np.sqrt(mse)) - p95 = float(np.percentile(sq_arr, 95)) - mean_dice_val = float(np.mean(dice_arr)) if dice_arr.size else float("nan") - mean_time = float(np.mean(time_arr)) if time_arr.size else float("nan") - print( - f"{method_name:<10}" - f"{sq_arr.size:>8}" - f"{mse:>12.3f}" - f"{rmse:>12.3f}" - f"{p95:>12.3f}" - f"{mean_dice_val:>12.3f}" - f"{mean_time:>12.2f}" - ) - print("=" * len(header)) + average_rms_errors = float("nan") + if fixed_landmarks is not None and moving_landmarks is not None: + # Landmarks live in LPS world space, unaffected by cropping, so + # the uncropped moving_landmarks are warped here. + warped_landmarks = warp_landmarks( + inverse_transform, moving_landmarks + ) + landmark_tools.write_landmarks_3dslicer( + warped_landmarks, + str(method_dir / f"{stem}_init_landmarks.mrk.json"), + ) + rms_errors = landmark_rms_errors(warped_landmarks, fixed_landmarks) + with detail_landmarks_file.open( + "a", newline="", encoding="utf-8" + ) as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + [ + "subject_id", + "method", + "stem", + "name", + "rms_err_mm", + ] + ) + for name, error in rms_errors: + writer.writerow( + [ + subject_id, + method_name + "_" + num_iters_str, + stem, + name, + error, + ] + ) + rms_values = np.asarray( + [e for _, e in rms_errors], dtype=np.float64 + ) + if rms_values.size and not np.all(np.isnan(rms_values)): + average_rms_errors = float(np.nanmean(rms_values)) -# %% [markdown] -# ## 7. Per-(method, step) timing summary -# -# Aggregates the live per-step timings into mean and total wall-clock -# seconds per (method, step), printed as a table and written to -# ``timing_summary_file``. ``frame_total`` is the end-to-end per-frame -# time (register + all warps/writes + scoring); the other rows are its -# components. + print( + f"Method: {method_name}_{num_iters_str}, Subject: {subject_id}, " + f"Timepoint: {stem}, time: {time_elapsed:.1f}s, " + f"loss: {loss:.4f}, Dice: {average_dice:.4f}, " + f"RMS(mm): {average_rms_errors:.4f}" + ) # %% -if timing_rows: - # Preserve the pipeline order in which steps are timed; any unexpected - # step name is appended in first-seen order so nothing is dropped. - step_order = [ - "register", - "write_transforms", - "warp_image", - "warp_labelmap", - "warp_mask", - "dice", - "landmarks", - "frame_total", - ] - seconds_by_method_step: dict[str, dict[str, list[float]]] = {} - for row in timing_rows: - method_name = str(row["method"]) - step = str(row["step"]) - seconds = float(row["seconds"]) - seconds_by_method_step.setdefault(method_name, {}).setdefault(step, []).append( - seconds - ) - if step not in step_order: - step_order.append(step) - - timing_summary_rows: list[dict[str, object]] = [] - timing_header = ( - f"{'Method':<10}{'Step':<18}{'N':>6}{'mean_sec':>12}{'total_sec':>12}" - ) - print() - print("=" * len(timing_header)) - print("Timing summary (wall-clock seconds)") - print("=" * len(timing_header)) - print(timing_header) - print("-" * len(timing_header)) - for method_name in methods: - step_times = seconds_by_method_step.get(method_name, {}) - if not step_times: - continue - for step in step_order: - values = step_times.get(step) - if not values: - continue - arr = np.asarray(values, dtype=np.float64) - mean_sec = float(np.mean(arr)) - total_sec = float(np.sum(arr)) - timing_summary_rows.append( - { - "method": method_name, - "step": step, - "n": int(arr.size), - "mean_sec": mean_sec, - "total_sec": total_sec, - } - ) - print( - f"{method_name:<10}{step:<18}{arr.size:>6}" - f"{mean_sec:>12.2f}{total_sec:>12.2f}" - ) - print("-" * len(timing_header)) - print("=" * len(timing_header)) - - with timing_summary_file.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=list(timing_summary_rows[0].keys())) - writer.writeheader() - writer.writerows(timing_summary_rows) - print(f"Wrote timing summary: {timing_summary_file}") diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py index 8ed8f6d..eede97d 100644 --- a/experiments/LongitudinalRegistration/2-finetune_icon.py +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -18,7 +18,7 @@ # # In addition to the original ``gated_nii`` frames, each patient's training # group is augmented with that patient's ANTS- and Greedy-warped frames -# written by ``1-preregistration.py`` (warped image + labelmap per gated +# written by ``1-initial_registration.py`` (warped image + labelmap per gated # frame, under ``output_dir / / ``). Because the warped # frames are merged into the *same* ``subject_id`` group, uniGradICON pairs the # original gated frames and both backends' pre-registered frames together. @@ -45,16 +45,16 @@ # the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to # ``output_dir / fine_tune_name``. _HERE = Path(__file__).parent -output_dir = _HERE / "results" -fine_tune_name = "icon_finetuned" +output_dir = _HERE / "results_finetuning" +fine_tune_name = "icon_finetuning" -# Pre-registration augmentation: ``1-preregistration.py`` warps every gated +# Pre-registration augmentation: ``1-initial_registration.py`` warps every gated # moving frame into reference space with these backends and writes the warped -# image + labelmap under ``preregistration_dir / .lower() / +# image + labelmap under ``initial_registration_dir / .lower() / # ``. Those warped frames are merged into each patient's training # group below (section 4b). -preregistration_dir = output_dir -preregistration_methods = ["ANTS", "greedy"] +initial_registration_dir = output_dir +initial_registration_methods = ["Greedy"] # Fixed train/test split: sort patients in ``ref_data_dir`` by filename; # first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies @@ -98,12 +98,12 @@ print(f" Test (last {len(test_subjects)}): {test_subjects}") # %% [markdown] -# ## 3. Gather the train cohort's gated frames and labelmaps +# ## 3. Gather the train cohort's gated frames and labelmaps and masks # # For each train-cohort patient, list gated frames in # ``src_data_dir_base / `` (excluding ``"nop"`` non-gated # references) and pair each frame with its -# ``_labelmap.nii.gz`` under ``segmentation_dir_base / ``. +# ``_labelmap.nii.gz`` and ``_mask.nii.gz`` under ``segmentation_dir_base / ``. # Patients with no source directory or no valid frames are skipped here only # — they remain part of the canonical train list above, but contribute no # training data. Missing labelmaps are recorded as ``None`` so the workflow @@ -111,53 +111,16 @@ # %% train_image_files: list[list[str]] = [] -train_segmentation_files: list[list[Optional[str]]] = [] +train_labelmap_files: list[list[Optional[str]]] = [] +train_mask_files: list[list[Optional[str]]] = [] valid_train_subjects: list[str] = [] -for patient_id in train_subjects: - src_dir = src_data_dir_base / patient_id - seg_dir = segmentation_dir_base / patient_id - - if not src_dir.is_dir(): - print(f" Skipping {patient_id}: source dir {src_dir} not found") - continue - - frame_names = sorted( - f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") - ) - if not frame_names: - print(f" Skipping {patient_id}: no valid frames in {src_dir}") - continue - - image_paths = [str(src_dir / f) for f in frame_names] - seg_paths: list[Optional[str]] = [] - for f in frame_names: - labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") - seg_paths.append(str(labelmap) if labelmap.exists() else None) - - train_image_files.append(image_paths) - train_segmentation_files.append(seg_paths) - valid_train_subjects.append(patient_id) - - n_seg = sum(1 for s in seg_paths if s is not None) - print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") - -# %% [markdown] -# ## 4. Pre-compute loss-function masks next to each labelmap -# -# Use :meth:`LabelmapTools.convert_labelmap_to_mask` (``>0`` threshold + 5 mm -# physical-radius dilation) to derive each frame's binary heart-ROI mask and -# write it as ``_mask.nii.gz`` in the labelmap's own directory. -# Pre-computing here means the workflow does not have to re-derive masks -# during ``run_fine_tuning`` and the same masks are reused by downstream -# evaluation scripts. - -# %% -mask_dilation_mm = 5.0 +mask_dilation_mm = 3.0 labelmap_tools = LabelmapTools() -def derive_mask_for(labelmap_path: Path) -> str: +# %% +def load_or_derive_mask(labelmap_path: Path) -> str: """Create (or reuse) a loss-function mask next to ``labelmap_path``. Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm @@ -167,6 +130,9 @@ def derive_mask_for(labelmap_path: Path) -> str: (pre-registration warped labelmaps). Returns the mask path as a string; existing masks on disk are reused unmodified. """ + if not labelmap_path.exists(): + return None + name = labelmap_path.name if name.endswith(".nii.gz"): stem = name[:-7] @@ -183,32 +149,13 @@ def derive_mask_for(labelmap_path: Path) -> str: return str(mask_p) -train_mask_files: list[list[Optional[str]]] = [] -for seg_paths in train_segmentation_files: - train_mask_files.append( - [derive_mask_for(Path(s)) if s is not None else None for s in seg_paths] - ) - -# %% [markdown] -# ## 4b. Merge ANTS / Greedy pre-registered frames into each training group -# -# ``1-preregistration.py`` warps every gated moving frame into reference space -# with the ANTS and Greedy backends, writing ``.mha`` (warped image), -# ``_labelmap.mha`` (warped labelmap), and ``_deformation_grid.mha`` -# under ``preregistration_dir / / ``. Here those warped -# frames + labelmaps (with derived loss masks) are appended to the *same* -# patient's training group, so uniGradICON pairs the original gated frames and -# both backends' pre-registered frames together (they share a ``subject_id``). -# Patients/methods with no pre-registration output on disk are skipped. - - # %% def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str]]]: """Return ``(warped_image_paths, warped_labelmap_paths)`` for one - ``preregistration_dir / / `` directory. + ``initial_registration_dir / / `` directory. Enumerates the warped moving images (``.mha``), excluding the - ``_labelmap.mha``, ``_labelmap_mask.mha``, and ``_deformation_grid.mha`` + ``_labelmap.mha`` and ``_mask.mha`` companions, and pairs each with its ``_labelmap.mha`` (``None`` when that labelmap is absent). Returns empty lists when ``method_dir`` does not exist. @@ -217,70 +164,91 @@ def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str return [], [] companion_suffixes = ( "_labelmap.mha", - "_labelmap_mask.mha", - "_deformation_grid.mha", + "_mask.mha", ) image_paths: list[str] = [] labelmap_paths: list[Optional[str]] = [] - for mha in sorted(method_dir.glob("*.mha")): - if mha.name.endswith(companion_suffixes): + mask_paths: list[Optional[str]] = [] + for image in sorted(method_dir.glob("*.mha")): + if image.name.endswith(companion_suffixes): continue - stem = mha.name[:-4] + stem = image.name[:-4] labelmap = method_dir / f"{stem}_labelmap.mha" - image_paths.append(str(mha)) + mask = method_dir / f"{stem}_mask.mha" + image_paths.append(str(image)) labelmap_paths.append(str(labelmap) if labelmap.exists() else None) - return image_paths, labelmap_paths + mask_paths.append(str(mask) if mask.exists() else None) + return image_paths, labelmap_paths, mask_paths + + +# %% +train_mask_files: list[list[Optional[str]]] = [] +for labelmap_paths in train_labelmap_files: + train_mask_files.append( + [ + load_or_derive_mask(Path(s)) if s is not None else None + for s in labelmap_paths + ] + ) + +for patient_id in train_subjects: + src_dir = src_data_dir_base / patient_id + seg_dir = segmentation_dir_base / patient_id + + if not src_dir.is_dir(): + print(f" Skipping {patient_id}: source dir {src_dir} not found") + continue + + frame_names = sorted( + f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") + ) + if not frame_names: + print(f" Skipping {patient_id}: no valid frames in {src_dir}") + continue + + image_paths = [str(src_dir / f) for f in frame_names] + labelmap_paths: list[Optional[str]] = [] + mask_paths: list[Optional[str]] = [] + for f in frame_names: + labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") + labelmap_paths.append(str(labelmap) if labelmap.exists() else None) + mask = load_or_derive_mask(labelmap) + mask_paths.append(str(mask) if mask.exists() else None) + + train_image_files.append(image_paths) + train_labelmap_files.append(labelmap_paths) + valid_train_subjects.append(patient_id) + + n_seg = sum(1 for s in labelmap_paths if s is not None) + print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") for subject_index, patient_id in enumerate(valid_train_subjects): - for method_name in preregistration_methods: - method_dir = preregistration_dir / method_name.lower() / patient_id - warped_images, warped_labelmaps = gather_warped_frames(method_dir) + for method_name in initial_registration_methods: + method_dir = initial_registration_dir / method_name.lower() / patient_id + warped_images, warped_labelmaps, warped_masks = gather_warped_frames(method_dir) if not warped_images: print( - f" {patient_id}/{method_name}: no pre-registered frames " + f" {patient_id}/{method_name}: no initial-registered frames " f"in {method_dir}" ) continue - warped_masks: list[Optional[str]] = [] - for lm in warped_labelmaps: - if lm is None: - warped_masks.append(None) - continue - # 1-preregistration.py writes the warped loss mask next to the - # warped labelmap; prefer it, deriving one only if it is absent. - warped_mask = Path(f"{lm[:-4]}_mask.mha") - warped_masks.append( - str(warped_mask) if warped_mask.exists() else derive_mask_for(Path(lm)) - ) train_image_files[subject_index].extend(warped_images) - train_segmentation_files[subject_index].extend(warped_labelmaps) + train_labelmap_files[subject_index].extend(warped_labelmaps) train_mask_files[subject_index].extend(warped_masks) - n_seg = sum(1 for lm in warped_labelmaps if lm is not None) + n_warped = sum(1 for labelmap in warped_labelmaps if labelmap is not None) print( f" {patient_id}/{method_name}: +{len(warped_images)} warped frames, " - f"{n_seg} with labelmap" + f"{n_warped} with labelmap" ) -# %% [markdown] -# ## 5. Fine-tune uniGradICON on the train cohort -# -# Each train group now holds the original gated frames plus the merged ANTS -# and Greedy pre-registered frames (section 4b). The workflow consumes both -# the labelmaps (for paired-with-seg training) and the pre-computed masks (for -# ``loss_function_masking``) -# and launches ``unigradicon.finetuning.finetune`` as a subprocess. The -# final checkpoint lands at -# :meth:`WorkflowFineTuneICONRegistration.expected_weights_path`, which is -# the default ``--finetuned-weights-path`` read by ``2-recon_4d_icon_eval.py``. - # %% workflow = WorkflowFineTuneICONRegistration( subject_image_files=train_image_files, output_dir=output_dir, fine_tune_name=fine_tune_name, subject_ids=valid_train_subjects, - subject_segmentation_files=train_segmentation_files, + subject_labelmap_files=train_labelmap_files, subject_mask_files=train_mask_files, mask_dilation_mm=mask_dilation_mm, unigradicon_src_path=unigradicon_src_path, diff --git a/experiments/LongitudinalRegistration/4-recon_4d_all_eval.py b/experiments/LongitudinalRegistration/4-recon_4d_all_eval.py deleted file mode 100644 index 3a81f10..0000000 --- a/experiments/LongitudinalRegistration/4-recon_4d_all_eval.py +++ /dev/null @@ -1,706 +0,0 @@ -"""Compare longitudinal cardiac CT registration methods with landmarks. - -The experiment registers each gated time-point image to the high-resolution -reference image for the same subject. Input images are 3D CT volumes in LPS -world space. Landmarks are CSV rows with physical LPS coordinates -``Name,X,Y,Z`` in millimeters. - -Two accuracy modes are written: -1. Direct landmarks: reference landmarks are transformed into each time-point - image space with the inverse registration transform and compared to the - precomputed time-point landmarks. -2. Re-segmented landmarks: the reference image is warped into each time-point - image space, re-segmented with Simpleware, and the newly extracted landmarks - are compared to the precomputed time-point landmarks. -""" - -from __future__ import annotations - -import argparse -import csv -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import itk -import numpy as np - -from physiomotion4d import ( - RegisterTimeSeriesImages, - SegmentHeartSimpleware, - TransformTools, -) - -DEFAULT_REF_DIR = Path("d:/PhysioMotion4D/duke_data/ref_images") -DEFAULT_TIMEPOINT_BASE_DIR = Path("d:/PhysioMotion4D/duke_data/gated_nii") -DEFAULT_SEGMENTATION_BASE_DIR = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -DEFAULT_OUTPUT_DIR = Path("d:/PhysioMotion4D/duke_data/longitudinal_registration") -DEFAULT_EXCLUDE_TOKENS = ("nop", "dia", "sys", "_ref") -DEFAULT_SEGMENTATION_DIR = "results-labelmaps_and_landmarks" -DEFAULT_METHODS = ("ANTS", "greedy", "icon_default", "ants_icon_default") -TIMEPOINT_RE = re.compile(r"_g(?P[0-9]{3})") - - -Landmarks = dict[str, tuple[float, float, float]] - - -@dataclass(frozen=True) -class MethodSpec: - """Registration method plus optional ICON checkpoint.""" - - output_name: str - registration_method: str - icon_weights_path: Optional[Path] = None - - -@dataclass(frozen=True) -class ImageArtifacts: - """Input files associated with one image volume.""" - - image_file: Path - landmark_file: Optional[Path] - labelmap_file: Optional[Path] - timepoint: str - - -def nii_stem(path: Path) -> str: - """Return a stable stem for ``.nii.gz`` or single-suffix files.""" - if path.name.endswith(".nii.gz"): - return path.name[:-7] - return path.stem - - -def timepoint_from_name(path: Path) -> str: - """Extract the gated time-point tag from a filename.""" - match = TIMEPOINT_RE.search(path.name) - if match: - return match.group("timepoint") - return nii_stem(path) - - -def first_existing(paths: list[Path]) -> Optional[Path]: - """Return the first existing path from a candidate list.""" - for path in paths: - if path.exists(): - return path - return None - - -def landmark_candidates( - image_file: Path, - segmentation_dir: str, - artifact_dir: Optional[Path], -) -> list[Path]: - """Return likely landmark CSV paths for an image.""" - stem = nii_stem(image_file) - parent = image_file.parent - seg_parent = parent / segmentation_dir - candidates = [ - parent / f"{stem}_landmark.csv", - parent / f"{stem}_landmarks.csv", - seg_parent / f"{stem}_landmark.csv", - seg_parent / f"{stem}_landmarks.csv", - ] - if artifact_dir is not None: - candidates = [ - artifact_dir / f"{stem}_landmark.csv", - artifact_dir / f"{stem}_landmarks.csv", - *candidates, - ] - return candidates - - -def labelmap_candidates( - image_file: Path, - segmentation_dir: str, - artifact_dir: Optional[Path], -) -> list[Path]: - """Return likely labelmap paths for an image.""" - stem = nii_stem(image_file) - parent = image_file.parent - seg_parent = parent / segmentation_dir - candidates = [ - parent / f"{stem}_labelmap.nii.gz", - seg_parent / f"{stem}_labelmap.nii.gz", - ] - if artifact_dir is not None: - candidates = [artifact_dir / f"{stem}_labelmap.nii.gz", *candidates] - return candidates - - -def image_artifacts( - image_file: Path, - segmentation_dir: str, - artifact_dir: Optional[Path] = None, -) -> ImageArtifacts: - """Find landmarks and labelmaps associated with one image.""" - return ImageArtifacts( - image_file=image_file, - landmark_file=first_existing( - landmark_candidates(image_file, segmentation_dir, artifact_dir) - ), - labelmap_file=first_existing( - labelmap_candidates(image_file, segmentation_dir, artifact_dir) - ), - timepoint=timepoint_from_name(image_file), - ) - - -def read_landmarks(path: Path) -> Landmarks: - """Read physical LPS landmarks from ``Name,X,Y,Z`` CSV.""" - landmarks: Landmarks = {} - with path.open(newline="", encoding="utf-8-sig") as fh: - for row in csv.DictReader(fh): - landmarks[row["Name"]] = ( - float(row["X"]), - float(row["Y"]), - float(row["Z"]), - ) - return landmarks - - -def write_landmarks(path: Path, landmarks: Landmarks) -> None: - """Write physical LPS landmarks to ``Name,X,Y,Z`` CSV.""" - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.writer(fh) - writer.writerow(["Name", "X", "Y", "Z"]) - for name, coords in sorted(landmarks.items()): - writer.writerow([name, coords[0], coords[1], coords[2]]) - - -def transform_landmarks(landmarks: Landmarks, transform: itk.Transform) -> Landmarks: - """Apply an ITK physical-space transform to landmark coordinates.""" - transformed: Landmarks = {} - for name, point in landmarks.items(): - transformed_point = transform.TransformPoint(point) - transformed[name] = ( - float(transformed_point[0]), - float(transformed_point[1]), - float(transformed_point[2]), - ) - return transformed - - -def landmark_errors(source: Landmarks, target: Landmarks) -> dict[str, float]: - """Return per-landmark Euclidean errors in millimeters.""" - errors: dict[str, float] = {} - for name in sorted(source.keys() & target.keys()): - source_point = np.asarray(source[name], dtype=np.float64) - target_point = np.asarray(target[name], dtype=np.float64) - errors[name] = float(np.linalg.norm(source_point - target_point)) - return errors - - -def summarize_errors(errors: dict[str, float], prefix: str) -> dict[str, object]: - """Summarize landmark errors for one comparison mode.""" - if not errors: - return { - f"{prefix}_landmarks": 0, - f"{prefix}_mean_mm": "", - f"{prefix}_median_mm": "", - f"{prefix}_max_mm": "", - } - values = np.asarray(list(errors.values()), dtype=np.float64) - return { - f"{prefix}_landmarks": len(errors), - f"{prefix}_mean_mm": float(np.mean(values)), - f"{prefix}_median_mm": float(np.median(values)), - f"{prefix}_max_mm": float(np.max(values)), - } - - -def write_error_details( - path: Path, - subject_id: str, - method_name: str, - timepoint: str, - mode: str, - errors: dict[str, float], -) -> None: - """Append per-landmark errors to the detail CSV.""" - exists = path.exists() - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("a", newline="", encoding="utf-8") as fh: - fieldnames = ["subject_id", "method", "timepoint", "mode", "name", "error_mm"] - writer = csv.DictWriter(fh, fieldnames=fieldnames) - if not exists: - writer.writeheader() - for name, error in sorted(errors.items()): - writer.writerow( - { - "subject_id": subject_id, - "method": method_name, - "timepoint": timepoint, - "mode": mode, - "name": name, - "error_mm": error, - } - ) - - -def dice_by_label( - labelmap_a: itk.Image, - labelmap_b: itk.Image, -) -> dict[int, float]: - """Compute Dice scores for labels present in either 3D labelmap.""" - arr_a = itk.array_from_image(labelmap_a) - arr_b = itk.array_from_image(labelmap_b) - if arr_a.shape != arr_b.shape: - return {} - labels = sorted(set(np.unique(arr_a)).union(set(np.unique(arr_b))) - {0}) - scores: dict[int, float] = {} - for label in labels: - mask_a = arr_a == label - mask_b = arr_b == label - denom = int(mask_a.sum() + mask_b.sum()) - if denom > 0: - scores[int(label)] = float( - 2.0 * np.logical_and(mask_a, mask_b).sum() / denom - ) - return scores - - -def summarize_dice(scores: dict[int, float]) -> dict[str, object]: - """Summarize per-label Dice scores.""" - if not scores: - return {"dice_labels": 0, "dice_mean": "", "dice_min": ""} - values = np.asarray(list(scores.values()), dtype=np.float64) - return { - "dice_labels": len(scores), - "dice_mean": float(np.mean(values)), - "dice_min": float(np.min(values)), - } - - -def discover_subjects( - reference_dir: Path, - timepoint_base_dir: Path, - reference_pattern: str, - timepoint_pattern: str, - exclude_tokens: tuple[str, ...], - segmentation_dir: str, - segmentation_base_dir: Optional[Path], -) -> list[tuple[str, ImageArtifacts, list[ImageArtifacts]]]: - """Discover reference and time-point files for each subject.""" - if not reference_dir.exists(): - raise FileNotFoundError(f"Reference image directory not found: {reference_dir}") - if not timepoint_base_dir.exists(): - raise FileNotFoundError( - f"Time-point image base directory not found: {timepoint_base_dir}" - ) - - subjects: list[tuple[str, ImageArtifacts, list[ImageArtifacts]]] = [] - for reference_file in sorted(reference_dir.glob(reference_pattern)): - subject_id = reference_file.name[:6] - source_dir = timepoint_base_dir / subject_id - if not source_dir.exists(): - raise FileNotFoundError( - f"No time-point directory for {subject_id}: {source_dir}" - ) - artifact_dir = None - if segmentation_base_dir is not None: - candidate_dir = segmentation_base_dir / subject_id - if candidate_dir.exists(): - artifact_dir = candidate_dir - - reference_in_source = source_dir / reference_file.name - reference_artifacts = image_artifacts( - reference_in_source if reference_in_source.exists() else reference_file, - segmentation_dir, - artifact_dir, - ) - - timepoint_files = [ - path - for path in sorted(source_dir.glob(timepoint_pattern)) - if not any(token in path.name for token in exclude_tokens) - ] - timepoints = [ - image_artifacts(path, segmentation_dir, artifact_dir) - for path in timepoint_files - if path.is_file() - ] - subjects.append((subject_id, reference_artifacts, timepoints)) - return subjects - - -def build_method_specs( - method_names: list[str], - finetuned_weights_path: Optional[Path], -) -> list[MethodSpec]: - """Map output method labels to registrar methods and optional weights.""" - specs: list[MethodSpec] = [] - for method_name in method_names: - if method_name == "ANTS": - specs.append(MethodSpec(method_name, "ANTS")) - elif method_name == "greedy": - specs.append(MethodSpec(method_name, "greedy")) - elif method_name == "icon_default": - specs.append(MethodSpec(method_name, "ICON")) - elif method_name == "ants_icon_default": - specs.append(MethodSpec(method_name, "ANTS_ICON")) - elif method_name == "greedy_icon_default": - specs.append(MethodSpec(method_name, "greedy_ICON")) - elif method_name == "icon_finetuned": - specs.append(MethodSpec(method_name, "ICON", finetuned_weights_path)) - elif method_name == "ants_icon_finetuned": - specs.append(MethodSpec(method_name, "ANTS_ICON", finetuned_weights_path)) - elif method_name == "greedy_icon_finetuned": - specs.append(MethodSpec(method_name, "greedy_ICON", finetuned_weights_path)) - else: - raise ValueError(f"Unknown method: {method_name}") - - for spec in specs: - if "finetuned" in spec.output_name and spec.icon_weights_path is None: - raise ValueError(f"{spec.output_name} requires --finetuned-weights-path") - return specs - - -def configure_registrar( - method_spec: MethodSpec, - fixed_image: itk.Image, - fixed_labelmap: Optional[itk.Image], - ants_iterations: list[int], - greedy_iterations: list[int], - icon_iterations: int, -) -> RegisterTimeSeriesImages: - """Create and configure the time-series registrar.""" - registrar = RegisterTimeSeriesImages( - registration_method=method_spec.registration_method - ) - registrar.set_modality("ct") - registrar.set_fixed_image(fixed_image) - registrar.set_fixed_labelmap(fixed_labelmap) - registrar.set_number_of_iterations_ANTS(ants_iterations) - registrar.set_number_of_iterations_greedy(greedy_iterations) - registrar.set_number_of_iterations_ICON(icon_iterations) - if method_spec.icon_weights_path is not None: - registrar.registrar_ICON.set_weights_path(str(method_spec.icon_weights_path)) - return registrar - - -def write_summary(path: Path, rows: list[dict[str, object]]) -> None: - """Write experiment summary rows.""" - if not rows: - return - path.parent.mkdir(parents=True, exist_ok=True) - fieldnames = list(rows[0].keys()) - with path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(rows) - - -def run_method_for_subject( - subject_id: str, - reference_artifacts: ImageArtifacts, - timepoint_artifacts: list[ImageArtifacts], - method_spec: MethodSpec, - output_dir: Path, - run_resegmentation: bool, - ants_iterations: list[int], - greedy_iterations: list[int], - icon_iterations: int, - error_detail_file: Path, -) -> list[dict[str, object]]: - """Run one registration method for one subject and return summary rows.""" - if reference_artifacts.landmark_file is None: - raise FileNotFoundError( - f"Missing reference landmarks for {reference_artifacts.image_file}" - ) - - fixed_image = itk.imread(str(reference_artifacts.image_file), pixel_type=itk.F) - fixed_labelmap = None - if reference_artifacts.labelmap_file is not None: - fixed_labelmap = itk.imread(str(reference_artifacts.labelmap_file)) - - moving_images = [ - itk.imread(str(artifacts.image_file), pixel_type=itk.F) - for artifacts in timepoint_artifacts - ] - moving_labelmaps = None - if all(artifacts.labelmap_file is not None for artifacts in timepoint_artifacts): - moving_labelmaps = [ - itk.imread(str(artifacts.labelmap_file)) - for artifacts in timepoint_artifacts - ] - - registrar = configure_registrar( - method_spec, - fixed_image, - fixed_labelmap, - ants_iterations, - greedy_iterations, - icon_iterations, - ) - - result = registrar.register_time_series( - moving_images=moving_images, - moving_labelmaps=moving_labelmaps, - reference_frame=0, - register_reference=True, - prior_weight=0.0, - ) - - reference_landmarks = read_landmarks(reference_artifacts.landmark_file) - transform_tools = TransformTools() - segmenter = SegmentHeartSimpleware() if run_resegmentation else None - subject_method_dir = output_dir / method_spec.output_name / subject_id - subject_method_dir.mkdir(parents=True, exist_ok=True) - - rows: list[dict[str, object]] = [] - for index, artifacts in enumerate(timepoint_artifacts): - timepoint_dir = subject_method_dir / artifacts.timepoint - timepoint_dir.mkdir(parents=True, exist_ok=True) - - forward_transform = result["forward_transforms"][index] - inverse_transform = result["inverse_transforms"][index] - loss = result["losses"][index] - - forward_file = timepoint_dir / "time_to_reference.hdf" - inverse_file = timepoint_dir / "reference_to_time.hdf" - itk.transformwrite(forward_transform, str(forward_file), compression=True) - itk.transformwrite(inverse_transform, str(inverse_file), compression=True) - - moving_to_reference = transform_tools.transform_image( - moving_images[index], - forward_transform, - fixed_image, - ) - moving_to_reference_file = timepoint_dir / "time_to_reference.mha" - itk.imwrite( - moving_to_reference, str(moving_to_reference_file), compression=True - ) - - reference_to_time = transform_tools.transform_image( - fixed_image, - inverse_transform, - moving_images[index], - ) - reference_to_time_file = timepoint_dir / "reference_to_time.mha" - itk.imwrite(reference_to_time, str(reference_to_time_file), compression=True) - - row: dict[str, object] = { - "subject_id": subject_id, - "method": method_spec.output_name, - "timepoint": artifacts.timepoint, - "moving_image": str(artifacts.image_file), - "forward_transform": str(forward_file), - "inverse_transform": str(inverse_file), - "loss": float(loss), - } - - if artifacts.landmark_file is not None: - timepoint_landmarks = read_landmarks(artifacts.landmark_file) - # Warp the reference landmarks into the timepoint (moving) space to - # compare against this timepoint's landmarks. Warping reference -> - # time POINTS uses forward_transform (the fixed -> moving point map), - # which is the opposite of the reference_to_time IMAGE above (images - # pull back, points push forward). See - # docs/developer/transform_conventions. - direct_landmarks = transform_landmarks( - reference_landmarks, - forward_transform, - ) - direct_errors = landmark_errors(direct_landmarks, timepoint_landmarks) - write_error_details( - error_detail_file, - subject_id, - method_spec.output_name, - artifacts.timepoint, - "direct", - direct_errors, - ) - row.update(summarize_errors(direct_errors, "direct")) - else: - row.update(summarize_errors({}, "direct")) - - if run_resegmentation and segmenter is not None: - segmentation = segmenter.segment( - reference_to_time, - contrast_enhanced_study=False, - ) - warped_labelmap = segmentation["labelmap"] - warped_labelmap_file = timepoint_dir / "reference_to_time_labelmap.nii.gz" - itk.imwrite(warped_labelmap, str(warped_labelmap_file), compression=True) - reseg_landmarks = segmenter.get_landmarks() - reseg_landmark_file = timepoint_dir / "reference_to_time_landmark.csv" - write_landmarks(reseg_landmark_file, reseg_landmarks) - row["resegmented_labelmap"] = str(warped_labelmap_file) - row["resegmented_landmarks"] = str(reseg_landmark_file) - - if artifacts.landmark_file is not None: - timepoint_landmarks = read_landmarks(artifacts.landmark_file) - reseg_errors = landmark_errors(reseg_landmarks, timepoint_landmarks) - write_error_details( - error_detail_file, - subject_id, - method_spec.output_name, - artifacts.timepoint, - "resegmented", - reseg_errors, - ) - row.update(summarize_errors(reseg_errors, "resegmented")) - else: - row.update(summarize_errors({}, "resegmented")) - - if artifacts.labelmap_file is not None: - timepoint_labelmap = itk.imread(str(artifacts.labelmap_file)) - row.update( - summarize_dice(dice_by_label(warped_labelmap, timepoint_labelmap)) - ) - else: - row.update(summarize_dice({})) - else: - row["resegmented_labelmap"] = "" - row["resegmented_landmarks"] = "" - row.update(summarize_errors({}, "resegmented")) - row.update(summarize_dice({})) - - rows.append(row) - - return rows - - -def parse_iterations(value: str) -> list[int]: - """Parse comma-separated multi-resolution iteration counts.""" - return [int(item.strip()) for item in value.split(",") if item.strip()] - - -def main() -> int: - """Run the longitudinal registration comparison experiment.""" - parser = argparse.ArgumentParser( - description="Compare ANTS, Greedy, and ICON longitudinal registration." - ) - parser.add_argument("--reference-dir", type=Path, default=DEFAULT_REF_DIR) - parser.add_argument( - "--timepoint-base-dir", - type=Path, - default=DEFAULT_TIMEPOINT_BASE_DIR, - ) - parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR) - parser.add_argument( - "--segmentation-base-dir", - type=Path, - default=DEFAULT_SEGMENTATION_BASE_DIR, - help="Directory with per-subject precomputed *_labelmap and *_landmark files.", - ) - parser.add_argument("--reference-pattern", default="pm00*.nii.gz") - parser.add_argument("--timepoint-pattern", default="*.nii.gz") - parser.add_argument("--segmentation-dir", default=DEFAULT_SEGMENTATION_DIR) - parser.add_argument( - "--exclude-token", - action="append", - default=list(DEFAULT_EXCLUDE_TOKENS), - help="Filename token to exclude from time-point inputs.", - ) - parser.add_argument( - "--methods", - nargs="+", - default=None, - help="Methods to run. Defaults include finetuned methods when weights are set.", - ) - parser.add_argument("--finetuned-weights-path", type=Path, default=None) - parser.add_argument("--max-subjects", type=int, default=None) - parser.add_argument("--max-timepoints", type=int, default=None) - parser.add_argument("--ANTS-iterations", default="30,15,7,3") - parser.add_argument("--greedy-iterations", default="30,15,7,3") - parser.add_argument("--ICON-iterations", type=int, default=20) - parser.add_argument( - "--skip-resegmentation", - action="store_true", - help="Skip Simpleware re-segmentation mode.", - ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Validate discovered files and planned methods without registration.", - ) - args = parser.parse_args() - - method_names = args.methods - if method_names is None: - method_names = list(DEFAULT_METHODS) - method_names.append("greedy_icon_default") - if args.finetuned_weights_path is not None: - method_names.extend( - [ - "icon_finetuned", - "ants_icon_finetuned", - "greedy_icon_finetuned", - ] - ) - - method_specs = build_method_specs(method_names, args.finetuned_weights_path) - subjects = discover_subjects( - args.reference_dir, - args.timepoint_base_dir, - args.reference_pattern, - args.timepoint_pattern, - tuple(args.exclude_token), - args.segmentation_dir, - args.segmentation_base_dir, - ) - if args.max_subjects is not None: - subjects = subjects[: args.max_subjects] - - if args.dry_run: - for subject_id, reference_artifacts, timepoint_artifacts in subjects: - if args.max_timepoints is not None: - timepoint_artifacts = timepoint_artifacts[: args.max_timepoints] - missing_landmarks = sum( - artifacts.landmark_file is None for artifacts in timepoint_artifacts - ) - missing_labelmaps = sum( - artifacts.labelmap_file is None for artifacts in timepoint_artifacts - ) - print( - f"{subject_id}: {len(timepoint_artifacts)} time points, " - f"reference_landmarks={reference_artifacts.landmark_file is not None}, " - f"reference_labelmap={reference_artifacts.labelmap_file is not None}, " - f"missing_time_landmarks={missing_landmarks}, " - f"missing_time_labelmaps={missing_labelmaps}" - ) - print("Methods: " + ", ".join(spec.output_name for spec in method_specs)) - return 0 - - summary_rows: list[dict[str, object]] = [] - detail_file = args.output_dir / "landmark_errors_by_point.csv" - if detail_file.exists(): - detail_file.unlink() - - for subject_id, reference_artifacts, timepoint_artifacts in subjects: - if args.max_timepoints is not None: - timepoint_artifacts = timepoint_artifacts[: args.max_timepoints] - if not timepoint_artifacts: - raise ValueError(f"No time-point images found for {subject_id}") - print( - f"Running {subject_id}: {len(timepoint_artifacts)} time points, " - f"{len(method_specs)} methods" - ) - for method_spec in method_specs: - print(f" Method: {method_spec.output_name}") - rows = run_method_for_subject( - subject_id, - reference_artifacts, - timepoint_artifacts, - method_spec, - args.output_dir, - not args.skip_resegmentation, - parse_iterations(args.ants_iterations), - parse_iterations(args.greedy_iterations), - args.icon_iterations, - detail_file, - ) - summary_rows.extend(rows) - write_summary(args.output_dir / "registration_summary.csv", summary_rows) - - print(f"Wrote summary: {args.output_dir / 'registration_summary.csv'}") - print(f"Wrote landmark details: {detail_file}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/experiments/LongitudinalRegistration/experiment_recon_4d.py b/experiments/LongitudinalRegistration/experiment_recon_4d.py deleted file mode 100644 index 7d9ce26..0000000 --- a/experiments/LongitudinalRegistration/experiment_recon_4d.py +++ /dev/null @@ -1,191 +0,0 @@ -# %% [markdown] -# # 4D CT Reconstruction Using RegisterTimeSeriesImages Class -# -# This script demonstrates the use of the `RegisterTimeSeriesImages` class -# to register a time series of CT images to a common reference frame. -# -# This is a refactored version of `reconstruct_4d_ct.ipynb` that uses -# the new class-based approach,including: -# - Registration of time series images using ANTs, Greedy, ICON, or combined -# ANTs/Greedy + ICON methods -# - Reconstruction of time series using the `reconstruct_time_series()` method -# - Optional upsampling to fixed image resolution while preserving spatial positioning -# - -# %% -# Import necessary libraries -######################################################## - -import os - -import itk -import numpy as np -from physiomotion4d import RegisterTimeSeriesImages - -# %% -# Identify reference images -######################################################## - -ref_data_dir = "d:/PhysioMotion4D/duke_data/ref_images" -src_data_dir_base = "d:/PhysioMotion4D/duke_data/gated_nii" -dest_data_dir_base = "d:/PhysioMotion4D/duke_data/recon4d" - -ref_files = [ - os.path.join(ref_data_dir, f) - for f in sorted(os.listdir(ref_data_dir)) - if f.startswith("pm00") and f.endswith(".nii.gz") -] - -print(f"Found {len(ref_files)} reference images") - -# %% -# Identify source data directories and files using reference image names -######################################################## - - -print(os.path.basename(ref_files[0])[:6]) -src_data_dirs = [] -src_data_files = [] -for ref_file in ref_files: - src_dir = os.path.join(src_data_dir_base, os.path.basename(ref_file)[:6]) - src_data_dirs.append(src_dir) - - file_list = sorted(os.listdir(src_dir)) - valid_file_list = [ - f - for f in file_list - if "dia" not in f - and "nop" not in f - and "sys" not in f - and f.endswith(".nii.gz") - ] - src_data_files.append(valid_file_list) - -print(f"Found {len(src_data_dirs)} source data directories") -for d, fs in zip(src_data_dirs, src_data_files): - print(f"{d}: {len(fs)} files") - for f in fs: - print(f" {f}") - -# %% -# Define registration function -######################################################## - - -def register_time_series( - reference_image_file: str, - source_image_dir: str, - source_image_files: list[str], - registration_method: str, -) -> None: - # ANTs registration - if registration_method in ["ANTS", "greedy"]: - number_of_iterations = [30, 15, 7, 3] - elif registration_method == "ICON": - number_of_iterations = 20 - elif registration_method in ["ANTS_ICON", "greedy_ICON"]: - number_of_iterations = [[30, 15, 7, 3], 20] - else: - raise ValueError(f"Invalid registration method: {registration_method}") - - # Create output dir - output_dir = os.path.join( - dest_data_dir_base, registration_method, os.path.basename(source_image_dir) - ) - os.makedirs(output_dir, exist_ok=True) - - # Read the reference image as the fixed image - fixed_image = itk.imread(reference_image_file, pixel_type=itk.F) - - images = [] - for file in source_image_files: - img = itk.imread(os.path.join(source_image_dir, file), pixel_type=itk.F) - images.append(img) - - reference_image_num = 7 - register_start_to_reference = True - if reference_image_file in source_image_files: - reference_image_num = source_image_files.index(reference_image_file) - register_start_to_reference = False - - portion_of_prior_transform_to_init_next_transform = 0.0 - - # Register the time series - registrar = RegisterTimeSeriesImages(registration_method=registration_method) - registrar.set_modality("ct") - registrar.set_fixed_image(fixed_image) - if registration_method == "ANTS": - registrar.set_number_of_iterations_ANTS(number_of_iterations) - elif registration_method == "greedy": - registrar.set_number_of_iterations_greedy(number_of_iterations) - elif registration_method == "ICON": - registrar.set_number_of_iterations_ICON(number_of_iterations) - elif registration_method == "ANTS_ICON": - registrar.set_number_of_iterations_ANTS(number_of_iterations[0]) - registrar.set_number_of_iterations_ICON(number_of_iterations[1]) - elif registration_method == "greedy_ICON": - registrar.set_number_of_iterations_greedy(number_of_iterations[0]) - registrar.set_number_of_iterations_ICON(number_of_iterations[1]) - else: - raise ValueError(f"Invalid registration method: {registration_method}") - - result = registrar.register_time_series( - moving_images=images, - reference_frame=reference_image_num, - register_reference=register_start_to_reference, - prior_weight=portion_of_prior_transform_to_init_next_transform, - ) - - upsampled_images = registrar.reconstruct_time_series( - moving_images=images, - inverse_transforms=result["inverse_transforms"], - upsample_to_fixed_resolution=True, - ) - - losses = result["losses"] - print("Registration complete!") - print(f" Average loss: {np.mean(losses):.6f}") - print(f" Min loss: {np.min(losses):.6f}") - print(f" Max loss: {np.max(losses):.6f}") - print("") - print("Saving results...") - output_file_basename = os.path.basename(reference_image_file)[:6] - for i, fwd_transform in enumerate(result["forward_transforms"]): - time_point_index = source_image_files[i].index("_g") + 2 - time_point = source_image_files[i][time_point_index : time_point_index + 3] - - output_file = f"{output_file_basename}_{time_point}_fwd.hdf" - itk.transformwrite( - fwd_transform, - os.path.join(output_dir, output_file), - compression=True, - ) - - inv_transform = result["inverse_transforms"][i] - output_file = f"{output_file_basename}_{time_point}_inv.hdf" - itk.transformwrite( - inv_transform, - os.path.join(output_dir, output_file), - compression=True, - ) - - output_file = f"{output_file_basename}_{time_point}_hrr.mha" - itk.imwrite( - upsampled_images[i], - os.path.join(output_dir, output_file), - compression=True, - ) - - -# %% -# Register time series -######################################################## - -for ref_file, src_dir, src_files in zip(ref_files, src_data_dirs, src_data_files): - register_time_series(ref_file, src_dir, src_files, "ANTS") - register_time_series(ref_file, src_dir, src_files, "greedy") - register_time_series(ref_file, src_dir, src_files, "ICON") - register_time_series(ref_file, src_dir, src_files, "ANTS_ICON") - register_time_series(ref_file, src_dir, src_files, "greedy_ICON") - -# %% diff --git a/experiments/LongitudinalRegistration/registration_results_analysis.py b/experiments/LongitudinalRegistration/registration_results_analysis.py new file mode 100644 index 0000000..50b8896 --- /dev/null +++ b/experiments/LongitudinalRegistration/registration_results_analysis.py @@ -0,0 +1,244 @@ +"""Summarize registration Dice and landmark RMSE results across experiments. + +For every ``results_*`` directory under a base directory this script reads: + +- ``registration_dice_init.csv`` with columns + ``subject_id, method, stem, label, dice`` (one row per subject / time point + / anatomy label). +- ``registration_landmarks_init.csv`` with columns + ``subject_id, method, stem, name, rms_err_mm`` (one row per subject / time + point / landmark). + +It pools all subjects and all time points (rows) within each +``(results_dir, method)`` group and reports the mean, standard deviation, and +95th percentile of the Dice score per label (1..10) and across all labels, and +the same statistics of the landmark RMSE per landmark and across all landmarks. + +**Duplicate handling.** If the same ``(subject_id, method, stem, label)`` +combination appears more than once in a single Dice CSV, the *n*-th occurrence +is treated as if it came from a separate directory named +``{results_dir}_{n}`` (e.g. ``results_ml_2``, ``results_ml_3``). The same +applies for ``(subject_id, method, stem, name)`` duplicates in the landmark +CSV. + +Two summary tables are produced, each indexed by ``(results_dir, method)`` +(with any ``_n`` suffixes introduced by duplicate splitting): one for label +Dice scores and one for landmark RMSE. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd + +DICE_CSV = "registration_dice_init.csv" +LANDMARKS_CSV = "registration_landmarks_init.csv" + +ALL_KEY = "all" +STATS = ("mean", "std", "p95") + + +def _split_by_occurrence( + frame: pd.DataFrame, + key_cols: list[str], + dir_name: str, +) -> dict[str, pd.DataFrame]: + """Split ``frame`` into per-occurrence sub-frames keyed by a group name. + + Within ``frame``, each unique combination of ``key_cols`` may appear + multiple times. The *n*-th occurrence (1-indexed) of a given key is + assigned to group ``dir_name`` (for n=1) or ``{dir_name}_{n}`` (for n>1). + Sub-frames for different occurrence numbers are returned in a dict whose + keys follow that naming convention. + + Args: + frame: Input data frame, one row per observation. + key_cols: Columns that together identify a unique observation. + dir_name: Base name used to construct group keys. + + Returns: + Ordered dict mapping group name → sub-data-frame. + """ + occurrence = frame.groupby(key_cols, sort=False).cumcount() + 1 # 1-indexed + max_occ = int(occurrence.max()) + result: dict[str, pd.DataFrame] = {} + for n in range(1, max_occ + 1): + sub = frame[occurrence == n] + if sub.empty: + continue + group_name = dir_name if n == 1 else f"{dir_name}_{n}" + result[group_name] = sub.reset_index(drop=True) + return result + + +def _aggregate( + frames: dict[str, pd.DataFrame], + group_col: str, + value_col: str, +) -> pd.DataFrame: + """Aggregate one value column grouped by ``group_col`` for each (directory, method). + + Args: + frames: Mapping from ``results_*`` directory name to its data frame. + group_col: Column whose distinct values become summary categories + (``"label"`` for Dice, ``"name"`` for landmarks). + value_col: Column to summarize (``"dice"`` or ``"rms_err_mm"``). + + Returns: + A data frame indexed by ``(results_dir, method)`` with a two-level + column ``(category, stat)`` where ``category`` is each distinct group + value plus ``"all"`` and ``stat`` is one of ``mean``, ``std``, ``p95``. + Categories are pooled across all subjects and time points within each + ``(directory, method)`` group. + """ + + def _stats(series: pd.Series) -> dict[str, float]: + # Drop missing measurements so a single blank row does not poison the + # mean / std / percentile for an entire landmark or the pooled "all". + values = pd.to_numeric(series, errors="coerce").dropna().to_numpy(dtype=float) + if values.size == 0: + return {stat: float("nan") for stat in STATS} + return { + "mean": float(np.mean(values)), + "std": float(np.std(values, ddof=1)) if values.size > 1 else 0.0, + "p95": float(np.percentile(values, 95)), + } + + rows: dict[tuple[str, str], dict[tuple[object, str], float]] = {} + for dir_name, frame in frames.items(): + method_groups: list[tuple[str, pd.DataFrame]] + if "method" in frame.columns: + method_groups = [ + (str(m), sub) for m, sub in frame.groupby("method", sort=True) + ] + else: + method_groups = [("", frame)] + for method, mframe in method_groups: + row: dict[tuple[object, str], float] = {} + for category, group in mframe.groupby(group_col): + for stat, value in _stats(group[value_col]).items(): + row[(category, stat)] = value + for stat, value in _stats(mframe[value_col]).items(): + row[(ALL_KEY, stat)] = value + rows[(dir_name, method)] = row + + table = pd.DataFrame.from_dict(rows, orient="index") + table.index = pd.MultiIndex.from_tuples( + table.index, names=["results_dir", "method"] + ) + # from_dict yields a flat Index of tuples; promote to a MultiIndex so the + # reindex below aligns on (category, stat) rather than producing NaNs. + table.columns = pd.MultiIndex.from_tuples(table.columns) + + # Order columns: numeric/string categories sorted, "all" last. + categories = sorted( + {category for category, _ in table.columns if category != ALL_KEY}, + key=lambda c: (0, int(c)) if str(c).isdigit() else (1, str(c)), + ) + categories.append(ALL_KEY) + ordered = [(category, stat) for category in categories for stat in STATS] + table = table.reindex(columns=pd.MultiIndex.from_tuples(ordered)) + return table + + +def summarize(base_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]: + """Build the Dice and landmark RMSE summary tables for ``base_dir``. + + Args: + base_dir: Directory containing one or more ``results_*`` subdirectories. + + Returns: + ``(dice_table, landmark_table)`` summary frames, each indexed by + ``(results_dir, method)``. + """ + result_dirs = sorted(p for p in base_dir.glob("results_*") if p.is_dir()) + if not result_dirs: + raise FileNotFoundError(f"No results_* directories found in {base_dir}") + + dice_frames: dict[str, pd.DataFrame] = {} + landmark_frames: dict[str, pd.DataFrame] = {} + for result_dir in result_dirs: + dice_path = result_dir / DICE_CSV + landmark_path = result_dir / LANDMARKS_CSV + if dice_path.is_file(): + dice_frames.update( + _split_by_occurrence( + pd.read_csv(dice_path), + ["subject_id", "method", "stem", "label"], + result_dir.name, + ) + ) + if landmark_path.is_file(): + landmark_frames.update( + _split_by_occurrence( + pd.read_csv(landmark_path), + ["subject_id", "method", "stem", "name"], + result_dir.name, + ) + ) + + dice_table = _aggregate(dice_frames, "label", "dice") + landmark_table = _aggregate(landmark_frames, "name", "rms_err_mm") + return dice_table, landmark_table + + +def _print_table(title: str, table: pd.DataFrame) -> None: + """Print a summary table with a heading at full width.""" + print(f"\n{'=' * 70}\n{title}\n{'=' * 70}") + with pd.option_context( + "display.max_columns", + None, + "display.width", + None, + "display.float_format", + lambda v: f"{v:.4f}", + ): + print(table) + + +def main(argv: Optional[list[str]] = None) -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "base_dir", + nargs="?", + type=Path, + default=Path(__file__).resolve().parent, + help="Directory containing results_* subdirectories " + "(default: this script's directory).", + ) + parser.add_argument( + "--out-dir", + type=Path, + default=None, + help="Optional directory to write summary CSV files into.", + ) + args = parser.parse_args(argv) + + dice_table, landmark_table = summarize(args.base_dir) + + _print_table( + "Dice score by label (mean / std / 95th percentile), grouped by " + "results_dir + method, pooled over subjects and time points", + dice_table, + ) + _print_table( + "Landmark RMSE [mm] (mean / std / 95th percentile), grouped by " + "results_dir + method, pooled over subjects and time points", + landmark_table, + ) + + if args.out_dir is not None: + args.out_dir.mkdir(parents=True, exist_ok=True) + dice_out = args.out_dir / "summary_dice.csv" + landmark_out = args.out_dir / "summary_landmarks.csv" + dice_table.to_csv(dice_out) + landmark_table.to_csv(landmark_out) + print(f"\nWrote {dice_out}\nWrote {landmark_out}") + + +if __name__ == "__main__": + main() diff --git a/src/physiomotion4d/labelmap_tools.py b/src/physiomotion4d/labelmap_tools.py index ff5ad78..2ab4d6c 100644 --- a/src/physiomotion4d/labelmap_tools.py +++ b/src/physiomotion4d/labelmap_tools.py @@ -98,3 +98,90 @@ def convert_labelmap_to_mask( return itk.binary_dilate_image_filter( mask, kernel=structuring_element, foreground_value=1 ) + + def create_distance_map( + self, + labelmap: itk.Image, + max_distance_mm: float = 20.0, + distance_scale: float = 5.0, + ) -> itk.Image: + """Encode a labelmap as a continuous label-plus-boundary-distance image. + + Each output voxel holds its original integer label plus a small + fractional offset that encodes how far the voxel lies from the nearest + boundary between two differently-labeled regions: + + value = label + min(distance_to_nearest_boundary_mm, + max_distance_mm) / distance_scale + + The boundary set is every voxel that 6-neighbors a voxel with a + different label (background label ``0`` participates, so the outer + surface of each structure is a boundary). The unsigned physical + distance from each voxel to that set is computed with + ``SignedMaurerDistanceMapImageFilter`` (taking the magnitude), clipped + to ``max_distance_mm``, divided by ``distance_scale``, and added to the + voxel's original label. + + With the defaults (``50`` mm clip, ``100`` scale) the fractional offset + stays in ``[0.0, 0.5]``, so it never reaches the next integer label and + label identity is recoverable as ``floor(value)``. + + The motivation is registration metrics such as Greedy's NCC: a raw + integer labelmap is piecewise-constant, so the local variance inside + each region is zero and NCC produces NaN gradients. Replacing it with + this continuous encoding gives every region a smoothly varying signal + while preserving label identity. + + Axis ordering: ``labelmap`` is a scalar 3D ``itk.Image`` in ITK + world-axis order (X, Y, Z). All work is done on the numpy view + (Z, Y, X) and written back through ``CopyInformation``, so origin, + spacing, and direction are preserved. + + Args: + labelmap: Multi-label (or binary) ``itk.Image`` of integer labels. + max_distance_mm: Distance clip, in millimeters. Default 50.0. + distance_scale: Divisor applied to the clipped distance before it + is added to the label. Default 100.0. With the default clip + this bounds the fractional offset to ``[0, 0.5]``. + + Returns: + ``itk.Image[itk.F, 3]`` in the same physical space as ``labelmap`` + (origin, spacing, direction copied from the input). + """ + labels = itk.array_from_image(labelmap) + + # A voxel is on a label boundary when it differs from a 6-connected + # neighbor along any axis. Mark both voxels straddling each change. + boundary = np.zeros(labels.shape, dtype=bool) + for axis in range(labels.ndim): + changed = np.diff(labels, axis=axis) != 0 + lower = [slice(None)] * labels.ndim + upper = [slice(None)] * labels.ndim + lower[axis] = slice(0, -1) + upper[axis] = slice(1, None) + boundary[tuple(lower)] |= changed + boundary[tuple(upper)] |= changed + + if boundary.any(): + boundary_image = itk.image_from_array(boundary.astype(np.uint8)) + boundary_image.CopyInformation(labelmap) + distance_filter = itk.SignedMaurerDistanceMapImageFilter.New( + Input=boundary_image + ) + distance_filter.SetSquaredDistance(False) + distance_filter.SetUseImageSpacing(True) + distance_filter.Update() + distance = np.abs( + itk.array_from_image(distance_filter.GetOutput()).astype(np.float32) + ) + else: + # No inter-label boundary exists (single uniform label); every + # voxel gets a zero offset. + distance = np.zeros(labels.shape, dtype=np.float32) + + offset = np.clip(distance, 0.0, max_distance_mm) / distance_scale + encoded = labels.astype(np.float32) + offset + + encoded_image = itk.image_from_array(encoded) + encoded_image.CopyInformation(labelmap) + return encoded_image diff --git a/src/physiomotion4d/register_images_greedy.py b/src/physiomotion4d/register_images_greedy.py index 8a7a6cb..9e3016f 100644 --- a/src/physiomotion4d/register_images_greedy.py +++ b/src/physiomotion4d/register_images_greedy.py @@ -66,6 +66,15 @@ class RegisterImagesGreedy(RegisterImagesBase): deformable_smoothing: Smoothing sigmas for deformable (e.g. "2.0vox 0.5vox") """ + # picsl_greedy 0.0.12 segfaults when its multi-component metric (image + + # labelmap channels) allocates a working buffer for a fixed grid larger + # than roughly 100M voxels (empirically: 95M voxels succeeds, 104M crashes; + # single-channel metrics are unaffected at any size). When the labelmap + # channel is active, the metric inputs are isotropically downsampled to + # stay under this conservative cap. Greedy emits physical-space transforms, + # so a coarser metric grid only coarsens warp sampling, not the frame. + _MAX_METRIC_VOXELS = 90_000_000 + def __init__(self, log_level: int | str = logging.INFO) -> None: """Initialize the Greedy image registration class. @@ -137,6 +146,74 @@ def _greedy_iterations_str(self) -> str: """Format iterations as Greedy -n string (e.g. 40x20x10).""" return "x".join(str(i) for i in self.number_of_iterations) + def _metric_downsample_scale(self, reference_image: itk.Image) -> float: + """Per-axis scale that keeps ``reference_image`` under the voxel cap. + + Returns ``1.0`` when the grid already fits within + ``_MAX_METRIC_VOXELS``; otherwise returns the isotropic per-axis factor + ``(_MAX_METRIC_VOXELS / voxels) ** (1/3)`` (always < 1.0) so the + downsampled grid lands at or just below the cap. + + Args: + reference_image: The fixed metric image (X, Y, Z) whose voxel count + drives the Greedy multi-component buffer size. + + Returns: + Per-axis resampling scale in ``(0, 1]``. + """ + size = reference_image.GetLargestPossibleRegion().GetSize() + voxels = int(size[0]) * int(size[1]) * int(size[2]) + if voxels <= self._MAX_METRIC_VOXELS: + return 1.0 + scale = float((self._MAX_METRIC_VOXELS / voxels) ** (1.0 / 3.0)) + self.log_info( + "Greedy labelmap metric: downsampling %d-voxel fixed grid by " + "%.3f/axis to stay under the %d-voxel picsl_greedy crash threshold.", + voxels, + scale, + self._MAX_METRIC_VOXELS, + ) + return scale + + def _downsample_image( + self, image: itk.Image, scale: float, nearest: bool = False + ) -> itk.Image: + """Isotropically resample ``image`` by ``scale`` (no-op when >= 1.0). + + The physical extent is preserved exactly: the new per-axis spacing is + chosen so ``new_size * new_spacing == old_size * old_spacing``, so the + coarser grid covers the same world-space region with the same origin + and direction. Axis order is ITK world order (X, Y, Z). + + Args: + image: Scalar 3D ``itk.Image`` to resample. + scale: Per-axis factor in ``(0, 1]``; ``>= 1.0`` returns ``image`` + unchanged so the full-resolution path is untouched. + nearest: Use nearest-neighbor interpolation (for labelmaps and + masks) instead of linear. + + Returns: + The resampled ``itk.Image``, or ``image`` itself when ``scale`` is + ``>= 1.0``. + """ + if scale >= 1.0: + return image + size = image.GetLargestPossibleRegion().GetSize() + spacing = image.GetSpacing() + new_size = [max(1, int(round(int(size[i]) * scale))) for i in range(3)] + new_spacing = [float(spacing[i]) * int(size[i]) / new_size[i] for i in range(3)] + kwargs: dict[str, Any] = { + "output_origin": image.GetOrigin(), + "output_direction": image.GetDirection(), + "size": new_size, + "output_spacing": new_spacing, + } + if nearest: + kwargs["interpolator"] = itk.NearestNeighborInterpolateImageFunction.New( + image + ) + return itk.resample_image_filter(image, **kwargs) + def _write_affine_matrix_file(self, mat_4x4: NDArray[np.float64]) -> str: """Write a 4x4 RAS affine matrix to a temporary Greedy ``.mat`` file. @@ -200,23 +277,33 @@ def _registration_method_affine_or_rigid( self, fixed_sitk: Any, moving_sitk: Any, - fixed_mask_sitk: Optional[Any], - moving_mask_sitk: Optional[Any], iterations_str: str, metric_str: str, dof: int, + fixed_mask_sitk: Optional[Any] = None, + moving_mask_sitk: Optional[Any] = None, + fixed_labelmap_sitk: Optional[Any] = None, + moving_labelmap_sitk: Optional[Any] = None, initial_affine: Optional[NDArray[np.float64]] = None, ) -> tuple[NDArray[np.float64], float]: """Run Greedy affine or rigid registration. Returns (4x4 matrix, loss).""" Greedy3D = _try_import_greedy() g = Greedy3D() - cmd = f"-i fixed moving -a -dof {dof} -n {iterations_str} -m {metric_str} -o aff_out" + cmd = "-d 3" + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd += " -w 0.60" + cmd += " -i fixed moving" kwargs: dict[str, Any] = { "fixed": fixed_sitk, "moving": moving_sitk, - "aff_out": None, } + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd += " -w 0.40 -i fixed_labelmap moving_labelmap" + kwargs["fixed_labelmap"] = fixed_labelmap_sitk + kwargs["moving_labelmap"] = moving_labelmap_sitk + cmd += f" -a -dof {dof} -n {iterations_str} -m {metric_str} -o aff_out" + kwargs["aff_out"] = None if fixed_mask_sitk is not None and moving_mask_sitk is not None: cmd += " -gm fixed_mask -mm moving_mask" kwargs["fixed_mask"] = fixed_mask_sitk @@ -247,10 +334,12 @@ def _registration_method_deformable( self, fixed_sitk: Any, moving_sitk: Any, - fixed_mask_sitk: Optional[Any], - moving_mask_sitk: Optional[Any], iterations_str: str, metric_str: str, + fixed_mask_sitk: Optional[Any] = None, + moving_mask_sitk: Optional[Any] = None, + fixed_labelmap_sitk: Optional[Any] = None, + moving_labelmap_sitk: Optional[Any] = None, initial_affine: Optional[NDArray[np.float64]] = None, ) -> tuple[Optional[NDArray[np.float64]], Any, float]: """Run Greedy deformable registration. Returns (affine 4x4 or None, warp_sitk, loss).""" @@ -259,8 +348,20 @@ def _registration_method_deformable( # Optional affine init (uses configured metric) if initial_affine is None: - cmd_aff = f"-i fixed moving -a -dof 6 -n {iterations_str} -m {metric_str} -o aff_init" - kwargs_aff = {"fixed": fixed_sitk, "moving": moving_sitk, "aff_init": None} + cmd_aff = "-d 3" + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_aff += " -w 0.60" + cmd_aff += " -i fixed moving" + kwargs_aff = { + "fixed": fixed_sitk, + "moving": moving_sitk, + } + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_aff += " -w 0.40 -i fixed_labelmap moving_labelmap" + kwargs_aff["fixed_labelmap"] = fixed_labelmap_sitk + kwargs_aff["moving_labelmap"] = moving_labelmap_sitk + cmd_aff += f" -a -dof 12 -n {iterations_str} -m {metric_str} -o aff_init" + kwargs_aff["aff_init"] = None if fixed_mask_sitk is not None and moving_mask_sitk is not None: cmd_aff += " -gm fixed_mask -mm moving_mask" kwargs_aff["fixed_mask"] = fixed_mask_sitk @@ -273,15 +374,23 @@ def _registration_method_deformable( # Greedy crashes (heap corruption) when the affine init is passed as an # in-memory matrix via -it; write it to a temp file and pass the path. initial_affine_file = self._write_affine_matrix_file(initial_affine) - cmd_def = ( - f"-i fixed moving -it {initial_affine_file} -n {iterations_str} " - f"-m {metric_str} -s {self.deformable_smoothing} -o warp_out" - ) + cmd_def = "-d 3" + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_def += " -w 0.60" + cmd_def += " -i fixed moving" kwargs_def = { "fixed": fixed_sitk, "moving": moving_sitk, - "warp_out": None, } + if fixed_labelmap_sitk is not None and moving_labelmap_sitk is not None: + cmd_def += " -w 0.40 -i fixed_labelmap moving_labelmap" + kwargs_def["fixed_labelmap"] = fixed_labelmap_sitk + kwargs_def["moving_labelmap"] = moving_labelmap_sitk + cmd_def += ( + f" -it {initial_affine_file} -n {iterations_str}" + f" -m {metric_str} -s {self.deformable_smoothing} -o warp_out" + ) + kwargs_def["warp_out"] = None if fixed_mask_sitk is not None and moving_mask_sitk is not None: cmd_def += " -gm fixed_mask -mm moving_mask" kwargs_def["fixed_mask"] = fixed_mask_sitk @@ -326,15 +435,67 @@ def registration_method( raise ValueError("Fixed image must be set before registration.") moving_pre = moving_image_pre if moving_image_pre is not None else moving_image - fixed_sitk = self._itk_to_sitk(self.fixed_image_pre) + + # The labelmap is added as a second Greedy metric channel only when both + # the fixed and moving labelmaps are present. That multi-component + # metric crashes picsl_greedy on large grids, so downsample every metric + # input by a single isotropic scale when (and only when) the channel is + # active; the single-channel path stays full resolution. + use_labelmap_channel = ( + self.fixed_labelmap is not None and moving_labelmap is not None + ) + metric_scale = ( + self._metric_downsample_scale(self.fixed_image_pre) + if use_labelmap_channel + else 1.0 + ) + + fixed_pre = self._downsample_image(self.fixed_image_pre, metric_scale) + moving_pre = self._downsample_image(moving_pre, metric_scale) + # warp_out lands on the (possibly downsampled) fixed grid; use the same + # grid as the displacement-field reference so shapes match. + displacement_reference = fixed_pre + fixed_sitk = self._itk_to_sitk(fixed_pre) moving_sitk = self._itk_to_sitk(moving_pre) + # Greedy applies one global metric to every input channel. A raw + # integer labelmap is piecewise-constant, so NCC sees zero local + # variance and emits NaN gradients (a native crash). Encode each + # labelmap as a continuous label-plus-boundary-distance field instead. + from physiomotion4d.labelmap_tools import LabelmapTools + + labelmap_tools = LabelmapTools() + fixed_labelmap_sitk = None + moving_labelmap_sitk = None + if self.fixed_labelmap is not None: + fixed_labelmap_ds = self._downsample_image( + self.fixed_labelmap, metric_scale, nearest=True + ) + fixed_labelmap_dist_map = labelmap_tools.create_distance_map( + fixed_labelmap_ds + ) + fixed_labelmap_sitk = self._itk_to_sitk(fixed_labelmap_dist_map) + if moving_labelmap is not None: + moving_labelmap_ds = self._downsample_image( + moving_labelmap, metric_scale, nearest=True + ) + moving_labelmap_dist_map = labelmap_tools.create_distance_map( + moving_labelmap_ds + ) + moving_labelmap_sitk = self._itk_to_sitk(moving_labelmap_dist_map) + fixed_mask_sitk = None moving_mask_sitk = None if self.fixed_mask is not None: - fixed_mask_sitk = self._itk_to_sitk(self.fixed_mask) + fixed_mask_ds = self._downsample_image( + self.fixed_mask, metric_scale, nearest=True + ) + fixed_mask_sitk = self._itk_to_sitk(fixed_mask_ds) if moving_mask is not None: - moving_mask_sitk = self._itk_to_sitk(moving_mask) + moving_mask_ds = self._downsample_image( + moving_mask, metric_scale, nearest=True + ) + moving_mask_sitk = self._itk_to_sitk(moving_mask_ds) iterations_str = self._greedy_iterations_str() metric_str = self._greedy_metric() @@ -365,10 +526,12 @@ def registration_method( mat, loss_val = self._registration_method_affine_or_rigid( fixed_sitk, moving_sitk, - fixed_mask_sitk, - moving_mask_sitk, - iterations_str, - metric_str, + fixed_mask_sitk=fixed_mask_sitk, + moving_mask_sitk=moving_mask_sitk, + fixed_labelmap_sitk=fixed_labelmap_sitk, + moving_labelmap_sitk=moving_labelmap_sitk, + iterations_str=iterations_str, + metric_str=metric_str, dof=6, initial_affine=initial_affine, ) @@ -380,10 +543,12 @@ def registration_method( mat, loss_val = self._registration_method_affine_or_rigid( fixed_sitk, moving_sitk, - fixed_mask_sitk, - moving_mask_sitk, - iterations_str, - metric_str, + fixed_mask_sitk=fixed_mask_sitk, + moving_mask_sitk=moving_mask_sitk, + fixed_labelmap_sitk=fixed_labelmap_sitk, + moving_labelmap_sitk=moving_labelmap_sitk, + iterations_str=iterations_str, + metric_str=metric_str, dof=12, initial_affine=initial_affine, ) @@ -396,10 +561,12 @@ def registration_method( aff_mat, warp_sitk, loss_val = self._registration_method_deformable( fixed_sitk, moving_sitk, - fixed_mask_sitk, - moving_mask_sitk, - iterations_str, - metric_str, + fixed_mask_sitk=fixed_mask_sitk, + moving_mask_sitk=moving_mask_sitk, + fixed_labelmap_sitk=fixed_labelmap_sitk, + moving_labelmap_sitk=moving_labelmap_sitk, + iterations_str=iterations_str, + metric_str=metric_str, initial_affine=initial_affine, ) aff_tfm = ( @@ -408,7 +575,7 @@ def registration_method( # warp_sitk can be displacement field (SimpleITK image) or numpy if hasattr(warp_sitk, "GetSize"): disp_tfm = self._sitk_warp_to_itk_displacement_transform( - warp_sitk, self.fixed_image + warp_sitk, displacement_reference ) else: # Assume numpy displacement field (z,y,x,3) @@ -416,7 +583,7 @@ def registration_method( image_tools = ImageTools() warp_arr = np.asarray(warp_sitk, dtype=np.float64) - ref = self.fixed_image + ref = displacement_reference disp_itk = image_tools.convert_array_to_image_of_vectors( warp_arr, ref, itk.D ) From 3e590f0499530070e4cee9985cc18c1cbe676ea9 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Thu, 4 Jun 2026 11:35:26 -0400 Subject: [PATCH 05/10] ENH: Generating summary results per method/iterations. --- .../registration_results_analysis.py | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/experiments/LongitudinalRegistration/registration_results_analysis.py b/experiments/LongitudinalRegistration/registration_results_analysis.py index 50b8896..562f61b 100644 --- a/experiments/LongitudinalRegistration/registration_results_analysis.py +++ b/experiments/LongitudinalRegistration/registration_results_analysis.py @@ -145,6 +145,35 @@ def _stats(series: pd.Series) -> dict[str, float]: return table +def _method_summary( + dice_table: pd.DataFrame, + landmark_table: pd.DataFrame, +) -> pd.DataFrame: + """Collapse both summary tables to one row per ``(results_dir, method)``. + + Extracts the ``"all"`` slice (pooled over every label / landmark) from + each table and concatenates the columns side-by-side. + + Args: + dice_table: Output of :func:`_aggregate` for Dice scores. + landmark_table: Output of :func:`_aggregate` for landmark RMSE. + + Returns: + Data frame indexed by ``(results_dir, method)`` with flat columns + ``dice_mean``, ``dice_std``, ``dice_p95``, ``landmark_mean``, + ``landmark_std``, ``landmark_p95``. + """ + parts: list[pd.DataFrame] = [] + for table, prefix in ((dice_table, "dice"), (landmark_table, "landmark")): + if ALL_KEY in table.columns.get_level_values(0): + sub = table[ALL_KEY].copy() + sub.columns = [f"{prefix}_{c}" for c in sub.columns] + parts.append(sub) + if not parts: + return pd.DataFrame() + return pd.concat(parts, axis=1) + + def summarize(base_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]: """Build the Dice and landmark RMSE summary tables for ``base_dir``. @@ -219,7 +248,13 @@ def main(argv: Optional[list[str]] = None) -> None: args = parser.parse_args(argv) dice_table, landmark_table = summarize(args.base_dir) + methods_table = _method_summary(dice_table, landmark_table) + _print_table( + "Per-method summary (mean / std / 95th percentile across ALL labels " + "and ALL landmarks), grouped by results_dir + method", + methods_table, + ) _print_table( "Dice score by label (mean / std / 95th percentile), grouped by " "results_dir + method, pooled over subjects and time points", @@ -233,11 +268,13 @@ def main(argv: Optional[list[str]] = None) -> None: if args.out_dir is not None: args.out_dir.mkdir(parents=True, exist_ok=True) + methods_out = args.out_dir / "summary_methods.csv" dice_out = args.out_dir / "summary_dice.csv" landmark_out = args.out_dir / "summary_landmarks.csv" + methods_table.to_csv(methods_out) dice_table.to_csv(dice_out) landmark_table.to_csv(landmark_out) - print(f"\nWrote {dice_out}\nWrote {landmark_out}") + print(f"\nWrote {methods_out}\nWrote {dice_out}\nWrote {landmark_out}") if __name__ == "__main__": From b7365a554d0279f0857135e4a3fd5e0e8b1091f9 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Thu, 4 Jun 2026 12:21:19 -0400 Subject: [PATCH 06/10] BUG: Fixed API to fine-tune-icon --- .../1-initial_registration.py | 23 ++++++------------- .../workflow_fine_tune_icon_registration.py | 20 ++++++++++++---- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/experiments/LongitudinalRegistration/1-initial_registration.py b/experiments/LongitudinalRegistration/1-initial_registration.py index 197c0c2..a5ab56f 100644 --- a/experiments/LongitudinalRegistration/1-initial_registration.py +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -27,11 +27,11 @@ src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -use_mask_list = [False, False, False, False, False, False, False, False] -use_labelmap_list = [True, False, True, False, True, False, True, False] +use_mask_list = [False, False, False, False, False] +use_labelmap_list = [False, False, False, False, False] # ICON only -use_mass_list = [False, False, False, False, False, False, False, False] +use_mass_list = [False, False, False, False, False] methods_list = [ ["Greedy"], @@ -39,9 +39,6 @@ ["Greedy"], ["Greedy"], ["Greedy"], - ["Greedy"], - ["Greedy"], - ["Greedy"], ] number_of_iterations_ANTS_list = [ [40, 20, 10], @@ -49,21 +46,15 @@ [40, 20, 10], [40, 20, 10], [40, 20, 10], - [40, 20, 10], - [40, 20, 10], - [40, 20, 10], ] number_of_iterations_greedy_list = [ + [20, 20, 10], [40, 20, 10], - [40, 20, 10], - [80, 20, 10], + [60, 20, 10], [80, 20, 10], - [40, 40, 10], - [40, 40, 10], - [80, 40, 5], - [80, 40, 5], + [100, 20, 10], ] -number_of_iterations_ICON_list = [100, 100, 100, 100, 100, 100, 100, 100] +number_of_iterations_ICON_list = [100, 100, 100, 100, 100] exclude_tokens = ["nop"] ref_suffix = "_ref" diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 3b91e0a..539a4e4 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -262,6 +262,11 @@ def __init__( self.subject_mask_files = subject_mask_files self.subject_landmark_files = subject_landmark_files + self.use_segmentations: bool = subject_segmentation_files is not None + self.use_masks: bool = ( + subject_mask_files is not None or subject_segmentation_files is not None + ) + self.output_dir = Path(output_dir).resolve() self.fine_tune_name = fine_tune_name self.experiment_dir = self.output_dir / fine_tune_name @@ -290,8 +295,8 @@ def __init__( self.labelmap_tools = LabelmapTools(log_level=log_level) self.registrar: Optional[RegisterTimeSeriesImages] = None - self._use_segmentations: Optional[bool] = None - self._use_masks: Optional[bool] = None + self._use_segmentations: bool = self.use_segmentations + self._use_masks: bool = self.use_masks self._dataset_json_path: Optional[Path] = None self._config_yaml_path: Optional[Path] = None @@ -368,7 +373,9 @@ def _derive_mask( return mask_path def prepare_dataset( - self, use_segmentations: bool = True, use_masks: bool = True + self, + use_segmentations: Optional[bool] = None, + use_masks: Optional[bool] = None, ) -> Path: """Write the uniGradICON dataset JSON from the configured file lists. @@ -392,6 +399,11 @@ def prepare_dataset( """ self.experiment_dir.mkdir(parents=True, exist_ok=True) + if use_segmentations is None: + use_segmentations = self.use_segmentations + if use_masks is None: + use_masks = self.use_masks + self._use_segmentations = use_segmentations self._use_masks = use_masks @@ -524,7 +536,7 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, "loss_function_masking": self._use_masks, - "use_label": False, + "use_label": self._use_segmentations, "roi_masking": False, }, "datasets": [ From fffc95603b0e7b9dbf725ca482627b95239e0b66 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Fri, 5 Jun 2026 10:55:20 -0400 Subject: [PATCH 07/10] DOC: Coderabbit suggested changes to match code/doc and asserts --- .../1-initial_registration.py | 26 +++++-------------- .../2-finetune_icon.py | 8 +++--- .../3-recon_4d_icon_eval.py | 17 +++++++----- src/physiomotion4d/labelmap_tools.py | 14 +++++----- src/physiomotion4d/register_images_ants.py | 6 +++++ src/physiomotion4d/register_images_icon.py | 18 ------------- .../register_time_series_images.py | 2 +- .../segment_heart_simpleware.py | 2 +- src/physiomotion4d/transform_tools.py | 2 +- tests/test_labelmap_tools.py | 4 +++ 10 files changed, 41 insertions(+), 58 deletions(-) diff --git a/experiments/LongitudinalRegistration/1-initial_registration.py b/experiments/LongitudinalRegistration/1-initial_registration.py index a5ab56f..8775814 100644 --- a/experiments/LongitudinalRegistration/1-initial_registration.py +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -27,34 +27,22 @@ src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -use_mask_list = [False, False, False, False, False] -use_labelmap_list = [False, False, False, False, False] +use_mask_list = [False] +use_labelmap_list = [False] # ICON only -use_mass_list = [False, False, False, False, False] +use_mass_list = [False] methods_list = [ ["Greedy"], - ["Greedy"], - ["Greedy"], - ["Greedy"], - ["Greedy"], ] number_of_iterations_ANTS_list = [ [40, 20, 10], - [40, 20, 10], - [40, 20, 10], - [40, 20, 10], - [40, 20, 10], ] number_of_iterations_greedy_list = [ - [20, 20, 10], - [40, 20, 10], [60, 20, 10], - [80, 20, 10], - [100, 20, 10], ] -number_of_iterations_ICON_list = [100, 100, 100, 100, 100] +number_of_iterations_ICON_list = [100] exclude_tokens = ["nop"] ref_suffix = "_ref" @@ -167,8 +155,8 @@ def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: 3 mm physical-radius dilation) and write it out so subsequent runs and the ICON eval reuse the same mask. """ - # if mask_path.exists(): - # return itk.imread(str(mask_path)) + if mask_path.exists(): + return itk.imread(str(mask_path)) mask = labelmap_tools.convert_labelmap_to_mask( labelmap, dilation_in_mm=mask_dilation_mm, @@ -472,7 +460,7 @@ def crop_image_to_mask( else: # ICON: GPU deep-learning deformable registration. reg = RegisterImagesICON() reg.set_number_of_iterations(number_of_iterations_ICON) - num_iters_str = ".".join(str(n) for n in number_of_iterations_ICON) + num_iters_str = str(number_of_iterations_ICON) reg.set_multi_modality(False) reg.set_mass_preservation(use_mass) if icon_weights_path is not None: diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py index eede97d..85c8a06 100644 --- a/experiments/LongitudinalRegistration/2-finetune_icon.py +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -39,7 +39,7 @@ # %% ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") -segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") +labelmap_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") # Where the workflow writes the dataset JSON, YAML config, derived masks, and # the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to @@ -103,7 +103,7 @@ # For each train-cohort patient, list gated frames in # ``src_data_dir_base / `` (excluding ``"nop"`` non-gated # references) and pair each frame with its -# ``_labelmap.nii.gz`` and ``_mask.nii.gz`` under ``segmentation_dir_base / ``. +# ``_labelmap.nii.gz`` and ``_mask.nii.gz`` under ``labelmap_dir_base / ``. # Patients with no source directory or no valid frames are skipped here only # — they remain part of the canonical train list above, but contribute no # training data. Missing labelmaps are recorded as ``None`` so the workflow @@ -120,7 +120,7 @@ # %% -def load_or_derive_mask(labelmap_path: Path) -> str: +def load_or_derive_mask(labelmap_path: Path) -> Optional[str]: """Create (or reuse) a loss-function mask next to ``labelmap_path``. Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm @@ -193,7 +193,7 @@ def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str for patient_id in train_subjects: src_dir = src_data_dir_base / patient_id - seg_dir = segmentation_dir_base / patient_id + seg_dir = labelmap_dir_base / patient_id if not src_dir.is_dir(): print(f" Skipping {patient_id}: source dir {src_dir} not found") diff --git a/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py index bd462b4..6b61734 100644 --- a/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py +++ b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py @@ -3,11 +3,11 @@ # # Enumerates the Duke patient cohort by sorting ``ref_images/`` and uses the # *last 20%* of patients as the held-out test set — the same fixed split -# applied by ``1-finetune_icon.py`` (first 80% train, last 20% test). For +# applied by ``2-finetune_icon.py`` (first 80% train, last 20% test). For # each test subject the 70th-percentile gated frame is selected as the # reference and every other frame is registered to it twice with # ``RegisterTimeSeriesImages``: once with the default uniGradICON weights and -# once with the finetuned checkpoint from ``1-finetune_icon.py``. The +# once with the finetuned checkpoint from ``2-finetune_icon.py``. The # resampler-convention inverse transform (which maps moving-grid points back # to reference-grid points) is applied to each time-point's precomputed # landmarks to land them in reference space, and the Euclidean error against @@ -248,11 +248,14 @@ # ## 5. Write the wide-form per-timepoint summary CSV # %% -with summary_file.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) - writer.writeheader() - writer.writerows(summary_rows) -print(f"Wrote summary: {summary_file}") +if not summary_rows: + print("No summary rows to write") +else: + with summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) + writer.writeheader() + writer.writerows(summary_rows) + print(f"Wrote summary: {summary_file}") print(f"Wrote landmark details: {detail_file}") # %% [markdown] diff --git a/src/physiomotion4d/labelmap_tools.py b/src/physiomotion4d/labelmap_tools.py index 2ab4d6c..08fe4e6 100644 --- a/src/physiomotion4d/labelmap_tools.py +++ b/src/physiomotion4d/labelmap_tools.py @@ -112,7 +112,7 @@ def create_distance_map( boundary between two differently-labeled regions: value = label + min(distance_to_nearest_boundary_mm, - max_distance_mm) / distance_scale + max_distance_mm) / distance_scale The boundary set is every voxel that 6-neighbors a voxel with a different label (background label ``0`` participates, so the outer @@ -122,9 +122,9 @@ def create_distance_map( to ``max_distance_mm``, divided by ``distance_scale``, and added to the voxel's original label. - With the defaults (``50`` mm clip, ``100`` scale) the fractional offset - stays in ``[0.0, 0.5]``, so it never reaches the next integer label and - label identity is recoverable as ``floor(value)``. + With the defaults (``20`` mm clip, ``5`` scale) the fractional offset + stays in ``[0.0, 4.0]``, potentially passing adjacent integer labels but + emphasizing in medial alignment as well as boundary. The motivation is registration metrics such as Greedy's NCC: a raw integer labelmap is piecewise-constant, so the local variance inside @@ -139,10 +139,10 @@ def create_distance_map( Args: labelmap: Multi-label (or binary) ``itk.Image`` of integer labels. - max_distance_mm: Distance clip, in millimeters. Default 50.0. + max_distance_mm: Distance clip, in millimeters. Default 20.0. distance_scale: Divisor applied to the clipped distance before it - is added to the label. Default 100.0. With the default clip - this bounds the fractional offset to ``[0, 0.5]``. + is added to the label. Default 5.0. With the default clip + this bounds the fractional offset to ``[0, 4.0]``. Returns: ``itk.Image[itk.F, 3]`` in the same physical space as ``labelmap`` diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 7412763..43cf59d 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -607,6 +607,12 @@ def registration_method( initial_forward_transform, self.fixed_image, ) + if self.moving_mask is not None: + self.moving_mask_pre = transform_tools.transform_image( + self.moving_mask, + initial_forward_transform, + self.fixed_image, + ) transform_type = None if self.transform_type == "Deformable": diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 3d28dad..65c527e 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -10,8 +10,6 @@ """ import logging -import pathlib -import sys from typing import Optional, Union import icon_registration as icon @@ -303,29 +301,13 @@ def _ensure_net(self) -> None: """ if self.net is not None: return - main_module = sys.modules.get("__main__") - main_file = getattr(main_module, "__file__", None) - top_dir = pathlib.Path.cwd() - if main_file is not None: - top_dir = pathlib.Path(main_file).resolve().parent if self.use_multi_modality: - if self.weights_path is None: - self.weights_path = str( - top_dir - / "network_weights" - / "multigradicon1.0" - / "Step_2_final.trch" - ) self.net = get_multigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, weights_location=self.weights_path, ) else: - if self.weights_path is None: - self.weights_path = str( - top_dir / "network_weights" / "unigradicon1.0" / "Step_2_final.trch" - ) self.net = get_unigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 89dd02d..6ced80c 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -211,7 +211,7 @@ def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: This passes through to the underlying registration method. Args: - fixed_labelmap (itk.Image): Labelmap defining ROI + fixed_labelmap (Optional[itk.Image]): Labelmap defining ROI """ self.fixed_labelmap = fixed_labelmap diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index 1d59e4a..63ba60d 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -106,7 +106,7 @@ def __init__(self, log_level: int | str = logging.INFO): self._finalize_other_group() # Path to Simpleware Medical console executable - self.simpleware_exe_path = "C:/Program Files/Synopsys/Simpleware Medical/X-2025.06/ConsoleSimplewareMedical.exe" + self.simpleware_exe_path = "C:/Program Files/Synopsys/Simpleware Medical/Y-2026.03/ConsoleSimplewareMedical.exe" # Path to the Simpleware Python script for heart segmentation self.simpleware_script_path = os.path.join( diff --git a/src/physiomotion4d/transform_tools.py b/src/physiomotion4d/transform_tools.py index 2a22aa2..f0ed856 100644 --- a/src/physiomotion4d/transform_tools.py +++ b/src/physiomotion4d/transform_tools.py @@ -365,7 +365,7 @@ def transform_dataset( np.array(tfm.TransformPoint((float(p[0]), float(p[1]), float(p[2])))) for p in pnts ] - new_mesh.points = np.asarray(new_pnts, dtype=float) + new_mesh.points = np.asarray(new_pnts, dtype=float).reshape(-1, 3) if with_deformation_magnitude: if cp is not None: diff --git a/tests/test_labelmap_tools.py b/tests/test_labelmap_tools.py index e6f15cf..6e1279c 100644 --- a/tests/test_labelmap_tools.py +++ b/tests/test_labelmap_tools.py @@ -105,3 +105,7 @@ def test_preserves_image_information(self, labelmap_tools: LabelmapTools) -> Non assert list(mask.GetSpacing()) == [0.5, 1.0, 2.0] assert list(mask.GetOrigin()) == [10.0, -5.0, 3.0] + assert np.allclose( + itk.array_from_matrix(mask.GetDirection()), + itk.array_from_matrix(labelmap.GetDirection()), + ) From 19b59b05628c1a380b0e022db7517057031c024a Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Fri, 5 Jun 2026 11:54:46 -0400 Subject: [PATCH 08/10] ENH: Additional coderabbit comments --- src/physiomotion4d/register_images_ants.py | 15 +++------- src/physiomotion4d/transform_tools.py | 34 +++++++++++++--------- tests/test_transform_tools.py | 18 ++++++++++++ 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 43cf59d..869c17d 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -510,7 +510,7 @@ def registration_method( moving_image: itk.Image, moving_mask: Optional[itk.Image] = None, moving_labelmap: Optional[itk.Image] = None, - moving_image_pre: Optional[ants.ANTsImage] = None, + moving_image_pre: Optional[itk.Image] = None, initial_forward_transform: Optional[itk.Transform] = None, ) -> dict[str, Union[itk.Transform, float]]: """Register moving image to fixed image using ANTs registration algorithm. @@ -523,8 +523,8 @@ def registration_method( moving_image (itk.image): The 3D image to be registered/aligned. moving_mask (itk.image, optional): Binary mask defining the region of interest in the moving image - moving_image_pre (ants.core.ANTsImage, optional): Pre-processed moving image - in ANTs format. If None, preprocessing is performed automatically + moving_image_pre (itk.Image, optional): Pre-processed moving image. + If None, preprocessing is performed automatically initial_forward_transform (itk.Transform, optional): Initial forward transform (same convention as the returned forward_transform: used to warp the moving image onto the fixed @@ -592,13 +592,6 @@ def registration_method( if self.fixed_image_pre is None: self.fixed_image_pre = self.preprocess(self.fixed_image, self.modality) - # Apply any initial transform by pre-warping the moving image onto the - # fixed grid (the same approach RegisterImagesICON uses), instead of - # passing it to ants.registration as an initial_transform. ANTS - # mishandles matrix (affine/translation) initial transforms, badly - # corrupting the result; pre-warping keeps the composition below - # self-consistent for any initial transform type. The registration then - # solves the residual and the composition recovers the full transform. if initial_forward_transform is not None: self.log_info("Pre-warping moving image with initial transform...") transform_tools = TransformTools() @@ -608,7 +601,7 @@ def registration_method( self.fixed_image, ) if self.moving_mask is not None: - self.moving_mask_pre = transform_tools.transform_image( + self.moving_mask = transform_tools.transform_image( self.moving_mask, initial_forward_transform, self.fixed_image, diff --git a/src/physiomotion4d/transform_tools.py b/src/physiomotion4d/transform_tools.py index f0ed856..c7dee4d 100644 --- a/src/physiomotion4d/transform_tools.py +++ b/src/physiomotion4d/transform_tools.py @@ -89,7 +89,7 @@ def combine_displacement_field_transforms( mode: str = "compose", tfm1_blur_sigma: float = 0.0, tfm2_blur_sigma: float = 0.0, - ) -> itk.DisplacementFieldTransform: + ) -> itk.CompositeTransform: """ Compose two displacement field transforms. @@ -225,7 +225,7 @@ def convert_transform_to_displacement_field( field = tfm.GetDisplacementField() field_arr = itk.array_view_from_image(tfm.GetDisplacementField()) reference_image_arr = itk.array_view_from_image(reference_image) - if field_arr.shape[:2] != reference_image_arr.shape: + if field_arr.shape[:3] != reference_image_arr.shape: field_filter = itk.TransformToDisplacementFieldFilter[ itk.Image[itk.Vector[itk.F, 3], 3], TfmPrecision ].New() @@ -419,7 +419,7 @@ def transform_image( ... ) >>> # Transform label map preserving discrete values >>> warped_labels = transform_tools.transform_image( - ... labelmap, transform, reference, tfm_type='nearest' + ... labelmap, transform, reference, interpolation_method='nearest' ... ) """ # Handle case where tfm is a list (e.g., from itk.transformread) @@ -618,6 +618,7 @@ def combine_transforms_with_masks( sum_fields_arr = mask1_arr * field1_arr + mask2_arr * field2_arr denom = mask1_arr + mask2_arr + denom[denom == 0] = 1.0 combined_field_arr = sum_fields_arr / denom @@ -757,17 +758,22 @@ def generate_grid_image( width_max = width_min + line_width for i in range(grid_size): for j in range(grid_size): - min_idx0 = int(i * grid_spacing[0]) - width_min - max_idx0 = int(i * grid_spacing[0]) + width_max - min_idx1 = int(j * grid_spacing[1]) - width_min - max_idx1 = int(j * grid_spacing[1]) + width_max - img_arr[min_idx0:max_idx0, min_idx1:max_idx1, :] = img_arr_max - min_idx2 = int(j * grid_spacing[2]) - width_min - max_idx2 = int(j * grid_spacing[2]) + width_max - img_arr[min_idx0:max_idx0, :, min_idx2:max_idx2] = img_arr_max - min_idx1 = int(i * grid_spacing[1]) - width_min - max_idx1 = int(i * grid_spacing[1]) + width_max - img_arr[:, min_idx1:max_idx1, min_idx2:max_idx2] = img_arr_max + min_idx0 = max(0, int(i * grid_spacing[0]) - width_min) + max_idx0 = min(img_arr.shape[0], int(i * grid_spacing[0]) + width_max) + min_idx1 = max(0, int(j * grid_spacing[1]) - width_min) + max_idx1 = min(img_arr.shape[1], int(j * grid_spacing[1]) + width_max) + if min_idx0 < max_idx0 and min_idx1 < max_idx1: + img_arr[min_idx0:max_idx0, min_idx1:max_idx1, :] = img_arr_max + + min_idx2 = max(0, int(j * grid_spacing[2]) - width_min) + max_idx2 = min(img_arr.shape[2], int(j * grid_spacing[2]) + width_max) + if min_idx0 < max_idx0 and min_idx2 < max_idx2: + img_arr[min_idx0:max_idx0, :, min_idx2:max_idx2] = img_arr_max + + min_idx1 = max(0, int(i * grid_spacing[1]) - width_min) + max_idx1 = min(img_arr.shape[1], int(i * grid_spacing[1]) + width_max) + if min_idx1 < max_idx1 and min_idx2 < max_idx2: + img_arr[:, min_idx1:max_idx1, min_idx2:max_idx2] = img_arr_max grid_image = itk.image_from_array(img_arr) grid_image.CopyInformation(reference_image) diff --git a/tests/test_transform_tools.py b/tests/test_transform_tools.py index 3b12635..75e16de 100644 --- a/tests/test_transform_tools.py +++ b/tests/test_transform_tools.py @@ -19,6 +19,24 @@ from physiomotion4d.transform_tools import TransformTools +def test_generate_grid_image_clamps_boundary_lines() -> None: + """ + Grid image clamps boundary slices for an ITK image with axes (X, Y, Z). + + The synthetic ITK image has axes (X, Y, Z) = (7, 6, 5), created from + NumPy array shape (Z, Y, X) = (5, 6, 7). + """ + image_arr = np.zeros((5, 6, 7), dtype=np.float32) + image_arr[-1, -1, -1] = 5.0 + image = itk.image_from_array(image_arr) + + grid_image = TransformTools().generate_grid_image(image, grid_size=2, line_width=3) + grid_arr = itk.array_from_image(grid_image) + + assert grid_arr.shape == image_arr.shape + assert grid_arr[0, 0, 0] == 5.0 + + @pytest.mark.slow class TestTransformTools: """Test suite for TransformTools functionality.""" From 6494c98bdd8b8b2f4b3ec7dafa397b08421979a8 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Fri, 5 Jun 2026 11:59:47 -0400 Subject: [PATCH 09/10] BUG: Update workflow parameter names for fine-tuning ICON --- .../workflow_fine_tune_icon_registration.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 539a4e4..c16e9f9 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -79,7 +79,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): subject_ids (Optional[list[str]]): One ID per subject (e.g. patient identifiers). Written into the dataset JSON's ``subject_id`` field; falls back to synthetic ``subject_NNNN`` when ``None``. - subject_segmentation_files (Optional[list[list[Optional[str]]]]): + subject_labelmap_files (Optional[list[list[Optional[str]]]]): Per-subject multi-label segmentation/labelmap paths aligned with ``subject_image_files``. ``None`` (or per-image ``None``) means no segmentation for that image. If supplied for at least one image, @@ -88,7 +88,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): Per-subject binary mask paths aligned with ``subject_image_files``. When supplied for a frame these masks are used directly for loss-function masking; otherwise masks are derived from - ``subject_segmentation_files``. + ``subject_labelmap_files``. subject_landmark_files (Optional[list[list[Optional[str]]]]): Per-subject landmark CSV paths (``Name,X,Y,Z`` format) aligned with ``subject_image_files``. Recorded in the dataset JSON for @@ -115,7 +115,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): ... ], ... output_dir=Path('d:/PhysioMotion4D/icon_finetuned'), ... fine_tune_name='duke_4d_gated_icon_ft', - ... subject_segmentation_files=[ + ... subject_labelmap_files=[ ... ['pm0001/g000_labelmap.nii.gz', 'pm0001/g050_labelmap.nii.gz'], ... ['pm0002/g000_labelmap.nii.gz', 'pm0002/g050_labelmap.nii.gz'], ... ], @@ -140,7 +140,7 @@ def __init__( output_dir: Path, fine_tune_name: str, subject_ids: Optional[list[str]] = None, - subject_segmentation_files: Optional[list[list[Optional[str]]]] = None, + subject_labelmap_files: Optional[list[list[Optional[str]]]] = None, subject_mask_files: Optional[list[list[Optional[str]]]] = None, subject_landmark_files: Optional[list[list[Optional[str]]]] = None, epochs: int = 2000, @@ -177,7 +177,7 @@ def __init__( JSON's ``subject_id`` field so paired training groups frames that share an ID. ``None`` falls back to synthetic IDs of the form ``subject_0000``, ``subject_0001``, ... Must be unique. - subject_segmentation_files: Per-subject multi-label segmentation + subject_labelmap_files: Per-subject multi-label segmentation (labelmap) paths matching ``subject_image_files``. ``None`` disables paired-with-seg training. Individual ``None`` entries inside the inner lists skip just @@ -185,7 +185,7 @@ def __init__( subject_mask_files: Per-subject binary mask paths matching ``subject_image_files``. When supplied these are used directly for ICON loss-function masking; otherwise masks are derived - from ``subject_segmentation_files`` via a >0 threshold and + from ``subject_labelmap_files`` via a >0 threshold and dilation by ``mask_dilation_mm``. Per-image ``None`` entries fall back to the derived mask for that frame (or skip it if no segmentation is available either). @@ -223,7 +223,7 @@ def __init__( Raises: ValueError: If ``subject_image_files`` is empty. - ValueError: If ``subject_segmentation_files``, + ValueError: If ``subject_labelmap_files``, ``subject_mask_files``, or ``subject_landmark_files`` is provided with a shape that does not match ``subject_image_files``. @@ -246,8 +246,8 @@ def __init__( self._validate_companion_shape( subject_image_files, - subject_segmentation_files, - "subject_segmentation_files", + subject_labelmap_files, + "subject_labelmap_files", ) self._validate_companion_shape( subject_image_files, subject_mask_files, "subject_mask_files" @@ -258,13 +258,13 @@ def __init__( self.subject_image_files = subject_image_files self.subject_ids = subject_ids - self.subject_segmentation_files = subject_segmentation_files + self.subject_labelmap_files = subject_labelmap_files self.subject_mask_files = subject_mask_files self.subject_landmark_files = subject_landmark_files - self.use_segmentations: bool = subject_segmentation_files is not None + self.use_segmentations: bool = subject_labelmap_files is not None self.use_masks: bool = ( - subject_mask_files is not None or subject_segmentation_files is not None + subject_mask_files is not None or subject_labelmap_files is not None ) self.output_dir = Path(output_dir).resolve() @@ -384,7 +384,7 @@ def prepare_dataset( ``subject_id`` derived from the inner-list index. Masks are taken from ``subject_mask_files`` when supplied for a frame; - otherwise they are derived from ``subject_segmentation_files`` via a + otherwise they are derived from ``subject_labelmap_files`` via a >0 threshold and ``mask_dilation_mm`` mm dilation. Frames are skipped (with a log warning) when a required companion (segmentation for paired-with-seg training, or mask for loss-function masking) is @@ -419,8 +419,8 @@ def prepare_dataset( seg_list = [None] * len(image_files) else: seg_list = ( - self.subject_segmentation_files[subject_index] - if self.subject_segmentation_files is not None + self.subject_labelmap_files[subject_index] + if self.subject_labelmap_files is not None else [None] * len(image_files) ) mask_list: list[Optional[str]] From 49212f51140c393ed032d9ccf4508a3603cd6ecc Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Fri, 5 Jun 2026 15:52:56 -0400 Subject: [PATCH 10/10] BUG: Fix API naming --- .../2-finetune_icon.py | 72 ++++++++++--------- src/physiomotion4d/register_images_ants.py | 1 + src/physiomotion4d/transform_tools.py | 7 +- ...st_workflow_fine_tune_icon_registration.py | 16 ++--- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py index 85c8a06..05e55db 100644 --- a/experiments/LongitudinalRegistration/2-finetune_icon.py +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -110,9 +110,9 @@ # skips just that frame. # %% -train_image_files: list[list[str]] = [] -train_labelmap_files: list[list[Optional[str]]] = [] -train_mask_files: list[list[Optional[str]]] = [] +train_image_files: list[list[Path]] = [] +train_labelmap_files: list[list[Optional[Path]]] = [] +train_mask_files: list[list[Optional[Path]]] = [] valid_train_subjects: list[str] = [] mask_dilation_mm = 3.0 @@ -120,15 +120,15 @@ # %% -def load_or_derive_mask(labelmap_path: Path) -> Optional[str]: +def load_or_derive_mask(labelmap_path: Path) -> Optional[Path]: """Create (or reuse) a loss-function mask next to ``labelmap_path``. Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm via :meth:`LabelmapTools.convert_labelmap_to_mask`, writing the result as ``_mask.nii.gz`` in the labelmap's own directory. Handles both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha`` - (pre-registration warped labelmaps). Returns the mask path as a string; - existing masks on disk are reused unmodified. + (pre-registration warped labelmaps). Returns the mask path; existing + masks on disk are reused unmodified. """ if not labelmap_path.exists(): return None @@ -146,12 +146,14 @@ def load_or_derive_mask(labelmap_path: Path) -> Optional[str]: itk.imread(str(labelmap_path)), dilation_in_mm=mask_dilation_mm ) itk.imwrite(mask, str(mask_p), compression=True) - return str(mask_p) + return mask_p # %% -def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str]]]: - """Return ``(warped_image_paths, warped_labelmap_paths)`` for one +def gather_warped_frames( + method_dir: Path, +) -> tuple[list[Path], list[Optional[Path]], list[Optional[Path]]]: + """Return ``(warped_image_paths, warped_labelmap_paths, warped_mask_paths)`` for one ``initial_registration_dir / / `` directory. Enumerates the warped moving images (``.mha``), excluding the @@ -161,36 +163,27 @@ def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str not exist. """ if not method_dir.is_dir(): - return [], [] + return [], [], [] companion_suffixes = ( "_labelmap.mha", "_mask.mha", ) - image_paths: list[str] = [] - labelmap_paths: list[Optional[str]] = [] - mask_paths: list[Optional[str]] = [] + image_paths: list[Path] = [] + labelmap_paths: list[Optional[Path]] = [] + mask_paths: list[Optional[Path]] = [] for image in sorted(method_dir.glob("*.mha")): if image.name.endswith(companion_suffixes): continue stem = image.name[:-4] labelmap = method_dir / f"{stem}_labelmap.mha" mask = method_dir / f"{stem}_mask.mha" - image_paths.append(str(image)) - labelmap_paths.append(str(labelmap) if labelmap.exists() else None) - mask_paths.append(str(mask) if mask.exists() else None) + image_paths.append(image) + labelmap_paths.append(labelmap if labelmap.exists() else None) + mask_paths.append(mask if mask.exists() else None) return image_paths, labelmap_paths, mask_paths # %% -train_mask_files: list[list[Optional[str]]] = [] -for labelmap_paths in train_labelmap_files: - train_mask_files.append( - [ - load_or_derive_mask(Path(s)) if s is not None else None - for s in labelmap_paths - ] - ) - for patient_id in train_subjects: src_dir = src_data_dir_base / patient_id seg_dir = labelmap_dir_base / patient_id @@ -206,17 +199,18 @@ def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str print(f" Skipping {patient_id}: no valid frames in {src_dir}") continue - image_paths = [str(src_dir / f) for f in frame_names] - labelmap_paths: list[Optional[str]] = [] - mask_paths: list[Optional[str]] = [] + image_paths = [src_dir / f for f in frame_names] + labelmap_paths: list[Optional[Path]] = [] + mask_paths: list[Optional[Path]] = [] for f in frame_names: labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") - labelmap_paths.append(str(labelmap) if labelmap.exists() else None) + labelmap_paths.append(labelmap if labelmap.exists() else None) mask = load_or_derive_mask(labelmap) - mask_paths.append(str(mask) if mask.exists() else None) + mask_paths.append(mask) train_image_files.append(image_paths) train_labelmap_files.append(labelmap_paths) + train_mask_files.append(mask_paths) valid_train_subjects.append(patient_id) n_seg = sum(1 for s in labelmap_paths if s is not None) @@ -244,12 +238,24 @@ def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str # %% workflow = WorkflowFineTuneICONRegistration( - subject_image_files=train_image_files, + subject_image_files=[ + [str(image_path) for image_path in image_paths] + for image_paths in train_image_files + ], output_dir=output_dir, fine_tune_name=fine_tune_name, subject_ids=valid_train_subjects, - subject_labelmap_files=train_labelmap_files, - subject_mask_files=train_mask_files, + subject_labelmap_files=[ + [ + str(labelmap_path) if labelmap_path is not None else None + for labelmap_path in labelmap_paths + ] + for labelmap_paths in train_labelmap_files + ], + subject_mask_files=[ + [str(mask_path) if mask_path is not None else None for mask_path in mask_paths] + for mask_paths in train_mask_files + ], mask_dilation_mm=mask_dilation_mm, unigradicon_src_path=unigradicon_src_path, epochs=500, diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 869c17d..d31552c 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -605,6 +605,7 @@ def registration_method( self.moving_mask, initial_forward_transform, self.fixed_image, + interpolation_method="nearest", ) transform_type = None diff --git a/src/physiomotion4d/transform_tools.py b/src/physiomotion4d/transform_tools.py index c7dee4d..a1f4197 100644 --- a/src/physiomotion4d/transform_tools.py +++ b/src/physiomotion4d/transform_tools.py @@ -89,12 +89,13 @@ def combine_displacement_field_transforms( mode: str = "compose", tfm1_blur_sigma: float = 0.0, tfm2_blur_sigma: float = 0.0, - ) -> itk.CompositeTransform: + ) -> itk.Transform: """ Compose two displacement field transforms. - Composes two displacement field transforms into a single displacement field - transform. + In ``add`` mode, returns a single displacement field transform with + weighted summed vectors. In ``compose`` mode, returns a composite + transform containing both weighted displacement field transforms. """ assert mode in ["add", "compose"], "Invalid mode" diff --git a/tests/test_workflow_fine_tune_icon_registration.py b/tests/test_workflow_fine_tune_icon_registration.py index a90d9ea..311fd96 100644 --- a/tests/test_workflow_fine_tune_icon_registration.py +++ b/tests/test_workflow_fine_tune_icon_registration.py @@ -46,7 +46,7 @@ def two_subject_dataset(tmp_path: Path) -> dict[str, Any]: output_dir = tmp_path / "ft_out" subject_image_files: list[list[str]] = [] - subject_segmentation_files: list[list[Optional[str]]] = [] + subject_labelmap_files: list[list[Optional[str]]] = [] for patient_id in ("pm0001", "pm0002"): pdir = data_dir / patient_id pdir.mkdir() @@ -60,14 +60,14 @@ def two_subject_dataset(tmp_path: Path) -> dict[str, Any]: images.append(str(image_path)) segs.append(str(label_path)) subject_image_files.append(images) - subject_segmentation_files.append(segs) + subject_labelmap_files.append(segs) return { "output_dir": output_dir, "fine_tune_name": "test_exp", "subject_ids": ["pm0001", "pm0002"], "subject_image_files": subject_image_files, - "subject_segmentation_files": subject_segmentation_files, + "subject_labelmap_files": subject_labelmap_files, } @@ -139,7 +139,7 @@ def test_use_segmentations_and_use_masks_flags(tmp_path: Path) -> None: assert not none_wf.use_masks seg_only = WorkflowFineTuneICONRegistration( - **base, subject_segmentation_files=[["seg.nii.gz"]] + **base, subject_labelmap_files=[["seg.nii.gz"]] ) assert seg_only.use_segmentations assert seg_only.use_masks # derived from segs @@ -195,7 +195,7 @@ def test_prepare_dataset_skips_frames_with_missing_segmentation( subject_image_files=[[str(img_a), str(img_b)]], output_dir=tmp_path / "out", fine_tune_name="exp", - subject_segmentation_files=[[str(seg_a), None]], + subject_labelmap_files=[[str(seg_a), None]], log_level=logging.CRITICAL, ) dataset_json_path = workflow.prepare_dataset() @@ -220,7 +220,7 @@ def test_prepare_dataset_uses_explicit_mask_over_derived(tmp_path: Path) -> None subject_image_files=[[str(image)]], output_dir=tmp_path / "out", fine_tune_name="exp", - subject_segmentation_files=[[str(seg)]], + subject_labelmap_files=[[str(seg)]], subject_mask_files=[[str(explicit_mask)]], log_level=logging.CRITICAL, ) @@ -269,7 +269,7 @@ def test_prepare_dataset_derives_mask_next_to_labelmap_by_default( seg_files = [ Path(s) - for inner in workflow.subject_segmentation_files or [] + for inner in workflow.subject_labelmap_files or [] for s in inner if s is not None ] @@ -299,7 +299,7 @@ def test_prepare_dataset_derives_mask_under_explicit_mask_dir( # None of the labelmap-adjacent locations should have been written to. seg_files = [ Path(s) - for inner in workflow.subject_segmentation_files or [] + for inner in workflow.subject_labelmap_files or [] for s in inner if s is not None ]