diff --git a/README.md b/README.md
index 6c83125..e7324ef 100644
--- a/README.md
+++ b/README.md
@@ -88,13 +88,14 @@ print(f"PhysioMotion4D version: {physiomotion4d.__version__}")
- **Workflow Classes**: Complete end-to-end pipeline processors
- `WorkflowConvertHeartGatedCTToUSD`: Heart-gated CT to USD processing workflow
- - `WorkflowRegisterHeartModelToPatient`: Model-to-patient registration workflow
+ - `WorkflowCreateStatisticalModel`: Create PCA statistical shape model from sample meshes
+ - `WorkflowFitStatisticalModelToPatient`: Model-to-patient registration workflow
- **Segmentation Classes**: Multiple AI-based chest segmentation implementations
- `SegmentChestTotalSegmentator`: TotalSegmentator-based segmentation
- `SegmentChestVista3D`: VISTA-3D model-based segmentation
- `SegmentChestVista3DNIM`: NVIDIA NIM version of VISTA-3D
- `SegmentChestEnsemble`: Ensemble segmentation combining multiple methods
- - `SegmentChestBase`: Base class for custom segmentation methods
+ - `SegmentAnatomyBase`: Base class for custom segmentation methods
- **Registration Classes**: Multiple registration methods for different use cases
- Image-to-Image Registration:
- `RegisterImagesICON`: Deep learning-based registration using Icon algorithm
@@ -156,13 +157,34 @@ physiomotion4d-heart-gated-ct cardiac.nrrd \
For Python API usage and advanced customization, see the examples below or refer to the CLI implementation in `src/physiomotion4d/cli/`.
+#### Create Statistical Model
+
+Build a PCA statistical shape model from sample meshes aligned to a reference:
+
+```bash
+# From a directory of sample meshes
+physiomotion4d-create-statistical-model \
+ --sample-meshes-dir ./input_meshes \
+ --reference-mesh average_mesh.vtk \
+ --output-dir ./pca_output
+
+# With custom PCA components
+physiomotion4d-create-statistical-model \
+ --sample-meshes-dir ./meshes \
+ --reference-mesh average_mesh.vtk \
+ --output-dir ./pca_output \
+ --pca-components 20
+```
+
+Outputs: `pca_mean_surface.vtp`, `pca_mean.vtu` (if reference is volumetric), and `pca_model.json`.
+
#### Heart Model to Patient Registration
Register a generic heart model to patient-specific data:
```bash
# Basic registration
-physiomotion4d-register-heart-model \
+physiomotion4d-fit-statistical-model-to-patient \
--template-model heart_model.vtu \
--template-labelmap heart_labelmap.nii.gz \
--patient-models lv.vtp rv.vtp myo.vtp \
@@ -170,7 +192,7 @@ physiomotion4d-register-heart-model \
--output-dir ./results
# With PCA shape fitting
-physiomotion4d-register-heart-model \
+physiomotion4d-fit-statistical-model-to-patient \
--template-model heart_model.vtu \
--template-labelmap heart_labelmap.nii.gz \
--patient-models lv.vtp rv.vtp myo.vtp \
@@ -203,7 +225,7 @@ final_usd = processor.process()
### Python API - Model to Patient Registration
```python
-from physiomotion4d import WorkflowRegisterHeartModelToPatient
+from physiomotion4d import WorkflowFitStatisticalModelToPatient
import pyvista as pv
import itk
@@ -213,7 +235,7 @@ patient_surfaces = [pv.read("lv.stl"), pv.read("rv.stl")]
reference_image = itk.imread("patient_ct.nii.gz")
# Initialize and run workflow
-workflow = WorkflowRegisterHeartModelToPatient(
+workflow = WorkflowFitStatisticalModelToPatient(
moving_mesh=model_mesh,
fixed_meshes=patient_surfaces,
fixed_image=reference_image
@@ -352,13 +374,13 @@ PhysioMotion4D provides standardized logging through the `PhysioMotion4DBase` cl
```python
import logging
-from physiomotion4d import WorkflowRegisterHeartModelToPatient, PhysioMotion4DBase
+from physiomotion4d import WorkflowFitStatisticalModelToPatient, PhysioMotion4DBase
# Control logging level globally for all classes
PhysioMotion4DBase.set_log_level(logging.DEBUG)
# Or filter to show logs from specific classes only
-PhysioMotion4DBase.set_log_classes(["WorkflowRegisterHeartModelToPatient", "RegisterModelsPCA"])
+PhysioMotion4DBase.set_log_classes(["WorkflowFitStatisticalModelToPatient", "RegisterModelsPCA"])
# Show all classes again
PhysioMotion4DBase.set_log_all_classes()
@@ -436,7 +458,7 @@ Advanced registration between generic anatomical models and patient-specific dat
- **`heart_model_to_model_registration_pca.ipynb`**: PCA-based statistical shape model registration
- **`heart_model_to_patient.ipynb`**: Complete model-to-patient registration workflow
-Uses the `WorkflowRegisterHeartModelToPatient` class for three-stage registration:
+Uses the `WorkflowFitStatisticalModelToPatient` class for three-stage registration:
1. ICP-based rough alignment
2. Mask-to-mask deformable registration
3. Optional PCA-constrained shape fitting
diff --git a/docs/api/index.rst b/docs/api/index.rst
index 611a778..4ff8f53 100644
--- a/docs/api/index.rst
+++ b/docs/api/index.rst
@@ -76,10 +76,11 @@ By Category
**Workflows**
* :class:`~physiomotion4d.WorkflowConvertHeartGatedCTToUSD` - Heart CT to USD
- * :class:`~physiomotion4d.WorkflowRegisterHeartModelToPatient` - Heart model registration
+ * :class:`~physiomotion4d.WorkflowCreateStatisticalModel` - Create PCA statistical shape model
+ * :class:`~physiomotion4d.WorkflowFitStatisticalModelToPatient` - Heart model registration
**Segmentation**
- * :class:`~physiomotion4d.SegmentChestBase` - Base segmentation class
+ * :class:`~physiomotion4d.SegmentAnatomyBase` - Base segmentation class
* :class:`~physiomotion4d.SegmentChestTotalSegmentator` - TotalSegmentator
* :class:`~physiomotion4d.SegmentChestVista3D` - VISTA-3D model
* :class:`~physiomotion4d.SegmentChestVista3DNIM` - VISTA-3D NIM
diff --git a/docs/api/segmentation/base.rst b/docs/api/segmentation/base.rst
index 4446f22..895dd90 100644
--- a/docs/api/segmentation/base.rst
+++ b/docs/api/segmentation/base.rst
@@ -9,7 +9,7 @@ Abstract base class for all segmentation methods.
Class Reference
===============
-.. autoclass:: SegmentChestBase
+.. autoclass:: SegmentAnatomyBase
:members:
:undoc-members:
:show-inheritance:
@@ -18,7 +18,7 @@ Class Reference
Overview
========
-:class:`SegmentChestBase` provides the foundation for all segmentation implementations in PhysioMotion4D. It defines the common interface and shared functionality that all segmentation methods must implement.
+:class:`SegmentAnatomyBase` provides the foundation for all segmentation implementations in PhysioMotion4D. It defines the common interface and shared functionality that all segmentation methods must implement.
**Key Responsibilities**:
* Define standard segmentation interface
@@ -87,17 +87,17 @@ These methods are provided by the base class:
Creating Custom Segmentation Classes
=====================================
-To create a new segmentation method, inherit from :class:`SegmentChestBase`:
+To create a new segmentation method, inherit from :class:`SegmentAnatomyBase`:
Basic Implementation
--------------------
.. code-block:: python
- from physiomotion4d import SegmentChestBase
+ from physiomotion4d import SegmentAnatomyBase
import numpy as np
- class CustomSegmentator(SegmentChestBase):
+ class CustomSegmentator(SegmentAnatomyBase):
"""Custom segmentation implementation."""
def __init__(self, param1=None, verbose=False):
@@ -150,7 +150,7 @@ With Custom Post-Processing
.. code-block:: python
- class CustomSegmentator(SegmentChestBase):
+ class CustomSegmentator(SegmentAnatomyBase):
"""Segmentator with custom post-processing."""
def post_process(self, labelmap):
@@ -243,7 +243,7 @@ Validate that required structures are present:
.. code-block:: python
- class ValidatingSegmentator(SegmentChestBase):
+ class ValidatingSegmentator(SegmentAnatomyBase):
"""Segmentator with validation."""
def segment(self, image_path):
@@ -268,7 +268,7 @@ Track segmentation progress for long operations:
.. code-block:: python
- class ProgressSegmentator(SegmentChestBase):
+ class ProgressSegmentator(SegmentAnatomyBase):
"""Segmentator with progress tracking."""
def segment(self, image_path):
diff --git a/docs/api/segmentation/index.rst b/docs/api/segmentation/index.rst
index 0964ff0..d326ef1 100644
--- a/docs/api/segmentation/index.rst
+++ b/docs/api/segmentation/index.rst
@@ -16,7 +16,7 @@ PhysioMotion4D supports multiple segmentation approaches:
* **VISTA-3D NIM**: NVIDIA Inference Microservice version
* **Ensemble**: Combine multiple methods for improved accuracy
-All segmentation classes inherit from :class:`SegmentChestBase` and provide consistent interfaces.
+All segmentation classes inherit from :class:`SegmentAnatomyBase` and provide consistent interfaces.
Quick Links
===========
diff --git a/docs/api/workflows.rst b/docs/api/workflows.rst
index 811e3a4..5239bfd 100644
--- a/docs/api/workflows.rst
+++ b/docs/api/workflows.rst
@@ -49,7 +49,7 @@ Heart Gated CT to USD
Heart Model to Patient Registration
------------------------------------
-.. autoclass:: WorkflowRegisterHeartModelToPatient
+.. autoclass:: WorkflowFitStatisticalModelToPatient
:members:
:undoc-members:
:show-inheritance:
@@ -67,9 +67,9 @@ Heart Model to Patient Registration
.. code-block:: python
- from physiomotion4d import WorkflowRegisterHeartModelToPatient
+ from physiomotion4d import WorkflowFitStatisticalModelToPatient
- workflow = WorkflowRegisterHeartModelToPatient(
+ workflow = WorkflowFitStatisticalModelToPatient(
model_file="heart_template.vtk",
patient_image="patient_ct.nrrd",
output_directory="./registration_results",
@@ -221,7 +221,7 @@ Combine multiple workflows:
heart_result = heart_workflow.process()
# Second workflow: Register model
- registration_workflow = WorkflowRegisterHeartModelToPatient(
+ registration_workflow = WorkflowFitStatisticalModelToPatient(
model_file=heart_result['model'],
patient_image=patient_ct,
output_directory="./registration_output"
diff --git a/docs/architecture.rst b/docs/architecture.rst
index 8b5d442..bdb21f5 100644
--- a/docs/architecture.rst
+++ b/docs/architecture.rst
@@ -31,7 +31,7 @@ Core Components
2. **Segmentation Module**
- * Base class: :class:`SegmentChestBase`
+ * Base class: :class:`SegmentAnatomyBase`
* Implementations: TotalSegmentator, VISTA-3D, Ensemble
3. **Registration Module**
diff --git a/docs/cli_scripts/create_statistical_model.rst b/docs/cli_scripts/create_statistical_model.rst
new file mode 100644
index 0000000..ccd6af6
--- /dev/null
+++ b/docs/cli_scripts/create_statistical_model.rst
@@ -0,0 +1,102 @@
+====================================
+Create Statistical Model
+====================================
+
+Overview
+========
+
+The ``physiomotion4d-create-statistical-model`` command-line tool builds a PCA
+(Principal Component Analysis) statistical shape model from a sample of meshes
+aligned to a reference mesh. This mirrors the pipeline in the
+Heart-Create_Statistical_Model experiment notebooks.
+
+The workflow:
+
+1. **Extract surfaces** from sample and reference meshes
+2. **ICP alignment**: Affine align each sample surface to the reference surface
+3. **Deformable registration**: ANTs SyN to establish dense correspondence
+4. **Correspondence**: Build aligned shapes with reference topology
+5. **PCA**: Compute mean shape and principal components
+
+Outputs written to the output directory:
+
+* ``pca_mean_surface.vtp`` — Mean shape as a surface (PolyData)
+* ``pca_mean.vtu`` — Reference volume mesh in mean space (only if reference is volumetric)
+* ``pca_model.json`` — PCA model (eigenvalues, components) for use with
+ :class:`physiomotion4d.WorkflowFitStatisticalModelToPatient` or
+ :class:`physiomotion4d.RegisterModelsPCA`
+
+Installation
+============
+
+The script is installed with PhysioMotion4D:
+
+.. code-block:: bash
+
+ pip install physiomotion4d
+
+Quick Start
+===========
+
+Basic Usage
+-----------
+
+Create a PCA model from a directory of sample meshes and a reference mesh:
+
+.. code-block:: bash
+
+ physiomotion4d-create-statistical-model \
+ --sample-meshes-dir ./input_meshes \
+ --reference-mesh average_mesh.vtk \
+ --output-dir ./pca_output
+
+Explicit Sample List
+--------------------
+
+Provide sample mesh paths explicitly instead of a directory:
+
+.. code-block:: bash
+
+ physiomotion4d-create-statistical-model \
+ --sample-meshes 01.vtk 02.vtk 03.vtu 04.vtp \
+ --reference-mesh average_mesh.vtk \
+ --output-dir ./pca_output
+
+With Custom Parameters
+----------------------
+
+.. code-block:: bash
+
+ physiomotion4d-create-statistical-model \
+ --sample-meshes-dir ./meshes \
+ --reference-mesh average_mesh.vtk \
+ --output-dir ./pca_output \
+ --pca-components 20
+
+Command-Line Arguments
+======================
+
+Required Arguments
+------------------
+
+``--sample-meshes-dir DIR`` or ``--sample-meshes PATH [PATH ...]``
+ Either a directory containing sample mesh files (``.vtk``, ``.vtu``, ``.vtp``)
+ or a list of paths to sample meshes. One of these is required.
+
+``--reference-mesh PATH``
+ Path to the reference mesh. Its surface is used as the alignment target for
+ all samples.
+
+``--output-dir DIR``
+ Output directory. Writes ``pca_mean_surface.vtp``, ``pca_mean.vtu`` (if
+ reference is volumetric), and ``pca_model.json``.
+
+Optional Arguments
+------------------
+
+``--pca-components N``
+ Number of PCA components to retain (default: 15).
+
+See :class:`physiomotion4d.WorkflowCreateStatisticalModel` for the full API and
+additional parameters (e.g. ``reference_spatial_resolution``,
+``reference_buffer_factor``) that can be exposed in future CLI versions.
diff --git a/docs/cli_scripts/heart_model_to_patient.rst b/docs/cli_scripts/fit_statistical_model_to_patient.rst
similarity index 90%
rename from docs/cli_scripts/heart_model_to_patient.rst
rename to docs/cli_scripts/fit_statistical_model_to_patient.rst
index d2f85d0..669fe15 100644
--- a/docs/cli_scripts/heart_model_to_patient.rst
+++ b/docs/cli_scripts/fit_statistical_model_to_patient.rst
@@ -5,7 +5,7 @@ Heart Model to Patient Registration
Overview
========
-The ``physiomotion4d-register-heart-model`` command-line tool registers generic anatomical heart models to patient-specific imaging data and surface models. This workflow enables:
+The ``physiomotion4d-fit-statistical-model-to-patient`` command-line tool registers generic anatomical heart models to patient-specific imaging data and surface models. This workflow enables:
* Patient-specific anatomical modeling from generic templates
* Multi-stage registration combining ICP, PCA, and deformable methods
@@ -38,7 +38,7 @@ Register a generic heart model to patient data:
.. code-block:: bash
- physiomotion4d-register-heart-model \
+ physiomotion4d-fit-statistical-model-to-patient \
--template-model heart_model.vtu \
--template-labelmap heart_labelmap.nii.gz \
--patient-models lv.vtp rv.vtp myo.vtp \
@@ -52,7 +52,7 @@ Include statistical shape model fitting:
.. code-block:: bash
- physiomotion4d-register-heart-model \
+ physiomotion4d-fit-statistical-model-to-patient \
--template-model heart_model.vtu \
--template-labelmap heart_labelmap.nii.gz \
--patient-models lv.vtp rv.vtp myo.vtp \
@@ -82,7 +82,7 @@ Required Arguments
``--output-dir DIR``
Output directory for results
-See :class:`physiomotion4d.WorkflowRegisterHeartModelToPatient` for API documentation.
+See :class:`physiomotion4d.WorkflowFitStatisticalModelToPatient` for API documentation.
Template Labelmap Configuration
--------------------------------
@@ -141,7 +141,7 @@ Example 1: Basic Registration
.. code-block:: bash
- physiomotion4d-register-heart-model \
+ physiomotion4d-fit-statistical-model-to-patient \
--template-model heart_model.vtu \
--template-labelmap heart_labelmap.nii.gz \
--patient-models lv.vtp rv.vtp myo.vtp \
@@ -153,7 +153,7 @@ Example 2: PCA-Based Registration
.. code-block:: bash
- physiomotion4d-register-heart-model \
+ physiomotion4d-fit-statistical-model-to-patient \
--template-model heart_model.vtu \
--template-labelmap heart_labelmap.nii.gz \
--patient-models lv.vtp rv.vtp \
diff --git a/docs/cli_scripts/overview.rst b/docs/cli_scripts/overview.rst
index 2fa843c..180e6fc 100644
--- a/docs/cli_scripts/overview.rst
+++ b/docs/cli_scripts/overview.rst
@@ -9,7 +9,8 @@ This section provides comprehensive guides for using PhysioMotion4D's command-li
**CLI Commands: Your Definitive Resource** ⭐
The examples and workflows documented here are based on production-ready CLI commands
- (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-register-heart-model``) and their
+ (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-create-statistical-model``,
+ ``physiomotion4d-fit-statistical-model-to-patient``) and their
implementations in ``src/physiomotion4d/cli/``. These are your **primary resource** for:
* Production-ready workflow implementations
@@ -48,7 +49,9 @@ Current Scripts
- Description
* - :doc:`heart_gated_ct`
- Process cardiac gated CT to animated heart models with physiological motion
- * - :doc:`heart_model_to_patient`
+ * - :doc:`create_statistical_model`
+ - Build a PCA statistical shape model from sample meshes aligned to a reference
+ * - :doc:`fit_statistical_model_to_patient`
- Register generic heart models to patient-specific imaging data and surface models
Upcoming Scripts
diff --git a/docs/developer/architecture.rst b/docs/developer/architecture.rst
index eacbd21..bf67f13 100644
--- a/docs/developer/architecture.rst
+++ b/docs/developer/architecture.rst
@@ -46,10 +46,10 @@ The package is organized into functional modules:
│
├── Workflow Classes
│ ├── workflow_convert_heart_gated_ct_to_usd.py Cardiac CT → USD
- │ └── workflow_register_heart_model_to_patient.py Model → Patient
+ │ └── workflow_fit_statistical_model_to_patient.py Model → Patient
│
├── Segmentation
- │ ├── segment_chest_base.py Base segmentation
+ │ ├── segment_anatomy_base.py Base segmentation
│ ├── segment_chest_total_segmentator.py TotalSegmentator
│ ├── segment_chest_vista_3d.py VISTA-3D
│ ├── segment_chest_vista_3d_nim.py VISTA-3D NIM
@@ -95,9 +95,9 @@ Most PhysioMotion4D classes inherit from :class:`PhysioMotion4DBase`:
PhysioMotion4DBase
├── Workflow Classes
│ ├── WorkflowConvertHeartGatedCTToUSD
- │ └── WorkflowRegisterHeartModelToPatient
+ │ └── WorkflowFitStatisticalModelToPatient
├── Segmentation Classes
- │ ├── SegmentChestBase
+ │ ├── SegmentAnatomyBase
│ │ ├── SegmentChestTotalSegmentator
│ │ ├── SegmentChestVista3D
│ │ └── SegmentChestEnsemble
@@ -221,7 +221,7 @@ Extension Points
PhysioMotion4D is designed for extension:
**Add New Segmentation Methods**
- Inherit from :class:`SegmentChestBase`
+ Inherit from :class:`SegmentAnatomyBase`
**Add New Registration Methods**
Inherit from :class:`RegisterImagesBase`
diff --git a/docs/developer/extending.rst b/docs/developer/extending.rst
index 1ebd916..d2a10f7 100644
--- a/docs/developer/extending.rst
+++ b/docs/developer/extending.rst
@@ -249,14 +249,14 @@ Custom Segmentation Methods
Adding New Segmentation Algorithms
-----------------------------------
-Extend :class:`SegmentChestBase`:
+Extend :class:`SegmentAnatomyBase`:
.. code-block:: python
- from physiomotion4d import SegmentChestBase
+ from physiomotion4d import SegmentAnatomyBase
import torch
- class MyCustomSegmentator(SegmentChestBase):
+ class MyCustomSegmentator(SegmentAnatomyBase):
"""Custom deep learning segmentation."""
def __init__(
diff --git a/docs/developer/segmentation.rst b/docs/developer/segmentation.rst
index d5ab30f..656cd00 100644
--- a/docs/developer/segmentation.rst
+++ b/docs/developer/segmentation.rst
@@ -21,17 +21,17 @@ PhysioMotion4D supports multiple segmentation approaches:
* **VISTA-3D NIM**: NVIDIA Inference Microservice version
* **Ensemble**: Combine multiple methods for improved accuracy
-All segmentation classes inherit from :class:`SegmentChestBase` and provide consistent interfaces.
+All segmentation classes inherit from :class:`SegmentAnatomyBase` and provide consistent interfaces.
Base Segmentation Class
=======================
-SegmentChestBase
+SegmentAnatomyBase
----------------
Abstract base class for all segmentation methods.
-.. autoclass:: physiomotion4d.SegmentChestBase
+.. autoclass:: physiomotion4d.SegmentAnatomyBase
:members:
:undoc-members:
:show-inheritance:
@@ -274,10 +274,10 @@ Add custom post-processing steps:
.. code-block:: python
- from physiomotion4d import SegmentChestBase
+ from physiomotion4d import SegmentAnatomyBase
import numpy as np
- class CustomSegmentator(SegmentChestBase):
+ class CustomSegmentator(SegmentAnatomyBase):
"""Custom segmentator with post-processing."""
def post_process(self, labelmap):
diff --git a/docs/developer/workflows.rst b/docs/developer/workflows.rst
index c44aba1..cadd1da 100644
--- a/docs/developer/workflows.rst
+++ b/docs/developer/workflows.rst
@@ -37,7 +37,7 @@ Each CLI script wraps a corresponding workflow class:
* - ``physiomotion4d-heart-gated-ct``
- :class:`WorkflowConvertHeartGatedCTToUSD`
* - ``physiomotion4d-heart-model-to-patient``
- - :class:`WorkflowRegisterHeartModelToPatient` *(planned)*
+ - :class:`WorkflowFitStatisticalModelToPatient` *(planned)*
* - ``physiomotion4d-lung-gated-ct``
- :class:`LungGatedCTToUSDWorkflow` *(planned)*
* - ``physiomotion4d-4dct-reconstruction``
@@ -109,7 +109,7 @@ Process 4D cardiac CT to animated USD models.
See :doc:`../cli_scripts/heart_gated_ct` for CLI usage.
-WorkflowRegisterHeartModelToPatient
+WorkflowFitStatisticalModelToPatient
------------------------------------
.. note::
@@ -121,9 +121,9 @@ Register population heart models to patient images.
.. code-block:: python
- from physiomotion4d import WorkflowRegisterHeartModelToPatient
+ from physiomotion4d import WorkflowFitStatisticalModelToPatient
- workflow = WorkflowRegisterHeartModelToPatient(
+ workflow = WorkflowFitStatisticalModelToPatient(
model_file="population_heart_model.vtk",
patient_image="patient_ct.nrrd",
output_directory="./results"
@@ -131,7 +131,7 @@ Register population heart models to patient images.
registered_model = workflow.process()
-See :doc:`../cli_scripts/heart_model_to_patient` for planned CLI usage.
+See :doc:`../cli_scripts/fit_statistical_model_to_patient` for planned CLI usage.
Common Workflow Patterns
========================
diff --git a/docs/examples.rst b/docs/examples.rst
index 4aee427..2fb4b6d 100644
--- a/docs/examples.rst
+++ b/docs/examples.rst
@@ -8,7 +8,8 @@ see the :doc:`cli_scripts/overview` section.
.. note::
**For Production Workflows:** The CLI commands (``physiomotion4d-heart-gated-ct``,
- ``physiomotion4d-register-heart-model``) and their implementations in ``src/physiomotion4d/cli/``
+ ``physiomotion4d-create-statistical-model``, ``physiomotion4d-fit-statistical-model-to-patient``)
+ and their implementations in ``src/physiomotion4d/cli/``
are the definitive source for proper library usage, class instantiation, and best practices.
The ``experiments/`` directory contains research prototypes that informed development but should
diff --git a/docs/index.rst b/docs/index.rst
index 70ca847..4e8bd44 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -56,7 +56,8 @@ PhysioMotion4D is a comprehensive medical imaging package that converts 3D and 4
**Getting Started with Code Examples:**
This documentation uses examples from the CLI commands (``physiomotion4d-heart-gated-ct``,
- ``physiomotion4d-register-heart-model``) and their implementations in ``src/physiomotion4d/cli/``,
+ ``physiomotion4d-create-statistical-model``, ``physiomotion4d-fit-statistical-model-to-patient``)
+ and their implementations in ``src/physiomotion4d/cli/``,
which contain production-ready workflows and proper library usage patterns. The repository also includes
an ``experiments/`` directory with research prototypes that can inspire adaptations to
new digital twin models and anatomical regions—see the experiments README for details on
@@ -76,7 +77,8 @@ PhysioMotion4D is a comprehensive medical imaging package that converts 3D and 4
cli_scripts/overview
cli_scripts/heart_gated_ct
- cli_scripts/heart_model_to_patient
+ cli_scripts/create_statistical_model
+ cli_scripts/fit_statistical_model_to_patient
cli_scripts/lung_gated_ct
cli_scripts/4dct_reconstruction
cli_scripts/vtk_to_usd
diff --git a/docs/quickstart.rst b/docs/quickstart.rst
index 22e4355..b37dfd5 100644
--- a/docs/quickstart.rst
+++ b/docs/quickstart.rst
@@ -260,7 +260,8 @@ Now that you've completed your first workflow:
**About CLI Commands and Experiments:**
* **CLI Commands** ⭐ **PRIMARY RESOURCE** - Production-ready workflows with proper class usage
- (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-register-heart-model``).
+ (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-create-statistical-model``,
+ ``physiomotion4d-fit-statistical-model-to-patient``).
See ``src/physiomotion4d/cli/`` for implementation details.
* **experiments/** - Research prototypes and design explorations. These demonstrate conceptual
diff --git a/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb b/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb
index 05ecc54..29aa5f6 100644
--- a/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb
+++ b/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb
@@ -132,7 +132,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "df57b010",
"metadata": {
"execution": {
@@ -155,7 +155,7 @@
}
],
"source": [
- "# Load average surface (this will be used as the mean for PCA)\n",
+ "# Load average surface (this will be replaced by the mean from PCA)\n",
"average_mesh = pv.read(average_surface_path)\n",
"n_points = average_mesh.n_points\n",
"n_features = n_points * 3 # x, y, z coordinates for each point\n",
diff --git a/experiments/Heart-Create_Statistical_Model/README.md b/experiments/Heart-Create_Statistical_Model/README.md
index ab39cf7..00b5286 100644
--- a/experiments/Heart-Create_Statistical_Model/README.md
+++ b/experiments/Heart-Create_Statistical_Model/README.md
@@ -145,10 +145,10 @@ After completing this experiment, you will have generated files in `kcl-heart-mo
The outputs from this experiment are used in the `Heart-Statistical_Model_To_Patient` experiment:
```python
-from physiomotion4d import WorkflowRegisterHeartModelToPatient
+from physiomotion4d import WorkflowFitStatisticalModelToPatient
# Use PCA model from this experiment
-workflow = WorkflowRegisterHeartModelToPatient(
+workflow = WorkflowFitStatisticalModelToPatient(
moving_mesh=mean_shape,
fixed_meshes=patient_surfaces,
fixed_image=patient_ct,
diff --git a/experiments/Heart-Simpleware_Segmentation/README.md b/experiments/Heart-Simpleware_Segmentation/README.md
index b561c40..2eb59a2 100644
--- a/experiments/Heart-Simpleware_Segmentation/README.md
+++ b/experiments/Heart-Simpleware_Segmentation/README.md
@@ -225,9 +225,9 @@ workflow.set_static_labelmap(result["labelmap"])
### Statistical Model Registration
Register segmentation with heart model using `Heart-Statistical_Model_To_Patient`:
```python
-from physiomotion4d.workflow_register_heart_model_to_patient import WorkflowRegisterHeartModelToPatient
+from physiomotion4d.workflow_fit_statistical_model_to_patient import WorkflowFitStatisticalModelToPatient
-workflow = WorkflowRegisterHeartModelToPatient()
+workflow = WorkflowFitStatisticalModelToPatient()
workflow.set_patient_segmentation(result["labelmap"])
# Perform model-to-patient registration
```
diff --git a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb
index 091e82b..7674b49 100644
--- a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb
+++ b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb
@@ -1,846 +1,849 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# PCA-based Heart Model to Image Registration Experiment\n",
- "\n",
- "This notebook demonstrates using the `RegisterModelToImagePCA` class to register\n",
- "a statistical shape model to patient CT images using PCA-based shape variation.\n",
- "\n",
- "## Overview\n",
- "- Uses the KCL Heart Model PCA statistical shape model\n",
- "- Registers to the same Duke Heart CT data as the original notebook\n",
- "- Two-stage optimization: rigid alignment + PCA shape fitting\n",
- "- Converts segmentation mask to intensity image for registration"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Setup and Imports"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:36.139622Z",
- "iopub.status.busy": "2026-02-09T06:34:36.139622Z",
- "iopub.status.idle": "2026-02-09T06:34:51.022942Z",
- "shell.execute_reply": "2026-02-09T06:34:51.022012Z"
- }
- },
- "outputs": [],
- "source": [
- "# PCA-based Heart Model to Image Registration Experiment\n",
- "\n",
- "import os\n",
- "from pathlib import Path\n",
- "\n",
- "import itk\n",
- "import numpy as np\n",
- "import pyvista as pv\n",
- "from itk import TubeTK as ttk\n",
- "\n",
- "# Import from PhysioMotion4D package\n",
- "from physiomotion4d import (\n",
- " ContourTools,\n",
- " RegisterModelsICP,\n",
- " RegisterModelsPCA,\n",
- " TransformTools,\n",
- ")\n",
- "from physiomotion4d.notebook_utils import running_as_test"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Define File Paths"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:51.023944Z",
- "iopub.status.busy": "2026-02-09T06:34:51.023944Z",
- "iopub.status.idle": "2026-02-09T06:34:51.037942Z",
- "shell.execute_reply": "2026-02-09T06:34:51.036949Z"
- }
- },
- "outputs": [],
- "source": [
- "# Patient CT image (defines coordinate frame)\n",
- "patient_data_dir = Path.cwd().parent.parent / \"data\" / \"Slicer-Heart-CT\"\n",
- "patient_ct_path = patient_data_dir / \"patient_img.mha\"\n",
- "patient_ct_heart_mask_path = patient_data_dir / \"patient_heart_wall_mask.nii.gz\"\n",
- "\n",
- "# PCA heart model data\n",
- "heart_model_data_dir = Path.cwd().parent.parent / \"data\" / \"KCL-Heart-Model\"\n",
- "heart_model_path = heart_model_data_dir / \"average_mesh.vtk\"\n",
- "\n",
- "# PCA statistical model (from Heart-Create_Statistical_Model workflow)\n",
- "template_model_data_dir = (\n",
- " Path.cwd().parent / \"Heart-Create_Statistical_Model\" / \"kcl-heart-model\"\n",
- ")\n",
- "template_model_surface_path = template_model_data_dir / \"pca_mean.vtp\"\n",
- "pca_json_path = template_model_data_dir / \"pca_model.json\"\n",
- "\n",
- "# Output directory\n",
- "output_dir = Path.cwd() / \"results_pca\"\n",
- "os.makedirs(output_dir, exist_ok=True)\n",
- "\n",
- "print(f\"Patient data: {patient_data_dir}\")\n",
- "print(f\"PCA Model data: {template_model_data_dir}\")\n",
- "print(f\"Output directory: {output_dir}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load and Preprocess Patient Image"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:51.050954Z",
- "iopub.status.busy": "2026-02-09T06:34:51.050954Z",
- "iopub.status.idle": "2026-02-09T06:34:56.086219Z",
- "shell.execute_reply": "2026-02-09T06:34:56.085110Z"
- }
- },
- "outputs": [],
- "source": [
- "# Load patient CT image\n",
- "print(\"Loading patient CT image...\")\n",
- "patient_image = itk.imread(str(patient_ct_path))\n",
- "print(f\" Original size: {itk.size(patient_image)}\")\n",
- "print(f\" Original spacing: {itk.spacing(patient_image)}\")\n",
- "\n",
- "# Resample to 1mm isotropic spacing\n",
- "print(\"Resampling to sotropic...\")\n",
- "resampler = ttk.ResampleImage.New(Input=patient_image)\n",
- "resampler.SetMakeHighResIso(True)\n",
- "resampler.Update()\n",
- "patient_image = resampler.GetOutput()\n",
- "\n",
- "print(f\" Resampled size: {itk.size(patient_image)}\")\n",
- "print(f\" Resampled spacing: {itk.spacing(patient_image)}\")\n",
- "\n",
- "# Save preprocessed image\n",
- "itk.imwrite(patient_image, str(output_dir / \"patient_image.mha\"), compression=True)\n",
- "print(\"✓ Saved preprocessed image\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load and Process Heart Segmentation Mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:56.087775Z",
- "iopub.status.busy": "2026-02-09T06:34:56.087775Z",
- "iopub.status.idle": "2026-02-09T06:34:56.115692Z",
- "shell.execute_reply": "2026-02-09T06:34:56.115196Z"
- }
- },
- "outputs": [],
- "source": [
- "# Load heart segmentation mask\n",
- "print(\"Loading heart segmentation mask...\")\n",
- "patient_heart_mask_image = itk.imread(str(patient_ct_heart_mask_path))\n",
- "\n",
- "print(f\" Mask size: {itk.size(patient_heart_mask_image)}\")\n",
- "print(f\" Mask spacing: {itk.spacing(patient_heart_mask_image)}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:56.117556Z",
- "iopub.status.busy": "2026-02-09T06:34:56.117556Z",
- "iopub.status.idle": "2026-02-09T06:34:56.462795Z",
- "shell.execute_reply": "2026-02-09T06:34:56.462795Z"
- }
- },
- "outputs": [],
- "source": [
- "# Handle image orientation (flip if needed)\n",
- "flip0 = np.array(patient_heart_mask_image.GetDirection())[0, 0] < 0\n",
- "flip1 = np.array(patient_heart_mask_image.GetDirection())[1, 1] < 0\n",
- "flip2 = np.array(patient_heart_mask_image.GetDirection())[2, 2] < 0\n",
- "\n",
- "if flip0 or flip1 or flip2:\n",
- " print(f\"Flipping image axes: {flip0}, {flip1}, {flip2}\")\n",
- "\n",
- " # Flip CT image\n",
- " flip_filter = itk.FlipImageFilter.New(Input=patient_image)\n",
- " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n",
- " flip_filter.SetFlipAboutOrigin(True)\n",
- " flip_filter.Update()\n",
- " patient_image = flip_filter.GetOutput()\n",
- " id_mat = itk.Matrix[itk.D, 3, 3]()\n",
- " id_mat.SetIdentity()\n",
- " patient_image.SetDirection(id_mat)\n",
- "\n",
- " # Flip mask image\n",
- " flip_filter = itk.FlipImageFilter.New(Input=patient_heart_mask_image)\n",
- " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n",
- " flip_filter.SetFlipAboutOrigin(True)\n",
- " flip_filter.Update()\n",
- " patient_heart_mask_image = flip_filter.GetOutput()\n",
- " patient_heart_mask_image.SetDirection(id_mat)\n",
- "\n",
- " print(\"✓ Images flipped to standard orientation\")\n",
- "\n",
- "# Save oriented images\n",
- "itk.imwrite(\n",
- " patient_image, str(output_dir / \"patient_image_oriented.mha\"), compression=True\n",
- ")\n",
- "itk.imwrite(\n",
- " patient_heart_mask_image,\n",
- " str(output_dir / \"patient_heart_mask_oriented.mha\"),\n",
- " compression=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Convert Segmentation Mask to a Surface"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:56.464797Z",
- "iopub.status.busy": "2026-02-09T06:34:56.464797Z",
- "iopub.status.idle": "2026-02-09T06:34:56.917954Z",
- "shell.execute_reply": "2026-02-09T06:34:56.917429Z"
- }
- },
- "outputs": [],
- "source": [
- "contour_tools = ContourTools()\n",
- "patient_surface = contour_tools.extract_contours(patient_heart_mask_image)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Perform Initial ICP Affine Registration\n",
- "\n",
- "Use ICP (Iterative Closest Point) with affine mode to align the model surface to the patient surface extracted from the segmentation mask. This provides a good initial alignment for the PCA-based registration.\n",
- "\n",
- "The ICP registration pipeline:\n",
- "1. Centroid alignment (automatic)\n",
- "2. Rigid ICP alignment\n",
- "\n",
- "The PCA registration will then refine this initial alignment with shape model constraints."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:34:56.919468Z",
- "iopub.status.busy": "2026-02-09T06:34:56.919468Z",
- "iopub.status.idle": "2026-02-09T06:35:10.733991Z",
- "shell.execute_reply": "2026-02-09T06:35:10.733143Z"
- }
- },
- "outputs": [],
- "source": [
- "# Load the pca model\n",
- "print(\"Loading PCA heart model...\")\n",
- "template_model = pv.read(str(heart_model_path))\n",
- "\n",
- "template_model_surface = pv.read(template_model_surface_path)\n",
- "print(f\" Template surface: {template_model_surface.n_points} points\")\n",
- "\n",
- "icp_registrar = RegisterModelsICP(fixed_model=patient_surface)\n",
- "\n",
- "# Use fewer iterations when run as test (pytest) for faster execution\n",
- "max_iterations_icp = 100 if running_as_test() else 2000\n",
- "icp_result = icp_registrar.register(\n",
- " transform_type=\"Affine\",\n",
- " moving_model=template_model_surface,\n",
- " max_iterations=max_iterations_icp,\n",
- ")\n",
- "\n",
- "# Get the aligned mesh and transform\n",
- "icp_registered_model_surface = icp_result[\"registered_model\"]\n",
- "icp_forward_point_transform = icp_result[\"forward_point_transform\"]\n",
- "\n",
- "print(\"\\n✓ ICP affine registration complete\")\n",
- "print(\" Transform =\", icp_result[\"forward_point_transform\"])\n",
- "\n",
- "# Save aligned model\n",
- "icp_registered_model_surface.save(str(output_dir / \"icp_registered_model_surface.vtp\"))\n",
- "print(\" Saved ICP-aligned model surface\")\n",
- "\n",
- "itk.transformwrite(\n",
- " [icp_result[\"forward_point_transform\"]],\n",
- " str(output_dir / \"icp_transform.hdf\"),\n",
- " compression=True,\n",
- ")\n",
- "print(\" Saved ICP transform\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:35:10.735086Z",
- "iopub.status.busy": "2026-02-09T06:35:10.735086Z",
- "iopub.status.idle": "2026-02-09T06:35:14.798820Z",
- "shell.execute_reply": "2026-02-09T06:35:14.797942Z"
- }
- },
- "outputs": [],
- "source": [
- "# Apply ICP transform to the full average mesh (not just surface)\n",
- "# This gets the volumetric mesh into patient space for PCA registration\n",
- "transform_tools = TransformTools()\n",
- "icp_registered_model = transform_tools.transform_pvcontour(\n",
- " template_model, icp_forward_point_transform\n",
- ")\n",
- "icp_registered_model.save(str(output_dir / \"icp_registered_model.vtk\"))\n",
- "print(\"\\n✓ Applied ICP transform to full model mesh\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Initialize PCA Registration"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:35:14.800938Z",
- "iopub.status.busy": "2026-02-09T06:35:14.800938Z",
- "iopub.status.idle": "2026-02-09T06:35:18.694406Z",
- "shell.execute_reply": "2026-02-09T06:35:18.693110Z"
- }
- },
- "outputs": [],
- "source": [
- "## Initialize PCA Registration\n",
- "print(\"=\" * 70)\n",
- "\n",
- "# Use the ICP-aligned mesh as the starting point for PCA registration\n",
- "pca_registrar = RegisterModelsPCA.from_json(\n",
- " pca_template_model=template_model_surface,\n",
- " pca_json_filename=pca_json_path,\n",
- " pca_number_of_modes=10,\n",
- " pre_pca_transform=icp_forward_point_transform,\n",
- " fixed_model=patient_surface,\n",
- " reference_image=patient_image,\n",
- ")\n",
- "\n",
- "itk.imwrite(pca_registrar.fixed_distance_map, str(output_dir / \"distance_map.mha\"))\n",
- "\n",
- "print(\"✓ PCA registrar initialized\")\n",
- "print(\" Using ICP-aligned mesh as starting point\")\n",
- "print(f\" Number of points: {len(pca_registrar.pca_template_model.points)}\")\n",
- "print(f\" Number of PCA modes: {pca_registrar.pca_number_of_modes}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Run PCA-Based Shape Optimization\n",
- "\n",
- "Now that we have a good initial alignment from ICP affine registration, we run the PCA-based registration to optimize the shape parameters."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:35:18.696501Z",
- "iopub.status.busy": "2026-02-09T06:35:18.695503Z",
- "iopub.status.idle": "2026-02-09T06:36:49.984747Z",
- "shell.execute_reply": "2026-02-09T06:36:49.983820Z"
- }
- },
- "outputs": [],
- "source": [
- "print(\"\\n\" + \"=\" * 70)\n",
- "print(\"PCA-BASED SHAPE OPTIMIZATION\")\n",
- "print(\"=\" * 70)\n",
- "print(\"\\nRunning complete PCA registration pipeline...\")\n",
- "print(\" (Starting from ICP-aligned mesh)\")\n",
- "\n",
- "result = pca_registrar.register(\n",
- " pca_number_of_modes=10, # Use first 10 PCA modes\n",
- ")\n",
- "\n",
- "dm = pca_registrar.fixed_distance_map\n",
- "itk.imwrite(dm, str(output_dir / \"target_distance_map.mha\"))\n",
- "\n",
- "pca_registered_model_surface = result[\"registered_model\"]\n",
- "\n",
- "print(\"\\n✓ PCA registration complete\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Display Registration Results\n",
- "\n",
- "Review the optimization results from the PCA registration pipeline.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:36:49.986008Z",
- "iopub.status.busy": "2026-02-09T06:36:49.986008Z",
- "iopub.status.idle": "2026-02-09T06:36:49.999829Z",
- "shell.execute_reply": "2026-02-09T06:36:49.998876Z"
- }
- },
- "outputs": [],
- "source": [
- "print(\"\\n\" + \"=\" * 70)\n",
- "print(\"REGISTRATION RESULTS\")\n",
- "print(\"=\" * 70)\n",
- "\n",
- "# Display results\n",
- "print(\"\\nFinal Registration Metrics:\")\n",
- "print(f\" Final mean intensity: {result['mean_distance']:.4f}\")\n",
- "\n",
- "print(\"\\nOptimized PCA Coefficients (in units of std deviations):\")\n",
- "for i, coef in enumerate(result[\"pca_coefficients\"]):\n",
- " print(f\" Mode {i + 1:2d}: {coef:7.4f}\")\n",
- "\n",
- "print(\"\\n✓ Registration pipeline complete!\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Save Registration Results\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:36:50.001252Z",
- "iopub.status.busy": "2026-02-09T06:36:50.001252Z",
- "iopub.status.idle": "2026-02-09T06:36:51.739511Z",
- "shell.execute_reply": "2026-02-09T06:36:51.738563Z"
- }
- },
- "outputs": [],
- "source": [
- "print(\"\\nSaving results...\")\n",
- "\n",
- "# Save final PCA-registered mesh\n",
- "pca_registered_model_surface.save(str(output_dir / \"pca_registered_model_surface.vtk\"))\n",
- "print(\" Saved final PCA-registered mesh\")\n",
- "\n",
- "ref_image = contour_tools.create_reference_image(pca_registered_model_surface)\n",
- "\n",
- "distance_map = contour_tools.create_distance_map(\n",
- " pca_registered_model_surface,\n",
- " ref_image,\n",
- " squared_distance=True,\n",
- " negative_inside=False,\n",
- " zero_inside=True,\n",
- " norm_to_max_distance=200.0,\n",
- ")\n",
- "\n",
- "itk.imwrite(distance_map, str(output_dir / \"pca_distance_map.mha\"))\n",
- "\n",
- "# Save PCA coefficients\n",
- "np.savetxt(\n",
- " str(output_dir / \"pca_coefficients.txt\"),\n",
- " result[\"pca_coefficients\"],\n",
- " header=f\"PCA coefficients for {len(result['pca_coefficients'])} modes\",\n",
- ")\n",
- "print(\" Saved PCA coefficients\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Visualize Results\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:36:51.741507Z",
- "iopub.status.busy": "2026-02-09T06:36:51.741507Z",
- "iopub.status.idle": "2026-02-09T06:36:54.084672Z",
- "shell.execute_reply": "2026-02-09T06:36:54.082694Z"
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PCA-based Heart Model to Image Registration Experiment\n",
+ "\n",
+ "This notebook demonstrates using the `RegisterModelToImagePCA` class to register\n",
+ "a statistical shape model to patient CT images using PCA-based shape variation.\n",
+ "\n",
+ "## Overview\n",
+ "- Uses the KCL Heart Model PCA statistical shape model\n",
+ "- Registers to the same Duke Heart CT data as the original notebook\n",
+ "- Two-stage optimization: rigid alignment + PCA shape fitting\n",
+ "- Converts segmentation mask to intensity image for registration"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup and Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:36.139622Z",
+ "iopub.status.busy": "2026-02-09T06:34:36.139622Z",
+ "iopub.status.idle": "2026-02-09T06:34:51.022942Z",
+ "shell.execute_reply": "2026-02-09T06:34:51.022012Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# PCA-based Heart Model to Image Registration Experiment\n",
+ "\n",
+ "import json\n",
+ "import os\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import itk\n",
+ "import numpy as np\n",
+ "import pyvista as pv\n",
+ "from itk import TubeTK as ttk\n",
+ "\n",
+ "# Import from PhysioMotion4D package\n",
+ "from physiomotion4d import (\n",
+ " ContourTools,\n",
+ " RegisterModelsICP,\n",
+ " RegisterModelsPCA,\n",
+ " TransformTools,\n",
+ ")\n",
+ "from physiomotion4d.notebook_utils import running_as_test"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define File Paths"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:51.023944Z",
+ "iopub.status.busy": "2026-02-09T06:34:51.023944Z",
+ "iopub.status.idle": "2026-02-09T06:34:51.037942Z",
+ "shell.execute_reply": "2026-02-09T06:34:51.036949Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Patient CT image (defines coordinate frame)\n",
+ "patient_data_dir = Path.cwd().parent.parent / \"data\" / \"Slicer-Heart-CT\"\n",
+ "patient_ct_path = patient_data_dir / \"patient_img.mha\"\n",
+ "patient_ct_heart_mask_path = patient_data_dir / \"patient_heart_wall_mask.nii.gz\"\n",
+ "\n",
+ "# PCA heart model data\n",
+ "heart_model_data_dir = Path.cwd().parent.parent / \"data\" / \"KCL-Heart-Model\"\n",
+ "heart_model_path = heart_model_data_dir / \"average_mesh.vtk\"\n",
+ "\n",
+ "# PCA statistical model (from Heart-Create_Statistical_Model workflow)\n",
+ "template_model_data_dir = (\n",
+ " Path.cwd().parent / \"Heart-Create_Statistical_Model\" / \"kcl-heart-model\"\n",
+ ")\n",
+ "template_model_surface_path = template_model_data_dir / \"pca_mean.vtp\"\n",
+ "pca_json_path = template_model_data_dir / \"pca_model.json\"\n",
+ "\n",
+ "# Output directory\n",
+ "output_dir = Path.cwd() / \"results_pca\"\n",
+ "os.makedirs(output_dir, exist_ok=True)\n",
+ "\n",
+ "print(f\"Patient data: {patient_data_dir}\")\n",
+ "print(f\"PCA Model data: {template_model_data_dir}\")\n",
+ "print(f\"Output directory: {output_dir}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load and Preprocess Patient Image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:51.050954Z",
+ "iopub.status.busy": "2026-02-09T06:34:51.050954Z",
+ "iopub.status.idle": "2026-02-09T06:34:56.086219Z",
+ "shell.execute_reply": "2026-02-09T06:34:56.085110Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Load patient CT image\n",
+ "print(\"Loading patient CT image...\")\n",
+ "patient_image = itk.imread(str(patient_ct_path))\n",
+ "print(f\" Original size: {itk.size(patient_image)}\")\n",
+ "print(f\" Original spacing: {itk.spacing(patient_image)}\")\n",
+ "\n",
+ "# Resample to 1mm isotropic spacing\n",
+ "print(\"Resampling to sotropic...\")\n",
+ "resampler = ttk.ResampleImage.New(Input=patient_image)\n",
+ "resampler.SetMakeHighResIso(True)\n",
+ "resampler.Update()\n",
+ "patient_image = resampler.GetOutput()\n",
+ "\n",
+ "print(f\" Resampled size: {itk.size(patient_image)}\")\n",
+ "print(f\" Resampled spacing: {itk.spacing(patient_image)}\")\n",
+ "\n",
+ "# Save preprocessed image\n",
+ "itk.imwrite(patient_image, str(output_dir / \"patient_image.mha\"), compression=True)\n",
+ "print(\"✓ Saved preprocessed image\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load and Process Heart Segmentation Mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:56.087775Z",
+ "iopub.status.busy": "2026-02-09T06:34:56.087775Z",
+ "iopub.status.idle": "2026-02-09T06:34:56.115692Z",
+ "shell.execute_reply": "2026-02-09T06:34:56.115196Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Load heart segmentation mask\n",
+ "print(\"Loading heart segmentation mask...\")\n",
+ "patient_heart_mask_image = itk.imread(str(patient_ct_heart_mask_path))\n",
+ "\n",
+ "print(f\" Mask size: {itk.size(patient_heart_mask_image)}\")\n",
+ "print(f\" Mask spacing: {itk.spacing(patient_heart_mask_image)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:56.117556Z",
+ "iopub.status.busy": "2026-02-09T06:34:56.117556Z",
+ "iopub.status.idle": "2026-02-09T06:34:56.462795Z",
+ "shell.execute_reply": "2026-02-09T06:34:56.462795Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Handle image orientation (flip if needed)\n",
+ "flip0 = np.array(patient_heart_mask_image.GetDirection())[0, 0] < 0\n",
+ "flip1 = np.array(patient_heart_mask_image.GetDirection())[1, 1] < 0\n",
+ "flip2 = np.array(patient_heart_mask_image.GetDirection())[2, 2] < 0\n",
+ "\n",
+ "if flip0 or flip1 or flip2:\n",
+ " print(f\"Flipping image axes: {flip0}, {flip1}, {flip2}\")\n",
+ "\n",
+ " # Flip CT image\n",
+ " flip_filter = itk.FlipImageFilter.New(Input=patient_image)\n",
+ " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n",
+ " flip_filter.SetFlipAboutOrigin(True)\n",
+ " flip_filter.Update()\n",
+ " patient_image = flip_filter.GetOutput()\n",
+ " id_mat = itk.Matrix[itk.D, 3, 3]()\n",
+ " id_mat.SetIdentity()\n",
+ " patient_image.SetDirection(id_mat)\n",
+ "\n",
+ " # Flip mask image\n",
+ " flip_filter = itk.FlipImageFilter.New(Input=patient_heart_mask_image)\n",
+ " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n",
+ " flip_filter.SetFlipAboutOrigin(True)\n",
+ " flip_filter.Update()\n",
+ " patient_heart_mask_image = flip_filter.GetOutput()\n",
+ " patient_heart_mask_image.SetDirection(id_mat)\n",
+ "\n",
+ " print(\"✓ Images flipped to standard orientation\")\n",
+ "\n",
+ "# Save oriented images\n",
+ "itk.imwrite(\n",
+ " patient_image, str(output_dir / \"patient_image_oriented.mha\"), compression=True\n",
+ ")\n",
+ "itk.imwrite(\n",
+ " patient_heart_mask_image,\n",
+ " str(output_dir / \"patient_heart_mask_oriented.mha\"),\n",
+ " compression=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Convert Segmentation Mask to a Surface"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:56.464797Z",
+ "iopub.status.busy": "2026-02-09T06:34:56.464797Z",
+ "iopub.status.idle": "2026-02-09T06:34:56.917954Z",
+ "shell.execute_reply": "2026-02-09T06:34:56.917429Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "contour_tools = ContourTools()\n",
+ "patient_surface = contour_tools.extract_contours(patient_heart_mask_image)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Perform Initial ICP Affine Registration\n",
+ "\n",
+ "Use ICP (Iterative Closest Point) with affine mode to align the model surface to the patient surface extracted from the segmentation mask. This provides a good initial alignment for the PCA-based registration.\n",
+ "\n",
+ "The ICP registration pipeline:\n",
+ "1. Centroid alignment (automatic)\n",
+ "2. Rigid ICP alignment\n",
+ "\n",
+ "The PCA registration will then refine this initial alignment with shape model constraints."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:34:56.919468Z",
+ "iopub.status.busy": "2026-02-09T06:34:56.919468Z",
+ "iopub.status.idle": "2026-02-09T06:35:10.733991Z",
+ "shell.execute_reply": "2026-02-09T06:35:10.733143Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Load the pca model\n",
+ "print(\"Loading PCA heart model...\")\n",
+ "template_model = pv.read(str(heart_model_path))\n",
+ "\n",
+ "template_model_surface = pv.read(template_model_surface_path)\n",
+ "print(f\" Template surface: {template_model_surface.n_points} points\")\n",
+ "\n",
+ "icp_registrar = RegisterModelsICP(fixed_model=patient_surface)\n",
+ "\n",
+ "# Use fewer iterations when run as test (pytest) for faster execution\n",
+ "max_iterations_icp = 100 if running_as_test() else 2000\n",
+ "icp_result = icp_registrar.register(\n",
+ " transform_type=\"Affine\",\n",
+ " moving_model=template_model_surface,\n",
+ " max_iterations=max_iterations_icp,\n",
+ ")\n",
+ "\n",
+ "# Get the aligned mesh and transform\n",
+ "icp_registered_model_surface = icp_result[\"registered_model\"]\n",
+ "icp_forward_point_transform = icp_result[\"forward_point_transform\"]\n",
+ "\n",
+ "print(\"\\n✓ ICP affine registration complete\")\n",
+ "print(\" Transform =\", icp_result[\"forward_point_transform\"])\n",
+ "\n",
+ "# Save aligned model\n",
+ "icp_registered_model_surface.save(str(output_dir / \"icp_registered_model_surface.vtp\"))\n",
+ "print(\" Saved ICP-aligned model surface\")\n",
+ "\n",
+ "itk.transformwrite(\n",
+ " [icp_result[\"forward_point_transform\"]],\n",
+ " str(output_dir / \"icp_transform.hdf\"),\n",
+ " compression=True,\n",
+ ")\n",
+ "print(\" Saved ICP transform\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:35:10.735086Z",
+ "iopub.status.busy": "2026-02-09T06:35:10.735086Z",
+ "iopub.status.idle": "2026-02-09T06:35:14.798820Z",
+ "shell.execute_reply": "2026-02-09T06:35:14.797942Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Apply ICP transform to the full average mesh (not just surface)\n",
+ "# This gets the volumetric mesh into patient space for PCA registration\n",
+ "transform_tools = TransformTools()\n",
+ "icp_registered_model = transform_tools.transform_pvcontour(\n",
+ " template_model, icp_forward_point_transform\n",
+ ")\n",
+ "icp_registered_model.save(str(output_dir / \"icp_registered_model.vtk\"))\n",
+ "print(\"\\n✓ Applied ICP transform to full model mesh\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Initialize PCA Registration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:35:14.800938Z",
+ "iopub.status.busy": "2026-02-09T06:35:14.800938Z",
+ "iopub.status.idle": "2026-02-09T06:35:18.694406Z",
+ "shell.execute_reply": "2026-02-09T06:35:18.693110Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "## Initialize PCA Registration\n",
+ "print(\"=\" * 70)\n",
+ "\n",
+ "# Use the ICP-aligned mesh as the starting point for PCA registration\n",
+ "with open(pca_json_path, encoding=\"utf-8\") as f:\n",
+ " pca_model = json.load(f)\n",
+ "pca_registrar = RegisterModelsPCA.from_pca_model(\n",
+ " pca_template_model=template_model_surface,\n",
+ " pca_model=pca_model,\n",
+ " pca_number_of_modes=10,\n",
+ " pre_pca_transform=icp_forward_point_transform,\n",
+ " fixed_model=patient_surface,\n",
+ " reference_image=patient_image,\n",
+ ")\n",
+ "\n",
+ "itk.imwrite(pca_registrar.fixed_distance_map, str(output_dir / \"distance_map.mha\"))\n",
+ "\n",
+ "print(\"✓ PCA registrar initialized\")\n",
+ "print(\" Using ICP-aligned mesh as starting point\")\n",
+ "print(f\" Number of points: {len(pca_registrar.pca_template_model.points)}\")\n",
+ "print(f\" Number of PCA modes: {pca_registrar.pca_number_of_modes}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run PCA-Based Shape Optimization\n",
+ "\n",
+ "Now that we have a good initial alignment from ICP affine registration, we run the PCA-based registration to optimize the shape parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:35:18.696501Z",
+ "iopub.status.busy": "2026-02-09T06:35:18.695503Z",
+ "iopub.status.idle": "2026-02-09T06:36:49.984747Z",
+ "shell.execute_reply": "2026-02-09T06:36:49.983820Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "print(\"\\n\" + \"=\" * 70)\n",
+ "print(\"PCA-BASED SHAPE OPTIMIZATION\")\n",
+ "print(\"=\" * 70)\n",
+ "print(\"\\nRunning complete PCA registration pipeline...\")\n",
+ "print(\" (Starting from ICP-aligned mesh)\")\n",
+ "\n",
+ "result = pca_registrar.register(\n",
+ " pca_number_of_modes=10, # Use first 10 PCA modes\n",
+ ")\n",
+ "\n",
+ "dm = pca_registrar.fixed_distance_map\n",
+ "itk.imwrite(dm, str(output_dir / \"target_distance_map.mha\"))\n",
+ "\n",
+ "pca_registered_model_surface = result[\"registered_model\"]\n",
+ "\n",
+ "print(\"\\n✓ PCA registration complete\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Display Registration Results\n",
+ "\n",
+ "Review the optimization results from the PCA registration pipeline.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:36:49.986008Z",
+ "iopub.status.busy": "2026-02-09T06:36:49.986008Z",
+ "iopub.status.idle": "2026-02-09T06:36:49.999829Z",
+ "shell.execute_reply": "2026-02-09T06:36:49.998876Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "print(\"\\n\" + \"=\" * 70)\n",
+ "print(\"REGISTRATION RESULTS\")\n",
+ "print(\"=\" * 70)\n",
+ "\n",
+ "# Display results\n",
+ "print(\"\\nFinal Registration Metrics:\")\n",
+ "print(f\" Final mean intensity: {result['mean_distance']:.4f}\")\n",
+ "\n",
+ "print(\"\\nOptimized PCA Coefficients (in units of std deviations):\")\n",
+ "for i, coef in enumerate(result[\"pca_coefficients\"]):\n",
+ " print(f\" Mode {i + 1:2d}: {coef:7.4f}\")\n",
+ "\n",
+ "print(\"\\n✓ Registration pipeline complete!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save Registration Results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:36:50.001252Z",
+ "iopub.status.busy": "2026-02-09T06:36:50.001252Z",
+ "iopub.status.idle": "2026-02-09T06:36:51.739511Z",
+ "shell.execute_reply": "2026-02-09T06:36:51.738563Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "print(\"\\nSaving results...\")\n",
+ "\n",
+ "# Save final PCA-registered mesh\n",
+ "pca_registered_model_surface.save(str(output_dir / \"pca_registered_model_surface.vtk\"))\n",
+ "print(\" Saved final PCA-registered mesh\")\n",
+ "\n",
+ "ref_image = contour_tools.create_reference_image(pca_registered_model_surface)\n",
+ "\n",
+ "distance_map = contour_tools.create_distance_map(\n",
+ " pca_registered_model_surface,\n",
+ " ref_image,\n",
+ " squared_distance=True,\n",
+ " negative_inside=False,\n",
+ " zero_inside=True,\n",
+ " norm_to_max_distance=200.0,\n",
+ ")\n",
+ "\n",
+ "itk.imwrite(distance_map, str(output_dir / \"pca_distance_map.mha\"))\n",
+ "\n",
+ "# Save PCA coefficients\n",
+ "np.savetxt(\n",
+ " str(output_dir / \"pca_coefficients.txt\"),\n",
+ " result[\"pca_coefficients\"],\n",
+ " header=f\"PCA coefficients for {len(result['pca_coefficients'])} modes\",\n",
+ ")\n",
+ "print(\" Saved PCA coefficients\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Visualize Results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:36:51.741507Z",
+ "iopub.status.busy": "2026-02-09T06:36:51.741507Z",
+ "iopub.status.idle": "2026-02-09T06:36:54.084672Z",
+ "shell.execute_reply": "2026-02-09T06:36:54.082694Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Create side-by-side comparison\n",
+ "plotter = pv.Plotter(shape=(1, 2), window_size=[1000, 600])\n",
+ "\n",
+ "plotter.subplot(0, 0)\n",
+ "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n",
+ "plotter.add_mesh(\n",
+ " icp_registered_model_surface, color=\"green\", opacity=1.0, label=\"ICP Registered\"\n",
+ ")\n",
+ "plotter.add_title(\"ICP Shape Fitting\")\n",
+ "plotter.add_axes()\n",
+ "\n",
+ "# After PCA shape fitting\n",
+ "plotter.subplot(0, 1)\n",
+ "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n",
+ "plotter.add_mesh(\n",
+ " pca_registered_model_surface, color=\"green\", opacity=1.0, label=\"PCA Registered\"\n",
+ ")\n",
+ "plotter.add_title(\"PCA Shape Fitting\")\n",
+ "plotter.add_axes()\n",
+ "\n",
+ "plotter.link_views()\n",
+ "plotter.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Visualize PCA Displacement Magnitude\n",
+ "\n",
+ "Compute and display the displacement magnitude caused by PCA optimization. This shows how much each point moved from the ICP-aligned mean shape to the final PCA-registered shape."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-02-09T06:36:54.086869Z",
+ "iopub.status.busy": "2026-02-09T06:36:54.086230Z",
+ "iopub.status.idle": "2026-02-09T06:36:55.292551Z",
+ "shell.execute_reply": "2026-02-09T06:36:55.291527Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Compute displacement from ICP-aligned (mean shape) to PCA-registered shape\n",
+ "icp_points = icp_registered_model_surface.points\n",
+ "pca_points = pca_registered_model_surface.points\n",
+ "\n",
+ "# Calculate displacement vectors\n",
+ "displacement_vectors = pca_points - icp_points\n",
+ "\n",
+ "# Compute surface normals for the ICP-aligned mesh\n",
+ "icp_registered_model_with_normals = icp_registered_model_surface.compute_normals(\n",
+ " point_normals=True, cell_normals=False\n",
+ ")\n",
+ "normals = icp_registered_model_with_normals.point_data[\"Normals\"]\n",
+ "\n",
+ "# Calculate signed displacement along the normal direction\n",
+ "# Positive = outward displacement, Negative = inward displacement\n",
+ "signed_displacement = np.sum(displacement_vectors * normals, axis=1)\n",
+ "\n",
+ "# Add displacement as scalar data to the mesh\n",
+ "pca_registered_model_with_displacement = pca_registered_model_surface.copy()\n",
+ "pca_registered_model_with_displacement[\"PCA Signed Displacement (mm)\"] = (\n",
+ " signed_displacement\n",
+ ")\n",
+ "\n",
+ "# Print statistics\n",
+ "print(\"PCA Signed Displacement Statistics:\")\n",
+ "print(f\" Mean displacement: {np.mean(signed_displacement):.2f} mm\")\n",
+ "print(f\" Max displacement (outward): {np.max(signed_displacement):.2f} mm\")\n",
+ "print(f\" Min displacement (inward): {np.min(signed_displacement):.2f} mm\")\n",
+ "print(f\" Std displacement: {np.std(signed_displacement):.2f} mm\")\n",
+ "\n",
+ "# Visualize the signed displacement with diverging colormap\n",
+ "# Blue = inward displacement, Red = outward displacement\n",
+ "plotter = pv.Plotter(window_size=[800, 600])\n",
+ "plotter.add_mesh(\n",
+ " pca_registered_model_with_displacement,\n",
+ " scalars=\"PCA Signed Displacement (mm)\",\n",
+ " cmap=\"RdBu_r\", # Red for positive (outward), Blue for negative (inward)\n",
+ " clim=[\n",
+ " -np.max(np.abs(signed_displacement)),\n",
+ " np.max(np.abs(signed_displacement)),\n",
+ " ], # Symmetric color scale\n",
+ " show_scalar_bar=True,\n",
+ " scalar_bar_args={\n",
+ " \"title\": \"PCA Signed Displacement (mm)\\n(Red=Outward, Blue=Inward)\",\n",
+ " \"vertical\": True,\n",
+ " \"position_x\": 0.82,\n",
+ " \"position_y\": 0.1,\n",
+ " },\n",
+ ")\n",
+ "plotter.add_title(\"PCA Signed Displacement on Registered Model\")\n",
+ "plotter.add_axes()\n",
+ "plotter.show()\n",
+ "\n",
+ "# Save the mesh with displacement data\n",
+ "pca_registered_model_with_displacement.save(\n",
+ " str(output_dir / \"pca_registered_model_with_signed_displacement.vtp\")\n",
+ ")\n",
+ "print(\"\\n✓ Saved model with signed displacement data\")"
+ ]
}
- },
- "outputs": [],
- "source": [
- "# Create side-by-side comparison\n",
- "plotter = pv.Plotter(shape=(1, 2), window_size=[1000, 600])\n",
- "\n",
- "plotter.subplot(0, 0)\n",
- "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n",
- "plotter.add_mesh(\n",
- " icp_registered_model_surface, color=\"green\", opacity=1.0, label=\"ICP Registered\"\n",
- ")\n",
- "plotter.add_title(\"ICP Shape Fitting\")\n",
- "plotter.add_axes()\n",
- "\n",
- "# After PCA shape fitting\n",
- "plotter.subplot(0, 1)\n",
- "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n",
- "plotter.add_mesh(\n",
- " pca_registered_model_surface, color=\"green\", opacity=1.0, label=\"PCA Registered\"\n",
- ")\n",
- "plotter.add_title(\"PCA Shape Fitting\")\n",
- "plotter.add_axes()\n",
- "\n",
- "plotter.link_views()\n",
- "plotter.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Visualize PCA Displacement Magnitude\n",
- "\n",
- "Compute and display the displacement magnitude caused by PCA optimization. This shows how much each point moved from the ICP-aligned mean shape to the final PCA-registered shape."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "execution": {
- "iopub.execute_input": "2026-02-09T06:36:54.086869Z",
- "iopub.status.busy": "2026-02-09T06:36:54.086230Z",
- "iopub.status.idle": "2026-02-09T06:36:55.292551Z",
- "shell.execute_reply": "2026-02-09T06:36:55.291527Z"
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "state": {
+ "12ebef1bb4134652baada92b4bf41e65": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_d13b70b19456415a92798dc847c57042",
+ "placeholder": "",
+ "style": "IPY_MODEL_8a5ce0cea571455f908701b9b74e615b",
+ "tabbable": null,
+ "tooltip": null,
+ "value": ""
+ }
+ },
+ "8a5ce0cea571455f908701b9b74e615b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "aa7837c5f1364c6c9086d42707ea1ac5": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b493ad90f44d494ebad4e99da0e2eb5f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "2.0.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_allow_html": false,
+ "layout": "IPY_MODEL_aa7837c5f1364c6c9086d42707ea1ac5",
+ "placeholder": "",
+ "style": "IPY_MODEL_c21c90fea7bb4fa490ee1b0ca2073d14",
+ "tabbable": null,
+ "tooltip": null,
+ "value": ""
+ }
+ },
+ "c21c90fea7bb4fa490ee1b0ca2073d14": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "2.0.0",
+ "model_name": "HTMLStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "2.0.0",
+ "_model_name": "HTMLStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "StyleView",
+ "background": null,
+ "description_width": "",
+ "font_size": null,
+ "text_color": null
+ }
+ },
+ "d13b70b19456415a92798dc847c57042": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "2.0.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "2.0.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "2.0.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border_bottom": null,
+ "border_left": null,
+ "border_right": null,
+ "border_top": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ }
+ },
+ "version_major": 2,
+ "version_minor": 0
+ }
}
- },
- "outputs": [],
- "source": [
- "# Compute displacement from ICP-aligned (mean shape) to PCA-registered shape\n",
- "icp_points = icp_registered_model_surface.points\n",
- "pca_points = pca_registered_model_surface.points\n",
- "\n",
- "# Calculate displacement vectors\n",
- "displacement_vectors = pca_points - icp_points\n",
- "\n",
- "# Compute surface normals for the ICP-aligned mesh\n",
- "icp_registered_model_with_normals = icp_registered_model_surface.compute_normals(\n",
- " point_normals=True, cell_normals=False\n",
- ")\n",
- "normals = icp_registered_model_with_normals.point_data[\"Normals\"]\n",
- "\n",
- "# Calculate signed displacement along the normal direction\n",
- "# Positive = outward displacement, Negative = inward displacement\n",
- "signed_displacement = np.sum(displacement_vectors * normals, axis=1)\n",
- "\n",
- "# Add displacement as scalar data to the mesh\n",
- "pca_registered_model_with_displacement = pca_registered_model_surface.copy()\n",
- "pca_registered_model_with_displacement[\"PCA Signed Displacement (mm)\"] = (\n",
- " signed_displacement\n",
- ")\n",
- "\n",
- "# Print statistics\n",
- "print(\"PCA Signed Displacement Statistics:\")\n",
- "print(f\" Mean displacement: {np.mean(signed_displacement):.2f} mm\")\n",
- "print(f\" Max displacement (outward): {np.max(signed_displacement):.2f} mm\")\n",
- "print(f\" Min displacement (inward): {np.min(signed_displacement):.2f} mm\")\n",
- "print(f\" Std displacement: {np.std(signed_displacement):.2f} mm\")\n",
- "\n",
- "# Visualize the signed displacement with diverging colormap\n",
- "# Blue = inward displacement, Red = outward displacement\n",
- "plotter = pv.Plotter(window_size=[800, 600])\n",
- "plotter.add_mesh(\n",
- " pca_registered_model_with_displacement,\n",
- " scalars=\"PCA Signed Displacement (mm)\",\n",
- " cmap=\"RdBu_r\", # Red for positive (outward), Blue for negative (inward)\n",
- " clim=[\n",
- " -np.max(np.abs(signed_displacement)),\n",
- " np.max(np.abs(signed_displacement)),\n",
- " ], # Symmetric color scale\n",
- " show_scalar_bar=True,\n",
- " scalar_bar_args={\n",
- " \"title\": \"PCA Signed Displacement (mm)\\n(Red=Outward, Blue=Inward)\",\n",
- " \"vertical\": True,\n",
- " \"position_x\": 0.82,\n",
- " \"position_y\": 0.1,\n",
- " },\n",
- ")\n",
- "plotter.add_title(\"PCA Signed Displacement on Registered Model\")\n",
- "plotter.add_axes()\n",
- "plotter.show()\n",
- "\n",
- "# Save the mesh with displacement data\n",
- "pca_registered_model_with_displacement.save(\n",
- " str(output_dir / \"pca_registered_model_with_signed_displacement.vtp\")\n",
- ")\n",
- "print(\"\\n✓ Saved model with signed displacement data\")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.11"
},
- "widgets": {
- "application/vnd.jupyter.widget-state+json": {
- "state": {
- "12ebef1bb4134652baada92b4bf41e65": {
- "model_module": "@jupyter-widgets/controls",
- "model_module_version": "2.0.0",
- "model_name": "HTMLModel",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "2.0.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "2.0.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_allow_html": false,
- "layout": "IPY_MODEL_d13b70b19456415a92798dc847c57042",
- "placeholder": "",
- "style": "IPY_MODEL_8a5ce0cea571455f908701b9b74e615b",
- "tabbable": null,
- "tooltip": null,
- "value": ""
- }
- },
- "8a5ce0cea571455f908701b9b74e615b": {
- "model_module": "@jupyter-widgets/controls",
- "model_module_version": "2.0.0",
- "model_name": "HTMLStyleModel",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "2.0.0",
- "_model_name": "HTMLStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "2.0.0",
- "_view_name": "StyleView",
- "background": null,
- "description_width": "",
- "font_size": null,
- "text_color": null
- }
- },
- "aa7837c5f1364c6c9086d42707ea1ac5": {
- "model_module": "@jupyter-widgets/base",
- "model_module_version": "2.0.0",
- "model_name": "LayoutModel",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "2.0.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "2.0.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border_bottom": null,
- "border_left": null,
- "border_right": null,
- "border_top": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "b493ad90f44d494ebad4e99da0e2eb5f": {
- "model_module": "@jupyter-widgets/controls",
- "model_module_version": "2.0.0",
- "model_name": "HTMLModel",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "2.0.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "2.0.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_allow_html": false,
- "layout": "IPY_MODEL_aa7837c5f1364c6c9086d42707ea1ac5",
- "placeholder": "",
- "style": "IPY_MODEL_c21c90fea7bb4fa490ee1b0ca2073d14",
- "tabbable": null,
- "tooltip": null,
- "value": ""
- }
- },
- "c21c90fea7bb4fa490ee1b0ca2073d14": {
- "model_module": "@jupyter-widgets/controls",
- "model_module_version": "2.0.0",
- "model_name": "HTMLStyleModel",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "2.0.0",
- "_model_name": "HTMLStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "2.0.0",
- "_view_name": "StyleView",
- "background": null,
- "description_width": "",
- "font_size": null,
- "text_color": null
- }
- },
- "d13b70b19456415a92798dc847c57042": {
- "model_module": "@jupyter-widgets/base",
- "model_module_version": "2.0.0",
- "model_name": "LayoutModel",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "2.0.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "2.0.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border_bottom": null,
- "border_left": null,
- "border_right": null,
- "border_top": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- }
- },
- "version_major": 2,
- "version_minor": 0
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
+ "nbformat": 4,
+ "nbformat_minor": 4
}
diff --git a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb
index dcd7a59..6c08ef6 100644
--- a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb
+++ b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb
@@ -20,6 +20,7 @@
},
"outputs": [],
"source": [
+ "import json\n",
"import os\n",
"from pathlib import Path\n",
"\n",
@@ -31,7 +32,7 @@
"from physiomotion4d import (\n",
" ContourTools,\n",
" SegmentChestTotalSegmentator,\n",
- " WorkflowRegisterHeartModelToPatient,\n",
+ " WorkflowFitStatisticalModelToPatient,\n",
")"
]
},
@@ -214,16 +215,22 @@
},
"outputs": [],
"source": [
- "registrar = WorkflowRegisterHeartModelToPatient(\n",
+ "with open(pca_json_path, encoding=\"utf-8\") as f:\n",
+ " pca_model = json.load(f)\n",
+ "registrar = WorkflowFitStatisticalModelToPatient(\n",
" template_model=template_model,\n",
+ " patient_models=[patient_model],\n",
+ " patient_image=patient_image,\n",
+ ")\n",
+ "registrar.set_use_pca_registration(\n",
+ " True, pca_model=pca_model, pca_number_of_modes=pca_n_modes\n",
+ ")\n",
+ "registrar.set_use_mask_to_image_registration(\n",
+ " True,\n",
" template_labelmap=template_labelmap,\n",
- " template_labelmap_heart_muscle_ids=[1],\n",
- " template_labelmap_chamber_ids=[2, 3, 4, 5],\n",
+ " template_labelmap_organ_mesh_ids=[1],\n",
+ " template_labelmap_organ_extra_ids=[2, 3, 4, 5],\n",
" template_labelmap_background_ids=[6],\n",
- " patient_image=patient_image,\n",
- " patient_models=[patient_model],\n",
- " pca_json_filename=pca_json_path,\n",
- " pca_number_of_modes=pca_n_modes,\n",
")\n",
"\n",
"registrar.set_mask_dilation_mm(0)\n",
diff --git a/experiments/README.md b/experiments/README.md
index aac93f2..fb07670 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -14,7 +14,7 @@ of the PhysioMotion4D library.
- Command-line tools and parameter specifications
See:
-- **CLI Commands**: Run `physiomotion4d-heart-gated-ct --help` and `physiomotion4d-register-heart-model --help`
+- **CLI Commands**: Run `physiomotion4d-heart-gated-ct --help`, `physiomotion4d-create-statistical-model --help`, and `physiomotion4d-fit-statistical-model-to-patient --help`
- **CLI Implementation**: `src/physiomotion4d/cli/` for Python API examples
- **Library Classes**: `src/physiomotion4d/` for all workflow and utility classes
@@ -36,7 +36,7 @@ These experiments demonstrate key digital twin workflows that can be adapted to
anatomical regions, imaging modalities, and physiological motion tasks.
> **Note:** For production implementations of these workflows, use the CLI commands
-> (`physiomotion4d-heart-gated-ct`, `physiomotion4d-register-heart-model`) or consult
+> (`physiomotion4d-heart-gated-ct`, `physiomotion4d-create-statistical-model`, `physiomotion4d-fit-statistical-model-to-patient`) or consult
> the CLI implementation in `src/physiomotion4d/cli/` for proper class usage and parameter specifications.
### `Reconstruct4DCT` - High-Resolution 4D Reconstruction
@@ -275,7 +275,8 @@ Each subdirectory represents a different experimental domain:
1. **CLI Commands** ⭐ **PRIMARY RESOURCE**
- `physiomotion4d-heart-gated-ct` - Complete heart-gated CT workflow
- - `physiomotion4d-register-heart-model` - Model-to-patient registration
+ - `physiomotion4d-create-statistical-model` - Create PCA statistical shape model from sample meshes
+ - `physiomotion4d-fit-statistical-model-to-patient` - Model-to-patient registration
- Run with `--help` for all options and parameter specifications
- Tested on diverse datasets
@@ -336,7 +337,7 @@ The typical evolution path was:
When exploring new digital twin applications, you can follow a similar path:
- Start by understanding relevant experiments here as conceptual references
- Examine CLI implementations in `src/physiomotion4d/cli/` for proper library usage
-- Use CLI commands (`physiomotion4d-heart-gated-ct`, `physiomotion4d-register-heart-model`) as starting points
+- Use CLI commands (`physiomotion4d-heart-gated-ct`, `physiomotion4d-create-statistical-model`, `physiomotion4d-fit-statistical-model-to-patient`) as starting points
- Extend and adapt production code with your domain-specific requirements
- Contribute back improvements and new capabilities to the community
diff --git a/pyproject.toml b/pyproject.toml
index afac15a..1ea2f3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -163,7 +163,9 @@ Changelog = "https://github.com/aylward/PhysioMotion4d/blob/main/CHANGELOG.md"
# CLI commands installed via pip
# Entry points reference the main() functions in the cli submodule
physiomotion4d-heart-gated-ct = "physiomotion4d.cli.convert_heart_gated_ct_to_usd:main"
-physiomotion4d-register-heart-model = "physiomotion4d.cli.register_heart_model_to_patient:main"
+physiomotion4d-create-statistical-model = "physiomotion4d.cli.create_statistical_model:main"
+physiomotion4d-fit-statistical-model-to-patient = "physiomotion4d.cli.fit_statistical_model_to_patient:main"
+physiomotion4d-visualize-pca-modes = "physiomotion4d.cli.visualize_pca_modes:main"
[tool.setuptools.packages.find]
where = ["src"]
@@ -173,6 +175,7 @@ current_version = "2025.05.0"
version_pattern = "YYYY.0M.PATCH[PYTAGNUM]"
commit_message = "Bump version {old_version} -> {new_version}"
commit = true
+tag_message = "{new_version}"
tag = true
push = false
@@ -201,42 +204,20 @@ warn_unreachable = true
strict_equality = true
show_error_codes = true
exclude = '(?x)(^src[/\\\\]physiomotion4d[/\\\\]network_weights[/\\\\]vista3d[/\\\\])'
+# Third-party libs (itk, pyvista, pxr, vtk, simpleware, etc.) have no stubs
+disable_error_code = ["import-not-found", "import-untyped"]
+# When type-checking our package, ignore missing imports for third-party libs.
+# Override applies to the module being checked (physiomotion4d.*).
[[tool.mypy.overrides]]
-module = [
- "ants",
- "cupy",
- "huggingface_hub",
- "hugging_face_pipeline",
- "icon_registration.*",
- "ignite.*",
- "itk.*",
- "matplotlib",
- "matplotlib.*",
- "monai.*",
- "nibabel.*",
- "nrrd",
- "numpy",
- "numpy.*",
- "pxr",
- "pynrrd.*",
- "pyvista.*",
- "scripts.*",
- "SimpleITK",
- "simpleware.*",
- "torch",
- "torch.*",
- "totalsegmentator.*",
- "transformers",
- "transformers.*",
- "trimesh",
- "unigradicon.*",
- "vista3d",
- "vista3d.*",
- "vtk.*"
-]
+module = ["physiomotion4d", "physiomotion4d.*"]
ignore_missing_imports = true
+[tool.pyright]
+# Third-party packages (e.g. pyvista) are in dependencies but may have no stubs;
+# do not report import-not-found so analysis matches mypy overrides above.
+reportMissingImports = false
+
[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py
index df48af7..16e27cc 100644
--- a/src/physiomotion4d/__init__.py
+++ b/src/physiomotion4d/__init__.py
@@ -45,11 +45,12 @@
from .register_time_series_images import RegisterTimeSeriesImages
# Segmentation classes
-from .segment_chest_base import SegmentChestBase
+from .segment_anatomy_base import SegmentAnatomyBase
from .segment_chest_ensemble import SegmentChestEnsemble
from .segment_chest_total_segmentator import SegmentChestTotalSegmentator
from .segment_chest_vista_3d import SegmentChestVista3D
from .segment_chest_vista_3d_nim import SegmentChestVista3DNIM
+from .segment_heart_simpleware import SegmentHeartSimpleware
from .transform_tools import TransformTools
from .usd_anatomy_tools import USDAnatomyTools
from .usd_tools import USDTools
@@ -57,21 +58,24 @@
# Core workflow processor
from .workflow_convert_heart_gated_ct_to_usd import WorkflowConvertHeartGatedCTToUSD
from .workflow_reconstruct_highres_4d_ct import WorkflowReconstructHighres4DCT
-from .workflow_register_heart_model_to_patient import (
- WorkflowRegisterHeartModelToPatient,
+from .workflow_create_statistical_model import WorkflowCreateStatisticalModel
+from .workflow_fit_statistical_model_to_patient import (
+ WorkflowFitStatisticalModelToPatient,
)
__all__ = [
# Workflow classes
"WorkflowConvertHeartGatedCTToUSD",
+ "WorkflowCreateStatisticalModel",
"WorkflowReconstructHighres4DCT",
- "WorkflowRegisterHeartModelToPatient",
+ "WorkflowFitStatisticalModelToPatient",
# Segmentation classes
- "SegmentChestBase",
+ "SegmentAnatomyBase",
"SegmentChestEnsemble",
"SegmentChestTotalSegmentator",
"SegmentChestVista3D",
"SegmentChestVista3DNIM",
+ "SegmentHeartSimpleware",
# Registration classes
"RegisterImagesBase",
"RegisterImagesICON",
diff --git a/src/physiomotion4d/cli/__init__.py b/src/physiomotion4d/cli/__init__.py
index c40248b..632cf1c 100644
--- a/src/physiomotion4d/cli/__init__.py
+++ b/src/physiomotion4d/cli/__init__.py
@@ -1,3 +1,8 @@
"""Command-line interface modules for PhysioMotion4D."""
-__all__ = ["convert_heart_gated_ct_to_usd", "register_heart_model_to_patient"]
+__all__ = [
+ "convert_heart_gated_ct_to_usd",
+ "create_statistical_model",
+ "fit_statistical_model_to_patient",
+ "visualize_pca_modes",
+]
diff --git a/src/physiomotion4d/cli/create_statistical_model.py b/src/physiomotion4d/cli/create_statistical_model.py
new file mode 100644
index 0000000..635ab5d
--- /dev/null
+++ b/src/physiomotion4d/cli/create_statistical_model.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+"""
+Command-line interface for Create Statistical Model workflow.
+
+This script provides a CLI to build a PCA statistical shape model from a sample
+of meshes aligned to a reference mesh, as in the Heart-Create_Statistical_Model
+experiment notebooks. Outputs include pca_mean_surface.vtp, pca_mean.vtu (if
+reference is volumetric), and pca_model.json.
+"""
+
+import argparse
+import json
+import os
+import sys
+import traceback
+from pathlib import Path
+
+import pyvista as pv
+
+from physiomotion4d import WorkflowCreateStatisticalModel
+
+
+def main() -> int:
+ """Command-line interface for create statistical model workflow."""
+ parser = argparse.ArgumentParser(
+ description="Create a PCA statistical shape model from sample meshes aligned to a reference",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Create model from a directory of sample meshes and a reference mesh
+ %(prog)s \\
+ --sample-meshes-dir ./meshes \\
+ --reference-mesh average_mesh.vtk \\
+ --output-dir ./pca_model
+
+ # Specify sample meshes explicitly
+ %(prog)s \\
+ --sample-meshes 01.vtk 02.vtk 03.vtu \\
+ --reference-mesh average_mesh.vtk \\
+ --output-dir ./pca_model
+
+ # Custom PCA components
+ %(prog)s \\
+ --sample-meshes-dir ./meshes \\
+ --reference-mesh average_mesh.vtk \\
+ --output-dir ./pca_model \\
+ --pca-components 20
+ """,
+ )
+
+ # Sample meshes: either a directory or a list of files
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "--sample-meshes-dir",
+ type=Path,
+ metavar="DIR",
+ help="Directory containing sample mesh files (.vtk, .vtu, .vtp)",
+ )
+ group.add_argument(
+ "--sample-meshes",
+ nargs="+",
+ type=Path,
+ metavar="PATH",
+ help="Paths to sample mesh files",
+ )
+
+ parser.add_argument(
+ "--reference-mesh",
+ type=Path,
+ required=True,
+ metavar="PATH",
+ help="Path to reference mesh; its surface is used to align all samples",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=Path,
+ required=True,
+ metavar="DIR",
+ help="Output directory for pca_mean_surface.vtp, pca_mean.vtu, pca_model.json",
+ )
+
+ parser.add_argument(
+ "--pca-components",
+ type=int,
+ default=15,
+ help="Number of PCA components to retain (default: 15)",
+ )
+
+ args = parser.parse_args()
+
+ # Resolve sample mesh paths
+ if args.sample_meshes_dir is not None:
+ smd = Path(args.sample_meshes_dir)
+ sample_paths: list[Path] = []
+ for ext in [".vtk", ".vtp", ".vtu"]:
+ sample_paths.extend(smd.glob(f"*{ext}"))
+ sample_paths = sorted(set(sample_paths))
+ if not sample_paths:
+ print(
+ f"Error: No .vtk, .vtu, or .vtp files found in {args.sample_meshes_dir}"
+ )
+ return 1
+ else:
+ sample_paths = args.sample_meshes
+
+ # Validate paths
+ print("Validating input files...")
+ if not args.reference_mesh.exists():
+ print(f"Error: Reference mesh not found: {args.reference_mesh}")
+ return 1
+ for p in sample_paths:
+ if not p.exists():
+ print(f"Error: Sample mesh not found: {p}")
+ return 1
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load meshes
+ print("\nLoading meshes...")
+ try:
+ print(f" Reference mesh: {args.reference_mesh}")
+ reference_mesh = pv.read(args.reference_mesh)
+ print(f" Sample meshes: {len(sample_paths)} files")
+ sample_meshes = [pv.read(p) for p in sample_paths]
+ except (FileNotFoundError, OSError, RuntimeError) as e:
+ print(f"Error loading meshes: {e}")
+ traceback.print_exc()
+ return 1
+
+ # Run workflow
+ print("\nInitializing create statistical model workflow...")
+ try:
+ workflow = WorkflowCreateStatisticalModel(
+ sample_meshes=sample_meshes,
+ reference_mesh=reference_mesh,
+ pca_number_of_components=args.pca_components,
+ )
+ except (ValueError, RuntimeError) as e:
+ print(f"Error initializing workflow: {e}")
+ traceback.print_exc()
+ return 1
+
+ try:
+ print("\nRunning pipeline...")
+ print("=" * 70)
+ result = workflow.run_workflow()
+ print("=" * 70)
+ print("\nSaving outputs...")
+
+ out_surface = args.output_dir / "pca_mean_surface.vtp"
+ result["pca_mean_surface"].save(str(out_surface))
+ print(f" pca_mean_surface: {out_surface}")
+
+ if result.get("pca_mean_mesh") is not None:
+ out_mesh = args.output_dir / "pca_mean.vtu"
+ result["pca_mean_mesh"].save(str(out_mesh))
+ print(f" pca_mean_mesh: {out_mesh}")
+
+ out_json = args.output_dir / "pca_model.json"
+ with open(out_json, "w", encoding="utf-8") as f:
+ json.dump(result["pca_model"], f, indent=4)
+ print(f" pca_model: {out_json}")
+
+ print("\nCreate statistical model completed successfully.")
+ print(f"Outputs written to: {args.output_dir}")
+ return 0
+
+ except (RuntimeError, ValueError, OSError) as e:
+ print(f"\nError during workflow: {e}")
+ traceback.print_exc()
+ return 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/src/physiomotion4d/cli/register_heart_model_to_patient.py b/src/physiomotion4d/cli/fit_statistical_model_to_patient.py
similarity index 78%
rename from src/physiomotion4d/cli/register_heart_model_to_patient.py
rename to src/physiomotion4d/cli/fit_statistical_model_to_patient.py
index 1955174..56c86d5 100644
--- a/src/physiomotion4d/cli/register_heart_model_to_patient.py
+++ b/src/physiomotion4d/cli/fit_statistical_model_to_patient.py
@@ -8,6 +8,7 @@
"""
import argparse
+import json
import os
import sys
import traceback
@@ -15,7 +16,7 @@
import itk
import pyvista as pv
-from physiomotion4d import WorkflowRegisterHeartModelToPatient
+from physiomotion4d import WorkflowFitStatisticalModelToPatient
def main() -> int:
@@ -25,39 +26,36 @@ def main() -> int:
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
- # Basic registration with required inputs
+ # Basic registration (no patient image: reference image created from patient models)
%(prog)s \\
--template-model heart_model.vtu \\
- --template-labelmap heart_labelmap.nii.gz \\
--patient-models lv.vtp rv.vtp myo.vtp \\
- --patient-image patient_ct.nii.gz \\
--output-dir ./results
- # Registration with PCA shape fitting
+ # With patient image and PCA shape fitting
%(prog)s \\
--template-model heart_model.vtu \\
- --template-labelmap heart_labelmap.nii.gz \\
--patient-models lv.vtp rv.vtp myo.vtp \\
--patient-image patient_ct.nii.gz \\
--pca-json pca_model.json \\
--pca-number-of-modes 10 \\
--output-dir ./results
- # Registration with custom label IDs
+ # Enable mask-to-image refinement (requires template labelmap and label IDs)
%(prog)s \\
--template-model heart_model.vtu \\
- --template-labelmap heart_labelmap.nii.gz \\
- --patient-models lv.vtp rv.vtp \\
+ --patient-models lv.vtp rv.vtp myo.vtp \\
--patient-image patient_ct.nii.gz \\
+ --mask-to-image \\
+ --template-labelmap heart_labelmap.nii.gz \\
--template-labelmap-muscle-ids 1 2 3 \\
--template-labelmap-chamber-ids 4 5 6 \\
--template-labelmap-background-ids 0 \\
--output-dir ./results
- # Registration with ICON refinement
+ # With ICON refinement
%(prog)s \\
--template-model heart_model.vtu \\
- --template-labelmap heart_labelmap.nii.gz \\
--patient-models lv.vtp rv.vtp \\
--patient-image patient_ct.nii.gz \\
--use-icon-refinement \\
@@ -71,11 +69,6 @@ def main() -> int:
required=True,
help="Path to template/generic heart model (.vtu, .vtk, .stl)",
)
- parser.add_argument(
- "--template-labelmap",
- required=True,
- help="Path to template labelmap image (.nii.gz, .nrrd, .mha)",
- )
parser.add_argument(
"--patient-models",
nargs="+",
@@ -84,8 +77,11 @@ def main() -> int:
)
parser.add_argument(
"--patient-image",
- required=True,
- help="Path to patient CT/MRI image (.nii.gz, .nrrd, .mha)",
+ help="Path to patient CT/MRI image (.nii.gz, .nrrd, .mha). If omitted, a reference image is created from the patient models.",
+ )
+ parser.add_argument(
+ "--template-labelmap",
+ help="Path to template labelmap image (.nii.gz, .nrrd, .mha). Required when --mask-to-image is set.",
)
parser.add_argument(
"--output-dir", required=True, help="Output directory for results"
@@ -135,11 +131,11 @@ def main() -> int:
help="Disable mask-to-mask deformable registration",
)
parser.add_argument(
- "--no-mask-to-image",
+ "--mask-to-image",
dest="use_mask_to_image",
- action="store_false",
- default=True,
- help="Disable mask-to-image refinement registration",
+ action="store_true",
+ default=False,
+ help="Enable mask-to-image refinement (requires --template-labelmap and label IDs)",
)
parser.add_argument(
"--use-icon-refinement",
@@ -163,19 +159,23 @@ def main() -> int:
print(f"Error: Template model not found: {args.template_model}")
return 1
- if not os.path.exists(args.template_labelmap):
- print(f"Error: Template labelmap not found: {args.template_labelmap}")
- return 1
-
for patient_model in args.patient_models:
if not os.path.exists(patient_model):
print(f"Error: Patient model not found: {patient_model}")
return 1
- if not os.path.exists(args.patient_image):
+ if args.patient_image is not None and not os.path.exists(args.patient_image):
print(f"Error: Patient image not found: {args.patient_image}")
return 1
+ if args.use_mask_to_image:
+ if args.template_labelmap is None:
+ print("Error: --template-labelmap is required when --mask-to-image is set.")
+ return 1
+ if not os.path.exists(args.template_labelmap):
+ print(f"Error: Template labelmap not found: {args.template_labelmap}")
+ return 1
+
if args.pca_json and not os.path.exists(args.pca_json):
print(f"Error: PCA JSON file not found: {args.pca_json}")
return 1
@@ -193,17 +193,25 @@ def main() -> int:
)
template_model: pv.UnstructuredGrid = template_model_raw
- print(f" Loading template labelmap: {args.template_labelmap}")
- template_labelmap = itk.imread(args.template_labelmap)
-
print(" Loading patient models:")
patient_models = []
for patient_model_file in args.patient_models:
print(f" - {patient_model_file}")
patient_models.append(pv.read(patient_model_file))
- print(f" Loading patient image: {args.patient_image}")
- patient_image = itk.imread(args.patient_image)
+ if args.patient_image is not None:
+ print(f" Loading patient image: {args.patient_image}")
+ patient_image = itk.imread(args.patient_image)
+ else:
+ patient_image = None
+ print(
+ " No patient image: reference image will be created from patient models"
+ )
+
+ template_labelmap = None
+ if args.template_labelmap is not None:
+ print(f" Loading template labelmap: {args.template_labelmap}")
+ template_labelmap = itk.imread(args.template_labelmap)
except (FileNotFoundError, OSError, RuntimeError) as e:
print(f"Error loading input data: {e}")
@@ -213,17 +221,27 @@ def main() -> int:
# Initialize workflow
print("\nInitializing heart model to patient registration workflow...")
try:
- workflow = WorkflowRegisterHeartModelToPatient(
+ workflow = WorkflowFitStatisticalModelToPatient(
template_model=template_model,
- template_labelmap=template_labelmap,
- template_labelmap_heart_muscle_ids=args.template_labelmap_muscle_ids,
- template_labelmap_chamber_ids=args.template_labelmap_chamber_ids,
- template_labelmap_background_ids=args.template_labelmap_background_ids,
patient_models=patient_models,
patient_image=patient_image,
- pca_json_filename=args.pca_json,
- pca_number_of_modes=args.pca_number_of_modes,
)
+ if args.pca_json is not None:
+ with open(args.pca_json, encoding="utf-8") as f:
+ pca_model = json.load(f)
+ workflow.set_use_pca_registration(
+ True,
+ pca_model=pca_model,
+ pca_number_of_modes=args.pca_number_of_modes,
+ )
+ if args.use_mask_to_image:
+ workflow.set_use_mask_to_image_registration(
+ True,
+ template_labelmap=template_labelmap,
+ template_labelmap_organ_mesh_ids=args.template_labelmap_muscle_ids,
+ template_labelmap_organ_extra_ids=args.template_labelmap_chamber_ids,
+ template_labelmap_background_ids=args.template_labelmap_background_ids,
+ )
except (ValueError, RuntimeError, OSError) as e:
print(f"Error initializing workflow: {e}")
traceback.print_exc()
diff --git a/src/physiomotion4d/cli/visualize_pca_modes.py b/src/physiomotion4d/cli/visualize_pca_modes.py
new file mode 100644
index 0000000..5db4eb6
--- /dev/null
+++ b/src/physiomotion4d/cli/visualize_pca_modes.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python
+"""
+Command-line interface to visualize PCA modes of variation.
+
+Displays the first three principal components in a 1x3 PyVista plotter. A
+slider (0 to 4) controls the standard-deviation magnitude; each subplot
+shows the mean shape (gray), +sigma shape (coral), and -sigma shape (blue)
+for that PC (matching experiments/Heart-Create_Statistical_Model/
+5-compute_pca_model.ipynb cell 11).
+
+Inputs: pca_mean_surface.vtp (mean mesh topology and mean shape) and
+pca_model.json (components and eigenvalues).
+"""
+
+import argparse
+import json
+import sys
+import traceback
+from pathlib import Path
+
+import numpy as np
+import pyvista as pv
+
+
+def _shape_at_sigma(
+ mean_shape: np.ndarray,
+ components: list,
+ eigenvalues: list,
+ pc_index: int,
+ sigma: float,
+) -> np.ndarray:
+ """Return (n_points, 3) mesh points for mean + sigma * std_dev * component."""
+ pc = np.asarray(components[pc_index], dtype=np.float64)
+ std_dev = np.sqrt(eigenvalues[pc_index])
+ variation = mean_shape + (sigma * std_dev * pc)
+ n_points = mean_shape.size // 3
+ return variation.reshape(n_points, 3)
+
+
+def _generate_pc_variation(
+ mean_mesh: pv.PolyData,
+ mean_shape: np.ndarray,
+ components: list,
+ eigenvalues: list,
+ pc_index: int,
+ std_dev_multiplier: float = 3.0,
+) -> tuple[pv.PolyData, pv.PolyData, pv.PolyData]:
+ """Generate shape variations along a principal component.
+
+ Parameters
+ ----------
+ mean_mesh : pv.PolyData
+ Template mesh (topology only; points are replaced).
+ mean_shape : np.ndarray
+ Flattened mean shape (n_points * 3,).
+ components : list
+ List of component vectors (each length n_points * 3).
+ eigenvalues : list
+ Variance (eigenvalue) for each component.
+ pc_index : int
+ Index of the principal component (0-based).
+ std_dev_multiplier : float
+ How many standard deviations to vary (default: +/-3 sigma).
+
+ Returns
+ -------
+ tuple of (negative_mesh, mean_mesh, positive_mesh)
+ """
+ pc = np.asarray(components[pc_index], dtype=np.float64)
+ std_dev = np.sqrt(eigenvalues[pc_index])
+
+ negative_variation = mean_shape - (std_dev_multiplier * std_dev * pc)
+ positive_variation = mean_shape + (std_dev_multiplier * std_dev * pc)
+
+ n_points = mean_shape.size // 3
+ negative_points = negative_variation.reshape(n_points, 3)
+ positive_points = positive_variation.reshape(n_points, 3)
+ mean_points = mean_shape.reshape(n_points, 3)
+
+ negative_mesh = mean_mesh.copy()
+ negative_mesh.points = negative_points
+
+ positive_mesh = mean_mesh.copy()
+ positive_mesh.points = positive_points
+
+ mean_mesh_out = mean_mesh.copy()
+ mean_mesh_out.points = mean_points
+
+ return negative_mesh, mean_mesh_out, positive_mesh
+
+
+def main() -> int:
+ """Command-line interface for visualizing PCA modes."""
+ parser = argparse.ArgumentParser(
+ description="Visualize the first three PCA modes with a 0-4 sigma slider; each subplot shows mean, +sigma, -sigma (PyVista 1x3 plot).",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Slider 0-4 sigma; each subplot shows mean (gray), +sigma (coral), -sigma (blue)
+ %(prog)s pca_mean_surface.vtp pca_model.json
+
+ # Start with slider at 2 sigma
+ %(prog)s --std-dev 2.0 mean.vtp pca_model.json
+ """,
+ )
+
+ parser.add_argument(
+ "pca_mean_surface",
+ type=Path,
+ metavar="PCA_MEAN_SURFACE.vtp",
+ help="Path to PCA mean surface (.vtp); provides topology and mean shape.",
+ )
+ parser.add_argument(
+ "pca_model_json",
+ type=Path,
+ metavar="pca_model.json",
+ help="Path to PCA model JSON (components and eigenvalues).",
+ )
+ parser.add_argument(
+ "--std-dev",
+ type=float,
+ default=0.0,
+ metavar="N",
+ help="Initial slider position 0-4 std dev (default: 0).",
+ )
+
+ args = parser.parse_args()
+
+ if not args.pca_mean_surface.exists():
+ print(f"Error: PCA mean surface not found: {args.pca_mean_surface}")
+ return 1
+ if not args.pca_model_json.exists():
+ print(f"Error: PCA model JSON not found: {args.pca_model_json}")
+ return 1
+
+ try:
+ mean_mesh = pv.read(str(args.pca_mean_surface))
+ except (OSError, RuntimeError) as e:
+ print(f"Error loading PCA mean surface: {e}")
+ traceback.print_exc()
+ return 1
+
+ if not isinstance(mean_mesh, pv.PolyData):
+ print("Error: PCA mean surface must be a PolyData (.vtp).")
+ return 1
+
+ try:
+ with open(args.pca_model_json, encoding="utf-8") as f:
+ pca_model = json.load(f)
+ except (OSError, json.JSONDecodeError) as e:
+ print(f"Error loading PCA model JSON: {e}")
+ traceback.print_exc()
+ return 1
+
+ for key in ("components", "eigenvalues"):
+ if key not in pca_model:
+ print(f"Error: PCA model JSON must contain '{key}'.")
+ return 1
+
+ components = pca_model["components"]
+ eigenvalues = pca_model["eigenvalues"]
+
+ if len(components) < 3:
+ print(
+ f"Error: PCA model has only {len(components)} component(s); need at least 3 for visualization."
+ )
+ return 1
+
+ n_points = mean_mesh.n_points
+ n_features = n_points * 3
+ if len(components[0]) != n_features:
+ print(
+ f"Error: Mean surface has {n_points} points ({n_features} features), "
+ f"but PCA components have {len(components[0])} entries. Shapes must match."
+ )
+ return 1
+
+ mean_shape = mean_mesh.points.astype(np.float64).flatten()
+ n_points = mean_shape.size // 3
+
+ # Precompute component arrays for slider updates
+ pc_arrays = [np.asarray(components[i], dtype=np.float64) for i in range(3)]
+ std_devs = [np.sqrt(eigenvalues[i]) for i in range(3)]
+
+ plotter = pv.Plotter(shape=(1, 3))
+
+ # Slider 0 to 4; clamp initial to that range
+ initial_sigma = max(0.0, min(4.0, args.std_dev))
+
+ # Mean shape for reference (same in all subplots)
+ mean_ref = mean_mesh.copy()
+ mean_ref.points = mean_shape.reshape(n_points, 3)
+
+ # Per subplot: mean (static), +sigma mesh, -sigma mesh (updated by slider)
+ plus_meshes: list[pv.PolyData] = []
+ minus_meshes: list[pv.PolyData] = []
+ for col, pc_index in enumerate(range(3)):
+ points_plus = _shape_at_sigma(
+ mean_shape, components, eigenvalues, pc_index, initial_sigma
+ )
+ points_minus = _shape_at_sigma(
+ mean_shape, components, eigenvalues, pc_index, -initial_sigma
+ )
+ mesh_plus = mean_mesh.copy()
+ mesh_plus.points = points_plus
+ mesh_minus = mean_mesh.copy()
+ mesh_minus.points = points_minus
+ plus_meshes.append(mesh_plus)
+ minus_meshes.append(mesh_minus)
+ plotter.subplot(0, col)
+ plotter.add_mesh(
+ mean_ref.copy(),
+ color="lightgray",
+ opacity=0.4,
+ show_edges=False,
+ )
+ plotter.add_mesh(
+ mesh_minus,
+ color="lightblue",
+ opacity=1.0,
+ show_edges=False,
+ )
+ plotter.add_mesh(
+ mesh_plus,
+ color="lightcoral",
+ opacity=1.0,
+ show_edges=False,
+ )
+ plotter.add_text(
+ f"PC{pc_index + 1}",
+ font_size=12,
+ )
+ plotter.camera_position = "iso"
+
+ def _on_slider(sigma: float) -> None:
+ for i in range(3):
+ plus_pts = mean_shape + (sigma * std_devs[i] * pc_arrays[i])
+ minus_pts = mean_shape - (sigma * std_devs[i] * pc_arrays[i])
+ plus_meshes[i].points = plus_pts.reshape(n_points, 3)
+ minus_meshes[i].points = minus_pts.reshape(n_points, 3)
+ plotter.render()
+
+ # Slider: 0 to 4 std dev (shows mean, +sigma, -sigma)
+ plotter.subplot(0, 0)
+ plotter.add_slider_widget(
+ _on_slider,
+ rng=(0.0, 4.0),
+ value=initial_sigma,
+ title="Std dev",
+ pointa=(0.02, 0.1),
+ pointb=(0.98, 0.1),
+ )
+
+ plotter.link_views()
+ plotter.show()
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py
index 231bbc2..daf21db 100644
--- a/src/physiomotion4d/contour_tools.py
+++ b/src/physiomotion4d/contour_tools.py
@@ -265,14 +265,30 @@ def create_distance_map(
tmp_arr = np.zeros(size, dtype=np.int32)
itk_point = itk.Point[itk.D, 3]()
+ point_count = 0
for point in points:
itk_point[0] = float(point[0])
itk_point[1] = float(point[1])
itk_point[2] = float(point[2])
indx = reference_image.TransformPhysicalPointToIndex(itk_point)
+ if (
+ indx[0] < 0
+ or indx[1] < 0
+ or indx[2] < 0
+ or indx[0] >= size[0]
+ or indx[1] >= size[1]
+ or indx[2] >= size[2]
+ ):
+ continue
tmp_arr[indx[2], indx[1], indx[0]] = 1
+ point_count += 1
+
tmp_binary_image = itk.GetImageFromArray(tmp_arr.astype(np.uint8))
tmp_binary_image.CopyInformation(reference_image)
+ assert (
+ tmp_binary_image.GetLargestPossibleRegion().GetSize()
+ == reference_image.GetLargestPossibleRegion().GetSize()
+ )
distance_filter = itk.SignedMaurerDistanceMapImageFilter.New(
Input=tmp_binary_image
@@ -319,6 +335,15 @@ def create_deformation_field(
itk_point[1] = float(point[1])
itk_point[2] = float(point[2])
indx = reference_image.TransformPhysicalPointToIndex(itk_point)
+ if (
+ indx[0] < 0
+ or indx[1] < 0
+ or indx[2] < 0
+ or indx[0] >= size[0]
+ or indx[1] >= size[1]
+ or indx[2] >= size[2]
+ ):
+ continue
displacement_map_x[int(indx[2]), int(indx[1]), int(indx[0])] = (
point_displacements[i, 0]
)
@@ -332,6 +357,10 @@ def create_deformation_field(
norm_img = itk.GetImageFromArray(norm_map)
norm_img.CopyInformation(reference_image)
+ assert (
+ norm_img.GetLargestPossibleRegion().GetSize()
+ == reference_image.GetLargestPossibleRegion().GetSize()
+ )
blurred_norm = itk.SmoothingRecursiveGaussianImageFilter(
Input=norm_img, Sigma=blur_sigma
diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py
index 3cc9d1d..ccb974d 100644
--- a/src/physiomotion4d/image_tools.py
+++ b/src/physiomotion4d/image_tools.py
@@ -6,7 +6,7 @@
"""
import logging
-from typing import Any
+from typing import Any, Optional
import itk
import numpy as np
@@ -245,3 +245,40 @@ def convert_array_to_image_of_vectors(
itk.array_view_from_image(itk_image)[:] = arr_data
return itk_image
+
+ def flip_image_to_identity_direction(
+ self, in_image: itk.Image, in_mask: Optional[itk.Image] = None
+ ) -> Any | tuple[Any, Any]:
+ """
+ Flip the image to the identity direction.
+ """
+ flip0 = np.array(in_image.GetDirection())[0, 0] < 0
+ flip1 = np.array(in_image.GetDirection())[1, 1] < 0
+ flip2 = np.array(in_image.GetDirection())[2, 2] < 0
+ if flip0 or flip1 or flip2:
+ self.log_info(
+ f"Flipping image to identity direction: {flip0}, {flip1}, {flip2}"
+ )
+ flip_filter = itk.FlipImageFilter.New(Input=in_image)
+ flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])
+ flip_filter.SetFlipAboutOrigin(True)
+ flip_filter.Update()
+ out_image = flip_filter.GetOutput()
+ id_mat = itk.Matrix[itk.D, 3, 3]()
+ id_mat.SetIdentity()
+ out_image.SetDirection(id_mat)
+ if in_mask is not None:
+ flip_filter = itk.FlipImageFilter.New(Input=in_mask)
+ flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])
+ flip_filter.SetFlipAboutOrigin(True)
+ flip_filter.Update()
+ out_mask = flip_filter.GetOutput()
+ out_mask.SetDirection(id_mat)
+ return out_image, out_mask
+ else:
+ return out_image
+ else:
+ if in_mask is not None:
+ return in_image, in_mask
+ else:
+ return in_image
diff --git a/src/physiomotion4d/physiomotion4d_base.py b/src/physiomotion4d/physiomotion4d_base.py
index c4a7451..336a111 100644
--- a/src/physiomotion4d/physiomotion4d_base.py
+++ b/src/physiomotion4d/physiomotion4d_base.py
@@ -228,7 +228,7 @@ def set_log_classes(cls, class_names: list[str]) -> None:
Args:
class_names: List of class names to show logs from.
- Example: ["RegisterModelsPCA", "WorkflowRegisterHeartModelToPatient"]
+ Example: ["RegisterModelsPCA", "WorkflowFitStatisticalModelToPatient"]
Example:
>>> PhysioMotion4DBase.set_log_classes(['RegisterModelsPCA'])
@@ -263,7 +263,7 @@ def get_log_classes(cls) -> list[str]:
Example:
>>> classes = PhysioMotion4DBase.get_log_classes()
>>> print(classes)
- ['RegisterModelsPCA', 'WorkflowRegisterHeartModelToPatient']
+ ['RegisterModelsPCA', 'WorkflowFitStatisticalModelToPatient']
"""
if cls._class_filter is not None and cls._class_filter.enabled:
return sorted(cls._class_filter.allowed_classes)
diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py
index b50c7da..5ee8ead 100644
--- a/src/physiomotion4d/register_models_distance_maps.py
+++ b/src/physiomotion4d/register_models_distance_maps.py
@@ -120,7 +120,7 @@ def __init__(
moving_model: pv.PolyData,
fixed_model: pv.PolyData,
reference_image: itk.Image,
- roi_dilation_mm: float = 10,
+ roi_dilation_mm: float = 20,
log_level: int | str = logging.INFO,
):
"""Initialize mask-based model registration.
diff --git a/src/physiomotion4d/register_models_pca.py b/src/physiomotion4d/register_models_pca.py
index 895990e..952e70e 100644
--- a/src/physiomotion4d/register_models_pca.py
+++ b/src/physiomotion4d/register_models_pca.py
@@ -96,7 +96,7 @@ def __init__(
These are the square roots of pca_eigenvalues
pca_number_of_modes: Number of PCA modes to use. Default: -1 (use all)
pca_template_model_point_subsample: Step size for subsampling model points. Default: 4
- pre_pca_transform: Optional ITK transform to apply after PCA registration.
+ pre_pca_transform: Optional ITK transform to apply before PCA registration.
Default: None
fixed_distance_map: ITK image providing the distance map.
Default: None
@@ -161,21 +161,10 @@ def __init__(
self.registered_model_pca_coefficients: Optional[np.ndarray] = None
self.registered_model: Optional[pv.UnstructuredGrid] = None
self.registered_model_mean_distance: float = 0.0
- self.register_model_pca_deformation: Optional[np.ndarray] = None
+ self.registered_model_pca_deformation: Optional[np.ndarray] = None
self.forward_point_transform: Optional[itk.DisplacementFieldTransform] = None
self.inverse_point_transform: Optional[itk.DisplacementFieldTransform] = None
- self._template_model_pca_deformation_field_image: Optional[itk.Image] = None
- self._deformation_field_interpolator_x: Optional[
- itk.LinearInterpolateImageFunction
- ] = None
- self._deformation_field_interpolator_y: Optional[
- itk.LinearInterpolateImageFunction
- ] = None
- self._deformation_field_interpolator_z: Optional[
- itk.LinearInterpolateImageFunction
- ] = None
-
# Image interpolator (created when needed)
self._fixed_distance_map_interpolator: Optional[
itk.LinearInterpolateImageFunction
@@ -215,7 +204,7 @@ def from_json(
pca_json_filename: Path to the PCA model JSON file
pca_number_of_modes: Number of PCA modes to use. Default: 0 (use all)
pca_template_model_point_subsample: Step size for subsampling model points. Default: 4
- pre_pca_transform: Optional ITK transform to apply after PCA registration.
+ pre_pca_transform: Optional ITK transform to apply before PCA registration.
Default: None
fixed_distance_map: ITK image providing the distance values
for registration. If None, must be set later before registration.
@@ -283,6 +272,67 @@ def from_json(
logger.info(" ✓ Data validation successful!")
logger.info("PCA model data loaded successfully!")
+ return cls.from_pca_model(
+ pca_template_model=pca_template_model,
+ pca_model=pca_data,
+ pca_number_of_modes=pca_number_of_modes,
+ pca_template_model_point_subsample=pca_template_model_point_subsample,
+ pre_pca_transform=pre_pca_transform,
+ fixed_distance_map=fixed_distance_map,
+ fixed_model=fixed_model,
+ reference_image=reference_image,
+ log_level=log_level,
+ )
+
+ @classmethod
+ def from_pca_model(
+ cls,
+ pca_template_model: pv.UnstructuredGrid,
+ pca_model: dict,
+ pca_number_of_modes: int = 0,
+ pca_template_model_point_subsample: int = 4,
+ pre_pca_transform: Optional[itk.Transform] = None,
+ fixed_distance_map: Optional[itk.Image] = None,
+ fixed_model: Optional[pv.UnstructuredGrid] = None,
+ reference_image: Optional[itk.Image] = None,
+ log_level: int | str = logging.INFO,
+ ) -> Self:
+ """Create RegisterModelsPCA from a PCA model dictionary.
+
+ The dict must match the structure produced by
+ :class:`WorkflowCreateStatisticalModel` (key ``pca_model``):
+ ``explained_variance_ratio``, ``eigenvalues``, ``components``.
+
+ Args:
+ pca_template_model: Mean surface mesh to use as template
+ pca_model: PCA model dict with 'eigenvalues' and 'components' (and optionally
+ 'explained_variance_ratio')
+ pca_number_of_modes: Number of PCA modes to use. Default: 0 (use all)
+ pca_template_model_point_subsample: Step size for subsampling model points. Default: 4
+ pre_pca_transform: Optional ITK transform to apply before PCA registration.
+ fixed_distance_map: ITK image providing the distance values for registration.
+ fixed_model: Target surface mesh to register to.
+ reference_image: Reference image defining coordinate space.
+ log_level: Logging level.
+
+ Returns:
+ RegisterModelsPCA instance
+
+ Raises:
+ ValueError: If required keys are missing or dimensions invalid
+ """
+ if "eigenvalues" not in pca_model:
+ raise ValueError("'eigenvalues' field not found in pca_model")
+ pca_std_deviations = np.sqrt(np.array(pca_model["eigenvalues"]))
+ if "components" not in pca_model:
+ raise ValueError("'components' field not found in pca_model")
+ pca_eigenvectors = np.array(pca_model["components"], dtype=np.float64)
+ expected_size = pca_template_model.n_points * 3
+ if pca_eigenvectors.shape[1] != expected_size:
+ raise ValueError(
+ f"Component dimension mismatch: expected {expected_size} "
+ f"(3 × {pca_template_model.n_points} points), got {pca_eigenvectors.shape[1]}"
+ )
return cls(
pca_template_model=pca_template_model,
pca_eigenvectors=pca_eigenvectors,
@@ -597,8 +647,8 @@ def transform_template_model(self) -> pv.UnstructuredGrid:
self.log_info("Creating final registered model...")
# Compute PCA deformation
- if self.register_model_pca_deformation is None:
- self.register_model_pca_deformation = self._compute_pca_deformation(
+ if self.registered_model_pca_deformation is None:
+ self.registered_model_pca_deformation = self._compute_pca_deformation(
self.registered_model_pca_coefficients,
)
@@ -623,9 +673,9 @@ def transform_template_model(self) -> pv.UnstructuredGrid:
point = self.pre_pca_transform.TransformPoint(point)
# Add PCA deformation
- point[0] += self.register_model_pca_deformation[i, 0]
- point[1] += self.register_model_pca_deformation[i, 1]
- point[2] += self.register_model_pca_deformation[i, 2]
+ point[0] += self.registered_model_pca_deformation[i, 0]
+ point[1] += self.registered_model_pca_deformation[i, 1]
+ point[2] += self.registered_model_pca_deformation[i, 2]
# Store result
final_points[i, 0] = point[0]
@@ -655,8 +705,10 @@ def transform_point(
Returns:
Transformed ITK point
- Raises:
- ValueError: If registration has not been completed yet
+ Notes:
+ 1) if the point is outside the image bounds, the point is not transformed.
+ 2) if the pre_pca_transform is set and enabled, it is applied.
+ 3) if the forward point transform is not set, no errors are raised.
Example:
>>> p = itk.Point[itk.D, 3]()
@@ -664,81 +716,13 @@ def transform_point(
>>> transformed_p = registrar.transform_point(p)
"""
- if self._deformation_field_interpolator_x is None:
- field_array = itk.GetArrayFromImage(
- self._template_model_pca_deformation_field_image
- )
- field_x_image = itk.GetImageFromArray(field_array[:, :, :, 0])
- field_x_image.CopyInformation(
- self._template_model_pca_deformation_field_image
- )
- self._deformation_field_interpolator_x = itk.LinearInterpolateImageFunction[
- itk.Image[itk.D, 3], itk.D
- ].New()
- self._deformation_field_interpolator_x.SetInputImage(field_x_image)
-
- field_y_image = itk.GetImageFromArray(field_array[:, :, :, 1])
- field_y_image.CopyInformation(
- self._template_model_pca_deformation_field_image
- )
- self._deformation_field_interpolator_y = itk.LinearInterpolateImageFunction[
- itk.Image[itk.D, 3], itk.D
- ].New()
- self._deformation_field_interpolator_y.SetInputImage(field_y_image)
-
- field_z_image = itk.GetImageFromArray(field_array[:, :, :, 2])
- field_z_image.CopyInformation(
- self._template_model_pca_deformation_field_image
- )
- self._deformation_field_interpolator_z = itk.LinearInterpolateImageFunction[
- itk.Image[itk.D, 3], itk.D
- ].New()
- self._deformation_field_interpolator_z.SetInputImage(field_z_image)
-
- assert self._template_model_pca_deformation_field_image is not None, (
- "Deformation field image must be set"
- )
- assert self._deformation_field_interpolator_x is not None, (
- "Interpolator x must be initialized"
- )
- assert self._deformation_field_interpolator_y is not None, (
- "Interpolator y must be initialized"
- )
- assert self._deformation_field_interpolator_z is not None, (
- "Interpolator z must be initialized"
- )
if include_pre_pca_transform and self.pre_pca_transform is not None:
point = self.pre_pca_transform.TransformPoint(point)
- cindx = self._template_model_pca_deformation_field_image.TransformPhysicalPointToContinuousIndex(
- point
- )
- size = self._template_model_pca_deformation_field_image.GetLargestPossibleRegion().GetSize()
- if (
- cindx[0] < 0
- or cindx[0] >= size[0]
- or cindx[1] < 0
- or cindx[1] >= size[1]
- or cindx[2] < 0
- or cindx[2] >= size[2]
- ):
- self.log_error("Point is outside deformation field bounds")
- return point
-
- deformation_x = (
- self._deformation_field_interpolator_x.EvaluateAtContinuousIndex(cindx)
- )
- deformation_y = (
- self._deformation_field_interpolator_y.EvaluateAtContinuousIndex(cindx)
- )
- deformation_z = (
- self._deformation_field_interpolator_z.EvaluateAtContinuousIndex(cindx)
- )
-
- transformed_point = itk.Point[itk.D, 3]()
- transformed_point[0] = float(point[0] + deformation_x)
- transformed_point[1] = float(point[1] + deformation_y)
- transformed_point[2] = float(point[2] + deformation_z)
+ if self.forward_point_transform is not None:
+ transformed_point = self.forward_point_transform.TransformPoint(point)
+ else:
+ transformed_point = point
return transformed_point
@@ -750,13 +734,13 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict:
- 'forward_point_transform': Forward displacement field transform
- 'inverse_point_transform': Inverse displacement field transform
"""
- assert self.register_model_pca_deformation is not None, (
+ assert self.registered_model_pca_deformation is not None, (
"PCA deformation must be computed"
)
- self._template_model_pca_deformation_field_image = (
+ template_model_pca_deformation_field_image = (
self._contour_tools.create_deformation_field(
np.array(self.pca_template_model.points),
- self.register_model_pca_deformation,
+ self.registered_model_pca_deformation,
reference_image=reference_image,
blur_sigma=2.5,
ptype=itk.D,
@@ -765,7 +749,7 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict:
self.forward_point_transform = itk.DisplacementFieldTransform[itk.D, 3].New()
self.forward_point_transform.SetDisplacementField(
- self._template_model_pca_deformation_field_image
+ template_model_pca_deformation_field_image
)
transform_tools = TransformTools()
@@ -830,7 +814,7 @@ def register(
)
# Create final registered model
- self.register_model_pca_deformation = None
+ self.registered_model_pca_deformation = None
self.registered_model = self.transform_template_model()
# Return results as dictionary
diff --git a/src/physiomotion4d/segment_chest_base.py b/src/physiomotion4d/segment_anatomy_base.py
similarity index 97%
rename from src/physiomotion4d/segment_chest_base.py
rename to src/physiomotion4d/segment_anatomy_base.py
index b6f1f47..8a7e25d 100644
--- a/src/physiomotion4d/segment_chest_base.py
+++ b/src/physiomotion4d/segment_anatomy_base.py
@@ -1,7 +1,7 @@
-"""Base class for segmenting chest CT images.
+"""Base class for segmenting anatomy in CT images.
-This module provides the SegmentChestBase class that serves as a foundation
-for implementing different chest CT segmentation algorithms. It handles common
+This module provides the SegmentAnatomyBase class that serves as a foundation
+for implementing different anatomy CT segmentation algorithms. It handles common
preprocessing, postprocessing, and anatomical structure organization tasks.
"""
@@ -15,12 +15,12 @@
from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase
-class SegmentChestBase(PhysioMotion4DBase):
- """Base class for chest segmentation that provides common functionality for
- segmenting chest CT images.
+class SegmentAnatomyBase(PhysioMotion4DBase):
+ """Base class for anatomy segmentation that provides common functionality for
+ segmenting anatomy in CT images.
This class implements preprocessing, postprocessing, and mask creation
- methods that are shared across different chest segmentation
+ methods that are shared across different anatomy segmentation
implementations. It defines anatomical structure mappings and provides
utilities for image preprocessing, intensity rescaling, and mask generation.
@@ -43,7 +43,7 @@ class SegmentChestBase(PhysioMotion4DBase):
"""
def __init__(self, log_level: int | str = logging.INFO):
- """Initialize the SegmentChestBase class.
+ """Initialize the SegmentAnatomyBase class.
Sets up default parameters for image preprocessing and anatomical
structure ID mappings. Subclasses should call this constructor and
@@ -385,6 +385,9 @@ def segment_connected_component(
InsideValue=1,
OutsideValue=0,
)
+ thresh_arr = itk.GetArrayFromImage(thresh_image).astype(np.int16)
+ thresh_image = itk.GetImageFromArray(thresh_arr)
+ thresh_image.CopyInformation(preprocessed_image)
label_arr = itk.GetArrayFromImage(labelmap_image)
if labelmap_ids is None:
diff --git a/src/physiomotion4d/segment_chest_ensemble.py b/src/physiomotion4d/segment_chest_ensemble.py
index 8e394a7..b51bb3c 100644
--- a/src/physiomotion4d/segment_chest_ensemble.py
+++ b/src/physiomotion4d/segment_chest_ensemble.py
@@ -12,12 +12,12 @@
import itk
import numpy as np
-from physiomotion4d.segment_chest_base import SegmentChestBase
+from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase
from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator
from physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D
-class SegmentChestEnsemble(SegmentChestBase):
+class SegmentChestEnsemble(SegmentAnatomyBase):
"""
A class that inherits from physioSegmentChest and implements the
segmentation method using VISTA3D.
diff --git a/src/physiomotion4d/segment_chest_total_segmentator.py b/src/physiomotion4d/segment_chest_total_segmentator.py
index c2d8b18..d8bb53b 100644
--- a/src/physiomotion4d/segment_chest_total_segmentator.py
+++ b/src/physiomotion4d/segment_chest_total_segmentator.py
@@ -2,7 +2,7 @@
This module provides the SegmentChestTotalSegmentator class that implements
chest CT segmentation using the TotalSegmentator deep learning model. It inherits
-from SegmentChestBase and defines anatomical structure mappings specific to
+from SegmentAnatomyBase and defines anatomical structure mappings specific to
TotalSegmentator's output labels.
"""
@@ -15,10 +15,10 @@
import numpy as np
from totalsegmentator.python_api import totalsegmentator
-from physiomotion4d.segment_chest_base import SegmentChestBase
+from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase
-class SegmentChestTotalSegmentator(SegmentChestBase):
+class SegmentChestTotalSegmentator(SegmentAnatomyBase):
"""
Chest CT segmentation using TotalSegmentator deep learning model.
diff --git a/src/physiomotion4d/segment_chest_vista_3d.py b/src/physiomotion4d/segment_chest_vista_3d.py
index 5337960..2373178 100644
--- a/src/physiomotion4d/segment_chest_vista_3d.py
+++ b/src/physiomotion4d/segment_chest_vista_3d.py
@@ -26,10 +26,10 @@
import torch
from huggingface_hub import snapshot_download
-from physiomotion4d.segment_chest_base import SegmentChestBase
+from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase
-class SegmentChestVista3D(SegmentChestBase):
+class SegmentChestVista3D(SegmentAnatomyBase):
"""
Chest CT segmentation using NVIDIA VISTA-3D foundational model.
diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py
index fec54b2..be3eba0 100644
--- a/src/physiomotion4d/segment_heart_simpleware.py
+++ b/src/physiomotion4d/segment_heart_simpleware.py
@@ -2,23 +2,24 @@
This module provides the SegmentHeartSimpleware class that implements
heart segmentation using Synopsys Simpleware Medical's ASCardio module.
-It inherits from SegmentChestBase and provides heart-specific anatomical
+It inherits from SegmentAnatomyBase and provides heart-specific anatomical
structure mappings.
"""
import logging
import os
import subprocess
+import sys
import tempfile
import itk
import numpy as np
from itk import TubeTK as tube
-from physiomotion4d.segment_chest_base import SegmentChestBase
+from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase
-class SegmentHeartSimpleware(SegmentChestBase):
+class SegmentHeartSimpleware(SegmentAnatomyBase):
"""
Heart CT segmentation using Simpleware Medical's ASCardio module.
@@ -97,6 +98,8 @@ def __init__(self, log_level: int | str = logging.INFO):
# From Base Class
# self.contrast_mask_ids = {135: "contrast"}
+ self.trim_mesh_to_essentials = False
+
self.set_other_and_all_mask_ids()
# Path to Simpleware Medical console executable
@@ -109,6 +112,14 @@ def __init__(self, log_level: int | str = logging.INFO):
"SimplewareScript_heart_segmentation.py",
)
+ def set_trim_mesh_to_essentials(self, trim_mesh_to_essentials: bool) -> None:
+ """Set whether to trim mesh to common and critical structures.
+
+ Args:
+ trim_mesh_to_essentials (bool): Whether to reduce to essential.
+ """
+ self.trim_mesh_to_essentials = trim_mesh_to_essentials
+
def set_simpleware_executable_path(self, path: str) -> None:
"""Set the path to the Simpleware Medical console executable.
@@ -195,15 +206,43 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
self.log_info("Command: %s", " ".join(cmd))
try:
- # Run Simpleware Medical as a subprocess
- result = subprocess.run(
+ # Run Simpleware Medical as a subprocess. When the process exits,
+ # the OS frees all of its resources (GPU, memory); no extra
+ # cleanup is required. Using Popen so we can kill the process
+ # tree on timeout and ensure no child processes keep holding GPU.
+ proc = subprocess.Popen(
cmd,
- input=user_input,
- capture_output=True,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
text=True,
- check=True,
- timeout=600, # 10 minute timeout
+ start_new_session=(
+ sys.platform != "win32"
+ ), # process group on Unix
)
+ try:
+ stdout, stderr = proc.communicate(
+ input=user_input,
+ timeout=600, # 10 minute timeout
+ )
+ except subprocess.TimeoutExpired:
+ # Kill process tree so GPU/memory are released (child may have spawned others)
+ if sys.platform == "win32":
+ subprocess.run(
+ ["taskkill", "/F", "/T", "/PID", str(proc.pid)],
+ capture_output=True,
+ )
+ else:
+ os.killpg(os.getpgid(proc.pid), 9)
+ proc.wait()
+ raise RuntimeError(
+ "Simpleware Medical segmentation timed out after 600 seconds"
+ )
+ if proc.returncode != 0:
+ raise subprocess.CalledProcessError(
+ proc.returncode, cmd, stdout, stderr
+ )
+ result = type("Result", (), {"stdout": stdout, "stderr": stderr})()
# Log output from Simpleware
if result.stdout:
@@ -211,10 +250,6 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
if result.stderr:
self.log_warning("Simpleware stderr:\n%s", result.stderr)
- except subprocess.TimeoutExpired as e:
- raise RuntimeError(
- f"Simpleware Medical segmentation timed out after 600 seconds: {e}"
- )
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Simpleware Medical segmentation failed with return code {e.returncode}:\n"
@@ -223,7 +258,7 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
# Simpleware's right ventricle, left atrium, right atrium correspond to
# the interior of those regions.
- mask_ids_of_interior_regions = [2, 3, 4]
+ mask_ids_of_interior_regions = [1, 2, 3, 4]
# Check if output file was created
sz = [s for s in preprocessed_image.GetLargestPossibleRegion().GetSize()]
@@ -265,6 +300,16 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
"ensure the ASCardio module ran successfully."
)
+ if self.trim_mesh_to_essentials:
+ z = labelmap_array.shape[2] - 1
+ z_classes = np.unique(labelmap_array[z, :, :])
+ heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
+ while heart_count < 3 and z > 0:
+ z -= 1
+ z_classes = np.unique(labelmap_array[z, :, :])
+ heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
+ if z < labelmap_array.shape[2] - 3:
+ labelmap_array[(z + 3) :, :, :] = 0
labelmap_image = itk.GetImageFromArray(labelmap_array.astype(np.uint8))
labelmap_image.CopyInformation(preprocessed_image)
diff --git a/src/physiomotion4d/simpleware_medical/README.md b/src/physiomotion4d/simpleware_medical/README.md
index d44039f..6c20b41 100644
--- a/src/physiomotion4d/simpleware_medical/README.md
+++ b/src/physiomotion4d/simpleware_medical/README.md
@@ -6,7 +6,7 @@ This directory contains integration code for using Synopsys Simpleware Medical w
The integration enables PhysioMotion4D to leverage Simpleware Medical's ASCardio module for automated cardiac segmentation. The implementation uses a two-component architecture:
-1. **segment_heart_simpleware.py** (in parent directory): A Python class that inherits from `SegmentChestBase` and manages the external Simpleware Medical process
+1. **segment_heart_simpleware.py** (in parent directory): A Python class that inherits from `SegmentAnatomyBase` and manages the external Simpleware Medical process
2. **SimplewareScript_heart_segmentation.py** (this directory): A Python script that runs within the Simpleware Medical environment and performs the actual segmentation using ASCardio
## Requirements
diff --git a/src/physiomotion4d/workflow_create_statistical_model.py b/src/physiomotion4d/workflow_create_statistical_model.py
new file mode 100644
index 0000000..6aece41
--- /dev/null
+++ b/src/physiomotion4d/workflow_create_statistical_model.py
@@ -0,0 +1,302 @@
+"""Create a PCA statistical shape model from a sample of meshes.
+
+This module provides the WorkflowCreateStatisticalModel class that implements
+the pipeline from the Heart-Create_Statistical_Model experiment notebooks:
+
+1. Extract surfaces from sample and reference meshes
+2. ICP alignment: align each sample surface to the reference (template) surface
+3. Deformable registration: establish dense correspondence via mask-based SyN
+4. Correspondence: warp reference surface by each transform to get aligned shapes
+5. PCA: compute mean and modes from corresponded shapes
+
+Returns a dictionary of surfaces, meshes, and PCA model structure (no file I/O).
+"""
+
+import logging
+from typing import Any, Optional
+
+import itk
+import numpy as np
+import pyvista as pv
+from sklearn.decomposition import PCA
+
+from physiomotion4d.contour_tools import ContourTools
+from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase
+from physiomotion4d.register_models_distance_maps import RegisterModelsDistanceMaps
+from physiomotion4d.register_models_icp import RegisterModelsICP
+from physiomotion4d.transform_tools import TransformTools
+
+
+def _extract_surface(mesh: pv.DataSet) -> pv.PolyData:
+ """Extract surface from a mesh (PolyData or UnstructuredGrid)."""
+ if isinstance(mesh, pv.UnstructuredGrid):
+ return mesh.extract_surface()
+ if isinstance(mesh, pv.PolyData):
+ return mesh
+ return mesh.extract_surface()
+
+
+class WorkflowCreateStatisticalModel(PhysioMotion4DBase):
+ """Create a PCA statistical shape model from a sample of meshes aligned to a reference.
+
+ Pipeline (mirrors experiments/Heart-Create_Statistical_Model notebooks 1–5):
+ 1. Extract surfaces from sample meshes and reference mesh (reference surface = alignment target)
+ 2. ICP (affine) align each sample surface to the reference surface
+ 3. Deformable (ANTs SyN) registration of each aligned sample to reference
+ 4. Build corresponded shapes (reference topology) in reference space
+ 5. Compute PCA and return mean surface, reference mesh, and PCA model dict
+
+ Attributes:
+ sample_meshes (list): List of sample mesh DataSets (.vtk/.vtu/.vtp geometry)
+ reference_mesh (pv.DataSet): Reference mesh; its surface is used for alignment
+ pca_number_of_components (int): Number of PCA components to retain
+ reference_spatial_resolution (float): Resolution for reference image from mesh
+ reference_buffer_factor (float): Buffer around mesh for reference image
+ """
+
+ def __init__(
+ self,
+ sample_meshes: list[pv.DataSet],
+ reference_mesh: pv.DataSet,
+ pca_number_of_components: int = 15,
+ reference_spatial_resolution: float = 1.0,
+ reference_buffer_factor: float = 0.25,
+ log_level: int | str = logging.INFO,
+ ):
+ """Initialize the create-statistical-model workflow.
+
+ Args:
+ sample_meshes: List of sample mesh DataSets (PyVista PolyData or UnstructuredGrid).
+ reference_mesh: Reference mesh; its surface is used to align all samples.
+ pca_number_of_components: Number of PCA components. Default 15.
+ reference_spatial_resolution: Isotropic resolution (mm) for reference image. Default 1.0.
+ reference_buffer_factor: Buffer factor around mesh for reference image. Default 0.25.
+ log_level: Logging level.
+ """
+ super().__init__(
+ class_name="WorkflowCreateStatisticalModel", log_level=log_level
+ )
+ self.sample_meshes = list(sample_meshes)
+ self.reference_mesh = reference_mesh
+ self.pca_number_of_components = pca_number_of_components
+ self.reference_spatial_resolution = reference_spatial_resolution
+ self.reference_buffer_factor = reference_buffer_factor
+
+ self.contour_tools = ContourTools()
+ self.transform_tools = TransformTools()
+
+ # Set by pipeline
+ self.reference_surface: Optional[pv.PolyData] = None
+ self.sample_surfaces: list[pv.PolyData] = []
+ self.sample_ids: list[str] = []
+ self.aligned_surfaces: list[pv.PolyData] = []
+ self.forward_transforms: list = []
+ self.inverse_transforms: list = []
+ self.pca_input_surfaces: list[pv.PolyData] = []
+ self.pca_fitted: Optional[PCA] = None
+ self.pca_mean_surface: Optional[pv.PolyData] = None
+ self.pca_mean_mesh: Optional[pv.UnstructuredGrid] = None
+
+ def set_pca_number_of_components(self, n: int) -> None:
+ """Set number of PCA components to retain."""
+ self.pca_number_of_components = n
+
+ def _step1_extract_surfaces(self) -> None:
+ """Extract reference surface and all sample surfaces (notebook 1)."""
+ self.log_section("Step 1: Extract reference and sample surfaces", width=70)
+ if not self.sample_meshes:
+ raise ValueError("sample_meshes must not be empty")
+ self.reference_surface = _extract_surface(self.reference_mesh)
+ self.log_info(
+ "Reference surface: %d points",
+ self.reference_surface.n_points,
+ )
+ self.sample_surfaces = []
+ self.sample_ids = []
+ for i, mesh in enumerate(self.sample_meshes):
+ surface = _extract_surface(mesh)
+ self.sample_surfaces.append(surface)
+ self.sample_ids.append(str(i))
+ self.log_info("Extracted %d sample surfaces", len(self.sample_surfaces))
+
+ def _step2_icp_align(self) -> None:
+ """ICP (affine) align each sample surface to reference (notebook 2)."""
+ self.log_section("Step 2: ICP alignment to reference surface", width=70)
+ assert self.reference_surface is not None and self.sample_surfaces
+ self.aligned_surfaces = []
+ self.forward_transforms = []
+ self.inverse_transforms = []
+
+ for i, (sid, moving) in enumerate(zip(self.sample_ids, self.sample_surfaces)):
+ self.log_info(
+ "ICP aligning %s (%d/%d)", sid, i + 1, len(self.sample_surfaces)
+ )
+ if isinstance(moving, pv.UnstructuredGrid):
+ moving = moving.extract_surface()
+ registrar = RegisterModelsICP(fixed_model=self.reference_surface)
+ result = registrar.register(
+ moving_model=moving,
+ transform_type="Affine",
+ max_iterations=2000,
+ )
+ self.aligned_surfaces.append(result["registered_model"])
+ self.forward_transforms.append(result["forward_point_transform"])
+ self.inverse_transforms.append(result["inverse_point_transform"])
+
+ self.log_info(
+ "ICP alignment complete for %d samples", len(self.aligned_surfaces)
+ )
+
+ def _step3_deformable_correspondence(self) -> None:
+ """Deformable registration of each aligned sample to reference (notebook 3)."""
+ self.log_section("Step 3: Deformable registration (correspondence)", width=70)
+ assert self.reference_surface is not None and self.aligned_surfaces
+ reference_image = self.contour_tools.create_reference_image(
+ mesh=self.reference_surface,
+ spatial_resolution=self.reference_spatial_resolution,
+ buffer_factor=self.reference_buffer_factor,
+ ptype=itk.UC,
+ )
+ self.forward_transforms = []
+ self.inverse_transforms = []
+
+ for i, (sid, moving) in enumerate(zip(self.sample_ids, self.aligned_surfaces)):
+ self.log_info(
+ "Deformable registration %s (%d/%d)",
+ sid,
+ i + 1,
+ len(self.aligned_surfaces),
+ )
+ registrar = RegisterModelsDistanceMaps(
+ moving_model=moving,
+ fixed_model=self.reference_surface,
+ reference_image=reference_image,
+ )
+ result = registrar.register(
+ transform_type="Deformable",
+ use_icon=False,
+ )
+ self.forward_transforms.append(result["forward_transform"])
+ self.inverse_transforms.append(result["inverse_transform"])
+
+ self.log_info(
+ "Deformable registration complete for %d samples",
+ len(self.forward_transforms),
+ )
+
+ def _step4_build_pca_inputs(self) -> None:
+ """Build corresponded shapes in reference space (notebook 4).
+
+ For each case, reference_surface is warped by forward (image) deformation
+ (= inverse point) transform from step 3, so that we get reference topology
+ in ICP-aligned space with residual deformation per subject to be used as PCA
+ input.
+ """
+ self.log_section("Step 4: Build PCA inputs (corresponded shapes)", width=70)
+ assert self.reference_surface is not None and self.forward_transforms
+ self.pca_input_surfaces = []
+ for fwd_tfm in self.forward_transforms:
+ pca_input_surface = self.contour_tools.transform_contours(
+ self.reference_surface, tfm=fwd_tfm, with_deformation_magnitude=False
+ )
+ self.pca_input_surfaces.append(pca_input_surface)
+ self.log_info(
+ "Built %d corresponded surfaces for PCA", len(self.pca_input_surfaces)
+ )
+
+ def _step5_compute_pca(self) -> None:
+ """Compute PCA and mean surface (notebook 5)."""
+ self.log_section("Step 5: Compute PCA model", width=70)
+ assert self.reference_surface is not None and self.pca_input_surfaces
+ template = self.reference_surface
+ n_points = template.n_points
+
+ data_matrix = []
+ for i, mesh in enumerate(self.pca_input_surfaces):
+ if mesh.n_points != n_points:
+ raise ValueError(
+ f"Sample {self.sample_ids[i]} has {mesh.n_points} points, "
+ f"expected {n_points}. Topology must match."
+ )
+ data_matrix.append(mesh.points.flatten())
+ data_matrix = np.array(data_matrix)
+
+ if data_matrix.shape[0] - 1 < 2:
+ raise ValueError(
+ f"At least 2 samples are required for PCA. Got {data_matrix.shape[0]} samples."
+ )
+ n_comp = min(self.pca_number_of_components, data_matrix.shape[0] - 1)
+ if n_comp < self.pca_number_of_components:
+ self.log_warning(
+ "Reducing PCA components from %d to %d (n_samples=%d)",
+ self.pca_number_of_components,
+ n_comp,
+ data_matrix.shape[0],
+ )
+ self.pca_fitted = PCA(n_components=n_comp)
+ self.pca_fitted.fit(data_matrix)
+
+ self.pca_mean_surface = template.copy()
+ self.pca_mean_surface.points = self.pca_fitted.mean_.reshape(-1, 3)
+ self.log_info(
+ "PCA complete: %d components, variance explained %.4f",
+ len(self.pca_fitted.explained_variance_ratio_),
+ self.pca_fitted.explained_variance_ratio_.sum(),
+ )
+
+ reference_image = self.contour_tools.create_reference_image(
+ mesh=self.pca_mean_surface,
+ spatial_resolution=self.reference_spatial_resolution,
+ buffer_factor=self.reference_buffer_factor,
+ ptype=itk.UC,
+ )
+ mean_deformation_array = self.pca_mean_surface.points - template.points
+ mean_deformation_field = self.contour_tools.create_deformation_field(
+ points=template.points,
+ point_displacements=mean_deformation_array,
+ reference_image=reference_image,
+ blur_sigma=2.5,
+ ptype=itk.D,
+ )
+ mean_deformation_transform = itk.DisplacementFieldTransform[itk.D, 3].New()
+ mean_deformation_transform.SetDisplacementField(mean_deformation_field)
+ self.pca_mean_mesh = self.contour_tools.transform_contours(
+ self.reference_mesh,
+ tfm=mean_deformation_transform,
+ with_deformation_magnitude=False,
+ )
+
+ def _build_result(self) -> dict[str, Any]:
+ """Build result dictionary: surfaces, meshes, and PCA model structure."""
+ assert self.pca_mean_surface is not None and self.pca_fitted is not None
+ result: dict[str, Any] = {
+ "pca_mean_surface": self.pca_mean_surface,
+ "pca_mean_mesh": self.pca_mean_mesh,
+ "pca_model": {
+ "explained_variance_ratio": self.pca_fitted.explained_variance_ratio_.tolist(),
+ "eigenvalues": self.pca_fitted.explained_variance_.tolist(),
+ "components": [c.tolist() for c in self.pca_fitted.components_],
+ },
+ "pca_fitted": self.pca_fitted,
+ }
+ return result
+
+ def run_workflow(self) -> dict[str, Any]:
+ """Run the full pipeline and return a dictionary of results (no file I/O).
+
+ Returns:
+ dict with keys:
+ - pca_mean_surface: pv.PolyData mean shape surface
+ - pca_mean_mesh: pv.UnstructuredGrid reference volume mesh, or None if reference was surface-only
+ - pca_model: dict with "explained_variance_ratio", "eigenvalues", "components" (same structure as pca_model.json)
+ - pca_fitted: fitted sklearn PCA object
+ """
+ self.log_section("STARTING CREATE STATISTICAL MODEL WORKFLOW", width=70)
+ self._step1_extract_surfaces()
+ self._step2_icp_align()
+ self._step3_deformable_correspondence()
+ self._step4_build_pca_inputs()
+ self._step5_compute_pca()
+ result = self._build_result()
+ self.log_section("CREATE STATISTICAL MODEL WORKFLOW COMPLETE", width=70)
+ return result
diff --git a/src/physiomotion4d/workflow_register_heart_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py
similarity index 74%
rename from src/physiomotion4d/workflow_register_heart_model_to_patient.py
rename to src/physiomotion4d/workflow_fit_statistical_model_to_patient.py
index c63e643..70b0f9f 100644
--- a/src/physiomotion4d/workflow_register_heart_model_to_patient.py
+++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py
@@ -1,6 +1,6 @@
"""Model-to-image and model-to-model registration for anatomical models.
-This module provides the WorkflowRegisterHeartModelToPatient class for registering generic
+This module provides the WorkflowFitStatisticalModelToPatient class for registering generic
anatomical models to patient-specific imaging data and surface models.
The workflow includes:
1. Rough alignment using ICP (RegisterModelsICP)
@@ -23,7 +23,7 @@
"""
import logging
-from typing import Optional
+from typing import Any, Optional
import itk
import numpy as np
@@ -40,7 +40,7 @@
from physiomotion4d.transform_tools import TransformTools
-class WorkflowRegisterHeartModelToPatient(PhysioMotion4DBase):
+class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase):
"""Register anatomical models using multi-stage ICP, mask-based, and image-based
registration.
@@ -79,8 +79,9 @@ class WorkflowRegisterHeartModelToPatient(PhysioMotion4DBase):
transform_tools (TransformTools): Transform utilities
registrar_icon (RegisterImagesICON): ICON registration instance
registrar_ants (RegisterImagesANTs): ANTs registration instance
- pca_json_filename (str): PCA JSON filename (optional)
- pca_number_of_modes (int): Number of PCA modes to use
+ use_pca_registration (bool): Whether PCA registration is enabled (set via set_use_pca_registration)
+ pca_model (dict): PCA model dict when PCA enabled; same structure as WorkflowCreateStatisticalModel output
+ pca_number_of_modes (int): Number of PCA modes when PCA enabled
icp_forward_point_transform : ICP transforms
icp_inverse_point_transform : ICP inverse transforms
icp_template_model_surface: template model surface after ICP alignment
@@ -100,33 +101,24 @@ class WorkflowRegisterHeartModelToPatient(PhysioMotion4DBase):
registered_template_model_surface: Final registered model surface
Example:
- >>> # Initialize with minimal parameters
- >>> registrar = WorkflowRegisterHeartModelToPatient(
+ >>> # Initialize with minimal parameters (no labelmap; no patient image -> reference created from patient models)
+ >>> registrar = WorkflowFitStatisticalModelToPatient(
... template_model=heart_model,
... patient_models=[lv_model, mc_model, rv_model],
- ... patient_image=patient_ct,
- ... pca_json_filename='path/to/pca_model.json',
- ... pca_number_of_modes=10,
... )
- >>>
- >>> # Optional: Configure parameters (masks auto-generated if not set)
>>> registrar.set_roi_dilation_mm(20)
- >>>
- >>> # Run registration
- >>> patient_model = registrar.run_workflow()
+ >>> # To enable PCA registration, call before run_workflow():
+ >>> # registrar.set_use_pca_registration(True, pca_model=pca_model_dict, pca_number_of_modes=10)
+ >>> # To enable mask-to-image refinement:
+ >>> # registrar.set_use_mask_to_image_registration(True, template_labelmap, organ_mesh_ids, organ_extra_ids, background_ids)
+ >>> result = registrar.run_workflow()
"""
def __init__(
self,
template_model: pv.UnstructuredGrid,
- template_labelmap: itk.Image,
- template_labelmap_heart_muscle_ids: list[int],
- template_labelmap_chamber_ids: list[int],
- template_labelmap_background_ids: list[int],
patient_models: list,
- patient_image: itk.Image,
- pca_json_filename: Optional[str] = None,
- pca_number_of_modes: int = 0,
+ patient_image: Optional[itk.Image] = None,
log_level: int | str = logging.INFO,
):
"""Initialize the model-to-image-and-model registration pipeline.
@@ -135,39 +127,47 @@ def __init__(
template_model: Generic anatomical model to be registered
patient_models: List of patient-specific models extracted from imaging
data. Typically 3 models for cardiac applications: LV, myocardium, RV.
- patient_image: Patient image data providing the target coordinate frame
- (origin, spacing, direction). Used as reference for registration.
+ patient_image: Optional patient image providing the target coordinate frame.
+ If None, a reference image is created from the patient model surface
+ via create_reference_image (contour_tools).
log_level: Logging level (logging.DEBUG, logging.INFO, logging.WARNING).
Default: logging.INFO
"""
# Initialize base class with logging
super().__init__(
- class_name="WorkflowRegisterHeartModelToPatient", log_level=log_level
+ class_name="WorkflowFitStatisticalModelToPatient", log_level=log_level
)
self.template_model = template_model
self.template_model_surface = template_model.extract_surface()
- self.template_labelmap = template_labelmap
- self.template_labelmap_heart_muscle_ids = template_labelmap_heart_muscle_ids
- self.template_labelmap_chamber_ids = template_labelmap_chamber_ids
- self.template_labelmap_background_ids = template_labelmap_background_ids
+ self.template_labelmap: Optional[itk.Image] = None
+ self.template_labelmap_organ_mesh_ids: Optional[list[int]] = None
+ self.template_labelmap_organ_extra_ids: Optional[list[int]] = None
+ self.template_labelmap_background_ids: Optional[list[int]] = None
self.patient_models = patient_models
patient_models_surfaces = [model.extract_surface() for model in patient_models]
combined_patient_model = pv.merge(patient_models_surfaces)
self.patient_model_surface = combined_patient_model.extract_surface()
- self.patient_image = patient_image
-
- resampler = ttk.ResampleImage.New(Input=self.patient_image)
- resampler.SetMakeHighResIso(True)
- resampler.Update()
- self.patient_image = resampler.GetOutput()
-
- # Utilities
+ # Utilities (needed for create_reference_image when patient_image is None)
self.transform_tools = TransformTools()
self.contour_tools = ContourTools()
+ if patient_image is not None:
+ self.patient_image = patient_image
+ resampler = ttk.ResampleImage.New(Input=self.patient_image)
+ resampler.SetMakeHighResIso(True)
+ resampler.Update()
+ self.patient_image = resampler.GetOutput()
+ else:
+ self.patient_image = self.contour_tools.create_reference_image(
+ mesh=self.patient_model_surface,
+ spatial_resolution=1.0,
+ buffer_factor=0.25,
+ ptype=itk.F,
+ )
+
self.registrar_ants = RegisterImagesANTs()
self.registrar_ants.set_number_of_iterations([5, 2, 5])
# Icon registration for final mask-to-image step
@@ -194,12 +194,13 @@ def __init__(
self.icp_template_model_surface: Optional[pv.PolyData] = None
self.icp_template_labelmap: Optional[itk.Image] = None
- # Stage 1.5: PCA registration results (optional)
+ # Stage 1.5: PCA registration results (optional; enable via set_use_pca_registration(True, pca_model, pca_number_of_modes))
+ self.use_pca_registration = False
self.pca_registrar: Optional[RegisterModelsPCA] = None
self.pca_forward_point_transform: Optional[itk.Transform] = None
self.pca_inverse_point_transform: Optional[itk.Transform] = None
- self.pca_json_filename = pca_json_filename
- self.pca_number_of_modes = pca_number_of_modes
+ self.pca_model: Optional[dict[str, Any]] = None
+ self.pca_number_of_modes: int = 0
self.pca_coefficients: Optional[np.ndarray] = None
self.pca_template_model_surface: Optional[pv.PolyData] = None
self.pca_template_labelmap: Optional[itk.Image] = None
@@ -211,8 +212,8 @@ def __init__(
self.m2m_template_model_surface: Optional[pv.PolyData] = None
self.m2m_template_labelmap: Optional[itk.Image] = None
- # Stage 3: Mask-to-image registration results
- self.use_m2i_registration = True
+ # Stage 3: Mask-to-image registration results (disabled by default; enable via set_use_mask_to_image_registration(True, template_labelmap, ...))
+ self.use_m2i_registration = False
self.m2i_inverse_transform: Optional[itk.Transform] = None
self.m2i_forward_transform: Optional[itk.Transform] = None
self.m2i_template_model_surface: Optional[pv.PolyData] = None
@@ -320,6 +321,39 @@ def set_roi_dilation_mm(self, roi_dilation_mm: float) -> None:
"""
self.roi_dilation_mm = roi_dilation_mm
+ def set_use_pca_registration(
+ self,
+ use_pca_registration: bool,
+ pca_model: Optional[dict[str, Any]] = None,
+ pca_number_of_modes: int = 0,
+ ) -> None:
+ """Set whether to use PCA-based registration and provide the PCA model.
+
+ When enabling (True), pca_model and pca_number_of_modes must be provided.
+
+ Args:
+ use_pca_registration: Whether to use PCA registration after ICP.
+ pca_model: Required when use is True. PCA model dict (e.g. from
+ WorkflowCreateStatisticalModel result["pca_model"]) with keys
+ "eigenvalues" and "components".
+ pca_number_of_modes: Required when use is True. Number of PCA modes to use.
+ Default 0 means use all modes.
+
+ Raises:
+ ValueError: If use is True and pca_model is None.
+ """
+ if use_pca_registration:
+ if pca_model is None:
+ raise ValueError(
+ "When enabling PCA registration, pca_model must be provided."
+ )
+ self.pca_model = pca_model
+ self.pca_number_of_modes = pca_number_of_modes
+ else:
+ self.pca_model = None
+ self.pca_number_of_modes = 0
+ self.use_pca_registration = use_pca_registration
+
def set_use_mask_to_mask_registration(
self, use_mask_to_mask_registration: bool
) -> None:
@@ -332,13 +366,57 @@ def set_use_mask_to_mask_registration(
self.use_m2m_registration = use_mask_to_mask_registration
def set_use_mask_to_image_registration(
- self, use_mask_to_image_registration: bool
+ self,
+ use_mask_to_image_registration: bool,
+ template_labelmap: Optional[itk.Image] = None,
+ template_labelmap_organ_mesh_ids: Optional[list[int]] = None,
+ template_labelmap_organ_extra_ids: Optional[list[int]] = None,
+ template_labelmap_background_ids: Optional[list[int]] = None,
) -> None:
"""Set whether to use mask-to-image registration.
+ When enabling (True), a template labelmap and label IDs must be provided
+ so the workflow can propagate and refine the labelmap to the patient image.
+
Args:
- use_m2i: Whether to use mask-to-image registration. Default: True
+ use_mask_to_image_registration: Whether to use mask-to-image registration.
+ template_labelmap: Required when use is True. Template labelmap in template
+ model space (same geometry as template_model).
+ template_labelmap_organ_mesh_ids: Required when use is True. Label IDs for
+ organ mesh in the template labelmap.
+ template_labelmap_organ_extra_ids: Required when use is True. Label IDs for
+ organ-extra structures in the template labelmap.
+ template_labelmap_background_ids: Required when use is True. Label IDs for
+ background in the template labelmap.
+
+ Raises:
+ ValueError: If use is True and any of template_labelmap or the id lists
+ is None or missing.
"""
+ if use_mask_to_image_registration:
+ if template_labelmap is None:
+ raise ValueError(
+ "When enabling mask-to-image registration, template_labelmap must be provided."
+ )
+ if template_labelmap_organ_mesh_ids is None:
+ raise ValueError(
+ "When enabling mask-to-image registration, "
+ "template_labelmap_organ_mesh_ids must be provided."
+ )
+ if template_labelmap_organ_extra_ids is None:
+ raise ValueError(
+ "When enabling mask-to-image registration, "
+ "template_labelmap_organ_extra_ids must be provided."
+ )
+ if template_labelmap_background_ids is None:
+ raise ValueError(
+ "When enabling mask-to-image registration, "
+ "template_labelmap_background_ids must be provided."
+ )
+ self.template_labelmap = template_labelmap
+ self.template_labelmap_organ_mesh_ids = template_labelmap_organ_mesh_ids
+ self.template_labelmap_organ_extra_ids = template_labelmap_organ_extra_ids
+ self.template_labelmap_background_ids = template_labelmap_background_ids
self.use_m2i_registration = use_mask_to_image_registration
def register_model_to_model_icp(self) -> dict:
@@ -370,12 +448,15 @@ def register_model_to_model_icp(self) -> dict:
self.icp_inverse_point_transform = icp_result["inverse_point_transform"]
self.icp_template_model_surface = icp_result["registered_model"]
- self.icp_template_labelmap = self.transform_tools.transform_image(
- self.template_labelmap,
- self.icp_inverse_point_transform,
- self.patient_image,
- interpolation_method="nearest",
- )
+ if self.template_labelmap is not None:
+ self.icp_template_labelmap = self.transform_tools.transform_image(
+ self.template_labelmap,
+ self.icp_inverse_point_transform,
+ self.patient_image,
+ interpolation_method="nearest",
+ )
+ else:
+ self.icp_template_labelmap = None
self.log_info("Stage 1 complete: ICP alignment finished.")
@@ -408,16 +489,17 @@ def register_model_to_model_pca(self) -> dict:
width=70,
)
- if self.pca_json_filename is None:
+ if not self.use_pca_registration or self.pca_model is None:
self.pca_template_model_surface = self.icp_template_model_surface
+ self.pca_template_labelmap = self.icp_template_labelmap
return {
"pca_coefficients": None,
"registered_model_surface": self.pca_template_model_surface,
}
- self.pca_registrar = RegisterModelsPCA.from_json(
+ self.pca_registrar = RegisterModelsPCA.from_pca_model(
pca_template_model=self.icp_template_model_surface,
- pca_json_filename=self.pca_json_filename,
+ pca_model=self.pca_model,
pca_number_of_modes=self.pca_number_of_modes,
fixed_model=self.patient_model_surface,
reference_image=self.patient_image,
@@ -439,12 +521,15 @@ def register_model_to_model_pca(self) -> dict:
self.registered_template_model_surface = self.pca_template_model_surface
- self.pca_template_labelmap = self.transform_tools.transform_image(
- self.icp_template_labelmap,
- self.pca_inverse_point_transform,
- self.patient_image,
- interpolation_method="nearest",
- )
+ if self.icp_template_labelmap is not None:
+ self.pca_template_labelmap = self.transform_tools.transform_image(
+ self.icp_template_labelmap,
+ self.pca_inverse_point_transform,
+ self.patient_image,
+ interpolation_method="nearest",
+ )
+ else:
+ self.pca_template_labelmap = None
self.log_info("Stage 2 complete: PCA registration finished.")
@@ -503,12 +588,15 @@ def register_mask_to_mask(
self.registered_template_model_surface = self.m2m_template_model_surface
- self.m2m_template_labelmap = self.transform_tools.transform_image(
- self.pca_template_labelmap,
- self.m2m_forward_transform,
- self.patient_image,
- interpolation_method="nearest",
- )
+ if self.pca_template_labelmap is not None:
+ self.m2m_template_labelmap = self.transform_tools.transform_image(
+ self.pca_template_labelmap,
+ self.m2m_forward_transform,
+ self.patient_image,
+ interpolation_method="nearest",
+ )
+ else:
+ self.m2m_template_labelmap = None
self.log_info("Stage 3 complete: Mask-to-mask registration finished.")
@@ -537,6 +625,24 @@ def register_labelmap_to_image(
"Stage 4: Labelmap-to-Image Refinement (Icon Registration)", width=70
)
+ if (
+ self.template_labelmap is None
+ or self.template_labelmap_organ_mesh_ids is None
+ or self.template_labelmap_organ_extra_ids is None
+ or self.template_labelmap_background_ids is None
+ ):
+ raise ValueError(
+ "Mask-to-image registration requires template labelmap and label IDs. "
+ "Call set_use_mask_to_image_registration(True, template_labelmap, "
+ "organ_mesh_ids, organ_extra_ids, background_ids) before run_workflow()."
+ )
+ if self.m2m_template_labelmap is None:
+ raise ValueError(
+ "Mask-to-image registration requires a labelmap to have been set "
+ "(via set_use_mask_to_image_registration(True, ...)) before running "
+ "earlier stages so the labelmap is propagated through ICP/PCA/M2M."
+ )
+
labelmap_arr = itk.GetArrayFromImage(self.m2m_template_labelmap).astype(
np.uint16
)
@@ -546,12 +652,14 @@ def register_labelmap_to_image(
labelmap_arr,
)
labelmap_arr = np.where(
- np.isin(labelmap_arr, self.template_labelmap_heart_muscle_ids),
+ np.isin(labelmap_arr, self.template_labelmap_organ_mesh_ids),
0,
labelmap_arr,
)
labelmap_arr = np.where(
- np.isin(labelmap_arr, self.template_labelmap_chamber_ids), 1, labelmap_arr
+ np.isin(labelmap_arr, self.template_labelmap_organ_extra_ids),
+ 1,
+ labelmap_arr,
)
labelmap = itk.GetImageFromArray(labelmap_arr)
labelmap.CopyInformation(self.m2m_template_labelmap)
@@ -691,7 +799,7 @@ def transform_model(
def run_workflow(
self,
- use_mask_to_image_registration: bool = True,
+ use_mask_to_image_registration: bool = False,
use_mask_to_mask_registration: bool = True,
use_icon_registration_refinement: bool = False,
) -> dict:
@@ -701,21 +809,20 @@ def run_workflow(
1. ICP alignment (RegisterModelsICP)
2. PCA registration (PCA data was provided)
3. Mask-to-mask deformable registration (RegisterModelsDistanceMaps)
- 4. Optional mask-to-image refinement (Icon)
+ 4. Optional mask-to-image refinement (Icon); requires template labelmap and IDs
+ set via set_use_mask_to_image_registration(True, ...).
Args:
use_mask_to_image_registration: Whether to include mask-to-image
- registration stage.
- Default: True
+ registration stage. Default: False. When True, template labelmap and
+ label IDs must have been set via set_use_mask_to_image_registration(True, ...).
use_mask_to_mask_registration: Whether to include mask-to-mask registration
- stage.
- Default: True
+ stage. Default: True
use_icon_registration_refinement: Whether to include icon registration
- refinement stage.
- Default: False
+ refinement stage. Default: False
Returns:
- pv.UnstructuredGrid: Final registered model
+ dict with registered_template_model and registered_template_model_surface
"""
self.log_section(
"STARTING COMPLETE MODEL-TO-IMAGE-AND-MODEL REGISTRATION WORKFLOW", width=70
diff --git a/statistics.md b/statistics.md
index 4cfa1fe..fe67ee9 100644
--- a/statistics.md
+++ b/statistics.md
@@ -45,9 +45,9 @@ PhysioMotion4D is a sophisticated medical imaging package for generating anatomi
| --------------------------------------------- | ----- | ---------------------------------------------- |
| `transform_tools.py` | 1,142 | Transform manipulation utilities |
| `register_models_pca.py` | 818 | PCA-based statistical shape model registration |
-| `workflow_register_heart_model_to_patient.py` | 745 | Model-to-patient registration workflow |
+| `workflow_fit_statistical_model_to_patient.py` | 745 | Model-to-patient registration workflow |
| `register_images_ants.py` | 725 | ANTs-based image registration |
-| `segment_chest_base.py` | 672 | Base class for chest segmentation |
+| `segment_anatomy_base.py` | 672 | Base class for anatomy segmentation |
| `convert_vtk_to_usd_polymesh.py` | 622 | Polymesh USD conversion |
| `convert_vtk_to_usd_base.py` | 585 | Base USD conversion functionality |
| `workflow_convert_heart_gated_ct_to_usd.py` | 539 | Heart CT to USD workflow |