diff --git a/README.md b/README.md index 6c83125..e7324ef 100644 --- a/README.md +++ b/README.md @@ -88,13 +88,14 @@ print(f"PhysioMotion4D version: {physiomotion4d.__version__}") - **Workflow Classes**: Complete end-to-end pipeline processors - `WorkflowConvertHeartGatedCTToUSD`: Heart-gated CT to USD processing workflow - - `WorkflowRegisterHeartModelToPatient`: Model-to-patient registration workflow + - `WorkflowCreateStatisticalModel`: Create PCA statistical shape model from sample meshes + - `WorkflowFitStatisticalModelToPatient`: Model-to-patient registration workflow - **Segmentation Classes**: Multiple AI-based chest segmentation implementations - `SegmentChestTotalSegmentator`: TotalSegmentator-based segmentation - `SegmentChestVista3D`: VISTA-3D model-based segmentation - `SegmentChestVista3DNIM`: NVIDIA NIM version of VISTA-3D - `SegmentChestEnsemble`: Ensemble segmentation combining multiple methods - - `SegmentChestBase`: Base class for custom segmentation methods + - `SegmentAnatomyBase`: Base class for custom segmentation methods - **Registration Classes**: Multiple registration methods for different use cases - Image-to-Image Registration: - `RegisterImagesICON`: Deep learning-based registration using Icon algorithm @@ -156,13 +157,34 @@ physiomotion4d-heart-gated-ct cardiac.nrrd \ For Python API usage and advanced customization, see the examples below or refer to the CLI implementation in `src/physiomotion4d/cli/`. +#### Create Statistical Model + +Build a PCA statistical shape model from sample meshes aligned to a reference: + +```bash +# From a directory of sample meshes +physiomotion4d-create-statistical-model \ + --sample-meshes-dir ./input_meshes \ + --reference-mesh average_mesh.vtk \ + --output-dir ./pca_output + +# With custom PCA components +physiomotion4d-create-statistical-model \ + --sample-meshes-dir ./meshes \ + --reference-mesh average_mesh.vtk \ + --output-dir ./pca_output \ + --pca-components 20 +``` + +Outputs: `pca_mean_surface.vtp`, `pca_mean.vtu` (if reference is volumetric), and `pca_model.json`. + #### Heart Model to Patient Registration Register a generic heart model to patient-specific data: ```bash # Basic registration -physiomotion4d-register-heart-model \ +physiomotion4d-fit-statistical-model-to-patient \ --template-model heart_model.vtu \ --template-labelmap heart_labelmap.nii.gz \ --patient-models lv.vtp rv.vtp myo.vtp \ @@ -170,7 +192,7 @@ physiomotion4d-register-heart-model \ --output-dir ./results # With PCA shape fitting -physiomotion4d-register-heart-model \ +physiomotion4d-fit-statistical-model-to-patient \ --template-model heart_model.vtu \ --template-labelmap heart_labelmap.nii.gz \ --patient-models lv.vtp rv.vtp myo.vtp \ @@ -203,7 +225,7 @@ final_usd = processor.process() ### Python API - Model to Patient Registration ```python -from physiomotion4d import WorkflowRegisterHeartModelToPatient +from physiomotion4d import WorkflowFitStatisticalModelToPatient import pyvista as pv import itk @@ -213,7 +235,7 @@ patient_surfaces = [pv.read("lv.stl"), pv.read("rv.stl")] reference_image = itk.imread("patient_ct.nii.gz") # Initialize and run workflow -workflow = WorkflowRegisterHeartModelToPatient( +workflow = WorkflowFitStatisticalModelToPatient( moving_mesh=model_mesh, fixed_meshes=patient_surfaces, fixed_image=reference_image @@ -352,13 +374,13 @@ PhysioMotion4D provides standardized logging through the `PhysioMotion4DBase` cl ```python import logging -from physiomotion4d import WorkflowRegisterHeartModelToPatient, PhysioMotion4DBase +from physiomotion4d import WorkflowFitStatisticalModelToPatient, PhysioMotion4DBase # Control logging level globally for all classes PhysioMotion4DBase.set_log_level(logging.DEBUG) # Or filter to show logs from specific classes only -PhysioMotion4DBase.set_log_classes(["WorkflowRegisterHeartModelToPatient", "RegisterModelsPCA"]) +PhysioMotion4DBase.set_log_classes(["WorkflowFitStatisticalModelToPatient", "RegisterModelsPCA"]) # Show all classes again PhysioMotion4DBase.set_log_all_classes() @@ -436,7 +458,7 @@ Advanced registration between generic anatomical models and patient-specific dat - **`heart_model_to_model_registration_pca.ipynb`**: PCA-based statistical shape model registration - **`heart_model_to_patient.ipynb`**: Complete model-to-patient registration workflow -Uses the `WorkflowRegisterHeartModelToPatient` class for three-stage registration: +Uses the `WorkflowFitStatisticalModelToPatient` class for three-stage registration: 1. ICP-based rough alignment 2. Mask-to-mask deformable registration 3. Optional PCA-constrained shape fitting diff --git a/docs/api/index.rst b/docs/api/index.rst index 611a778..4ff8f53 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -76,10 +76,11 @@ By Category **Workflows** * :class:`~physiomotion4d.WorkflowConvertHeartGatedCTToUSD` - Heart CT to USD - * :class:`~physiomotion4d.WorkflowRegisterHeartModelToPatient` - Heart model registration + * :class:`~physiomotion4d.WorkflowCreateStatisticalModel` - Create PCA statistical shape model + * :class:`~physiomotion4d.WorkflowFitStatisticalModelToPatient` - Heart model registration **Segmentation** - * :class:`~physiomotion4d.SegmentChestBase` - Base segmentation class + * :class:`~physiomotion4d.SegmentAnatomyBase` - Base segmentation class * :class:`~physiomotion4d.SegmentChestTotalSegmentator` - TotalSegmentator * :class:`~physiomotion4d.SegmentChestVista3D` - VISTA-3D model * :class:`~physiomotion4d.SegmentChestVista3DNIM` - VISTA-3D NIM diff --git a/docs/api/segmentation/base.rst b/docs/api/segmentation/base.rst index 4446f22..895dd90 100644 --- a/docs/api/segmentation/base.rst +++ b/docs/api/segmentation/base.rst @@ -9,7 +9,7 @@ Abstract base class for all segmentation methods. Class Reference =============== -.. autoclass:: SegmentChestBase +.. autoclass:: SegmentAnatomyBase :members: :undoc-members: :show-inheritance: @@ -18,7 +18,7 @@ Class Reference Overview ======== -:class:`SegmentChestBase` provides the foundation for all segmentation implementations in PhysioMotion4D. It defines the common interface and shared functionality that all segmentation methods must implement. +:class:`SegmentAnatomyBase` provides the foundation for all segmentation implementations in PhysioMotion4D. It defines the common interface and shared functionality that all segmentation methods must implement. **Key Responsibilities**: * Define standard segmentation interface @@ -87,17 +87,17 @@ These methods are provided by the base class: Creating Custom Segmentation Classes ===================================== -To create a new segmentation method, inherit from :class:`SegmentChestBase`: +To create a new segmentation method, inherit from :class:`SegmentAnatomyBase`: Basic Implementation -------------------- .. code-block:: python - from physiomotion4d import SegmentChestBase + from physiomotion4d import SegmentAnatomyBase import numpy as np - class CustomSegmentator(SegmentChestBase): + class CustomSegmentator(SegmentAnatomyBase): """Custom segmentation implementation.""" def __init__(self, param1=None, verbose=False): @@ -150,7 +150,7 @@ With Custom Post-Processing .. code-block:: python - class CustomSegmentator(SegmentChestBase): + class CustomSegmentator(SegmentAnatomyBase): """Segmentator with custom post-processing.""" def post_process(self, labelmap): @@ -243,7 +243,7 @@ Validate that required structures are present: .. code-block:: python - class ValidatingSegmentator(SegmentChestBase): + class ValidatingSegmentator(SegmentAnatomyBase): """Segmentator with validation.""" def segment(self, image_path): @@ -268,7 +268,7 @@ Track segmentation progress for long operations: .. code-block:: python - class ProgressSegmentator(SegmentChestBase): + class ProgressSegmentator(SegmentAnatomyBase): """Segmentator with progress tracking.""" def segment(self, image_path): diff --git a/docs/api/segmentation/index.rst b/docs/api/segmentation/index.rst index 0964ff0..d326ef1 100644 --- a/docs/api/segmentation/index.rst +++ b/docs/api/segmentation/index.rst @@ -16,7 +16,7 @@ PhysioMotion4D supports multiple segmentation approaches: * **VISTA-3D NIM**: NVIDIA Inference Microservice version * **Ensemble**: Combine multiple methods for improved accuracy -All segmentation classes inherit from :class:`SegmentChestBase` and provide consistent interfaces. +All segmentation classes inherit from :class:`SegmentAnatomyBase` and provide consistent interfaces. Quick Links =========== diff --git a/docs/api/workflows.rst b/docs/api/workflows.rst index 811e3a4..5239bfd 100644 --- a/docs/api/workflows.rst +++ b/docs/api/workflows.rst @@ -49,7 +49,7 @@ Heart Gated CT to USD Heart Model to Patient Registration ------------------------------------ -.. autoclass:: WorkflowRegisterHeartModelToPatient +.. autoclass:: WorkflowFitStatisticalModelToPatient :members: :undoc-members: :show-inheritance: @@ -67,9 +67,9 @@ Heart Model to Patient Registration .. code-block:: python - from physiomotion4d import WorkflowRegisterHeartModelToPatient + from physiomotion4d import WorkflowFitStatisticalModelToPatient - workflow = WorkflowRegisterHeartModelToPatient( + workflow = WorkflowFitStatisticalModelToPatient( model_file="heart_template.vtk", patient_image="patient_ct.nrrd", output_directory="./registration_results", @@ -221,7 +221,7 @@ Combine multiple workflows: heart_result = heart_workflow.process() # Second workflow: Register model - registration_workflow = WorkflowRegisterHeartModelToPatient( + registration_workflow = WorkflowFitStatisticalModelToPatient( model_file=heart_result['model'], patient_image=patient_ct, output_directory="./registration_output" diff --git a/docs/architecture.rst b/docs/architecture.rst index 8b5d442..bdb21f5 100644 --- a/docs/architecture.rst +++ b/docs/architecture.rst @@ -31,7 +31,7 @@ Core Components 2. **Segmentation Module** - * Base class: :class:`SegmentChestBase` + * Base class: :class:`SegmentAnatomyBase` * Implementations: TotalSegmentator, VISTA-3D, Ensemble 3. **Registration Module** diff --git a/docs/cli_scripts/create_statistical_model.rst b/docs/cli_scripts/create_statistical_model.rst new file mode 100644 index 0000000..ccd6af6 --- /dev/null +++ b/docs/cli_scripts/create_statistical_model.rst @@ -0,0 +1,102 @@ +==================================== +Create Statistical Model +==================================== + +Overview +======== + +The ``physiomotion4d-create-statistical-model`` command-line tool builds a PCA +(Principal Component Analysis) statistical shape model from a sample of meshes +aligned to a reference mesh. This mirrors the pipeline in the +Heart-Create_Statistical_Model experiment notebooks. + +The workflow: + +1. **Extract surfaces** from sample and reference meshes +2. **ICP alignment**: Affine align each sample surface to the reference surface +3. **Deformable registration**: ANTs SyN to establish dense correspondence +4. **Correspondence**: Build aligned shapes with reference topology +5. **PCA**: Compute mean shape and principal components + +Outputs written to the output directory: + +* ``pca_mean_surface.vtp`` — Mean shape as a surface (PolyData) +* ``pca_mean.vtu`` — Reference volume mesh in mean space (only if reference is volumetric) +* ``pca_model.json`` — PCA model (eigenvalues, components) for use with + :class:`physiomotion4d.WorkflowFitStatisticalModelToPatient` or + :class:`physiomotion4d.RegisterModelsPCA` + +Installation +============ + +The script is installed with PhysioMotion4D: + +.. code-block:: bash + + pip install physiomotion4d + +Quick Start +=========== + +Basic Usage +----------- + +Create a PCA model from a directory of sample meshes and a reference mesh: + +.. code-block:: bash + + physiomotion4d-create-statistical-model \ + --sample-meshes-dir ./input_meshes \ + --reference-mesh average_mesh.vtk \ + --output-dir ./pca_output + +Explicit Sample List +-------------------- + +Provide sample mesh paths explicitly instead of a directory: + +.. code-block:: bash + + physiomotion4d-create-statistical-model \ + --sample-meshes 01.vtk 02.vtk 03.vtu 04.vtp \ + --reference-mesh average_mesh.vtk \ + --output-dir ./pca_output + +With Custom Parameters +---------------------- + +.. code-block:: bash + + physiomotion4d-create-statistical-model \ + --sample-meshes-dir ./meshes \ + --reference-mesh average_mesh.vtk \ + --output-dir ./pca_output \ + --pca-components 20 + +Command-Line Arguments +====================== + +Required Arguments +------------------ + +``--sample-meshes-dir DIR`` or ``--sample-meshes PATH [PATH ...]`` + Either a directory containing sample mesh files (``.vtk``, ``.vtu``, ``.vtp``) + or a list of paths to sample meshes. One of these is required. + +``--reference-mesh PATH`` + Path to the reference mesh. Its surface is used as the alignment target for + all samples. + +``--output-dir DIR`` + Output directory. Writes ``pca_mean_surface.vtp``, ``pca_mean.vtu`` (if + reference is volumetric), and ``pca_model.json``. + +Optional Arguments +------------------ + +``--pca-components N`` + Number of PCA components to retain (default: 15). + +See :class:`physiomotion4d.WorkflowCreateStatisticalModel` for the full API and +additional parameters (e.g. ``reference_spatial_resolution``, +``reference_buffer_factor``) that can be exposed in future CLI versions. diff --git a/docs/cli_scripts/heart_model_to_patient.rst b/docs/cli_scripts/fit_statistical_model_to_patient.rst similarity index 90% rename from docs/cli_scripts/heart_model_to_patient.rst rename to docs/cli_scripts/fit_statistical_model_to_patient.rst index d2f85d0..669fe15 100644 --- a/docs/cli_scripts/heart_model_to_patient.rst +++ b/docs/cli_scripts/fit_statistical_model_to_patient.rst @@ -5,7 +5,7 @@ Heart Model to Patient Registration Overview ======== -The ``physiomotion4d-register-heart-model`` command-line tool registers generic anatomical heart models to patient-specific imaging data and surface models. This workflow enables: +The ``physiomotion4d-fit-statistical-model-to-patient`` command-line tool registers generic anatomical heart models to patient-specific imaging data and surface models. This workflow enables: * Patient-specific anatomical modeling from generic templates * Multi-stage registration combining ICP, PCA, and deformable methods @@ -38,7 +38,7 @@ Register a generic heart model to patient data: .. code-block:: bash - physiomotion4d-register-heart-model \ + physiomotion4d-fit-statistical-model-to-patient \ --template-model heart_model.vtu \ --template-labelmap heart_labelmap.nii.gz \ --patient-models lv.vtp rv.vtp myo.vtp \ @@ -52,7 +52,7 @@ Include statistical shape model fitting: .. code-block:: bash - physiomotion4d-register-heart-model \ + physiomotion4d-fit-statistical-model-to-patient \ --template-model heart_model.vtu \ --template-labelmap heart_labelmap.nii.gz \ --patient-models lv.vtp rv.vtp myo.vtp \ @@ -82,7 +82,7 @@ Required Arguments ``--output-dir DIR`` Output directory for results -See :class:`physiomotion4d.WorkflowRegisterHeartModelToPatient` for API documentation. +See :class:`physiomotion4d.WorkflowFitStatisticalModelToPatient` for API documentation. Template Labelmap Configuration -------------------------------- @@ -141,7 +141,7 @@ Example 1: Basic Registration .. code-block:: bash - physiomotion4d-register-heart-model \ + physiomotion4d-fit-statistical-model-to-patient \ --template-model heart_model.vtu \ --template-labelmap heart_labelmap.nii.gz \ --patient-models lv.vtp rv.vtp myo.vtp \ @@ -153,7 +153,7 @@ Example 2: PCA-Based Registration .. code-block:: bash - physiomotion4d-register-heart-model \ + physiomotion4d-fit-statistical-model-to-patient \ --template-model heart_model.vtu \ --template-labelmap heart_labelmap.nii.gz \ --patient-models lv.vtp rv.vtp \ diff --git a/docs/cli_scripts/overview.rst b/docs/cli_scripts/overview.rst index 2fa843c..180e6fc 100644 --- a/docs/cli_scripts/overview.rst +++ b/docs/cli_scripts/overview.rst @@ -9,7 +9,8 @@ This section provides comprehensive guides for using PhysioMotion4D's command-li **CLI Commands: Your Definitive Resource** ⭐ The examples and workflows documented here are based on production-ready CLI commands - (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-register-heart-model``) and their + (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-create-statistical-model``, + ``physiomotion4d-fit-statistical-model-to-patient``) and their implementations in ``src/physiomotion4d/cli/``. These are your **primary resource** for: * Production-ready workflow implementations @@ -48,7 +49,9 @@ Current Scripts - Description * - :doc:`heart_gated_ct` - Process cardiac gated CT to animated heart models with physiological motion - * - :doc:`heart_model_to_patient` + * - :doc:`create_statistical_model` + - Build a PCA statistical shape model from sample meshes aligned to a reference + * - :doc:`fit_statistical_model_to_patient` - Register generic heart models to patient-specific imaging data and surface models Upcoming Scripts diff --git a/docs/developer/architecture.rst b/docs/developer/architecture.rst index eacbd21..bf67f13 100644 --- a/docs/developer/architecture.rst +++ b/docs/developer/architecture.rst @@ -46,10 +46,10 @@ The package is organized into functional modules: │ ├── Workflow Classes │ ├── workflow_convert_heart_gated_ct_to_usd.py Cardiac CT → USD - │ └── workflow_register_heart_model_to_patient.py Model → Patient + │ └── workflow_fit_statistical_model_to_patient.py Model → Patient │ ├── Segmentation - │ ├── segment_chest_base.py Base segmentation + │ ├── segment_anatomy_base.py Base segmentation │ ├── segment_chest_total_segmentator.py TotalSegmentator │ ├── segment_chest_vista_3d.py VISTA-3D │ ├── segment_chest_vista_3d_nim.py VISTA-3D NIM @@ -95,9 +95,9 @@ Most PhysioMotion4D classes inherit from :class:`PhysioMotion4DBase`: PhysioMotion4DBase ├── Workflow Classes │ ├── WorkflowConvertHeartGatedCTToUSD - │ └── WorkflowRegisterHeartModelToPatient + │ └── WorkflowFitStatisticalModelToPatient ├── Segmentation Classes - │ ├── SegmentChestBase + │ ├── SegmentAnatomyBase │ │ ├── SegmentChestTotalSegmentator │ │ ├── SegmentChestVista3D │ │ └── SegmentChestEnsemble @@ -221,7 +221,7 @@ Extension Points PhysioMotion4D is designed for extension: **Add New Segmentation Methods** - Inherit from :class:`SegmentChestBase` + Inherit from :class:`SegmentAnatomyBase` **Add New Registration Methods** Inherit from :class:`RegisterImagesBase` diff --git a/docs/developer/extending.rst b/docs/developer/extending.rst index 1ebd916..d2a10f7 100644 --- a/docs/developer/extending.rst +++ b/docs/developer/extending.rst @@ -249,14 +249,14 @@ Custom Segmentation Methods Adding New Segmentation Algorithms ----------------------------------- -Extend :class:`SegmentChestBase`: +Extend :class:`SegmentAnatomyBase`: .. code-block:: python - from physiomotion4d import SegmentChestBase + from physiomotion4d import SegmentAnatomyBase import torch - class MyCustomSegmentator(SegmentChestBase): + class MyCustomSegmentator(SegmentAnatomyBase): """Custom deep learning segmentation.""" def __init__( diff --git a/docs/developer/segmentation.rst b/docs/developer/segmentation.rst index d5ab30f..656cd00 100644 --- a/docs/developer/segmentation.rst +++ b/docs/developer/segmentation.rst @@ -21,17 +21,17 @@ PhysioMotion4D supports multiple segmentation approaches: * **VISTA-3D NIM**: NVIDIA Inference Microservice version * **Ensemble**: Combine multiple methods for improved accuracy -All segmentation classes inherit from :class:`SegmentChestBase` and provide consistent interfaces. +All segmentation classes inherit from :class:`SegmentAnatomyBase` and provide consistent interfaces. Base Segmentation Class ======================= -SegmentChestBase +SegmentAnatomyBase ---------------- Abstract base class for all segmentation methods. -.. autoclass:: physiomotion4d.SegmentChestBase +.. autoclass:: physiomotion4d.SegmentAnatomyBase :members: :undoc-members: :show-inheritance: @@ -274,10 +274,10 @@ Add custom post-processing steps: .. code-block:: python - from physiomotion4d import SegmentChestBase + from physiomotion4d import SegmentAnatomyBase import numpy as np - class CustomSegmentator(SegmentChestBase): + class CustomSegmentator(SegmentAnatomyBase): """Custom segmentator with post-processing.""" def post_process(self, labelmap): diff --git a/docs/developer/workflows.rst b/docs/developer/workflows.rst index c44aba1..cadd1da 100644 --- a/docs/developer/workflows.rst +++ b/docs/developer/workflows.rst @@ -37,7 +37,7 @@ Each CLI script wraps a corresponding workflow class: * - ``physiomotion4d-heart-gated-ct`` - :class:`WorkflowConvertHeartGatedCTToUSD` * - ``physiomotion4d-heart-model-to-patient`` - - :class:`WorkflowRegisterHeartModelToPatient` *(planned)* + - :class:`WorkflowFitStatisticalModelToPatient` *(planned)* * - ``physiomotion4d-lung-gated-ct`` - :class:`LungGatedCTToUSDWorkflow` *(planned)* * - ``physiomotion4d-4dct-reconstruction`` @@ -109,7 +109,7 @@ Process 4D cardiac CT to animated USD models. See :doc:`../cli_scripts/heart_gated_ct` for CLI usage. -WorkflowRegisterHeartModelToPatient +WorkflowFitStatisticalModelToPatient ------------------------------------ .. note:: @@ -121,9 +121,9 @@ Register population heart models to patient images. .. code-block:: python - from physiomotion4d import WorkflowRegisterHeartModelToPatient + from physiomotion4d import WorkflowFitStatisticalModelToPatient - workflow = WorkflowRegisterHeartModelToPatient( + workflow = WorkflowFitStatisticalModelToPatient( model_file="population_heart_model.vtk", patient_image="patient_ct.nrrd", output_directory="./results" @@ -131,7 +131,7 @@ Register population heart models to patient images. registered_model = workflow.process() -See :doc:`../cli_scripts/heart_model_to_patient` for planned CLI usage. +See :doc:`../cli_scripts/fit_statistical_model_to_patient` for planned CLI usage. Common Workflow Patterns ======================== diff --git a/docs/examples.rst b/docs/examples.rst index 4aee427..2fb4b6d 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -8,7 +8,8 @@ see the :doc:`cli_scripts/overview` section. .. note:: **For Production Workflows:** The CLI commands (``physiomotion4d-heart-gated-ct``, - ``physiomotion4d-register-heart-model``) and their implementations in ``src/physiomotion4d/cli/`` + ``physiomotion4d-create-statistical-model``, ``physiomotion4d-fit-statistical-model-to-patient``) + and their implementations in ``src/physiomotion4d/cli/`` are the definitive source for proper library usage, class instantiation, and best practices. The ``experiments/`` directory contains research prototypes that informed development but should diff --git a/docs/index.rst b/docs/index.rst index 70ca847..4e8bd44 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -56,7 +56,8 @@ PhysioMotion4D is a comprehensive medical imaging package that converts 3D and 4 **Getting Started with Code Examples:** This documentation uses examples from the CLI commands (``physiomotion4d-heart-gated-ct``, - ``physiomotion4d-register-heart-model``) and their implementations in ``src/physiomotion4d/cli/``, + ``physiomotion4d-create-statistical-model``, ``physiomotion4d-fit-statistical-model-to-patient``) + and their implementations in ``src/physiomotion4d/cli/``, which contain production-ready workflows and proper library usage patterns. The repository also includes an ``experiments/`` directory with research prototypes that can inspire adaptations to new digital twin models and anatomical regions—see the experiments README for details on @@ -76,7 +77,8 @@ PhysioMotion4D is a comprehensive medical imaging package that converts 3D and 4 cli_scripts/overview cli_scripts/heart_gated_ct - cli_scripts/heart_model_to_patient + cli_scripts/create_statistical_model + cli_scripts/fit_statistical_model_to_patient cli_scripts/lung_gated_ct cli_scripts/4dct_reconstruction cli_scripts/vtk_to_usd diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 22e4355..b37dfd5 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -260,7 +260,8 @@ Now that you've completed your first workflow: **About CLI Commands and Experiments:** * **CLI Commands** ⭐ **PRIMARY RESOURCE** - Production-ready workflows with proper class usage - (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-register-heart-model``). + (``physiomotion4d-heart-gated-ct``, ``physiomotion4d-create-statistical-model``, + ``physiomotion4d-fit-statistical-model-to-patient``). See ``src/physiomotion4d/cli/`` for implementation details. * **experiments/** - Research prototypes and design explorations. These demonstrate conceptual diff --git a/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb b/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb index 05ecc54..29aa5f6 100644 --- a/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb +++ b/experiments/Heart-Create_Statistical_Model/5-compute_pca_model.ipynb @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "df57b010", "metadata": { "execution": { @@ -155,7 +155,7 @@ } ], "source": [ - "# Load average surface (this will be used as the mean for PCA)\n", + "# Load average surface (this will be replaced by the mean from PCA)\n", "average_mesh = pv.read(average_surface_path)\n", "n_points = average_mesh.n_points\n", "n_features = n_points * 3 # x, y, z coordinates for each point\n", diff --git a/experiments/Heart-Create_Statistical_Model/README.md b/experiments/Heart-Create_Statistical_Model/README.md index ab39cf7..00b5286 100644 --- a/experiments/Heart-Create_Statistical_Model/README.md +++ b/experiments/Heart-Create_Statistical_Model/README.md @@ -145,10 +145,10 @@ After completing this experiment, you will have generated files in `kcl-heart-mo The outputs from this experiment are used in the `Heart-Statistical_Model_To_Patient` experiment: ```python -from physiomotion4d import WorkflowRegisterHeartModelToPatient +from physiomotion4d import WorkflowFitStatisticalModelToPatient # Use PCA model from this experiment -workflow = WorkflowRegisterHeartModelToPatient( +workflow = WorkflowFitStatisticalModelToPatient( moving_mesh=mean_shape, fixed_meshes=patient_surfaces, fixed_image=patient_ct, diff --git a/experiments/Heart-Simpleware_Segmentation/README.md b/experiments/Heart-Simpleware_Segmentation/README.md index b561c40..2eb59a2 100644 --- a/experiments/Heart-Simpleware_Segmentation/README.md +++ b/experiments/Heart-Simpleware_Segmentation/README.md @@ -225,9 +225,9 @@ workflow.set_static_labelmap(result["labelmap"]) ### Statistical Model Registration Register segmentation with heart model using `Heart-Statistical_Model_To_Patient`: ```python -from physiomotion4d.workflow_register_heart_model_to_patient import WorkflowRegisterHeartModelToPatient +from physiomotion4d.workflow_fit_statistical_model_to_patient import WorkflowFitStatisticalModelToPatient -workflow = WorkflowRegisterHeartModelToPatient() +workflow = WorkflowFitStatisticalModelToPatient() workflow.set_patient_segmentation(result["labelmap"]) # Perform model-to-patient registration ``` diff --git a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb index 091e82b..7674b49 100644 --- a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb +++ b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_model_registration_pca.ipynb @@ -1,846 +1,849 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PCA-based Heart Model to Image Registration Experiment\n", - "\n", - "This notebook demonstrates using the `RegisterModelToImagePCA` class to register\n", - "a statistical shape model to patient CT images using PCA-based shape variation.\n", - "\n", - "## Overview\n", - "- Uses the KCL Heart Model PCA statistical shape model\n", - "- Registers to the same Duke Heart CT data as the original notebook\n", - "- Two-stage optimization: rigid alignment + PCA shape fitting\n", - "- Converts segmentation mask to intensity image for registration" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup and Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:36.139622Z", - "iopub.status.busy": "2026-02-09T06:34:36.139622Z", - "iopub.status.idle": "2026-02-09T06:34:51.022942Z", - "shell.execute_reply": "2026-02-09T06:34:51.022012Z" - } - }, - "outputs": [], - "source": [ - "# PCA-based Heart Model to Image Registration Experiment\n", - "\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import itk\n", - "import numpy as np\n", - "import pyvista as pv\n", - "from itk import TubeTK as ttk\n", - "\n", - "# Import from PhysioMotion4D package\n", - "from physiomotion4d import (\n", - " ContourTools,\n", - " RegisterModelsICP,\n", - " RegisterModelsPCA,\n", - " TransformTools,\n", - ")\n", - "from physiomotion4d.notebook_utils import running_as_test" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define File Paths" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:51.023944Z", - "iopub.status.busy": "2026-02-09T06:34:51.023944Z", - "iopub.status.idle": "2026-02-09T06:34:51.037942Z", - "shell.execute_reply": "2026-02-09T06:34:51.036949Z" - } - }, - "outputs": [], - "source": [ - "# Patient CT image (defines coordinate frame)\n", - "patient_data_dir = Path.cwd().parent.parent / \"data\" / \"Slicer-Heart-CT\"\n", - "patient_ct_path = patient_data_dir / \"patient_img.mha\"\n", - "patient_ct_heart_mask_path = patient_data_dir / \"patient_heart_wall_mask.nii.gz\"\n", - "\n", - "# PCA heart model data\n", - "heart_model_data_dir = Path.cwd().parent.parent / \"data\" / \"KCL-Heart-Model\"\n", - "heart_model_path = heart_model_data_dir / \"average_mesh.vtk\"\n", - "\n", - "# PCA statistical model (from Heart-Create_Statistical_Model workflow)\n", - "template_model_data_dir = (\n", - " Path.cwd().parent / \"Heart-Create_Statistical_Model\" / \"kcl-heart-model\"\n", - ")\n", - "template_model_surface_path = template_model_data_dir / \"pca_mean.vtp\"\n", - "pca_json_path = template_model_data_dir / \"pca_model.json\"\n", - "\n", - "# Output directory\n", - "output_dir = Path.cwd() / \"results_pca\"\n", - "os.makedirs(output_dir, exist_ok=True)\n", - "\n", - "print(f\"Patient data: {patient_data_dir}\")\n", - "print(f\"PCA Model data: {template_model_data_dir}\")\n", - "print(f\"Output directory: {output_dir}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load and Preprocess Patient Image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:51.050954Z", - "iopub.status.busy": "2026-02-09T06:34:51.050954Z", - "iopub.status.idle": "2026-02-09T06:34:56.086219Z", - "shell.execute_reply": "2026-02-09T06:34:56.085110Z" - } - }, - "outputs": [], - "source": [ - "# Load patient CT image\n", - "print(\"Loading patient CT image...\")\n", - "patient_image = itk.imread(str(patient_ct_path))\n", - "print(f\" Original size: {itk.size(patient_image)}\")\n", - "print(f\" Original spacing: {itk.spacing(patient_image)}\")\n", - "\n", - "# Resample to 1mm isotropic spacing\n", - "print(\"Resampling to sotropic...\")\n", - "resampler = ttk.ResampleImage.New(Input=patient_image)\n", - "resampler.SetMakeHighResIso(True)\n", - "resampler.Update()\n", - "patient_image = resampler.GetOutput()\n", - "\n", - "print(f\" Resampled size: {itk.size(patient_image)}\")\n", - "print(f\" Resampled spacing: {itk.spacing(patient_image)}\")\n", - "\n", - "# Save preprocessed image\n", - "itk.imwrite(patient_image, str(output_dir / \"patient_image.mha\"), compression=True)\n", - "print(\"✓ Saved preprocessed image\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load and Process Heart Segmentation Mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:56.087775Z", - "iopub.status.busy": "2026-02-09T06:34:56.087775Z", - "iopub.status.idle": "2026-02-09T06:34:56.115692Z", - "shell.execute_reply": "2026-02-09T06:34:56.115196Z" - } - }, - "outputs": [], - "source": [ - "# Load heart segmentation mask\n", - "print(\"Loading heart segmentation mask...\")\n", - "patient_heart_mask_image = itk.imread(str(patient_ct_heart_mask_path))\n", - "\n", - "print(f\" Mask size: {itk.size(patient_heart_mask_image)}\")\n", - "print(f\" Mask spacing: {itk.spacing(patient_heart_mask_image)}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:56.117556Z", - "iopub.status.busy": "2026-02-09T06:34:56.117556Z", - "iopub.status.idle": "2026-02-09T06:34:56.462795Z", - "shell.execute_reply": "2026-02-09T06:34:56.462795Z" - } - }, - "outputs": [], - "source": [ - "# Handle image orientation (flip if needed)\n", - "flip0 = np.array(patient_heart_mask_image.GetDirection())[0, 0] < 0\n", - "flip1 = np.array(patient_heart_mask_image.GetDirection())[1, 1] < 0\n", - "flip2 = np.array(patient_heart_mask_image.GetDirection())[2, 2] < 0\n", - "\n", - "if flip0 or flip1 or flip2:\n", - " print(f\"Flipping image axes: {flip0}, {flip1}, {flip2}\")\n", - "\n", - " # Flip CT image\n", - " flip_filter = itk.FlipImageFilter.New(Input=patient_image)\n", - " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", - " flip_filter.SetFlipAboutOrigin(True)\n", - " flip_filter.Update()\n", - " patient_image = flip_filter.GetOutput()\n", - " id_mat = itk.Matrix[itk.D, 3, 3]()\n", - " id_mat.SetIdentity()\n", - " patient_image.SetDirection(id_mat)\n", - "\n", - " # Flip mask image\n", - " flip_filter = itk.FlipImageFilter.New(Input=patient_heart_mask_image)\n", - " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", - " flip_filter.SetFlipAboutOrigin(True)\n", - " flip_filter.Update()\n", - " patient_heart_mask_image = flip_filter.GetOutput()\n", - " patient_heart_mask_image.SetDirection(id_mat)\n", - "\n", - " print(\"✓ Images flipped to standard orientation\")\n", - "\n", - "# Save oriented images\n", - "itk.imwrite(\n", - " patient_image, str(output_dir / \"patient_image_oriented.mha\"), compression=True\n", - ")\n", - "itk.imwrite(\n", - " patient_heart_mask_image,\n", - " str(output_dir / \"patient_heart_mask_oriented.mha\"),\n", - " compression=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Convert Segmentation Mask to a Surface" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:56.464797Z", - "iopub.status.busy": "2026-02-09T06:34:56.464797Z", - "iopub.status.idle": "2026-02-09T06:34:56.917954Z", - "shell.execute_reply": "2026-02-09T06:34:56.917429Z" - } - }, - "outputs": [], - "source": [ - "contour_tools = ContourTools()\n", - "patient_surface = contour_tools.extract_contours(patient_heart_mask_image)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Perform Initial ICP Affine Registration\n", - "\n", - "Use ICP (Iterative Closest Point) with affine mode to align the model surface to the patient surface extracted from the segmentation mask. This provides a good initial alignment for the PCA-based registration.\n", - "\n", - "The ICP registration pipeline:\n", - "1. Centroid alignment (automatic)\n", - "2. Rigid ICP alignment\n", - "\n", - "The PCA registration will then refine this initial alignment with shape model constraints." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:34:56.919468Z", - "iopub.status.busy": "2026-02-09T06:34:56.919468Z", - "iopub.status.idle": "2026-02-09T06:35:10.733991Z", - "shell.execute_reply": "2026-02-09T06:35:10.733143Z" - } - }, - "outputs": [], - "source": [ - "# Load the pca model\n", - "print(\"Loading PCA heart model...\")\n", - "template_model = pv.read(str(heart_model_path))\n", - "\n", - "template_model_surface = pv.read(template_model_surface_path)\n", - "print(f\" Template surface: {template_model_surface.n_points} points\")\n", - "\n", - "icp_registrar = RegisterModelsICP(fixed_model=patient_surface)\n", - "\n", - "# Use fewer iterations when run as test (pytest) for faster execution\n", - "max_iterations_icp = 100 if running_as_test() else 2000\n", - "icp_result = icp_registrar.register(\n", - " transform_type=\"Affine\",\n", - " moving_model=template_model_surface,\n", - " max_iterations=max_iterations_icp,\n", - ")\n", - "\n", - "# Get the aligned mesh and transform\n", - "icp_registered_model_surface = icp_result[\"registered_model\"]\n", - "icp_forward_point_transform = icp_result[\"forward_point_transform\"]\n", - "\n", - "print(\"\\n✓ ICP affine registration complete\")\n", - "print(\" Transform =\", icp_result[\"forward_point_transform\"])\n", - "\n", - "# Save aligned model\n", - "icp_registered_model_surface.save(str(output_dir / \"icp_registered_model_surface.vtp\"))\n", - "print(\" Saved ICP-aligned model surface\")\n", - "\n", - "itk.transformwrite(\n", - " [icp_result[\"forward_point_transform\"]],\n", - " str(output_dir / \"icp_transform.hdf\"),\n", - " compression=True,\n", - ")\n", - "print(\" Saved ICP transform\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:35:10.735086Z", - "iopub.status.busy": "2026-02-09T06:35:10.735086Z", - "iopub.status.idle": "2026-02-09T06:35:14.798820Z", - "shell.execute_reply": "2026-02-09T06:35:14.797942Z" - } - }, - "outputs": [], - "source": [ - "# Apply ICP transform to the full average mesh (not just surface)\n", - "# This gets the volumetric mesh into patient space for PCA registration\n", - "transform_tools = TransformTools()\n", - "icp_registered_model = transform_tools.transform_pvcontour(\n", - " template_model, icp_forward_point_transform\n", - ")\n", - "icp_registered_model.save(str(output_dir / \"icp_registered_model.vtk\"))\n", - "print(\"\\n✓ Applied ICP transform to full model mesh\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialize PCA Registration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:35:14.800938Z", - "iopub.status.busy": "2026-02-09T06:35:14.800938Z", - "iopub.status.idle": "2026-02-09T06:35:18.694406Z", - "shell.execute_reply": "2026-02-09T06:35:18.693110Z" - } - }, - "outputs": [], - "source": [ - "## Initialize PCA Registration\n", - "print(\"=\" * 70)\n", - "\n", - "# Use the ICP-aligned mesh as the starting point for PCA registration\n", - "pca_registrar = RegisterModelsPCA.from_json(\n", - " pca_template_model=template_model_surface,\n", - " pca_json_filename=pca_json_path,\n", - " pca_number_of_modes=10,\n", - " pre_pca_transform=icp_forward_point_transform,\n", - " fixed_model=patient_surface,\n", - " reference_image=patient_image,\n", - ")\n", - "\n", - "itk.imwrite(pca_registrar.fixed_distance_map, str(output_dir / \"distance_map.mha\"))\n", - "\n", - "print(\"✓ PCA registrar initialized\")\n", - "print(\" Using ICP-aligned mesh as starting point\")\n", - "print(f\" Number of points: {len(pca_registrar.pca_template_model.points)}\")\n", - "print(f\" Number of PCA modes: {pca_registrar.pca_number_of_modes}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run PCA-Based Shape Optimization\n", - "\n", - "Now that we have a good initial alignment from ICP affine registration, we run the PCA-based registration to optimize the shape parameters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:35:18.696501Z", - "iopub.status.busy": "2026-02-09T06:35:18.695503Z", - "iopub.status.idle": "2026-02-09T06:36:49.984747Z", - "shell.execute_reply": "2026-02-09T06:36:49.983820Z" - } - }, - "outputs": [], - "source": [ - "print(\"\\n\" + \"=\" * 70)\n", - "print(\"PCA-BASED SHAPE OPTIMIZATION\")\n", - "print(\"=\" * 70)\n", - "print(\"\\nRunning complete PCA registration pipeline...\")\n", - "print(\" (Starting from ICP-aligned mesh)\")\n", - "\n", - "result = pca_registrar.register(\n", - " pca_number_of_modes=10, # Use first 10 PCA modes\n", - ")\n", - "\n", - "dm = pca_registrar.fixed_distance_map\n", - "itk.imwrite(dm, str(output_dir / \"target_distance_map.mha\"))\n", - "\n", - "pca_registered_model_surface = result[\"registered_model\"]\n", - "\n", - "print(\"\\n✓ PCA registration complete\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Display Registration Results\n", - "\n", - "Review the optimization results from the PCA registration pipeline.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:36:49.986008Z", - "iopub.status.busy": "2026-02-09T06:36:49.986008Z", - "iopub.status.idle": "2026-02-09T06:36:49.999829Z", - "shell.execute_reply": "2026-02-09T06:36:49.998876Z" - } - }, - "outputs": [], - "source": [ - "print(\"\\n\" + \"=\" * 70)\n", - "print(\"REGISTRATION RESULTS\")\n", - "print(\"=\" * 70)\n", - "\n", - "# Display results\n", - "print(\"\\nFinal Registration Metrics:\")\n", - "print(f\" Final mean intensity: {result['mean_distance']:.4f}\")\n", - "\n", - "print(\"\\nOptimized PCA Coefficients (in units of std deviations):\")\n", - "for i, coef in enumerate(result[\"pca_coefficients\"]):\n", - " print(f\" Mode {i + 1:2d}: {coef:7.4f}\")\n", - "\n", - "print(\"\\n✓ Registration pipeline complete!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save Registration Results\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:36:50.001252Z", - "iopub.status.busy": "2026-02-09T06:36:50.001252Z", - "iopub.status.idle": "2026-02-09T06:36:51.739511Z", - "shell.execute_reply": "2026-02-09T06:36:51.738563Z" - } - }, - "outputs": [], - "source": [ - "print(\"\\nSaving results...\")\n", - "\n", - "# Save final PCA-registered mesh\n", - "pca_registered_model_surface.save(str(output_dir / \"pca_registered_model_surface.vtk\"))\n", - "print(\" Saved final PCA-registered mesh\")\n", - "\n", - "ref_image = contour_tools.create_reference_image(pca_registered_model_surface)\n", - "\n", - "distance_map = contour_tools.create_distance_map(\n", - " pca_registered_model_surface,\n", - " ref_image,\n", - " squared_distance=True,\n", - " negative_inside=False,\n", - " zero_inside=True,\n", - " norm_to_max_distance=200.0,\n", - ")\n", - "\n", - "itk.imwrite(distance_map, str(output_dir / \"pca_distance_map.mha\"))\n", - "\n", - "# Save PCA coefficients\n", - "np.savetxt(\n", - " str(output_dir / \"pca_coefficients.txt\"),\n", - " result[\"pca_coefficients\"],\n", - " header=f\"PCA coefficients for {len(result['pca_coefficients'])} modes\",\n", - ")\n", - "print(\" Saved PCA coefficients\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize Results\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:36:51.741507Z", - "iopub.status.busy": "2026-02-09T06:36:51.741507Z", - "iopub.status.idle": "2026-02-09T06:36:54.084672Z", - "shell.execute_reply": "2026-02-09T06:36:54.082694Z" + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PCA-based Heart Model to Image Registration Experiment\n", + "\n", + "This notebook demonstrates using the `RegisterModelToImagePCA` class to register\n", + "a statistical shape model to patient CT images using PCA-based shape variation.\n", + "\n", + "## Overview\n", + "- Uses the KCL Heart Model PCA statistical shape model\n", + "- Registers to the same Duke Heart CT data as the original notebook\n", + "- Two-stage optimization: rigid alignment + PCA shape fitting\n", + "- Converts segmentation mask to intensity image for registration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:36.139622Z", + "iopub.status.busy": "2026-02-09T06:34:36.139622Z", + "iopub.status.idle": "2026-02-09T06:34:51.022942Z", + "shell.execute_reply": "2026-02-09T06:34:51.022012Z" + } + }, + "outputs": [], + "source": [ + "# PCA-based Heart Model to Image Registration Experiment\n", + "\n", + "import json\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "import itk\n", + "import numpy as np\n", + "import pyvista as pv\n", + "from itk import TubeTK as ttk\n", + "\n", + "# Import from PhysioMotion4D package\n", + "from physiomotion4d import (\n", + " ContourTools,\n", + " RegisterModelsICP,\n", + " RegisterModelsPCA,\n", + " TransformTools,\n", + ")\n", + "from physiomotion4d.notebook_utils import running_as_test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define File Paths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:51.023944Z", + "iopub.status.busy": "2026-02-09T06:34:51.023944Z", + "iopub.status.idle": "2026-02-09T06:34:51.037942Z", + "shell.execute_reply": "2026-02-09T06:34:51.036949Z" + } + }, + "outputs": [], + "source": [ + "# Patient CT image (defines coordinate frame)\n", + "patient_data_dir = Path.cwd().parent.parent / \"data\" / \"Slicer-Heart-CT\"\n", + "patient_ct_path = patient_data_dir / \"patient_img.mha\"\n", + "patient_ct_heart_mask_path = patient_data_dir / \"patient_heart_wall_mask.nii.gz\"\n", + "\n", + "# PCA heart model data\n", + "heart_model_data_dir = Path.cwd().parent.parent / \"data\" / \"KCL-Heart-Model\"\n", + "heart_model_path = heart_model_data_dir / \"average_mesh.vtk\"\n", + "\n", + "# PCA statistical model (from Heart-Create_Statistical_Model workflow)\n", + "template_model_data_dir = (\n", + " Path.cwd().parent / \"Heart-Create_Statistical_Model\" / \"kcl-heart-model\"\n", + ")\n", + "template_model_surface_path = template_model_data_dir / \"pca_mean.vtp\"\n", + "pca_json_path = template_model_data_dir / \"pca_model.json\"\n", + "\n", + "# Output directory\n", + "output_dir = Path.cwd() / \"results_pca\"\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "print(f\"Patient data: {patient_data_dir}\")\n", + "print(f\"PCA Model data: {template_model_data_dir}\")\n", + "print(f\"Output directory: {output_dir}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load and Preprocess Patient Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:51.050954Z", + "iopub.status.busy": "2026-02-09T06:34:51.050954Z", + "iopub.status.idle": "2026-02-09T06:34:56.086219Z", + "shell.execute_reply": "2026-02-09T06:34:56.085110Z" + } + }, + "outputs": [], + "source": [ + "# Load patient CT image\n", + "print(\"Loading patient CT image...\")\n", + "patient_image = itk.imread(str(patient_ct_path))\n", + "print(f\" Original size: {itk.size(patient_image)}\")\n", + "print(f\" Original spacing: {itk.spacing(patient_image)}\")\n", + "\n", + "# Resample to 1mm isotropic spacing\n", + "print(\"Resampling to sotropic...\")\n", + "resampler = ttk.ResampleImage.New(Input=patient_image)\n", + "resampler.SetMakeHighResIso(True)\n", + "resampler.Update()\n", + "patient_image = resampler.GetOutput()\n", + "\n", + "print(f\" Resampled size: {itk.size(patient_image)}\")\n", + "print(f\" Resampled spacing: {itk.spacing(patient_image)}\")\n", + "\n", + "# Save preprocessed image\n", + "itk.imwrite(patient_image, str(output_dir / \"patient_image.mha\"), compression=True)\n", + "print(\"✓ Saved preprocessed image\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load and Process Heart Segmentation Mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:56.087775Z", + "iopub.status.busy": "2026-02-09T06:34:56.087775Z", + "iopub.status.idle": "2026-02-09T06:34:56.115692Z", + "shell.execute_reply": "2026-02-09T06:34:56.115196Z" + } + }, + "outputs": [], + "source": [ + "# Load heart segmentation mask\n", + "print(\"Loading heart segmentation mask...\")\n", + "patient_heart_mask_image = itk.imread(str(patient_ct_heart_mask_path))\n", + "\n", + "print(f\" Mask size: {itk.size(patient_heart_mask_image)}\")\n", + "print(f\" Mask spacing: {itk.spacing(patient_heart_mask_image)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:56.117556Z", + "iopub.status.busy": "2026-02-09T06:34:56.117556Z", + "iopub.status.idle": "2026-02-09T06:34:56.462795Z", + "shell.execute_reply": "2026-02-09T06:34:56.462795Z" + } + }, + "outputs": [], + "source": [ + "# Handle image orientation (flip if needed)\n", + "flip0 = np.array(patient_heart_mask_image.GetDirection())[0, 0] < 0\n", + "flip1 = np.array(patient_heart_mask_image.GetDirection())[1, 1] < 0\n", + "flip2 = np.array(patient_heart_mask_image.GetDirection())[2, 2] < 0\n", + "\n", + "if flip0 or flip1 or flip2:\n", + " print(f\"Flipping image axes: {flip0}, {flip1}, {flip2}\")\n", + "\n", + " # Flip CT image\n", + " flip_filter = itk.FlipImageFilter.New(Input=patient_image)\n", + " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", + " flip_filter.SetFlipAboutOrigin(True)\n", + " flip_filter.Update()\n", + " patient_image = flip_filter.GetOutput()\n", + " id_mat = itk.Matrix[itk.D, 3, 3]()\n", + " id_mat.SetIdentity()\n", + " patient_image.SetDirection(id_mat)\n", + "\n", + " # Flip mask image\n", + " flip_filter = itk.FlipImageFilter.New(Input=patient_heart_mask_image)\n", + " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", + " flip_filter.SetFlipAboutOrigin(True)\n", + " flip_filter.Update()\n", + " patient_heart_mask_image = flip_filter.GetOutput()\n", + " patient_heart_mask_image.SetDirection(id_mat)\n", + "\n", + " print(\"✓ Images flipped to standard orientation\")\n", + "\n", + "# Save oriented images\n", + "itk.imwrite(\n", + " patient_image, str(output_dir / \"patient_image_oriented.mha\"), compression=True\n", + ")\n", + "itk.imwrite(\n", + " patient_heart_mask_image,\n", + " str(output_dir / \"patient_heart_mask_oriented.mha\"),\n", + " compression=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert Segmentation Mask to a Surface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:56.464797Z", + "iopub.status.busy": "2026-02-09T06:34:56.464797Z", + "iopub.status.idle": "2026-02-09T06:34:56.917954Z", + "shell.execute_reply": "2026-02-09T06:34:56.917429Z" + } + }, + "outputs": [], + "source": [ + "contour_tools = ContourTools()\n", + "patient_surface = contour_tools.extract_contours(patient_heart_mask_image)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Perform Initial ICP Affine Registration\n", + "\n", + "Use ICP (Iterative Closest Point) with affine mode to align the model surface to the patient surface extracted from the segmentation mask. This provides a good initial alignment for the PCA-based registration.\n", + "\n", + "The ICP registration pipeline:\n", + "1. Centroid alignment (automatic)\n", + "2. Rigid ICP alignment\n", + "\n", + "The PCA registration will then refine this initial alignment with shape model constraints." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:34:56.919468Z", + "iopub.status.busy": "2026-02-09T06:34:56.919468Z", + "iopub.status.idle": "2026-02-09T06:35:10.733991Z", + "shell.execute_reply": "2026-02-09T06:35:10.733143Z" + } + }, + "outputs": [], + "source": [ + "# Load the pca model\n", + "print(\"Loading PCA heart model...\")\n", + "template_model = pv.read(str(heart_model_path))\n", + "\n", + "template_model_surface = pv.read(template_model_surface_path)\n", + "print(f\" Template surface: {template_model_surface.n_points} points\")\n", + "\n", + "icp_registrar = RegisterModelsICP(fixed_model=patient_surface)\n", + "\n", + "# Use fewer iterations when run as test (pytest) for faster execution\n", + "max_iterations_icp = 100 if running_as_test() else 2000\n", + "icp_result = icp_registrar.register(\n", + " transform_type=\"Affine\",\n", + " moving_model=template_model_surface,\n", + " max_iterations=max_iterations_icp,\n", + ")\n", + "\n", + "# Get the aligned mesh and transform\n", + "icp_registered_model_surface = icp_result[\"registered_model\"]\n", + "icp_forward_point_transform = icp_result[\"forward_point_transform\"]\n", + "\n", + "print(\"\\n✓ ICP affine registration complete\")\n", + "print(\" Transform =\", icp_result[\"forward_point_transform\"])\n", + "\n", + "# Save aligned model\n", + "icp_registered_model_surface.save(str(output_dir / \"icp_registered_model_surface.vtp\"))\n", + "print(\" Saved ICP-aligned model surface\")\n", + "\n", + "itk.transformwrite(\n", + " [icp_result[\"forward_point_transform\"]],\n", + " str(output_dir / \"icp_transform.hdf\"),\n", + " compression=True,\n", + ")\n", + "print(\" Saved ICP transform\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:35:10.735086Z", + "iopub.status.busy": "2026-02-09T06:35:10.735086Z", + "iopub.status.idle": "2026-02-09T06:35:14.798820Z", + "shell.execute_reply": "2026-02-09T06:35:14.797942Z" + } + }, + "outputs": [], + "source": [ + "# Apply ICP transform to the full average mesh (not just surface)\n", + "# This gets the volumetric mesh into patient space for PCA registration\n", + "transform_tools = TransformTools()\n", + "icp_registered_model = transform_tools.transform_pvcontour(\n", + " template_model, icp_forward_point_transform\n", + ")\n", + "icp_registered_model.save(str(output_dir / \"icp_registered_model.vtk\"))\n", + "print(\"\\n✓ Applied ICP transform to full model mesh\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize PCA Registration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:35:14.800938Z", + "iopub.status.busy": "2026-02-09T06:35:14.800938Z", + "iopub.status.idle": "2026-02-09T06:35:18.694406Z", + "shell.execute_reply": "2026-02-09T06:35:18.693110Z" + } + }, + "outputs": [], + "source": [ + "## Initialize PCA Registration\n", + "print(\"=\" * 70)\n", + "\n", + "# Use the ICP-aligned mesh as the starting point for PCA registration\n", + "with open(pca_json_path, encoding=\"utf-8\") as f:\n", + " pca_model = json.load(f)\n", + "pca_registrar = RegisterModelsPCA.from_pca_model(\n", + " pca_template_model=template_model_surface,\n", + " pca_model=pca_model,\n", + " pca_number_of_modes=10,\n", + " pre_pca_transform=icp_forward_point_transform,\n", + " fixed_model=patient_surface,\n", + " reference_image=patient_image,\n", + ")\n", + "\n", + "itk.imwrite(pca_registrar.fixed_distance_map, str(output_dir / \"distance_map.mha\"))\n", + "\n", + "print(\"✓ PCA registrar initialized\")\n", + "print(\" Using ICP-aligned mesh as starting point\")\n", + "print(f\" Number of points: {len(pca_registrar.pca_template_model.points)}\")\n", + "print(f\" Number of PCA modes: {pca_registrar.pca_number_of_modes}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run PCA-Based Shape Optimization\n", + "\n", + "Now that we have a good initial alignment from ICP affine registration, we run the PCA-based registration to optimize the shape parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:35:18.696501Z", + "iopub.status.busy": "2026-02-09T06:35:18.695503Z", + "iopub.status.idle": "2026-02-09T06:36:49.984747Z", + "shell.execute_reply": "2026-02-09T06:36:49.983820Z" + } + }, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 70)\n", + "print(\"PCA-BASED SHAPE OPTIMIZATION\")\n", + "print(\"=\" * 70)\n", + "print(\"\\nRunning complete PCA registration pipeline...\")\n", + "print(\" (Starting from ICP-aligned mesh)\")\n", + "\n", + "result = pca_registrar.register(\n", + " pca_number_of_modes=10, # Use first 10 PCA modes\n", + ")\n", + "\n", + "dm = pca_registrar.fixed_distance_map\n", + "itk.imwrite(dm, str(output_dir / \"target_distance_map.mha\"))\n", + "\n", + "pca_registered_model_surface = result[\"registered_model\"]\n", + "\n", + "print(\"\\n✓ PCA registration complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Display Registration Results\n", + "\n", + "Review the optimization results from the PCA registration pipeline.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:36:49.986008Z", + "iopub.status.busy": "2026-02-09T06:36:49.986008Z", + "iopub.status.idle": "2026-02-09T06:36:49.999829Z", + "shell.execute_reply": "2026-02-09T06:36:49.998876Z" + } + }, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 70)\n", + "print(\"REGISTRATION RESULTS\")\n", + "print(\"=\" * 70)\n", + "\n", + "# Display results\n", + "print(\"\\nFinal Registration Metrics:\")\n", + "print(f\" Final mean intensity: {result['mean_distance']:.4f}\")\n", + "\n", + "print(\"\\nOptimized PCA Coefficients (in units of std deviations):\")\n", + "for i, coef in enumerate(result[\"pca_coefficients\"]):\n", + " print(f\" Mode {i + 1:2d}: {coef:7.4f}\")\n", + "\n", + "print(\"\\n✓ Registration pipeline complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Registration Results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:36:50.001252Z", + "iopub.status.busy": "2026-02-09T06:36:50.001252Z", + "iopub.status.idle": "2026-02-09T06:36:51.739511Z", + "shell.execute_reply": "2026-02-09T06:36:51.738563Z" + } + }, + "outputs": [], + "source": [ + "print(\"\\nSaving results...\")\n", + "\n", + "# Save final PCA-registered mesh\n", + "pca_registered_model_surface.save(str(output_dir / \"pca_registered_model_surface.vtk\"))\n", + "print(\" Saved final PCA-registered mesh\")\n", + "\n", + "ref_image = contour_tools.create_reference_image(pca_registered_model_surface)\n", + "\n", + "distance_map = contour_tools.create_distance_map(\n", + " pca_registered_model_surface,\n", + " ref_image,\n", + " squared_distance=True,\n", + " negative_inside=False,\n", + " zero_inside=True,\n", + " norm_to_max_distance=200.0,\n", + ")\n", + "\n", + "itk.imwrite(distance_map, str(output_dir / \"pca_distance_map.mha\"))\n", + "\n", + "# Save PCA coefficients\n", + "np.savetxt(\n", + " str(output_dir / \"pca_coefficients.txt\"),\n", + " result[\"pca_coefficients\"],\n", + " header=f\"PCA coefficients for {len(result['pca_coefficients'])} modes\",\n", + ")\n", + "print(\" Saved PCA coefficients\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:36:51.741507Z", + "iopub.status.busy": "2026-02-09T06:36:51.741507Z", + "iopub.status.idle": "2026-02-09T06:36:54.084672Z", + "shell.execute_reply": "2026-02-09T06:36:54.082694Z" + } + }, + "outputs": [], + "source": [ + "# Create side-by-side comparison\n", + "plotter = pv.Plotter(shape=(1, 2), window_size=[1000, 600])\n", + "\n", + "plotter.subplot(0, 0)\n", + "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n", + "plotter.add_mesh(\n", + " icp_registered_model_surface, color=\"green\", opacity=1.0, label=\"ICP Registered\"\n", + ")\n", + "plotter.add_title(\"ICP Shape Fitting\")\n", + "plotter.add_axes()\n", + "\n", + "# After PCA shape fitting\n", + "plotter.subplot(0, 1)\n", + "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n", + "plotter.add_mesh(\n", + " pca_registered_model_surface, color=\"green\", opacity=1.0, label=\"PCA Registered\"\n", + ")\n", + "plotter.add_title(\"PCA Shape Fitting\")\n", + "plotter.add_axes()\n", + "\n", + "plotter.link_views()\n", + "plotter.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize PCA Displacement Magnitude\n", + "\n", + "Compute and display the displacement magnitude caused by PCA optimization. This shows how much each point moved from the ICP-aligned mean shape to the final PCA-registered shape." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-09T06:36:54.086869Z", + "iopub.status.busy": "2026-02-09T06:36:54.086230Z", + "iopub.status.idle": "2026-02-09T06:36:55.292551Z", + "shell.execute_reply": "2026-02-09T06:36:55.291527Z" + } + }, + "outputs": [], + "source": [ + "# Compute displacement from ICP-aligned (mean shape) to PCA-registered shape\n", + "icp_points = icp_registered_model_surface.points\n", + "pca_points = pca_registered_model_surface.points\n", + "\n", + "# Calculate displacement vectors\n", + "displacement_vectors = pca_points - icp_points\n", + "\n", + "# Compute surface normals for the ICP-aligned mesh\n", + "icp_registered_model_with_normals = icp_registered_model_surface.compute_normals(\n", + " point_normals=True, cell_normals=False\n", + ")\n", + "normals = icp_registered_model_with_normals.point_data[\"Normals\"]\n", + "\n", + "# Calculate signed displacement along the normal direction\n", + "# Positive = outward displacement, Negative = inward displacement\n", + "signed_displacement = np.sum(displacement_vectors * normals, axis=1)\n", + "\n", + "# Add displacement as scalar data to the mesh\n", + "pca_registered_model_with_displacement = pca_registered_model_surface.copy()\n", + "pca_registered_model_with_displacement[\"PCA Signed Displacement (mm)\"] = (\n", + " signed_displacement\n", + ")\n", + "\n", + "# Print statistics\n", + "print(\"PCA Signed Displacement Statistics:\")\n", + "print(f\" Mean displacement: {np.mean(signed_displacement):.2f} mm\")\n", + "print(f\" Max displacement (outward): {np.max(signed_displacement):.2f} mm\")\n", + "print(f\" Min displacement (inward): {np.min(signed_displacement):.2f} mm\")\n", + "print(f\" Std displacement: {np.std(signed_displacement):.2f} mm\")\n", + "\n", + "# Visualize the signed displacement with diverging colormap\n", + "# Blue = inward displacement, Red = outward displacement\n", + "plotter = pv.Plotter(window_size=[800, 600])\n", + "plotter.add_mesh(\n", + " pca_registered_model_with_displacement,\n", + " scalars=\"PCA Signed Displacement (mm)\",\n", + " cmap=\"RdBu_r\", # Red for positive (outward), Blue for negative (inward)\n", + " clim=[\n", + " -np.max(np.abs(signed_displacement)),\n", + " np.max(np.abs(signed_displacement)),\n", + " ], # Symmetric color scale\n", + " show_scalar_bar=True,\n", + " scalar_bar_args={\n", + " \"title\": \"PCA Signed Displacement (mm)\\n(Red=Outward, Blue=Inward)\",\n", + " \"vertical\": True,\n", + " \"position_x\": 0.82,\n", + " \"position_y\": 0.1,\n", + " },\n", + ")\n", + "plotter.add_title(\"PCA Signed Displacement on Registered Model\")\n", + "plotter.add_axes()\n", + "plotter.show()\n", + "\n", + "# Save the mesh with displacement data\n", + "pca_registered_model_with_displacement.save(\n", + " str(output_dir / \"pca_registered_model_with_signed_displacement.vtp\")\n", + ")\n", + "print(\"\\n✓ Saved model with signed displacement data\")" + ] } - }, - "outputs": [], - "source": [ - "# Create side-by-side comparison\n", - "plotter = pv.Plotter(shape=(1, 2), window_size=[1000, 600])\n", - "\n", - "plotter.subplot(0, 0)\n", - "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n", - "plotter.add_mesh(\n", - " icp_registered_model_surface, color=\"green\", opacity=1.0, label=\"ICP Registered\"\n", - ")\n", - "plotter.add_title(\"ICP Shape Fitting\")\n", - "plotter.add_axes()\n", - "\n", - "# After PCA shape fitting\n", - "plotter.subplot(0, 1)\n", - "plotter.add_mesh(patient_surface, color=\"red\", opacity=1.0, label=\"Patient\")\n", - "plotter.add_mesh(\n", - " pca_registered_model_surface, color=\"green\", opacity=1.0, label=\"PCA Registered\"\n", - ")\n", - "plotter.add_title(\"PCA Shape Fitting\")\n", - "plotter.add_axes()\n", - "\n", - "plotter.link_views()\n", - "plotter.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize PCA Displacement Magnitude\n", - "\n", - "Compute and display the displacement magnitude caused by PCA optimization. This shows how much each point moved from the ICP-aligned mean shape to the final PCA-registered shape." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "execution": { - "iopub.execute_input": "2026-02-09T06:36:54.086869Z", - "iopub.status.busy": "2026-02-09T06:36:54.086230Z", - "iopub.status.idle": "2026-02-09T06:36:55.292551Z", - "shell.execute_reply": "2026-02-09T06:36:55.291527Z" + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "12ebef1bb4134652baada92b4bf41e65": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_d13b70b19456415a92798dc847c57042", + "placeholder": "​", + "style": "IPY_MODEL_8a5ce0cea571455f908701b9b74e615b", + "tabbable": null, + "tooltip": null, + "value": "" + } + }, + "8a5ce0cea571455f908701b9b74e615b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "aa7837c5f1364c6c9086d42707ea1ac5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b493ad90f44d494ebad4e99da0e2eb5f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_aa7837c5f1364c6c9086d42707ea1ac5", + "placeholder": "​", + "style": "IPY_MODEL_c21c90fea7bb4fa490ee1b0ca2073d14", + "tabbable": null, + "tooltip": null, + "value": "" + } + }, + "c21c90fea7bb4fa490ee1b0ca2073d14": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "d13b70b19456415a92798dc847c57042": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } } - }, - "outputs": [], - "source": [ - "# Compute displacement from ICP-aligned (mean shape) to PCA-registered shape\n", - "icp_points = icp_registered_model_surface.points\n", - "pca_points = pca_registered_model_surface.points\n", - "\n", - "# Calculate displacement vectors\n", - "displacement_vectors = pca_points - icp_points\n", - "\n", - "# Compute surface normals for the ICP-aligned mesh\n", - "icp_registered_model_with_normals = icp_registered_model_surface.compute_normals(\n", - " point_normals=True, cell_normals=False\n", - ")\n", - "normals = icp_registered_model_with_normals.point_data[\"Normals\"]\n", - "\n", - "# Calculate signed displacement along the normal direction\n", - "# Positive = outward displacement, Negative = inward displacement\n", - "signed_displacement = np.sum(displacement_vectors * normals, axis=1)\n", - "\n", - "# Add displacement as scalar data to the mesh\n", - "pca_registered_model_with_displacement = pca_registered_model_surface.copy()\n", - "pca_registered_model_with_displacement[\"PCA Signed Displacement (mm)\"] = (\n", - " signed_displacement\n", - ")\n", - "\n", - "# Print statistics\n", - "print(\"PCA Signed Displacement Statistics:\")\n", - "print(f\" Mean displacement: {np.mean(signed_displacement):.2f} mm\")\n", - "print(f\" Max displacement (outward): {np.max(signed_displacement):.2f} mm\")\n", - "print(f\" Min displacement (inward): {np.min(signed_displacement):.2f} mm\")\n", - "print(f\" Std displacement: {np.std(signed_displacement):.2f} mm\")\n", - "\n", - "# Visualize the signed displacement with diverging colormap\n", - "# Blue = inward displacement, Red = outward displacement\n", - "plotter = pv.Plotter(window_size=[800, 600])\n", - "plotter.add_mesh(\n", - " pca_registered_model_with_displacement,\n", - " scalars=\"PCA Signed Displacement (mm)\",\n", - " cmap=\"RdBu_r\", # Red for positive (outward), Blue for negative (inward)\n", - " clim=[\n", - " -np.max(np.abs(signed_displacement)),\n", - " np.max(np.abs(signed_displacement)),\n", - " ], # Symmetric color scale\n", - " show_scalar_bar=True,\n", - " scalar_bar_args={\n", - " \"title\": \"PCA Signed Displacement (mm)\\n(Red=Outward, Blue=Inward)\",\n", - " \"vertical\": True,\n", - " \"position_x\": 0.82,\n", - " \"position_y\": 0.1,\n", - " },\n", - ")\n", - "plotter.add_title(\"PCA Signed Displacement on Registered Model\")\n", - "plotter.add_axes()\n", - "plotter.show()\n", - "\n", - "# Save the mesh with displacement data\n", - "pca_registered_model_with_displacement.save(\n", - " str(output_dir / \"pca_registered_model_with_signed_displacement.vtp\")\n", - ")\n", - "print(\"\\n✓ Saved model with signed displacement data\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "state": { - "12ebef1bb4134652baada92b4bf41e65": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_d13b70b19456415a92798dc847c57042", - "placeholder": "​", - "style": "IPY_MODEL_8a5ce0cea571455f908701b9b74e615b", - "tabbable": null, - "tooltip": null, - "value": "" - } - }, - "8a5ce0cea571455f908701b9b74e615b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "aa7837c5f1364c6c9086d42707ea1ac5": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b493ad90f44d494ebad4e99da0e2eb5f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_aa7837c5f1364c6c9086d42707ea1ac5", - "placeholder": "​", - "style": "IPY_MODEL_c21c90fea7bb4fa490ee1b0ca2073d14", - "tabbable": null, - "tooltip": null, - "value": "" - } - }, - "c21c90fea7bb4fa490ee1b0ca2073d14": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "d13b70b19456415a92798dc847c57042": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - } - }, - "version_major": 2, - "version_minor": 0 - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb index dcd7a59..6c08ef6 100644 --- a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb +++ b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient.ipynb @@ -20,6 +20,7 @@ }, "outputs": [], "source": [ + "import json\n", "import os\n", "from pathlib import Path\n", "\n", @@ -31,7 +32,7 @@ "from physiomotion4d import (\n", " ContourTools,\n", " SegmentChestTotalSegmentator,\n", - " WorkflowRegisterHeartModelToPatient,\n", + " WorkflowFitStatisticalModelToPatient,\n", ")" ] }, @@ -214,16 +215,22 @@ }, "outputs": [], "source": [ - "registrar = WorkflowRegisterHeartModelToPatient(\n", + "with open(pca_json_path, encoding=\"utf-8\") as f:\n", + " pca_model = json.load(f)\n", + "registrar = WorkflowFitStatisticalModelToPatient(\n", " template_model=template_model,\n", + " patient_models=[patient_model],\n", + " patient_image=patient_image,\n", + ")\n", + "registrar.set_use_pca_registration(\n", + " True, pca_model=pca_model, pca_number_of_modes=pca_n_modes\n", + ")\n", + "registrar.set_use_mask_to_image_registration(\n", + " True,\n", " template_labelmap=template_labelmap,\n", - " template_labelmap_heart_muscle_ids=[1],\n", - " template_labelmap_chamber_ids=[2, 3, 4, 5],\n", + " template_labelmap_organ_mesh_ids=[1],\n", + " template_labelmap_organ_extra_ids=[2, 3, 4, 5],\n", " template_labelmap_background_ids=[6],\n", - " patient_image=patient_image,\n", - " patient_models=[patient_model],\n", - " pca_json_filename=pca_json_path,\n", - " pca_number_of_modes=pca_n_modes,\n", ")\n", "\n", "registrar.set_mask_dilation_mm(0)\n", diff --git a/experiments/README.md b/experiments/README.md index aac93f2..fb07670 100644 --- a/experiments/README.md +++ b/experiments/README.md @@ -14,7 +14,7 @@ of the PhysioMotion4D library. - Command-line tools and parameter specifications See: -- **CLI Commands**: Run `physiomotion4d-heart-gated-ct --help` and `physiomotion4d-register-heart-model --help` +- **CLI Commands**: Run `physiomotion4d-heart-gated-ct --help`, `physiomotion4d-create-statistical-model --help`, and `physiomotion4d-fit-statistical-model-to-patient --help` - **CLI Implementation**: `src/physiomotion4d/cli/` for Python API examples - **Library Classes**: `src/physiomotion4d/` for all workflow and utility classes @@ -36,7 +36,7 @@ These experiments demonstrate key digital twin workflows that can be adapted to anatomical regions, imaging modalities, and physiological motion tasks. > **Note:** For production implementations of these workflows, use the CLI commands -> (`physiomotion4d-heart-gated-ct`, `physiomotion4d-register-heart-model`) or consult +> (`physiomotion4d-heart-gated-ct`, `physiomotion4d-create-statistical-model`, `physiomotion4d-fit-statistical-model-to-patient`) or consult > the CLI implementation in `src/physiomotion4d/cli/` for proper class usage and parameter specifications. ### `Reconstruct4DCT` - High-Resolution 4D Reconstruction @@ -275,7 +275,8 @@ Each subdirectory represents a different experimental domain: 1. **CLI Commands** ⭐ **PRIMARY RESOURCE** - `physiomotion4d-heart-gated-ct` - Complete heart-gated CT workflow - - `physiomotion4d-register-heart-model` - Model-to-patient registration + - `physiomotion4d-create-statistical-model` - Create PCA statistical shape model from sample meshes + - `physiomotion4d-fit-statistical-model-to-patient` - Model-to-patient registration - Run with `--help` for all options and parameter specifications - Tested on diverse datasets @@ -336,7 +337,7 @@ The typical evolution path was: When exploring new digital twin applications, you can follow a similar path: - Start by understanding relevant experiments here as conceptual references - Examine CLI implementations in `src/physiomotion4d/cli/` for proper library usage -- Use CLI commands (`physiomotion4d-heart-gated-ct`, `physiomotion4d-register-heart-model`) as starting points +- Use CLI commands (`physiomotion4d-heart-gated-ct`, `physiomotion4d-create-statistical-model`, `physiomotion4d-fit-statistical-model-to-patient`) as starting points - Extend and adapt production code with your domain-specific requirements - Contribute back improvements and new capabilities to the community diff --git a/pyproject.toml b/pyproject.toml index afac15a..1ea2f3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,7 +163,9 @@ Changelog = "https://github.com/aylward/PhysioMotion4d/blob/main/CHANGELOG.md" # CLI commands installed via pip # Entry points reference the main() functions in the cli submodule physiomotion4d-heart-gated-ct = "physiomotion4d.cli.convert_heart_gated_ct_to_usd:main" -physiomotion4d-register-heart-model = "physiomotion4d.cli.register_heart_model_to_patient:main" +physiomotion4d-create-statistical-model = "physiomotion4d.cli.create_statistical_model:main" +physiomotion4d-fit-statistical-model-to-patient = "physiomotion4d.cli.fit_statistical_model_to_patient:main" +physiomotion4d-visualize-pca-modes = "physiomotion4d.cli.visualize_pca_modes:main" [tool.setuptools.packages.find] where = ["src"] @@ -173,6 +175,7 @@ current_version = "2025.05.0" version_pattern = "YYYY.0M.PATCH[PYTAGNUM]" commit_message = "Bump version {old_version} -> {new_version}" commit = true +tag_message = "{new_version}" tag = true push = false @@ -201,42 +204,20 @@ warn_unreachable = true strict_equality = true show_error_codes = true exclude = '(?x)(^src[/\\\\]physiomotion4d[/\\\\]network_weights[/\\\\]vista3d[/\\\\])' +# Third-party libs (itk, pyvista, pxr, vtk, simpleware, etc.) have no stubs +disable_error_code = ["import-not-found", "import-untyped"] +# When type-checking our package, ignore missing imports for third-party libs. +# Override applies to the module being checked (physiomotion4d.*). [[tool.mypy.overrides]] -module = [ - "ants", - "cupy", - "huggingface_hub", - "hugging_face_pipeline", - "icon_registration.*", - "ignite.*", - "itk.*", - "matplotlib", - "matplotlib.*", - "monai.*", - "nibabel.*", - "nrrd", - "numpy", - "numpy.*", - "pxr", - "pynrrd.*", - "pyvista.*", - "scripts.*", - "SimpleITK", - "simpleware.*", - "torch", - "torch.*", - "totalsegmentator.*", - "transformers", - "transformers.*", - "trimesh", - "unigradicon.*", - "vista3d", - "vista3d.*", - "vtk.*" -] +module = ["physiomotion4d", "physiomotion4d.*"] ignore_missing_imports = true +[tool.pyright] +# Third-party packages (e.g. pyvista) are in dependencies but may have no stubs; +# do not report import-not-found so analysis matches mypy overrides above. +reportMissingImports = false + [tool.pytest.ini_options] minversion = "7.0" addopts = [ diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index df48af7..16e27cc 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -45,11 +45,12 @@ from .register_time_series_images import RegisterTimeSeriesImages # Segmentation classes -from .segment_chest_base import SegmentChestBase +from .segment_anatomy_base import SegmentAnatomyBase from .segment_chest_ensemble import SegmentChestEnsemble from .segment_chest_total_segmentator import SegmentChestTotalSegmentator from .segment_chest_vista_3d import SegmentChestVista3D from .segment_chest_vista_3d_nim import SegmentChestVista3DNIM +from .segment_heart_simpleware import SegmentHeartSimpleware from .transform_tools import TransformTools from .usd_anatomy_tools import USDAnatomyTools from .usd_tools import USDTools @@ -57,21 +58,24 @@ # Core workflow processor from .workflow_convert_heart_gated_ct_to_usd import WorkflowConvertHeartGatedCTToUSD from .workflow_reconstruct_highres_4d_ct import WorkflowReconstructHighres4DCT -from .workflow_register_heart_model_to_patient import ( - WorkflowRegisterHeartModelToPatient, +from .workflow_create_statistical_model import WorkflowCreateStatisticalModel +from .workflow_fit_statistical_model_to_patient import ( + WorkflowFitStatisticalModelToPatient, ) __all__ = [ # Workflow classes "WorkflowConvertHeartGatedCTToUSD", + "WorkflowCreateStatisticalModel", "WorkflowReconstructHighres4DCT", - "WorkflowRegisterHeartModelToPatient", + "WorkflowFitStatisticalModelToPatient", # Segmentation classes - "SegmentChestBase", + "SegmentAnatomyBase", "SegmentChestEnsemble", "SegmentChestTotalSegmentator", "SegmentChestVista3D", "SegmentChestVista3DNIM", + "SegmentHeartSimpleware", # Registration classes "RegisterImagesBase", "RegisterImagesICON", diff --git a/src/physiomotion4d/cli/__init__.py b/src/physiomotion4d/cli/__init__.py index c40248b..632cf1c 100644 --- a/src/physiomotion4d/cli/__init__.py +++ b/src/physiomotion4d/cli/__init__.py @@ -1,3 +1,8 @@ """Command-line interface modules for PhysioMotion4D.""" -__all__ = ["convert_heart_gated_ct_to_usd", "register_heart_model_to_patient"] +__all__ = [ + "convert_heart_gated_ct_to_usd", + "create_statistical_model", + "fit_statistical_model_to_patient", + "visualize_pca_modes", +] diff --git a/src/physiomotion4d/cli/create_statistical_model.py b/src/physiomotion4d/cli/create_statistical_model.py new file mode 100644 index 0000000..635ab5d --- /dev/null +++ b/src/physiomotion4d/cli/create_statistical_model.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +""" +Command-line interface for Create Statistical Model workflow. + +This script provides a CLI to build a PCA statistical shape model from a sample +of meshes aligned to a reference mesh, as in the Heart-Create_Statistical_Model +experiment notebooks. Outputs include pca_mean_surface.vtp, pca_mean.vtu (if +reference is volumetric), and pca_model.json. +""" + +import argparse +import json +import os +import sys +import traceback +from pathlib import Path + +import pyvista as pv + +from physiomotion4d import WorkflowCreateStatisticalModel + + +def main() -> int: + """Command-line interface for create statistical model workflow.""" + parser = argparse.ArgumentParser( + description="Create a PCA statistical shape model from sample meshes aligned to a reference", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Create model from a directory of sample meshes and a reference mesh + %(prog)s \\ + --sample-meshes-dir ./meshes \\ + --reference-mesh average_mesh.vtk \\ + --output-dir ./pca_model + + # Specify sample meshes explicitly + %(prog)s \\ + --sample-meshes 01.vtk 02.vtk 03.vtu \\ + --reference-mesh average_mesh.vtk \\ + --output-dir ./pca_model + + # Custom PCA components + %(prog)s \\ + --sample-meshes-dir ./meshes \\ + --reference-mesh average_mesh.vtk \\ + --output-dir ./pca_model \\ + --pca-components 20 + """, + ) + + # Sample meshes: either a directory or a list of files + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--sample-meshes-dir", + type=Path, + metavar="DIR", + help="Directory containing sample mesh files (.vtk, .vtu, .vtp)", + ) + group.add_argument( + "--sample-meshes", + nargs="+", + type=Path, + metavar="PATH", + help="Paths to sample mesh files", + ) + + parser.add_argument( + "--reference-mesh", + type=Path, + required=True, + metavar="PATH", + help="Path to reference mesh; its surface is used to align all samples", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + metavar="DIR", + help="Output directory for pca_mean_surface.vtp, pca_mean.vtu, pca_model.json", + ) + + parser.add_argument( + "--pca-components", + type=int, + default=15, + help="Number of PCA components to retain (default: 15)", + ) + + args = parser.parse_args() + + # Resolve sample mesh paths + if args.sample_meshes_dir is not None: + smd = Path(args.sample_meshes_dir) + sample_paths: list[Path] = [] + for ext in [".vtk", ".vtp", ".vtu"]: + sample_paths.extend(smd.glob(f"*{ext}")) + sample_paths = sorted(set(sample_paths)) + if not sample_paths: + print( + f"Error: No .vtk, .vtu, or .vtp files found in {args.sample_meshes_dir}" + ) + return 1 + else: + sample_paths = args.sample_meshes + + # Validate paths + print("Validating input files...") + if not args.reference_mesh.exists(): + print(f"Error: Reference mesh not found: {args.reference_mesh}") + return 1 + for p in sample_paths: + if not p.exists(): + print(f"Error: Sample mesh not found: {p}") + return 1 + + os.makedirs(args.output_dir, exist_ok=True) + + # Load meshes + print("\nLoading meshes...") + try: + print(f" Reference mesh: {args.reference_mesh}") + reference_mesh = pv.read(args.reference_mesh) + print(f" Sample meshes: {len(sample_paths)} files") + sample_meshes = [pv.read(p) for p in sample_paths] + except (FileNotFoundError, OSError, RuntimeError) as e: + print(f"Error loading meshes: {e}") + traceback.print_exc() + return 1 + + # Run workflow + print("\nInitializing create statistical model workflow...") + try: + workflow = WorkflowCreateStatisticalModel( + sample_meshes=sample_meshes, + reference_mesh=reference_mesh, + pca_number_of_components=args.pca_components, + ) + except (ValueError, RuntimeError) as e: + print(f"Error initializing workflow: {e}") + traceback.print_exc() + return 1 + + try: + print("\nRunning pipeline...") + print("=" * 70) + result = workflow.run_workflow() + print("=" * 70) + print("\nSaving outputs...") + + out_surface = args.output_dir / "pca_mean_surface.vtp" + result["pca_mean_surface"].save(str(out_surface)) + print(f" pca_mean_surface: {out_surface}") + + if result.get("pca_mean_mesh") is not None: + out_mesh = args.output_dir / "pca_mean.vtu" + result["pca_mean_mesh"].save(str(out_mesh)) + print(f" pca_mean_mesh: {out_mesh}") + + out_json = args.output_dir / "pca_model.json" + with open(out_json, "w", encoding="utf-8") as f: + json.dump(result["pca_model"], f, indent=4) + print(f" pca_model: {out_json}") + + print("\nCreate statistical model completed successfully.") + print(f"Outputs written to: {args.output_dir}") + return 0 + + except (RuntimeError, ValueError, OSError) as e: + print(f"\nError during workflow: {e}") + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/physiomotion4d/cli/register_heart_model_to_patient.py b/src/physiomotion4d/cli/fit_statistical_model_to_patient.py similarity index 78% rename from src/physiomotion4d/cli/register_heart_model_to_patient.py rename to src/physiomotion4d/cli/fit_statistical_model_to_patient.py index 1955174..56c86d5 100644 --- a/src/physiomotion4d/cli/register_heart_model_to_patient.py +++ b/src/physiomotion4d/cli/fit_statistical_model_to_patient.py @@ -8,6 +8,7 @@ """ import argparse +import json import os import sys import traceback @@ -15,7 +16,7 @@ import itk import pyvista as pv -from physiomotion4d import WorkflowRegisterHeartModelToPatient +from physiomotion4d import WorkflowFitStatisticalModelToPatient def main() -> int: @@ -25,39 +26,36 @@ def main() -> int: formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Basic registration with required inputs + # Basic registration (no patient image: reference image created from patient models) %(prog)s \\ --template-model heart_model.vtu \\ - --template-labelmap heart_labelmap.nii.gz \\ --patient-models lv.vtp rv.vtp myo.vtp \\ - --patient-image patient_ct.nii.gz \\ --output-dir ./results - # Registration with PCA shape fitting + # With patient image and PCA shape fitting %(prog)s \\ --template-model heart_model.vtu \\ - --template-labelmap heart_labelmap.nii.gz \\ --patient-models lv.vtp rv.vtp myo.vtp \\ --patient-image patient_ct.nii.gz \\ --pca-json pca_model.json \\ --pca-number-of-modes 10 \\ --output-dir ./results - # Registration with custom label IDs + # Enable mask-to-image refinement (requires template labelmap and label IDs) %(prog)s \\ --template-model heart_model.vtu \\ - --template-labelmap heart_labelmap.nii.gz \\ - --patient-models lv.vtp rv.vtp \\ + --patient-models lv.vtp rv.vtp myo.vtp \\ --patient-image patient_ct.nii.gz \\ + --mask-to-image \\ + --template-labelmap heart_labelmap.nii.gz \\ --template-labelmap-muscle-ids 1 2 3 \\ --template-labelmap-chamber-ids 4 5 6 \\ --template-labelmap-background-ids 0 \\ --output-dir ./results - # Registration with ICON refinement + # With ICON refinement %(prog)s \\ --template-model heart_model.vtu \\ - --template-labelmap heart_labelmap.nii.gz \\ --patient-models lv.vtp rv.vtp \\ --patient-image patient_ct.nii.gz \\ --use-icon-refinement \\ @@ -71,11 +69,6 @@ def main() -> int: required=True, help="Path to template/generic heart model (.vtu, .vtk, .stl)", ) - parser.add_argument( - "--template-labelmap", - required=True, - help="Path to template labelmap image (.nii.gz, .nrrd, .mha)", - ) parser.add_argument( "--patient-models", nargs="+", @@ -84,8 +77,11 @@ def main() -> int: ) parser.add_argument( "--patient-image", - required=True, - help="Path to patient CT/MRI image (.nii.gz, .nrrd, .mha)", + help="Path to patient CT/MRI image (.nii.gz, .nrrd, .mha). If omitted, a reference image is created from the patient models.", + ) + parser.add_argument( + "--template-labelmap", + help="Path to template labelmap image (.nii.gz, .nrrd, .mha). Required when --mask-to-image is set.", ) parser.add_argument( "--output-dir", required=True, help="Output directory for results" @@ -135,11 +131,11 @@ def main() -> int: help="Disable mask-to-mask deformable registration", ) parser.add_argument( - "--no-mask-to-image", + "--mask-to-image", dest="use_mask_to_image", - action="store_false", - default=True, - help="Disable mask-to-image refinement registration", + action="store_true", + default=False, + help="Enable mask-to-image refinement (requires --template-labelmap and label IDs)", ) parser.add_argument( "--use-icon-refinement", @@ -163,19 +159,23 @@ def main() -> int: print(f"Error: Template model not found: {args.template_model}") return 1 - if not os.path.exists(args.template_labelmap): - print(f"Error: Template labelmap not found: {args.template_labelmap}") - return 1 - for patient_model in args.patient_models: if not os.path.exists(patient_model): print(f"Error: Patient model not found: {patient_model}") return 1 - if not os.path.exists(args.patient_image): + if args.patient_image is not None and not os.path.exists(args.patient_image): print(f"Error: Patient image not found: {args.patient_image}") return 1 + if args.use_mask_to_image: + if args.template_labelmap is None: + print("Error: --template-labelmap is required when --mask-to-image is set.") + return 1 + if not os.path.exists(args.template_labelmap): + print(f"Error: Template labelmap not found: {args.template_labelmap}") + return 1 + if args.pca_json and not os.path.exists(args.pca_json): print(f"Error: PCA JSON file not found: {args.pca_json}") return 1 @@ -193,17 +193,25 @@ def main() -> int: ) template_model: pv.UnstructuredGrid = template_model_raw - print(f" Loading template labelmap: {args.template_labelmap}") - template_labelmap = itk.imread(args.template_labelmap) - print(" Loading patient models:") patient_models = [] for patient_model_file in args.patient_models: print(f" - {patient_model_file}") patient_models.append(pv.read(patient_model_file)) - print(f" Loading patient image: {args.patient_image}") - patient_image = itk.imread(args.patient_image) + if args.patient_image is not None: + print(f" Loading patient image: {args.patient_image}") + patient_image = itk.imread(args.patient_image) + else: + patient_image = None + print( + " No patient image: reference image will be created from patient models" + ) + + template_labelmap = None + if args.template_labelmap is not None: + print(f" Loading template labelmap: {args.template_labelmap}") + template_labelmap = itk.imread(args.template_labelmap) except (FileNotFoundError, OSError, RuntimeError) as e: print(f"Error loading input data: {e}") @@ -213,17 +221,27 @@ def main() -> int: # Initialize workflow print("\nInitializing heart model to patient registration workflow...") try: - workflow = WorkflowRegisterHeartModelToPatient( + workflow = WorkflowFitStatisticalModelToPatient( template_model=template_model, - template_labelmap=template_labelmap, - template_labelmap_heart_muscle_ids=args.template_labelmap_muscle_ids, - template_labelmap_chamber_ids=args.template_labelmap_chamber_ids, - template_labelmap_background_ids=args.template_labelmap_background_ids, patient_models=patient_models, patient_image=patient_image, - pca_json_filename=args.pca_json, - pca_number_of_modes=args.pca_number_of_modes, ) + if args.pca_json is not None: + with open(args.pca_json, encoding="utf-8") as f: + pca_model = json.load(f) + workflow.set_use_pca_registration( + True, + pca_model=pca_model, + pca_number_of_modes=args.pca_number_of_modes, + ) + if args.use_mask_to_image: + workflow.set_use_mask_to_image_registration( + True, + template_labelmap=template_labelmap, + template_labelmap_organ_mesh_ids=args.template_labelmap_muscle_ids, + template_labelmap_organ_extra_ids=args.template_labelmap_chamber_ids, + template_labelmap_background_ids=args.template_labelmap_background_ids, + ) except (ValueError, RuntimeError, OSError) as e: print(f"Error initializing workflow: {e}") traceback.print_exc() diff --git a/src/physiomotion4d/cli/visualize_pca_modes.py b/src/physiomotion4d/cli/visualize_pca_modes.py new file mode 100644 index 0000000..5db4eb6 --- /dev/null +++ b/src/physiomotion4d/cli/visualize_pca_modes.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +""" +Command-line interface to visualize PCA modes of variation. + +Displays the first three principal components in a 1x3 PyVista plotter. A +slider (0 to 4) controls the standard-deviation magnitude; each subplot +shows the mean shape (gray), +sigma shape (coral), and -sigma shape (blue) +for that PC (matching experiments/Heart-Create_Statistical_Model/ +5-compute_pca_model.ipynb cell 11). + +Inputs: pca_mean_surface.vtp (mean mesh topology and mean shape) and +pca_model.json (components and eigenvalues). +""" + +import argparse +import json +import sys +import traceback +from pathlib import Path + +import numpy as np +import pyvista as pv + + +def _shape_at_sigma( + mean_shape: np.ndarray, + components: list, + eigenvalues: list, + pc_index: int, + sigma: float, +) -> np.ndarray: + """Return (n_points, 3) mesh points for mean + sigma * std_dev * component.""" + pc = np.asarray(components[pc_index], dtype=np.float64) + std_dev = np.sqrt(eigenvalues[pc_index]) + variation = mean_shape + (sigma * std_dev * pc) + n_points = mean_shape.size // 3 + return variation.reshape(n_points, 3) + + +def _generate_pc_variation( + mean_mesh: pv.PolyData, + mean_shape: np.ndarray, + components: list, + eigenvalues: list, + pc_index: int, + std_dev_multiplier: float = 3.0, +) -> tuple[pv.PolyData, pv.PolyData, pv.PolyData]: + """Generate shape variations along a principal component. + + Parameters + ---------- + mean_mesh : pv.PolyData + Template mesh (topology only; points are replaced). + mean_shape : np.ndarray + Flattened mean shape (n_points * 3,). + components : list + List of component vectors (each length n_points * 3). + eigenvalues : list + Variance (eigenvalue) for each component. + pc_index : int + Index of the principal component (0-based). + std_dev_multiplier : float + How many standard deviations to vary (default: +/-3 sigma). + + Returns + ------- + tuple of (negative_mesh, mean_mesh, positive_mesh) + """ + pc = np.asarray(components[pc_index], dtype=np.float64) + std_dev = np.sqrt(eigenvalues[pc_index]) + + negative_variation = mean_shape - (std_dev_multiplier * std_dev * pc) + positive_variation = mean_shape + (std_dev_multiplier * std_dev * pc) + + n_points = mean_shape.size // 3 + negative_points = negative_variation.reshape(n_points, 3) + positive_points = positive_variation.reshape(n_points, 3) + mean_points = mean_shape.reshape(n_points, 3) + + negative_mesh = mean_mesh.copy() + negative_mesh.points = negative_points + + positive_mesh = mean_mesh.copy() + positive_mesh.points = positive_points + + mean_mesh_out = mean_mesh.copy() + mean_mesh_out.points = mean_points + + return negative_mesh, mean_mesh_out, positive_mesh + + +def main() -> int: + """Command-line interface for visualizing PCA modes.""" + parser = argparse.ArgumentParser( + description="Visualize the first three PCA modes with a 0-4 sigma slider; each subplot shows mean, +sigma, -sigma (PyVista 1x3 plot).", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Slider 0-4 sigma; each subplot shows mean (gray), +sigma (coral), -sigma (blue) + %(prog)s pca_mean_surface.vtp pca_model.json + + # Start with slider at 2 sigma + %(prog)s --std-dev 2.0 mean.vtp pca_model.json + """, + ) + + parser.add_argument( + "pca_mean_surface", + type=Path, + metavar="PCA_MEAN_SURFACE.vtp", + help="Path to PCA mean surface (.vtp); provides topology and mean shape.", + ) + parser.add_argument( + "pca_model_json", + type=Path, + metavar="pca_model.json", + help="Path to PCA model JSON (components and eigenvalues).", + ) + parser.add_argument( + "--std-dev", + type=float, + default=0.0, + metavar="N", + help="Initial slider position 0-4 std dev (default: 0).", + ) + + args = parser.parse_args() + + if not args.pca_mean_surface.exists(): + print(f"Error: PCA mean surface not found: {args.pca_mean_surface}") + return 1 + if not args.pca_model_json.exists(): + print(f"Error: PCA model JSON not found: {args.pca_model_json}") + return 1 + + try: + mean_mesh = pv.read(str(args.pca_mean_surface)) + except (OSError, RuntimeError) as e: + print(f"Error loading PCA mean surface: {e}") + traceback.print_exc() + return 1 + + if not isinstance(mean_mesh, pv.PolyData): + print("Error: PCA mean surface must be a PolyData (.vtp).") + return 1 + + try: + with open(args.pca_model_json, encoding="utf-8") as f: + pca_model = json.load(f) + except (OSError, json.JSONDecodeError) as e: + print(f"Error loading PCA model JSON: {e}") + traceback.print_exc() + return 1 + + for key in ("components", "eigenvalues"): + if key not in pca_model: + print(f"Error: PCA model JSON must contain '{key}'.") + return 1 + + components = pca_model["components"] + eigenvalues = pca_model["eigenvalues"] + + if len(components) < 3: + print( + f"Error: PCA model has only {len(components)} component(s); need at least 3 for visualization." + ) + return 1 + + n_points = mean_mesh.n_points + n_features = n_points * 3 + if len(components[0]) != n_features: + print( + f"Error: Mean surface has {n_points} points ({n_features} features), " + f"but PCA components have {len(components[0])} entries. Shapes must match." + ) + return 1 + + mean_shape = mean_mesh.points.astype(np.float64).flatten() + n_points = mean_shape.size // 3 + + # Precompute component arrays for slider updates + pc_arrays = [np.asarray(components[i], dtype=np.float64) for i in range(3)] + std_devs = [np.sqrt(eigenvalues[i]) for i in range(3)] + + plotter = pv.Plotter(shape=(1, 3)) + + # Slider 0 to 4; clamp initial to that range + initial_sigma = max(0.0, min(4.0, args.std_dev)) + + # Mean shape for reference (same in all subplots) + mean_ref = mean_mesh.copy() + mean_ref.points = mean_shape.reshape(n_points, 3) + + # Per subplot: mean (static), +sigma mesh, -sigma mesh (updated by slider) + plus_meshes: list[pv.PolyData] = [] + minus_meshes: list[pv.PolyData] = [] + for col, pc_index in enumerate(range(3)): + points_plus = _shape_at_sigma( + mean_shape, components, eigenvalues, pc_index, initial_sigma + ) + points_minus = _shape_at_sigma( + mean_shape, components, eigenvalues, pc_index, -initial_sigma + ) + mesh_plus = mean_mesh.copy() + mesh_plus.points = points_plus + mesh_minus = mean_mesh.copy() + mesh_minus.points = points_minus + plus_meshes.append(mesh_plus) + minus_meshes.append(mesh_minus) + plotter.subplot(0, col) + plotter.add_mesh( + mean_ref.copy(), + color="lightgray", + opacity=0.4, + show_edges=False, + ) + plotter.add_mesh( + mesh_minus, + color="lightblue", + opacity=1.0, + show_edges=False, + ) + plotter.add_mesh( + mesh_plus, + color="lightcoral", + opacity=1.0, + show_edges=False, + ) + plotter.add_text( + f"PC{pc_index + 1}", + font_size=12, + ) + plotter.camera_position = "iso" + + def _on_slider(sigma: float) -> None: + for i in range(3): + plus_pts = mean_shape + (sigma * std_devs[i] * pc_arrays[i]) + minus_pts = mean_shape - (sigma * std_devs[i] * pc_arrays[i]) + plus_meshes[i].points = plus_pts.reshape(n_points, 3) + minus_meshes[i].points = minus_pts.reshape(n_points, 3) + plotter.render() + + # Slider: 0 to 4 std dev (shows mean, +sigma, -sigma) + plotter.subplot(0, 0) + plotter.add_slider_widget( + _on_slider, + rng=(0.0, 4.0), + value=initial_sigma, + title="Std dev", + pointa=(0.02, 0.1), + pointb=(0.98, 0.1), + ) + + plotter.link_views() + plotter.show() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index 231bbc2..daf21db 100644 --- a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -265,14 +265,30 @@ def create_distance_map( tmp_arr = np.zeros(size, dtype=np.int32) itk_point = itk.Point[itk.D, 3]() + point_count = 0 for point in points: itk_point[0] = float(point[0]) itk_point[1] = float(point[1]) itk_point[2] = float(point[2]) indx = reference_image.TransformPhysicalPointToIndex(itk_point) + if ( + indx[0] < 0 + or indx[1] < 0 + or indx[2] < 0 + or indx[0] >= size[0] + or indx[1] >= size[1] + or indx[2] >= size[2] + ): + continue tmp_arr[indx[2], indx[1], indx[0]] = 1 + point_count += 1 + tmp_binary_image = itk.GetImageFromArray(tmp_arr.astype(np.uint8)) tmp_binary_image.CopyInformation(reference_image) + assert ( + tmp_binary_image.GetLargestPossibleRegion().GetSize() + == reference_image.GetLargestPossibleRegion().GetSize() + ) distance_filter = itk.SignedMaurerDistanceMapImageFilter.New( Input=tmp_binary_image @@ -319,6 +335,15 @@ def create_deformation_field( itk_point[1] = float(point[1]) itk_point[2] = float(point[2]) indx = reference_image.TransformPhysicalPointToIndex(itk_point) + if ( + indx[0] < 0 + or indx[1] < 0 + or indx[2] < 0 + or indx[0] >= size[0] + or indx[1] >= size[1] + or indx[2] >= size[2] + ): + continue displacement_map_x[int(indx[2]), int(indx[1]), int(indx[0])] = ( point_displacements[i, 0] ) @@ -332,6 +357,10 @@ def create_deformation_field( norm_img = itk.GetImageFromArray(norm_map) norm_img.CopyInformation(reference_image) + assert ( + norm_img.GetLargestPossibleRegion().GetSize() + == reference_image.GetLargestPossibleRegion().GetSize() + ) blurred_norm = itk.SmoothingRecursiveGaussianImageFilter( Input=norm_img, Sigma=blur_sigma diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index 3cc9d1d..ccb974d 100644 --- a/src/physiomotion4d/image_tools.py +++ b/src/physiomotion4d/image_tools.py @@ -6,7 +6,7 @@ """ import logging -from typing import Any +from typing import Any, Optional import itk import numpy as np @@ -245,3 +245,40 @@ def convert_array_to_image_of_vectors( itk.array_view_from_image(itk_image)[:] = arr_data return itk_image + + def flip_image_to_identity_direction( + self, in_image: itk.Image, in_mask: Optional[itk.Image] = None + ) -> Any | tuple[Any, Any]: + """ + Flip the image to the identity direction. + """ + flip0 = np.array(in_image.GetDirection())[0, 0] < 0 + flip1 = np.array(in_image.GetDirection())[1, 1] < 0 + flip2 = np.array(in_image.GetDirection())[2, 2] < 0 + if flip0 or flip1 or flip2: + self.log_info( + f"Flipping image to identity direction: {flip0}, {flip1}, {flip2}" + ) + flip_filter = itk.FlipImageFilter.New(Input=in_image) + flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)]) + flip_filter.SetFlipAboutOrigin(True) + flip_filter.Update() + out_image = flip_filter.GetOutput() + id_mat = itk.Matrix[itk.D, 3, 3]() + id_mat.SetIdentity() + out_image.SetDirection(id_mat) + if in_mask is not None: + flip_filter = itk.FlipImageFilter.New(Input=in_mask) + flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)]) + flip_filter.SetFlipAboutOrigin(True) + flip_filter.Update() + out_mask = flip_filter.GetOutput() + out_mask.SetDirection(id_mat) + return out_image, out_mask + else: + return out_image + else: + if in_mask is not None: + return in_image, in_mask + else: + return in_image diff --git a/src/physiomotion4d/physiomotion4d_base.py b/src/physiomotion4d/physiomotion4d_base.py index c4a7451..336a111 100644 --- a/src/physiomotion4d/physiomotion4d_base.py +++ b/src/physiomotion4d/physiomotion4d_base.py @@ -228,7 +228,7 @@ def set_log_classes(cls, class_names: list[str]) -> None: Args: class_names: List of class names to show logs from. - Example: ["RegisterModelsPCA", "WorkflowRegisterHeartModelToPatient"] + Example: ["RegisterModelsPCA", "WorkflowFitStatisticalModelToPatient"] Example: >>> PhysioMotion4DBase.set_log_classes(['RegisterModelsPCA']) @@ -263,7 +263,7 @@ def get_log_classes(cls) -> list[str]: Example: >>> classes = PhysioMotion4DBase.get_log_classes() >>> print(classes) - ['RegisterModelsPCA', 'WorkflowRegisterHeartModelToPatient'] + ['RegisterModelsPCA', 'WorkflowFitStatisticalModelToPatient'] """ if cls._class_filter is not None and cls._class_filter.enabled: return sorted(cls._class_filter.allowed_classes) diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py index b50c7da..5ee8ead 100644 --- a/src/physiomotion4d/register_models_distance_maps.py +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -120,7 +120,7 @@ def __init__( moving_model: pv.PolyData, fixed_model: pv.PolyData, reference_image: itk.Image, - roi_dilation_mm: float = 10, + roi_dilation_mm: float = 20, log_level: int | str = logging.INFO, ): """Initialize mask-based model registration. diff --git a/src/physiomotion4d/register_models_pca.py b/src/physiomotion4d/register_models_pca.py index 895990e..952e70e 100644 --- a/src/physiomotion4d/register_models_pca.py +++ b/src/physiomotion4d/register_models_pca.py @@ -96,7 +96,7 @@ def __init__( These are the square roots of pca_eigenvalues pca_number_of_modes: Number of PCA modes to use. Default: -1 (use all) pca_template_model_point_subsample: Step size for subsampling model points. Default: 4 - pre_pca_transform: Optional ITK transform to apply after PCA registration. + pre_pca_transform: Optional ITK transform to apply before PCA registration. Default: None fixed_distance_map: ITK image providing the distance map. Default: None @@ -161,21 +161,10 @@ def __init__( self.registered_model_pca_coefficients: Optional[np.ndarray] = None self.registered_model: Optional[pv.UnstructuredGrid] = None self.registered_model_mean_distance: float = 0.0 - self.register_model_pca_deformation: Optional[np.ndarray] = None + self.registered_model_pca_deformation: Optional[np.ndarray] = None self.forward_point_transform: Optional[itk.DisplacementFieldTransform] = None self.inverse_point_transform: Optional[itk.DisplacementFieldTransform] = None - self._template_model_pca_deformation_field_image: Optional[itk.Image] = None - self._deformation_field_interpolator_x: Optional[ - itk.LinearInterpolateImageFunction - ] = None - self._deformation_field_interpolator_y: Optional[ - itk.LinearInterpolateImageFunction - ] = None - self._deformation_field_interpolator_z: Optional[ - itk.LinearInterpolateImageFunction - ] = None - # Image interpolator (created when needed) self._fixed_distance_map_interpolator: Optional[ itk.LinearInterpolateImageFunction @@ -215,7 +204,7 @@ def from_json( pca_json_filename: Path to the PCA model JSON file pca_number_of_modes: Number of PCA modes to use. Default: 0 (use all) pca_template_model_point_subsample: Step size for subsampling model points. Default: 4 - pre_pca_transform: Optional ITK transform to apply after PCA registration. + pre_pca_transform: Optional ITK transform to apply before PCA registration. Default: None fixed_distance_map: ITK image providing the distance values for registration. If None, must be set later before registration. @@ -283,6 +272,67 @@ def from_json( logger.info(" ✓ Data validation successful!") logger.info("PCA model data loaded successfully!") + return cls.from_pca_model( + pca_template_model=pca_template_model, + pca_model=pca_data, + pca_number_of_modes=pca_number_of_modes, + pca_template_model_point_subsample=pca_template_model_point_subsample, + pre_pca_transform=pre_pca_transform, + fixed_distance_map=fixed_distance_map, + fixed_model=fixed_model, + reference_image=reference_image, + log_level=log_level, + ) + + @classmethod + def from_pca_model( + cls, + pca_template_model: pv.UnstructuredGrid, + pca_model: dict, + pca_number_of_modes: int = 0, + pca_template_model_point_subsample: int = 4, + pre_pca_transform: Optional[itk.Transform] = None, + fixed_distance_map: Optional[itk.Image] = None, + fixed_model: Optional[pv.UnstructuredGrid] = None, + reference_image: Optional[itk.Image] = None, + log_level: int | str = logging.INFO, + ) -> Self: + """Create RegisterModelsPCA from a PCA model dictionary. + + The dict must match the structure produced by + :class:`WorkflowCreateStatisticalModel` (key ``pca_model``): + ``explained_variance_ratio``, ``eigenvalues``, ``components``. + + Args: + pca_template_model: Mean surface mesh to use as template + pca_model: PCA model dict with 'eigenvalues' and 'components' (and optionally + 'explained_variance_ratio') + pca_number_of_modes: Number of PCA modes to use. Default: 0 (use all) + pca_template_model_point_subsample: Step size for subsampling model points. Default: 4 + pre_pca_transform: Optional ITK transform to apply before PCA registration. + fixed_distance_map: ITK image providing the distance values for registration. + fixed_model: Target surface mesh to register to. + reference_image: Reference image defining coordinate space. + log_level: Logging level. + + Returns: + RegisterModelsPCA instance + + Raises: + ValueError: If required keys are missing or dimensions invalid + """ + if "eigenvalues" not in pca_model: + raise ValueError("'eigenvalues' field not found in pca_model") + pca_std_deviations = np.sqrt(np.array(pca_model["eigenvalues"])) + if "components" not in pca_model: + raise ValueError("'components' field not found in pca_model") + pca_eigenvectors = np.array(pca_model["components"], dtype=np.float64) + expected_size = pca_template_model.n_points * 3 + if pca_eigenvectors.shape[1] != expected_size: + raise ValueError( + f"Component dimension mismatch: expected {expected_size} " + f"(3 × {pca_template_model.n_points} points), got {pca_eigenvectors.shape[1]}" + ) return cls( pca_template_model=pca_template_model, pca_eigenvectors=pca_eigenvectors, @@ -597,8 +647,8 @@ def transform_template_model(self) -> pv.UnstructuredGrid: self.log_info("Creating final registered model...") # Compute PCA deformation - if self.register_model_pca_deformation is None: - self.register_model_pca_deformation = self._compute_pca_deformation( + if self.registered_model_pca_deformation is None: + self.registered_model_pca_deformation = self._compute_pca_deformation( self.registered_model_pca_coefficients, ) @@ -623,9 +673,9 @@ def transform_template_model(self) -> pv.UnstructuredGrid: point = self.pre_pca_transform.TransformPoint(point) # Add PCA deformation - point[0] += self.register_model_pca_deformation[i, 0] - point[1] += self.register_model_pca_deformation[i, 1] - point[2] += self.register_model_pca_deformation[i, 2] + point[0] += self.registered_model_pca_deformation[i, 0] + point[1] += self.registered_model_pca_deformation[i, 1] + point[2] += self.registered_model_pca_deformation[i, 2] # Store result final_points[i, 0] = point[0] @@ -655,8 +705,10 @@ def transform_point( Returns: Transformed ITK point - Raises: - ValueError: If registration has not been completed yet + Notes: + 1) if the point is outside the image bounds, the point is not transformed. + 2) if the pre_pca_transform is set and enabled, it is applied. + 3) if the forward point transform is not set, no errors are raised. Example: >>> p = itk.Point[itk.D, 3]() @@ -664,81 +716,13 @@ def transform_point( >>> transformed_p = registrar.transform_point(p) """ - if self._deformation_field_interpolator_x is None: - field_array = itk.GetArrayFromImage( - self._template_model_pca_deformation_field_image - ) - field_x_image = itk.GetImageFromArray(field_array[:, :, :, 0]) - field_x_image.CopyInformation( - self._template_model_pca_deformation_field_image - ) - self._deformation_field_interpolator_x = itk.LinearInterpolateImageFunction[ - itk.Image[itk.D, 3], itk.D - ].New() - self._deformation_field_interpolator_x.SetInputImage(field_x_image) - - field_y_image = itk.GetImageFromArray(field_array[:, :, :, 1]) - field_y_image.CopyInformation( - self._template_model_pca_deformation_field_image - ) - self._deformation_field_interpolator_y = itk.LinearInterpolateImageFunction[ - itk.Image[itk.D, 3], itk.D - ].New() - self._deformation_field_interpolator_y.SetInputImage(field_y_image) - - field_z_image = itk.GetImageFromArray(field_array[:, :, :, 2]) - field_z_image.CopyInformation( - self._template_model_pca_deformation_field_image - ) - self._deformation_field_interpolator_z = itk.LinearInterpolateImageFunction[ - itk.Image[itk.D, 3], itk.D - ].New() - self._deformation_field_interpolator_z.SetInputImage(field_z_image) - - assert self._template_model_pca_deformation_field_image is not None, ( - "Deformation field image must be set" - ) - assert self._deformation_field_interpolator_x is not None, ( - "Interpolator x must be initialized" - ) - assert self._deformation_field_interpolator_y is not None, ( - "Interpolator y must be initialized" - ) - assert self._deformation_field_interpolator_z is not None, ( - "Interpolator z must be initialized" - ) if include_pre_pca_transform and self.pre_pca_transform is not None: point = self.pre_pca_transform.TransformPoint(point) - cindx = self._template_model_pca_deformation_field_image.TransformPhysicalPointToContinuousIndex( - point - ) - size = self._template_model_pca_deformation_field_image.GetLargestPossibleRegion().GetSize() - if ( - cindx[0] < 0 - or cindx[0] >= size[0] - or cindx[1] < 0 - or cindx[1] >= size[1] - or cindx[2] < 0 - or cindx[2] >= size[2] - ): - self.log_error("Point is outside deformation field bounds") - return point - - deformation_x = ( - self._deformation_field_interpolator_x.EvaluateAtContinuousIndex(cindx) - ) - deformation_y = ( - self._deformation_field_interpolator_y.EvaluateAtContinuousIndex(cindx) - ) - deformation_z = ( - self._deformation_field_interpolator_z.EvaluateAtContinuousIndex(cindx) - ) - - transformed_point = itk.Point[itk.D, 3]() - transformed_point[0] = float(point[0] + deformation_x) - transformed_point[1] = float(point[1] + deformation_y) - transformed_point[2] = float(point[2] + deformation_z) + if self.forward_point_transform is not None: + transformed_point = self.forward_point_transform.TransformPoint(point) + else: + transformed_point = point return transformed_point @@ -750,13 +734,13 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict: - 'forward_point_transform': Forward displacement field transform - 'inverse_point_transform': Inverse displacement field transform """ - assert self.register_model_pca_deformation is not None, ( + assert self.registered_model_pca_deformation is not None, ( "PCA deformation must be computed" ) - self._template_model_pca_deformation_field_image = ( + template_model_pca_deformation_field_image = ( self._contour_tools.create_deformation_field( np.array(self.pca_template_model.points), - self.register_model_pca_deformation, + self.registered_model_pca_deformation, reference_image=reference_image, blur_sigma=2.5, ptype=itk.D, @@ -765,7 +749,7 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict: self.forward_point_transform = itk.DisplacementFieldTransform[itk.D, 3].New() self.forward_point_transform.SetDisplacementField( - self._template_model_pca_deformation_field_image + template_model_pca_deformation_field_image ) transform_tools = TransformTools() @@ -830,7 +814,7 @@ def register( ) # Create final registered model - self.register_model_pca_deformation = None + self.registered_model_pca_deformation = None self.registered_model = self.transform_template_model() # Return results as dictionary diff --git a/src/physiomotion4d/segment_chest_base.py b/src/physiomotion4d/segment_anatomy_base.py similarity index 97% rename from src/physiomotion4d/segment_chest_base.py rename to src/physiomotion4d/segment_anatomy_base.py index b6f1f47..8a7e25d 100644 --- a/src/physiomotion4d/segment_chest_base.py +++ b/src/physiomotion4d/segment_anatomy_base.py @@ -1,7 +1,7 @@ -"""Base class for segmenting chest CT images. +"""Base class for segmenting anatomy in CT images. -This module provides the SegmentChestBase class that serves as a foundation -for implementing different chest CT segmentation algorithms. It handles common +This module provides the SegmentAnatomyBase class that serves as a foundation +for implementing different anatomy CT segmentation algorithms. It handles common preprocessing, postprocessing, and anatomical structure organization tasks. """ @@ -15,12 +15,12 @@ from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class SegmentChestBase(PhysioMotion4DBase): - """Base class for chest segmentation that provides common functionality for - segmenting chest CT images. +class SegmentAnatomyBase(PhysioMotion4DBase): + """Base class for anatomy segmentation that provides common functionality for + segmenting anatomy in CT images. This class implements preprocessing, postprocessing, and mask creation - methods that are shared across different chest segmentation + methods that are shared across different anatomy segmentation implementations. It defines anatomical structure mappings and provides utilities for image preprocessing, intensity rescaling, and mask generation. @@ -43,7 +43,7 @@ class SegmentChestBase(PhysioMotion4DBase): """ def __init__(self, log_level: int | str = logging.INFO): - """Initialize the SegmentChestBase class. + """Initialize the SegmentAnatomyBase class. Sets up default parameters for image preprocessing and anatomical structure ID mappings. Subclasses should call this constructor and @@ -385,6 +385,9 @@ def segment_connected_component( InsideValue=1, OutsideValue=0, ) + thresh_arr = itk.GetArrayFromImage(thresh_image).astype(np.int16) + thresh_image = itk.GetImageFromArray(thresh_arr) + thresh_image.CopyInformation(preprocessed_image) label_arr = itk.GetArrayFromImage(labelmap_image) if labelmap_ids is None: diff --git a/src/physiomotion4d/segment_chest_ensemble.py b/src/physiomotion4d/segment_chest_ensemble.py index 8e394a7..b51bb3c 100644 --- a/src/physiomotion4d/segment_chest_ensemble.py +++ b/src/physiomotion4d/segment_chest_ensemble.py @@ -12,12 +12,12 @@ import itk import numpy as np -from physiomotion4d.segment_chest_base import SegmentChestBase +from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator from physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D -class SegmentChestEnsemble(SegmentChestBase): +class SegmentChestEnsemble(SegmentAnatomyBase): """ A class that inherits from physioSegmentChest and implements the segmentation method using VISTA3D. diff --git a/src/physiomotion4d/segment_chest_total_segmentator.py b/src/physiomotion4d/segment_chest_total_segmentator.py index c2d8b18..d8bb53b 100644 --- a/src/physiomotion4d/segment_chest_total_segmentator.py +++ b/src/physiomotion4d/segment_chest_total_segmentator.py @@ -2,7 +2,7 @@ This module provides the SegmentChestTotalSegmentator class that implements chest CT segmentation using the TotalSegmentator deep learning model. It inherits -from SegmentChestBase and defines anatomical structure mappings specific to +from SegmentAnatomyBase and defines anatomical structure mappings specific to TotalSegmentator's output labels. """ @@ -15,10 +15,10 @@ import numpy as np from totalsegmentator.python_api import totalsegmentator -from physiomotion4d.segment_chest_base import SegmentChestBase +from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase -class SegmentChestTotalSegmentator(SegmentChestBase): +class SegmentChestTotalSegmentator(SegmentAnatomyBase): """ Chest CT segmentation using TotalSegmentator deep learning model. diff --git a/src/physiomotion4d/segment_chest_vista_3d.py b/src/physiomotion4d/segment_chest_vista_3d.py index 5337960..2373178 100644 --- a/src/physiomotion4d/segment_chest_vista_3d.py +++ b/src/physiomotion4d/segment_chest_vista_3d.py @@ -26,10 +26,10 @@ import torch from huggingface_hub import snapshot_download -from physiomotion4d.segment_chest_base import SegmentChestBase +from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase -class SegmentChestVista3D(SegmentChestBase): +class SegmentChestVista3D(SegmentAnatomyBase): """ Chest CT segmentation using NVIDIA VISTA-3D foundational model. diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index fec54b2..be3eba0 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -2,23 +2,24 @@ This module provides the SegmentHeartSimpleware class that implements heart segmentation using Synopsys Simpleware Medical's ASCardio module. -It inherits from SegmentChestBase and provides heart-specific anatomical +It inherits from SegmentAnatomyBase and provides heart-specific anatomical structure mappings. """ import logging import os import subprocess +import sys import tempfile import itk import numpy as np from itk import TubeTK as tube -from physiomotion4d.segment_chest_base import SegmentChestBase +from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase -class SegmentHeartSimpleware(SegmentChestBase): +class SegmentHeartSimpleware(SegmentAnatomyBase): """ Heart CT segmentation using Simpleware Medical's ASCardio module. @@ -97,6 +98,8 @@ def __init__(self, log_level: int | str = logging.INFO): # From Base Class # self.contrast_mask_ids = {135: "contrast"} + self.trim_mesh_to_essentials = False + self.set_other_and_all_mask_ids() # Path to Simpleware Medical console executable @@ -109,6 +112,14 @@ def __init__(self, log_level: int | str = logging.INFO): "SimplewareScript_heart_segmentation.py", ) + def set_trim_mesh_to_essentials(self, trim_mesh_to_essentials: bool) -> None: + """Set whether to trim mesh to common and critical structures. + + Args: + trim_mesh_to_essentials (bool): Whether to reduce to essential. + """ + self.trim_mesh_to_essentials = trim_mesh_to_essentials + def set_simpleware_executable_path(self, path: str) -> None: """Set the path to the Simpleware Medical console executable. @@ -195,15 +206,43 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: self.log_info("Command: %s", " ".join(cmd)) try: - # Run Simpleware Medical as a subprocess - result = subprocess.run( + # Run Simpleware Medical as a subprocess. When the process exits, + # the OS frees all of its resources (GPU, memory); no extra + # cleanup is required. Using Popen so we can kill the process + # tree on timeout and ensure no child processes keep holding GPU. + proc = subprocess.Popen( cmd, - input=user_input, - capture_output=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, - check=True, - timeout=600, # 10 minute timeout + start_new_session=( + sys.platform != "win32" + ), # process group on Unix ) + try: + stdout, stderr = proc.communicate( + input=user_input, + timeout=600, # 10 minute timeout + ) + except subprocess.TimeoutExpired: + # Kill process tree so GPU/memory are released (child may have spawned others) + if sys.platform == "win32": + subprocess.run( + ["taskkill", "/F", "/T", "/PID", str(proc.pid)], + capture_output=True, + ) + else: + os.killpg(os.getpgid(proc.pid), 9) + proc.wait() + raise RuntimeError( + "Simpleware Medical segmentation timed out after 600 seconds" + ) + if proc.returncode != 0: + raise subprocess.CalledProcessError( + proc.returncode, cmd, stdout, stderr + ) + result = type("Result", (), {"stdout": stdout, "stderr": stderr})() # Log output from Simpleware if result.stdout: @@ -211,10 +250,6 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: if result.stderr: self.log_warning("Simpleware stderr:\n%s", result.stderr) - except subprocess.TimeoutExpired as e: - raise RuntimeError( - f"Simpleware Medical segmentation timed out after 600 seconds: {e}" - ) except subprocess.CalledProcessError as e: raise RuntimeError( f"Simpleware Medical segmentation failed with return code {e.returncode}:\n" @@ -223,7 +258,7 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: # Simpleware's right ventricle, left atrium, right atrium correspond to # the interior of those regions. - mask_ids_of_interior_regions = [2, 3, 4] + mask_ids_of_interior_regions = [1, 2, 3, 4] # Check if output file was created sz = [s for s in preprocessed_image.GetLargestPossibleRegion().GetSize()] @@ -265,6 +300,16 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: "ensure the ASCardio module ran successfully." ) + if self.trim_mesh_to_essentials: + z = labelmap_array.shape[2] - 1 + z_classes = np.unique(labelmap_array[z, :, :]) + heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes) + while heart_count < 3 and z > 0: + z -= 1 + z_classes = np.unique(labelmap_array[z, :, :]) + heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes) + if z < labelmap_array.shape[2] - 3: + labelmap_array[(z + 3) :, :, :] = 0 labelmap_image = itk.GetImageFromArray(labelmap_array.astype(np.uint8)) labelmap_image.CopyInformation(preprocessed_image) diff --git a/src/physiomotion4d/simpleware_medical/README.md b/src/physiomotion4d/simpleware_medical/README.md index d44039f..6c20b41 100644 --- a/src/physiomotion4d/simpleware_medical/README.md +++ b/src/physiomotion4d/simpleware_medical/README.md @@ -6,7 +6,7 @@ This directory contains integration code for using Synopsys Simpleware Medical w The integration enables PhysioMotion4D to leverage Simpleware Medical's ASCardio module for automated cardiac segmentation. The implementation uses a two-component architecture: -1. **segment_heart_simpleware.py** (in parent directory): A Python class that inherits from `SegmentChestBase` and manages the external Simpleware Medical process +1. **segment_heart_simpleware.py** (in parent directory): A Python class that inherits from `SegmentAnatomyBase` and manages the external Simpleware Medical process 2. **SimplewareScript_heart_segmentation.py** (this directory): A Python script that runs within the Simpleware Medical environment and performs the actual segmentation using ASCardio ## Requirements diff --git a/src/physiomotion4d/workflow_create_statistical_model.py b/src/physiomotion4d/workflow_create_statistical_model.py new file mode 100644 index 0000000..6aece41 --- /dev/null +++ b/src/physiomotion4d/workflow_create_statistical_model.py @@ -0,0 +1,302 @@ +"""Create a PCA statistical shape model from a sample of meshes. + +This module provides the WorkflowCreateStatisticalModel class that implements +the pipeline from the Heart-Create_Statistical_Model experiment notebooks: + +1. Extract surfaces from sample and reference meshes +2. ICP alignment: align each sample surface to the reference (template) surface +3. Deformable registration: establish dense correspondence via mask-based SyN +4. Correspondence: warp reference surface by each transform to get aligned shapes +5. PCA: compute mean and modes from corresponded shapes + +Returns a dictionary of surfaces, meshes, and PCA model structure (no file I/O). +""" + +import logging +from typing import Any, Optional + +import itk +import numpy as np +import pyvista as pv +from sklearn.decomposition import PCA + +from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase +from physiomotion4d.register_models_distance_maps import RegisterModelsDistanceMaps +from physiomotion4d.register_models_icp import RegisterModelsICP +from physiomotion4d.transform_tools import TransformTools + + +def _extract_surface(mesh: pv.DataSet) -> pv.PolyData: + """Extract surface from a mesh (PolyData or UnstructuredGrid).""" + if isinstance(mesh, pv.UnstructuredGrid): + return mesh.extract_surface() + if isinstance(mesh, pv.PolyData): + return mesh + return mesh.extract_surface() + + +class WorkflowCreateStatisticalModel(PhysioMotion4DBase): + """Create a PCA statistical shape model from a sample of meshes aligned to a reference. + + Pipeline (mirrors experiments/Heart-Create_Statistical_Model notebooks 1–5): + 1. Extract surfaces from sample meshes and reference mesh (reference surface = alignment target) + 2. ICP (affine) align each sample surface to the reference surface + 3. Deformable (ANTs SyN) registration of each aligned sample to reference + 4. Build corresponded shapes (reference topology) in reference space + 5. Compute PCA and return mean surface, reference mesh, and PCA model dict + + Attributes: + sample_meshes (list): List of sample mesh DataSets (.vtk/.vtu/.vtp geometry) + reference_mesh (pv.DataSet): Reference mesh; its surface is used for alignment + pca_number_of_components (int): Number of PCA components to retain + reference_spatial_resolution (float): Resolution for reference image from mesh + reference_buffer_factor (float): Buffer around mesh for reference image + """ + + def __init__( + self, + sample_meshes: list[pv.DataSet], + reference_mesh: pv.DataSet, + pca_number_of_components: int = 15, + reference_spatial_resolution: float = 1.0, + reference_buffer_factor: float = 0.25, + log_level: int | str = logging.INFO, + ): + """Initialize the create-statistical-model workflow. + + Args: + sample_meshes: List of sample mesh DataSets (PyVista PolyData or UnstructuredGrid). + reference_mesh: Reference mesh; its surface is used to align all samples. + pca_number_of_components: Number of PCA components. Default 15. + reference_spatial_resolution: Isotropic resolution (mm) for reference image. Default 1.0. + reference_buffer_factor: Buffer factor around mesh for reference image. Default 0.25. + log_level: Logging level. + """ + super().__init__( + class_name="WorkflowCreateStatisticalModel", log_level=log_level + ) + self.sample_meshes = list(sample_meshes) + self.reference_mesh = reference_mesh + self.pca_number_of_components = pca_number_of_components + self.reference_spatial_resolution = reference_spatial_resolution + self.reference_buffer_factor = reference_buffer_factor + + self.contour_tools = ContourTools() + self.transform_tools = TransformTools() + + # Set by pipeline + self.reference_surface: Optional[pv.PolyData] = None + self.sample_surfaces: list[pv.PolyData] = [] + self.sample_ids: list[str] = [] + self.aligned_surfaces: list[pv.PolyData] = [] + self.forward_transforms: list = [] + self.inverse_transforms: list = [] + self.pca_input_surfaces: list[pv.PolyData] = [] + self.pca_fitted: Optional[PCA] = None + self.pca_mean_surface: Optional[pv.PolyData] = None + self.pca_mean_mesh: Optional[pv.UnstructuredGrid] = None + + def set_pca_number_of_components(self, n: int) -> None: + """Set number of PCA components to retain.""" + self.pca_number_of_components = n + + def _step1_extract_surfaces(self) -> None: + """Extract reference surface and all sample surfaces (notebook 1).""" + self.log_section("Step 1: Extract reference and sample surfaces", width=70) + if not self.sample_meshes: + raise ValueError("sample_meshes must not be empty") + self.reference_surface = _extract_surface(self.reference_mesh) + self.log_info( + "Reference surface: %d points", + self.reference_surface.n_points, + ) + self.sample_surfaces = [] + self.sample_ids = [] + for i, mesh in enumerate(self.sample_meshes): + surface = _extract_surface(mesh) + self.sample_surfaces.append(surface) + self.sample_ids.append(str(i)) + self.log_info("Extracted %d sample surfaces", len(self.sample_surfaces)) + + def _step2_icp_align(self) -> None: + """ICP (affine) align each sample surface to reference (notebook 2).""" + self.log_section("Step 2: ICP alignment to reference surface", width=70) + assert self.reference_surface is not None and self.sample_surfaces + self.aligned_surfaces = [] + self.forward_transforms = [] + self.inverse_transforms = [] + + for i, (sid, moving) in enumerate(zip(self.sample_ids, self.sample_surfaces)): + self.log_info( + "ICP aligning %s (%d/%d)", sid, i + 1, len(self.sample_surfaces) + ) + if isinstance(moving, pv.UnstructuredGrid): + moving = moving.extract_surface() + registrar = RegisterModelsICP(fixed_model=self.reference_surface) + result = registrar.register( + moving_model=moving, + transform_type="Affine", + max_iterations=2000, + ) + self.aligned_surfaces.append(result["registered_model"]) + self.forward_transforms.append(result["forward_point_transform"]) + self.inverse_transforms.append(result["inverse_point_transform"]) + + self.log_info( + "ICP alignment complete for %d samples", len(self.aligned_surfaces) + ) + + def _step3_deformable_correspondence(self) -> None: + """Deformable registration of each aligned sample to reference (notebook 3).""" + self.log_section("Step 3: Deformable registration (correspondence)", width=70) + assert self.reference_surface is not None and self.aligned_surfaces + reference_image = self.contour_tools.create_reference_image( + mesh=self.reference_surface, + spatial_resolution=self.reference_spatial_resolution, + buffer_factor=self.reference_buffer_factor, + ptype=itk.UC, + ) + self.forward_transforms = [] + self.inverse_transforms = [] + + for i, (sid, moving) in enumerate(zip(self.sample_ids, self.aligned_surfaces)): + self.log_info( + "Deformable registration %s (%d/%d)", + sid, + i + 1, + len(self.aligned_surfaces), + ) + registrar = RegisterModelsDistanceMaps( + moving_model=moving, + fixed_model=self.reference_surface, + reference_image=reference_image, + ) + result = registrar.register( + transform_type="Deformable", + use_icon=False, + ) + self.forward_transforms.append(result["forward_transform"]) + self.inverse_transforms.append(result["inverse_transform"]) + + self.log_info( + "Deformable registration complete for %d samples", + len(self.forward_transforms), + ) + + def _step4_build_pca_inputs(self) -> None: + """Build corresponded shapes in reference space (notebook 4). + + For each case, reference_surface is warped by forward (image) deformation + (= inverse point) transform from step 3, so that we get reference topology + in ICP-aligned space with residual deformation per subject to be used as PCA + input. + """ + self.log_section("Step 4: Build PCA inputs (corresponded shapes)", width=70) + assert self.reference_surface is not None and self.forward_transforms + self.pca_input_surfaces = [] + for fwd_tfm in self.forward_transforms: + pca_input_surface = self.contour_tools.transform_contours( + self.reference_surface, tfm=fwd_tfm, with_deformation_magnitude=False + ) + self.pca_input_surfaces.append(pca_input_surface) + self.log_info( + "Built %d corresponded surfaces for PCA", len(self.pca_input_surfaces) + ) + + def _step5_compute_pca(self) -> None: + """Compute PCA and mean surface (notebook 5).""" + self.log_section("Step 5: Compute PCA model", width=70) + assert self.reference_surface is not None and self.pca_input_surfaces + template = self.reference_surface + n_points = template.n_points + + data_matrix = [] + for i, mesh in enumerate(self.pca_input_surfaces): + if mesh.n_points != n_points: + raise ValueError( + f"Sample {self.sample_ids[i]} has {mesh.n_points} points, " + f"expected {n_points}. Topology must match." + ) + data_matrix.append(mesh.points.flatten()) + data_matrix = np.array(data_matrix) + + if data_matrix.shape[0] - 1 < 2: + raise ValueError( + f"At least 2 samples are required for PCA. Got {data_matrix.shape[0]} samples." + ) + n_comp = min(self.pca_number_of_components, data_matrix.shape[0] - 1) + if n_comp < self.pca_number_of_components: + self.log_warning( + "Reducing PCA components from %d to %d (n_samples=%d)", + self.pca_number_of_components, + n_comp, + data_matrix.shape[0], + ) + self.pca_fitted = PCA(n_components=n_comp) + self.pca_fitted.fit(data_matrix) + + self.pca_mean_surface = template.copy() + self.pca_mean_surface.points = self.pca_fitted.mean_.reshape(-1, 3) + self.log_info( + "PCA complete: %d components, variance explained %.4f", + len(self.pca_fitted.explained_variance_ratio_), + self.pca_fitted.explained_variance_ratio_.sum(), + ) + + reference_image = self.contour_tools.create_reference_image( + mesh=self.pca_mean_surface, + spatial_resolution=self.reference_spatial_resolution, + buffer_factor=self.reference_buffer_factor, + ptype=itk.UC, + ) + mean_deformation_array = self.pca_mean_surface.points - template.points + mean_deformation_field = self.contour_tools.create_deformation_field( + points=template.points, + point_displacements=mean_deformation_array, + reference_image=reference_image, + blur_sigma=2.5, + ptype=itk.D, + ) + mean_deformation_transform = itk.DisplacementFieldTransform[itk.D, 3].New() + mean_deformation_transform.SetDisplacementField(mean_deformation_field) + self.pca_mean_mesh = self.contour_tools.transform_contours( + self.reference_mesh, + tfm=mean_deformation_transform, + with_deformation_magnitude=False, + ) + + def _build_result(self) -> dict[str, Any]: + """Build result dictionary: surfaces, meshes, and PCA model structure.""" + assert self.pca_mean_surface is not None and self.pca_fitted is not None + result: dict[str, Any] = { + "pca_mean_surface": self.pca_mean_surface, + "pca_mean_mesh": self.pca_mean_mesh, + "pca_model": { + "explained_variance_ratio": self.pca_fitted.explained_variance_ratio_.tolist(), + "eigenvalues": self.pca_fitted.explained_variance_.tolist(), + "components": [c.tolist() for c in self.pca_fitted.components_], + }, + "pca_fitted": self.pca_fitted, + } + return result + + def run_workflow(self) -> dict[str, Any]: + """Run the full pipeline and return a dictionary of results (no file I/O). + + Returns: + dict with keys: + - pca_mean_surface: pv.PolyData mean shape surface + - pca_mean_mesh: pv.UnstructuredGrid reference volume mesh, or None if reference was surface-only + - pca_model: dict with "explained_variance_ratio", "eigenvalues", "components" (same structure as pca_model.json) + - pca_fitted: fitted sklearn PCA object + """ + self.log_section("STARTING CREATE STATISTICAL MODEL WORKFLOW", width=70) + self._step1_extract_surfaces() + self._step2_icp_align() + self._step3_deformable_correspondence() + self._step4_build_pca_inputs() + self._step5_compute_pca() + result = self._build_result() + self.log_section("CREATE STATISTICAL MODEL WORKFLOW COMPLETE", width=70) + return result diff --git a/src/physiomotion4d/workflow_register_heart_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py similarity index 74% rename from src/physiomotion4d/workflow_register_heart_model_to_patient.py rename to src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index c63e643..70b0f9f 100644 --- a/src/physiomotion4d/workflow_register_heart_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -1,6 +1,6 @@ """Model-to-image and model-to-model registration for anatomical models. -This module provides the WorkflowRegisterHeartModelToPatient class for registering generic +This module provides the WorkflowFitStatisticalModelToPatient class for registering generic anatomical models to patient-specific imaging data and surface models. The workflow includes: 1. Rough alignment using ICP (RegisterModelsICP) @@ -23,7 +23,7 @@ """ import logging -from typing import Optional +from typing import Any, Optional import itk import numpy as np @@ -40,7 +40,7 @@ from physiomotion4d.transform_tools import TransformTools -class WorkflowRegisterHeartModelToPatient(PhysioMotion4DBase): +class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase): """Register anatomical models using multi-stage ICP, mask-based, and image-based registration. @@ -79,8 +79,9 @@ class WorkflowRegisterHeartModelToPatient(PhysioMotion4DBase): transform_tools (TransformTools): Transform utilities registrar_icon (RegisterImagesICON): ICON registration instance registrar_ants (RegisterImagesANTs): ANTs registration instance - pca_json_filename (str): PCA JSON filename (optional) - pca_number_of_modes (int): Number of PCA modes to use + use_pca_registration (bool): Whether PCA registration is enabled (set via set_use_pca_registration) + pca_model (dict): PCA model dict when PCA enabled; same structure as WorkflowCreateStatisticalModel output + pca_number_of_modes (int): Number of PCA modes when PCA enabled icp_forward_point_transform : ICP transforms icp_inverse_point_transform : ICP inverse transforms icp_template_model_surface: template model surface after ICP alignment @@ -100,33 +101,24 @@ class WorkflowRegisterHeartModelToPatient(PhysioMotion4DBase): registered_template_model_surface: Final registered model surface Example: - >>> # Initialize with minimal parameters - >>> registrar = WorkflowRegisterHeartModelToPatient( + >>> # Initialize with minimal parameters (no labelmap; no patient image -> reference created from patient models) + >>> registrar = WorkflowFitStatisticalModelToPatient( ... template_model=heart_model, ... patient_models=[lv_model, mc_model, rv_model], - ... patient_image=patient_ct, - ... pca_json_filename='path/to/pca_model.json', - ... pca_number_of_modes=10, ... ) - >>> - >>> # Optional: Configure parameters (masks auto-generated if not set) >>> registrar.set_roi_dilation_mm(20) - >>> - >>> # Run registration - >>> patient_model = registrar.run_workflow() + >>> # To enable PCA registration, call before run_workflow(): + >>> # registrar.set_use_pca_registration(True, pca_model=pca_model_dict, pca_number_of_modes=10) + >>> # To enable mask-to-image refinement: + >>> # registrar.set_use_mask_to_image_registration(True, template_labelmap, organ_mesh_ids, organ_extra_ids, background_ids) + >>> result = registrar.run_workflow() """ def __init__( self, template_model: pv.UnstructuredGrid, - template_labelmap: itk.Image, - template_labelmap_heart_muscle_ids: list[int], - template_labelmap_chamber_ids: list[int], - template_labelmap_background_ids: list[int], patient_models: list, - patient_image: itk.Image, - pca_json_filename: Optional[str] = None, - pca_number_of_modes: int = 0, + patient_image: Optional[itk.Image] = None, log_level: int | str = logging.INFO, ): """Initialize the model-to-image-and-model registration pipeline. @@ -135,39 +127,47 @@ def __init__( template_model: Generic anatomical model to be registered patient_models: List of patient-specific models extracted from imaging data. Typically 3 models for cardiac applications: LV, myocardium, RV. - patient_image: Patient image data providing the target coordinate frame - (origin, spacing, direction). Used as reference for registration. + patient_image: Optional patient image providing the target coordinate frame. + If None, a reference image is created from the patient model surface + via create_reference_image (contour_tools). log_level: Logging level (logging.DEBUG, logging.INFO, logging.WARNING). Default: logging.INFO """ # Initialize base class with logging super().__init__( - class_name="WorkflowRegisterHeartModelToPatient", log_level=log_level + class_name="WorkflowFitStatisticalModelToPatient", log_level=log_level ) self.template_model = template_model self.template_model_surface = template_model.extract_surface() - self.template_labelmap = template_labelmap - self.template_labelmap_heart_muscle_ids = template_labelmap_heart_muscle_ids - self.template_labelmap_chamber_ids = template_labelmap_chamber_ids - self.template_labelmap_background_ids = template_labelmap_background_ids + self.template_labelmap: Optional[itk.Image] = None + self.template_labelmap_organ_mesh_ids: Optional[list[int]] = None + self.template_labelmap_organ_extra_ids: Optional[list[int]] = None + self.template_labelmap_background_ids: Optional[list[int]] = None self.patient_models = patient_models patient_models_surfaces = [model.extract_surface() for model in patient_models] combined_patient_model = pv.merge(patient_models_surfaces) self.patient_model_surface = combined_patient_model.extract_surface() - self.patient_image = patient_image - - resampler = ttk.ResampleImage.New(Input=self.patient_image) - resampler.SetMakeHighResIso(True) - resampler.Update() - self.patient_image = resampler.GetOutput() - - # Utilities + # Utilities (needed for create_reference_image when patient_image is None) self.transform_tools = TransformTools() self.contour_tools = ContourTools() + if patient_image is not None: + self.patient_image = patient_image + resampler = ttk.ResampleImage.New(Input=self.patient_image) + resampler.SetMakeHighResIso(True) + resampler.Update() + self.patient_image = resampler.GetOutput() + else: + self.patient_image = self.contour_tools.create_reference_image( + mesh=self.patient_model_surface, + spatial_resolution=1.0, + buffer_factor=0.25, + ptype=itk.F, + ) + self.registrar_ants = RegisterImagesANTs() self.registrar_ants.set_number_of_iterations([5, 2, 5]) # Icon registration for final mask-to-image step @@ -194,12 +194,13 @@ def __init__( self.icp_template_model_surface: Optional[pv.PolyData] = None self.icp_template_labelmap: Optional[itk.Image] = None - # Stage 1.5: PCA registration results (optional) + # Stage 1.5: PCA registration results (optional; enable via set_use_pca_registration(True, pca_model, pca_number_of_modes)) + self.use_pca_registration = False self.pca_registrar: Optional[RegisterModelsPCA] = None self.pca_forward_point_transform: Optional[itk.Transform] = None self.pca_inverse_point_transform: Optional[itk.Transform] = None - self.pca_json_filename = pca_json_filename - self.pca_number_of_modes = pca_number_of_modes + self.pca_model: Optional[dict[str, Any]] = None + self.pca_number_of_modes: int = 0 self.pca_coefficients: Optional[np.ndarray] = None self.pca_template_model_surface: Optional[pv.PolyData] = None self.pca_template_labelmap: Optional[itk.Image] = None @@ -211,8 +212,8 @@ def __init__( self.m2m_template_model_surface: Optional[pv.PolyData] = None self.m2m_template_labelmap: Optional[itk.Image] = None - # Stage 3: Mask-to-image registration results - self.use_m2i_registration = True + # Stage 3: Mask-to-image registration results (disabled by default; enable via set_use_mask_to_image_registration(True, template_labelmap, ...)) + self.use_m2i_registration = False self.m2i_inverse_transform: Optional[itk.Transform] = None self.m2i_forward_transform: Optional[itk.Transform] = None self.m2i_template_model_surface: Optional[pv.PolyData] = None @@ -320,6 +321,39 @@ def set_roi_dilation_mm(self, roi_dilation_mm: float) -> None: """ self.roi_dilation_mm = roi_dilation_mm + def set_use_pca_registration( + self, + use_pca_registration: bool, + pca_model: Optional[dict[str, Any]] = None, + pca_number_of_modes: int = 0, + ) -> None: + """Set whether to use PCA-based registration and provide the PCA model. + + When enabling (True), pca_model and pca_number_of_modes must be provided. + + Args: + use_pca_registration: Whether to use PCA registration after ICP. + pca_model: Required when use is True. PCA model dict (e.g. from + WorkflowCreateStatisticalModel result["pca_model"]) with keys + "eigenvalues" and "components". + pca_number_of_modes: Required when use is True. Number of PCA modes to use. + Default 0 means use all modes. + + Raises: + ValueError: If use is True and pca_model is None. + """ + if use_pca_registration: + if pca_model is None: + raise ValueError( + "When enabling PCA registration, pca_model must be provided." + ) + self.pca_model = pca_model + self.pca_number_of_modes = pca_number_of_modes + else: + self.pca_model = None + self.pca_number_of_modes = 0 + self.use_pca_registration = use_pca_registration + def set_use_mask_to_mask_registration( self, use_mask_to_mask_registration: bool ) -> None: @@ -332,13 +366,57 @@ def set_use_mask_to_mask_registration( self.use_m2m_registration = use_mask_to_mask_registration def set_use_mask_to_image_registration( - self, use_mask_to_image_registration: bool + self, + use_mask_to_image_registration: bool, + template_labelmap: Optional[itk.Image] = None, + template_labelmap_organ_mesh_ids: Optional[list[int]] = None, + template_labelmap_organ_extra_ids: Optional[list[int]] = None, + template_labelmap_background_ids: Optional[list[int]] = None, ) -> None: """Set whether to use mask-to-image registration. + When enabling (True), a template labelmap and label IDs must be provided + so the workflow can propagate and refine the labelmap to the patient image. + Args: - use_m2i: Whether to use mask-to-image registration. Default: True + use_mask_to_image_registration: Whether to use mask-to-image registration. + template_labelmap: Required when use is True. Template labelmap in template + model space (same geometry as template_model). + template_labelmap_organ_mesh_ids: Required when use is True. Label IDs for + organ mesh in the template labelmap. + template_labelmap_organ_extra_ids: Required when use is True. Label IDs for + organ-extra structures in the template labelmap. + template_labelmap_background_ids: Required when use is True. Label IDs for + background in the template labelmap. + + Raises: + ValueError: If use is True and any of template_labelmap or the id lists + is None or missing. """ + if use_mask_to_image_registration: + if template_labelmap is None: + raise ValueError( + "When enabling mask-to-image registration, template_labelmap must be provided." + ) + if template_labelmap_organ_mesh_ids is None: + raise ValueError( + "When enabling mask-to-image registration, " + "template_labelmap_organ_mesh_ids must be provided." + ) + if template_labelmap_organ_extra_ids is None: + raise ValueError( + "When enabling mask-to-image registration, " + "template_labelmap_organ_extra_ids must be provided." + ) + if template_labelmap_background_ids is None: + raise ValueError( + "When enabling mask-to-image registration, " + "template_labelmap_background_ids must be provided." + ) + self.template_labelmap = template_labelmap + self.template_labelmap_organ_mesh_ids = template_labelmap_organ_mesh_ids + self.template_labelmap_organ_extra_ids = template_labelmap_organ_extra_ids + self.template_labelmap_background_ids = template_labelmap_background_ids self.use_m2i_registration = use_mask_to_image_registration def register_model_to_model_icp(self) -> dict: @@ -370,12 +448,15 @@ def register_model_to_model_icp(self) -> dict: self.icp_inverse_point_transform = icp_result["inverse_point_transform"] self.icp_template_model_surface = icp_result["registered_model"] - self.icp_template_labelmap = self.transform_tools.transform_image( - self.template_labelmap, - self.icp_inverse_point_transform, - self.patient_image, - interpolation_method="nearest", - ) + if self.template_labelmap is not None: + self.icp_template_labelmap = self.transform_tools.transform_image( + self.template_labelmap, + self.icp_inverse_point_transform, + self.patient_image, + interpolation_method="nearest", + ) + else: + self.icp_template_labelmap = None self.log_info("Stage 1 complete: ICP alignment finished.") @@ -408,16 +489,17 @@ def register_model_to_model_pca(self) -> dict: width=70, ) - if self.pca_json_filename is None: + if not self.use_pca_registration or self.pca_model is None: self.pca_template_model_surface = self.icp_template_model_surface + self.pca_template_labelmap = self.icp_template_labelmap return { "pca_coefficients": None, "registered_model_surface": self.pca_template_model_surface, } - self.pca_registrar = RegisterModelsPCA.from_json( + self.pca_registrar = RegisterModelsPCA.from_pca_model( pca_template_model=self.icp_template_model_surface, - pca_json_filename=self.pca_json_filename, + pca_model=self.pca_model, pca_number_of_modes=self.pca_number_of_modes, fixed_model=self.patient_model_surface, reference_image=self.patient_image, @@ -439,12 +521,15 @@ def register_model_to_model_pca(self) -> dict: self.registered_template_model_surface = self.pca_template_model_surface - self.pca_template_labelmap = self.transform_tools.transform_image( - self.icp_template_labelmap, - self.pca_inverse_point_transform, - self.patient_image, - interpolation_method="nearest", - ) + if self.icp_template_labelmap is not None: + self.pca_template_labelmap = self.transform_tools.transform_image( + self.icp_template_labelmap, + self.pca_inverse_point_transform, + self.patient_image, + interpolation_method="nearest", + ) + else: + self.pca_template_labelmap = None self.log_info("Stage 2 complete: PCA registration finished.") @@ -503,12 +588,15 @@ def register_mask_to_mask( self.registered_template_model_surface = self.m2m_template_model_surface - self.m2m_template_labelmap = self.transform_tools.transform_image( - self.pca_template_labelmap, - self.m2m_forward_transform, - self.patient_image, - interpolation_method="nearest", - ) + if self.pca_template_labelmap is not None: + self.m2m_template_labelmap = self.transform_tools.transform_image( + self.pca_template_labelmap, + self.m2m_forward_transform, + self.patient_image, + interpolation_method="nearest", + ) + else: + self.m2m_template_labelmap = None self.log_info("Stage 3 complete: Mask-to-mask registration finished.") @@ -537,6 +625,24 @@ def register_labelmap_to_image( "Stage 4: Labelmap-to-Image Refinement (Icon Registration)", width=70 ) + if ( + self.template_labelmap is None + or self.template_labelmap_organ_mesh_ids is None + or self.template_labelmap_organ_extra_ids is None + or self.template_labelmap_background_ids is None + ): + raise ValueError( + "Mask-to-image registration requires template labelmap and label IDs. " + "Call set_use_mask_to_image_registration(True, template_labelmap, " + "organ_mesh_ids, organ_extra_ids, background_ids) before run_workflow()." + ) + if self.m2m_template_labelmap is None: + raise ValueError( + "Mask-to-image registration requires a labelmap to have been set " + "(via set_use_mask_to_image_registration(True, ...)) before running " + "earlier stages so the labelmap is propagated through ICP/PCA/M2M." + ) + labelmap_arr = itk.GetArrayFromImage(self.m2m_template_labelmap).astype( np.uint16 ) @@ -546,12 +652,14 @@ def register_labelmap_to_image( labelmap_arr, ) labelmap_arr = np.where( - np.isin(labelmap_arr, self.template_labelmap_heart_muscle_ids), + np.isin(labelmap_arr, self.template_labelmap_organ_mesh_ids), 0, labelmap_arr, ) labelmap_arr = np.where( - np.isin(labelmap_arr, self.template_labelmap_chamber_ids), 1, labelmap_arr + np.isin(labelmap_arr, self.template_labelmap_organ_extra_ids), + 1, + labelmap_arr, ) labelmap = itk.GetImageFromArray(labelmap_arr) labelmap.CopyInformation(self.m2m_template_labelmap) @@ -691,7 +799,7 @@ def transform_model( def run_workflow( self, - use_mask_to_image_registration: bool = True, + use_mask_to_image_registration: bool = False, use_mask_to_mask_registration: bool = True, use_icon_registration_refinement: bool = False, ) -> dict: @@ -701,21 +809,20 @@ def run_workflow( 1. ICP alignment (RegisterModelsICP) 2. PCA registration (PCA data was provided) 3. Mask-to-mask deformable registration (RegisterModelsDistanceMaps) - 4. Optional mask-to-image refinement (Icon) + 4. Optional mask-to-image refinement (Icon); requires template labelmap and IDs + set via set_use_mask_to_image_registration(True, ...). Args: use_mask_to_image_registration: Whether to include mask-to-image - registration stage. - Default: True + registration stage. Default: False. When True, template labelmap and + label IDs must have been set via set_use_mask_to_image_registration(True, ...). use_mask_to_mask_registration: Whether to include mask-to-mask registration - stage. - Default: True + stage. Default: True use_icon_registration_refinement: Whether to include icon registration - refinement stage. - Default: False + refinement stage. Default: False Returns: - pv.UnstructuredGrid: Final registered model + dict with registered_template_model and registered_template_model_surface """ self.log_section( "STARTING COMPLETE MODEL-TO-IMAGE-AND-MODEL REGISTRATION WORKFLOW", width=70 diff --git a/statistics.md b/statistics.md index 4cfa1fe..fe67ee9 100644 --- a/statistics.md +++ b/statistics.md @@ -45,9 +45,9 @@ PhysioMotion4D is a sophisticated medical imaging package for generating anatomi | --------------------------------------------- | ----- | ---------------------------------------------- | | `transform_tools.py` | 1,142 | Transform manipulation utilities | | `register_models_pca.py` | 818 | PCA-based statistical shape model registration | -| `workflow_register_heart_model_to_patient.py` | 745 | Model-to-patient registration workflow | +| `workflow_fit_statistical_model_to_patient.py` | 745 | Model-to-patient registration workflow | | `register_images_ants.py` | 725 | ANTs-based image registration | -| `segment_chest_base.py` | 672 | Base class for chest segmentation | +| `segment_anatomy_base.py` | 672 | Base class for anatomy segmentation | | `convert_vtk_to_usd_polymesh.py` | 622 | Polymesh USD conversion | | `convert_vtk_to_usd_base.py` | 585 | Base USD conversion functionality | | `workflow_convert_heart_gated_ct_to_usd.py` | 539 | Heart CT to USD workflow |