diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9afc9f1 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,135 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +**Python launcher:** Use `py` on this Windows system (not `python`). + +```bash +# Install in editable mode (preferred) +uv pip install -e . + +# Lint and format +ruff check . --fix +ruff format . + +# Type checking +mypy src/ + +# All pre-commit hooks +pre-commit run --all-files + +# Run fast tests (recommended for development) +py -m pytest tests/ -m "not slow and not requires_data" -v + +# Run a single test file +py -m pytest tests/test_contour_tools.py -v + +# Run a single test by name +py -m pytest tests/test_contour_tools.py::test_extract_surface -v + +# Run tests without GPU-dependent tests +py -m pytest tests/ --ignore=tests/test_segment_chest_total_segmentator.py \ + --ignore=tests/test_segment_chest_vista_3d.py \ + --ignore=tests/test_register_images_icon.py + +# Run with coverage +py -m pytest tests/ --cov=src/physiomotion4d --cov-report=html + +# Run experiment notebook tests (opt-in, very slow) +py -m pytest tests/ --run-experiments + +# Create baseline files when missing +py -m pytest tests/ --create-baselines +``` + +**Version bumping:** `bumpver update --patch` (or `--minor`, `--major`) + +## Architecture + +### Pipeline Overview + +PhysioMotion4D converts 4D CT scans (cardiac or pulmonary) into animated USD models for NVIDIA Omniverse. The pipeline flows: + +``` +4D CT → Segmentation → Registration → Contour Extraction → USD Export +``` + +### Class Hierarchy + +All major classes inherit from `PhysioMotion4DBase` (`src/physiomotion4d/physiomotion4d_base.py`), which provides a shared logger named `"PhysioMotion4D"`. Use `self.log_info()`, `self.log_debug()`, etc. — never `print()`. Use `PhysioMotion4DBase.set_log_classes([...])` to filter output to specific classes. + +### Workflow Classes (entry points) + +- **`WorkflowConvertHeartGatedCTToUSD`**: Full 4D cardiac CT → USD pipeline. Orchestrates: 4D→3D conversion → segmentation (TotalSegmentator) → registration (ICON or ANTs) → contour extraction → USD generation. +- **`WorkflowCreateStatisticalModel`**: Builds a PCA statistical shape model (sklearn) from a population of aligned meshes. Outputs `pca_model.json`, `pca_mean_surface.vtp`. +- **`WorkflowFitStatisticalModelToPatient`**: Multi-stage model-to-patient registration: (1) ICP rough alignment → (2) optional PCA shape fitting → (3) mask-to-mask deformable registration → (4) optional Icon final refinement. +- **`WorkflowReconstructHighres4DCT`**: Reconstructs high-resolution 4D CT from sparse time samples via deformable registration. + +### Segmentation Classes + +All segment methods return anatomy group masks (heart, lung, major_vessels, bone, soft_tissue, contrast, other, dynamic). The `SegmentAnatomyBase` abstract class defines the interface. + +- `SegmentChestTotalSegmentator` — default, CPU-capable +- `SegmentChestVista3D` — GPU-accelerated MONAI VISTA-3D model +- `SegmentChestVista3DNIM` — NIM cloud API version (requires `pip install physiomotion4d[nim]`) +- `SegmentChestEnsemble` — combines multiple methods +- `SegmentHeartSimpleware` — wraps Simpleware ScanIP SDK (requires Simpleware installation) + +### Registration Classes + +**Image-to-image:** +- `RegisterImagesICON` — deep learning, GPU, preferred for 4D CT +- `RegisterImagesANTs` — classical deformable, CPU-capable +- `RegisterTimeSeriesImages` — wraps ICON or ANTs for 4D time series; handles reference frame selection + +All image registerers follow the interface: `set_fixed_image()` → `register(moving_image)` → returns `{"forward_transform": ..., "inverse_transform": ...}` (ITK composite transforms). + +**Model-to-model/image:** +- `RegisterModelsICP` — centroid + affine ICP using VTK/PyVista +- `RegisterModelsICPITK` — ICP using ITK +- `RegisterModelsPCA` — PCA shape space fitting; requires `pca_model.json` +- `RegisterModelsDistanceMaps` — deformable registration via distance map matching (uses ANTs or ICON internally) + +### USD Pipeline + +Two APIs exist for VTK→USD conversion: + +1. **`ConvertVTKToUSD`** (`convert_vtk_to_usd.py`) — high-level, operates on PyVista objects in memory. Supports colormap overlays, multi-label anatomy, and animated time series. +2. **`vtk_to_usd/`** subpackage — file-based, modular. Core: `VTKToUSDConverter`, `ConversionSettings`, `MaterialData`. Use `convert_vtk_file()` for simple cases. + +`USDTools` and `USDAnatomyTools` handle USD stage merging, time-varying data preservation, and applying surgical materials from a materials library. + +### Key Data Conventions + +- Medical images use ITK (`itk.Image`); surfaces use PyVista (`pv.PolyData`, `pv.UnstructuredGrid`) +- Coordinate system: RAS (medical) internally; converted to Y-up for USD/Omniverse export +- Masks are ITK images with integer labels; anatomy groups use consistent label IDs across segmenters +- Transforms stored as ITK composite transforms in `.hdf` files + +### Testing + +- Test baselines are stored in `tests/baselines/` via **Git LFS** — run `git lfs pull` after cloning +- `tests/conftest.py` provides session-scoped fixtures that chain (download → convert → segment → register); most tests depend on upstream fixtures +- Test markers: `slow`, `requires_gpu`, `requires_data`, `experiment` (skipped by default; use `--run-experiments`) +- `test_tools.py` (`src/physiomotion4d/test_tools.py`) provides baseline comparison utilities + +### Reference Code + +API documentation and examples for advanced third-party libraries (ITK, VTK, PyVista, Omniverse, PhysicsNeMo, Simpleware, MONAI, OpenUSD) are in the `reference_code/` directory. + +## File Operations + +Use `git mv` / `git rm` for moving or deleting tracked files — not `mv` / `rm` — to preserve git history. + +## Documentation Policy + +Do **not** create new `.md` files unless explicitly requested. Document via docstrings and inline comments. A `README.md` may be created for new submodules that lack one. + +## Code Style + +- Single quotes for strings (`'...'`), double quotes for docstrings (`"""..."""`) +- Full type hints required (`mypy` is strict; `disallow_untyped_defs = true`) +- `Optional[X]` not `X | None` for ITK compatibility (ruff `UP007` is suppressed) +- Backward compatibility is **not** a priority — breaking changes are acceptable diff --git a/data/DirLab-4DCT/.gitignore b/data/DirLab-4DCT/.gitignore index 381f31a..9c77de8 100644 --- a/data/DirLab-4DCT/.gitignore +++ b/data/DirLab-4DCT/.gitignore @@ -1,3 +1,4 @@ Case*Pack Case*Deploy *.txt +*.mha diff --git a/experiments/Convert_VTK_To_USD/convert_chop_alterra_valve_to_usd.ipynb b/experiments/Convert_VTK_To_USD/convert_chop_alterra_valve_to_usd.ipynb index 45d83c2..b31e91c 100644 --- a/experiments/Convert_VTK_To_USD/convert_chop_alterra_valve_to_usd.ipynb +++ b/experiments/Convert_VTK_To_USD/convert_chop_alterra_valve_to_usd.ipynb @@ -25,19 +25,6 @@ "5. Create multiple variations (full resolution, subsampled, etc.)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "import re\n", - "import time as time_module\n", - "\n", - "import shutil" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -53,31 +40,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Configuration: Control which conversions to run\n", - "# Set to True to compute full time series (all frames) - takes longer\n", - "# Set to False to only compute subsampled time series (faster, for preview)\n", - "COMPUTE_FULL_TIME_SERIES = True # Default: only subsampled\n", - "\n", - "print(\"Time Series Configuration:\")\n", - "print(f\" - Compute Full Time Series: {COMPUTE_FULL_TIME_SERIES}\")\n", - "print(\" - Compute Subsampled Time Series: Always enabled\")\n", - "print()\n", - "if not COMPUTE_FULL_TIME_SERIES:\n", - " print(\"⚠️ Full time series conversion is DISABLED for faster execution.\")\n", - " print(\" Set COMPUTE_FULL_TIME_SERIES = True to enable full conversion.\")\n", - "else:\n", - " print(\"✓ Full time series conversion is ENABLED (this will take longer).\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", + "from pathlib import Path\n", + "import re\n", + "import time as time_module\n", + "\n", "import numpy as np\n", "\n", + "\n", "# Import the vtk_to_usd library\n", "from physiomotion4d.vtk_to_usd import (\n", " VTKToUSDConverter,\n", @@ -92,8 +61,8 @@ "from physiomotion4d.usd_tools import USDTools\n", "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n", "\n", - "# Configure logging\n", - "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s: %(message)s\")" + "# Use as a test\n", + "from physiomotion4d.notebook_utils import running_as_test" ] }, { @@ -109,16 +78,45 @@ "metadata": {}, "outputs": [], "source": [ + "# Set to True to use as a test. Automatically done by\n", + "# running_as_test() helper function.\n", + "quick_run = running_as_test()\n", + "quick_run_step = 4\n", + "\n", "# Define data directories (Alterra only)\n", "data_dir = Path.cwd().parent.parent / \"data\" / \"CHOP-Valve4D\"\n", - "Alterra_dir = data_dir / \"Alterra\"\n", - "output_dir = Path.cwd() / \"output\" / \"valve4d-alterra\"\n", - "output_dir.mkdir(parents=True, exist_ok=True)\n", + "alterra_dir = data_dir / \"Alterra\"\n", + "\n", + "output_dir = Path.cwd() / \"results\" / \"valve4d-alterra\"\n", + "if quick_run:\n", + " output_usd = output_dir / \"alterra_quick.usd\"\n", + "else:\n", + " output_usd = output_dir / \"alterra_full.usd\"\n", "\n", - "print(f\"Data directory: {data_dir}\")\n", - "print(f\"Output directory: {output_dir}\")\n", - "print(\"\\nDirectory status:\")\n", - "print(f\" Alterra: {'✓' if Alterra_dir.exists() else '✗'} {Alterra_dir}\")" + "colormap_primvar_substrs = [\"stress\", \"strain\"]\n", + "colormap_name = \"jet\" # matplotlib colormap name\n", + "colormap_range_min = 25\n", + "colormap_range_max = 200\n", + "\n", + "conversion_settings = ConversionSettings(\n", + " triangulate_meshes=True,\n", + " compute_normals=False, # Use existing normals if available\n", + " preserve_point_arrays=True,\n", + " preserve_cell_arrays=True,\n", + " separate_objects_by_cell_type=False,\n", + " separate_objects_by_connectivity=True, # Essential for alterra vtk file\n", + " up_axis=\"Y\",\n", + " times_per_second=60.0, # 60 FPS for smooth animation\n", + " use_time_samples=True,\n", + ")\n", + "\n", + "stent_material = MaterialData(\n", + " name=\"Alterra_valve\",\n", + " diffuse_color=(0.5, 0.5, 0.5),\n", + " roughness=0.4,\n", + " metallic=0.9,\n", + " use_vertex_colors=False,\n", + ")" ] }, { @@ -127,44 +125,21 @@ "metadata": {}, "outputs": [], "source": [ - "def discover_time_series(directory, pattern=r\"\\.t(\\d+)\\.vtk$\"):\n", - " \"\"\"Discover and sort time-series VTK files.\n", - "\n", - " Args:\n", - " directory: Directory containing VTK files\n", - " pattern: Regex pattern to extract time step number\n", - "\n", - " Returns:\n", - " list: Sorted list of (time_step, file_path) tuples\n", - " \"\"\"\n", - " vtk_files = list(Path(directory).glob(\"*.vtk\"))\n", - "\n", - " # Extract time step numbers and pair with files\n", - " time_series = []\n", - " for vtk_file in vtk_files:\n", - " match = re.search(pattern, vtk_file.name)\n", - " if match:\n", - " time_step = int(match.group(1))\n", - " time_series.append((time_step, vtk_file))\n", - "\n", - " # Sort by time step\n", - " time_series.sort(key=lambda x: x[0])\n", - "\n", - " return time_series\n", - "\n", - "\n", - "# Discover Alterra time series\n", - "Alterra_series = discover_time_series(Alterra_dir)\n", - "\n", - "print(\"=\" * 60)\n", - "print(\"Time-Series Discovery (Alterra)\")\n", - "print(\"=\" * 60)\n", - "print(\"\\nAlterra:\")\n", - "print(f\" Files found: {len(Alterra_series)}\")\n", - "if Alterra_series:\n", - " print(f\" Time range: t{Alterra_series[0][0]} to t{Alterra_series[-1][0]}\")\n", - " print(f\" First file: {Alterra_series[0][1].name}\")\n", - " print(f\" Last file: {Alterra_series[-1][1].name}\")" + "output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "vtk_files = list(Path(alterra_dir).glob(\"*.vtk\"))\n", + "pattern = r\"\\.t(\\d+)\\.vtk$\"\n", + "\n", + "# Extract time step numbers and pair with files\n", + "alterra_series = []\n", + "for vtk_file in vtk_files:\n", + " match = re.search(pattern, vtk_file.name)\n", + " if match:\n", + " time_step = int(match.group(1))\n", + " alterra_series.append((time_step, vtk_file))\n", + "\n", + "# Sort by time step\n", + "alterra_series.sort(key=lambda x: x[0])" ] }, { @@ -182,106 +157,49 @@ "metadata": {}, "outputs": [], "source": [ - "# Read first frame of Alterra\n", - "if Alterra_series:\n", - " print(\"=\" * 60)\n", - " print(\"Alterra - First Frame Analysis\")\n", - " print(\"=\" * 60)\n", - "\n", - " first_file = Alterra_series[0][1]\n", - " mesh_data = read_vtk_file(first_file, extract_surface=True)\n", - "\n", - " print(f\"\\nFile: {first_file.name}\")\n", - " print(\"\\nGeometry:\")\n", - " print(f\" Points: {len(mesh_data.points):,}\")\n", - " print(f\" Faces: {len(mesh_data.face_vertex_counts):,}\")\n", - " print(f\" Normals: {'Yes' if mesh_data.normals is not None else 'No'}\")\n", - " print(f\" Colors: {'Yes' if mesh_data.colors is not None else 'No'}\")\n", - "\n", - " # Bounding box\n", - " bbox_min = np.min(mesh_data.points, axis=0)\n", - " bbox_max = np.max(mesh_data.points, axis=0)\n", - " bbox_size = bbox_max - bbox_min\n", - " print(\"\\nBounding Box:\")\n", - " print(f\" Min: [{bbox_min[0]:.3f}, {bbox_min[1]:.3f}, {bbox_min[2]:.3f}]\")\n", - " print(f\" Max: [{bbox_max[0]:.3f}, {bbox_max[1]:.3f}, {bbox_max[2]:.3f}]\")\n", - " print(f\" Size: [{bbox_size[0]:.3f}, {bbox_size[1]:.3f}, {bbox_size[2]:.3f}]\")\n", - "\n", - " print(f\"\\nData Arrays ({len(mesh_data.generic_arrays)}):\")\n", - " for i, array in enumerate(mesh_data.generic_arrays, 1):\n", - " print(f\" {i}. {array.name}:\")\n", - " print(f\" - Type: {array.data_type.value}\")\n", - " print(f\" - Components: {array.num_components}\")\n", - " print(f\" - Interpolation: {array.interpolation}\")\n", - " print(f\" - Elements: {len(array.data):,}\")\n", - " if array.data.size > 0:\n", - " print(f\" - Range: [{np.min(array.data):.6f}, {np.max(array.data):.6f}]\")\n", - "\n", - " # Cell types (face vertex count) - TPV data has multiple cell types (triangle, quad, etc.)\n", - " unique_counts, num_each = np.unique(\n", - " mesh_data.face_vertex_counts, return_counts=True\n", - " )\n", - " print(\"\\nCell types (faces by vertex count):\")\n", - " for u, n in zip(unique_counts, num_each):\n", - " name = cell_type_name_for_vertex_count(int(u))\n", - " print(f\" {name} ({u} vertices): {n:,} faces\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Note: Helper functions removed - now using USDTools for primvar inspection and colorization\n", - "# The workflow has changed to: convert to USD first, then apply colormap post-processing\n", - "\n", - "# Configuration: choose colormap for visualization\n", - "DEFAULT_COLORMAP = \"viridis\" # matplotlib colormap name\n", - "\n", - "# Enable automatic colorization (will pick strain/stress primvars if available)\n", - "ENABLE_AUTO_COLORIZATION = True\n", - "\n", - "print(\"Colorization will be applied after USD conversion using USDTools methods\")\n", - "print(\" - USDTools.list_mesh_primvars() for inspection\")\n", - "print(\" - USDTools.pick_color_primvar() for selection\")\n", - "print(\" - USDTools.apply_colormap_from_primvar() for coloring\")\n", - "print(f\" - Colormap: {DEFAULT_COLORMAP}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "## 2. Configure Conversion Settings\n", - "\n", - "# Create converter settings\n", - "settings = ConversionSettings(\n", - " triangulate_meshes=True,\n", - " compute_normals=False, # Use existing normals if available\n", - " preserve_point_arrays=True,\n", - " preserve_cell_arrays=True,\n", - " separate_objects_by_cell_type=False,\n", - " separate_objects_by_connectivity=True,\n", - " up_axis=\"Y\",\n", - " times_per_second=60.0, # 60 FPS for smooth animation\n", - " use_time_samples=True,\n", - ")\n", - "\n", - "print(\"Conversion settings configured\")\n", - "print(f\" - Triangulate: {settings.triangulate_meshes}\")\n", - "print(f\" - Separate objects by cell type: {settings.separate_objects_by_cell_type}\")\n", - "print(f\" - FPS: {settings.times_per_second}\")\n", - "print(f\" - Up axis: {settings.up_axis}\")" + "# Debuggin\n", + "first_file = alterra_series[0][1]\n", + "mesh_data = read_vtk_file(first_file, extract_surface=True)\n", + "\n", + "print(f\"\\nFile: {first_file.name}\")\n", + "print(\"\\nGeometry:\")\n", + "print(f\" Points: {len(mesh_data.points):,}\")\n", + "print(f\" Faces: {len(mesh_data.face_vertex_counts):,}\")\n", + "print(f\" Normals: {'Yes' if mesh_data.normals is not None else 'No'}\")\n", + "print(f\" Colors: {'Yes' if mesh_data.colors is not None else 'No'}\")\n", + "\n", + "# Bounding box\n", + "bbox_min = np.min(mesh_data.points, axis=0)\n", + "bbox_max = np.max(mesh_data.points, axis=0)\n", + "bbox_size = bbox_max - bbox_min\n", + "print(\"\\nBounding Box:\")\n", + "print(f\" Min: [{bbox_min[0]:.3f}, {bbox_min[1]:.3f}, {bbox_min[2]:.3f}]\")\n", + "print(f\" Max: [{bbox_max[0]:.3f}, {bbox_max[1]:.3f}, {bbox_max[2]:.3f}]\")\n", + "print(f\" Size: [{bbox_size[0]:.3f}, {bbox_size[1]:.3f}, {bbox_size[2]:.3f}]\")\n", + "\n", + "print(f\"\\nData Arrays ({len(mesh_data.generic_arrays)}):\")\n", + "for i, array in enumerate(mesh_data.generic_arrays, 1):\n", + " print(f\" {i}. {array.name}:\")\n", + " print(f\" - Type: {array.data_type.value}\")\n", + " print(f\" - Components: {array.num_components}\")\n", + " print(f\" - Interpolation: {array.interpolation}\")\n", + " print(f\" - Elements: {len(array.data):,}\")\n", + " if array.data.size > 0:\n", + " print(f\" - Range: [{np.min(array.data):.6f}, {np.max(array.data):.6f}]\")\n", + "\n", + "# Cell types (face vertex count = triangle, quad, etc.)\n", + "unique_counts, num_each = np.unique(mesh_data.face_vertex_counts, return_counts=True)\n", + "print(\"\\nCell types (faces by vertex count):\")\n", + "for u, n in zip(unique_counts, num_each):\n", + " name = cell_type_name_for_vertex_count(int(u))\n", + " print(f\" {name} ({u} vertices): {n:,} faces\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. Convert Full Time Series - TPV25" + "## 3. Convert TPV25" ] }, { @@ -290,62 +208,45 @@ "metadata": {}, "outputs": [], "source": [ - "# Create material for Alterra\n", - "# Note: Vertex colors will be applied post-conversion by USDTools\n", - "Alterra_material = MaterialData(\n", - " name=\"Alterra_valve\",\n", - " diffuse_color=(0.85, 0.4, 0.4),\n", - " roughness=0.4,\n", - " metallic=0.0,\n", - " use_vertex_colors=False, # USDTools will bind vertex color material during colorization\n", - ")\n", - "\n", - "print(\"=\" * 60)\n", - "print(\"Converting Alterra Time Series\")\n", - "print(\"=\" * 60)\n", - "print(f\"Dataset: {len(Alterra_series)} frames\")\n", - "\n", - "# Convert Alterra (full resolution)\n", - "if COMPUTE_FULL_TIME_SERIES and Alterra_series:\n", - " converter = VTKToUSDConverter(settings)\n", + "converter = VTKToUSDConverter(conversion_settings)\n", "\n", - " Alterra_files = [file_path for _, file_path in Alterra_series]\n", - " Alterra_times = [float(time_step) for time_step, _ in Alterra_series]\n", + "alterra_files = [file_path for _, file_path in alterra_series]\n", + "alterra_times = [float(time_step) for time_step, _ in alterra_series]\n", "\n", - " output_usd = output_dir / \"Alterra_full.usd\"\n", + "if quick_run:\n", + " alterra_files = alterra_files[::quick_run_step]\n", + " alterra_times = alterra_times[::quick_run_step]\n", "\n", - " print(f\"\\nConverting to: {output_usd}\")\n", - " print(f\"Time codes: {Alterra_times[0]:.1f} to {Alterra_times[-1]:.1f}\")\n", - " print(\"\\nThis may take several minutes...\\n\")\n", + "print(f\"\\nConverting to: {output_usd}\")\n", + "print(f\"Number of time steps: {len(alterra_times)}\")\n", + "print(\"\\nThis may take several minutes...\\n\")\n", "\n", - " start_time = time_module.time()\n", + "start_time = time_module.time()\n", "\n", - " # Read MeshData\n", - " mesh_data_sequence = [read_vtk_file(f, extract_surface=True) for f in Alterra_files]\n", + "# Read MeshData\n", + "mesh_data_sequence = [read_vtk_file(f, extract_surface=True) for f in alterra_files]\n", "\n", - " # Validate topology consistency across time series\n", - " validation_report = validate_time_series_topology(\n", - " mesh_data_sequence, filenames=Alterra_files\n", + "# Validate topology consistency across time series\n", + "validation_report = validate_time_series_topology(\n", + " mesh_data_sequence, filenames=alterra_files\n", + ")\n", + "if not validation_report[\"is_consistent\"]:\n", + " print(\n", + " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", " )\n", - " if not validation_report[\"is_consistent\"]:\n", + " if validation_report[\"topology_changes\"]:\n", " print(\n", - " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", + " f\" Topology changes in {len(validation_report['topology_changes'])} frames\"\n", " )\n", - " if validation_report[\"topology_changes\"]:\n", - " print(\n", - " f\" Topology changes in {len(validation_report['topology_changes'])} frames\"\n", - " )\n", - "\n", - " # Convert to USD (preserves all primvars from VTK)\n", - " stage = converter.convert_mesh_data_sequence(\n", - " mesh_data_sequence=mesh_data_sequence,\n", - " output_usd=output_usd,\n", - " mesh_name=\"AlterraValve\",\n", - " time_codes=Alterra_times,\n", - " material=Alterra_material,\n", - " )\n", "\n", - " shutil.copy(output_usd, output_usd.with_suffix(\".save.usd\"))" + "# Convert to USD (preserves all primvars from VTK)\n", + "stage = converter.convert_mesh_data_sequence(\n", + " mesh_data_sequence=mesh_data_sequence,\n", + " output_usd=output_usd,\n", + " mesh_name=\"AlterraValve\",\n", + " time_codes=alterra_times,\n", + " material=stent_material,\n", + ")" ] }, { @@ -354,59 +255,33 @@ "metadata": {}, "outputs": [], "source": [ - "if COMPUTE_FULL_TIME_SERIES and Alterra_series:\n", - " # Post-process: apply colormap visualization using USDTools\n", - " if ENABLE_AUTO_COLORIZATION:\n", - " usd_tools = USDTools()\n", - " usd_anatomy_tools = USDAnatomyTools(stage)\n", - " if settings.separate_objects_by_connectivity is True:\n", - " mesh_path1 = \"/World/Meshes/AlterraValve_object3\"\n", - " mesh_path2 = \"/World/Meshes/AlterraValve_object4\"\n", - " elif settings.separate_objects_by_cell_type is True:\n", - " mesh_path1 = \"/World/Meshes/AlterraValve_triangle1\"\n", - " mesh_path2 = \"/World/Meshes/AlterraValve_triangle1\"\n", - " else:\n", - " mesh_path1 = \"/World/Meshes/AlterraValve\"\n", - " mesh_path2 = None\n", - "\n", - " # Inspect and select primvar for coloring\n", - " primvars = usd_tools.list_mesh_primvars(str(output_usd), mesh_path1)\n", - " print(primvars)\n", - " color_primvar = usd_tools.pick_color_primvar(\n", - " primvars, keywords=(\"strain\", \"stress\")\n", - " )\n", - "\n", - " if color_primvar:\n", - " print(f\"\\nApplying colormap to '{color_primvar}'\")\n", - " usd_tools.apply_colormap_from_primvar(\n", - " str(output_usd),\n", - " mesh_path1,\n", - " color_primvar,\n", - " # intensity_range=(0, 300),\n", - " cmap=\"hot\",\n", - " # use_sigmoid_scale=True,\n", - " bind_vertex_color_material=True,\n", - " )\n", - " if mesh_path2 is not None:\n", - " mesh_prim = stage.GetPrimAtPath(mesh_path2)\n", - " usd_anatomy_tools.apply_anatomy_material_to_prim(\n", - " mesh_prim, usd_anatomy_tools.bone_params\n", - " )\n", - "\n", - " if not validation_report[\"is_consistent\"]:\n", - " print(\n", - " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", - " )\n", - " if validation_report[\"topology_changes\"]:\n", - " print(\"\\nNo strain/stress primvar found for coloring\")\n", + "usd_tools = USDTools()\n", + "usd_anatomy_tools = USDAnatomyTools(stage)\n", + "if conversion_settings.separate_objects_by_connectivity is True:\n", + " vessel_path = \"/World/Meshes/AlterraValve_object3\"\n", + "elif conversion_settings.separate_objects_by_cell_type is True:\n", + " vessel_path = \"/World/Meshes/AlterraValve_triangle1\"\n", + "else:\n", + " vessel_path = \"/World/Meshes/AlterraValve\"\n", "\n", - " print(f\" Size: {output_usd.stat().st_size / (1024 * 1024):.2f} MB\")\n", - " print(f\" Time range: {stage.GetStartTimeCode()} - {stage.GetEndTimeCode()}\")\n", - " print(\n", - " f\" Duration: {(stage.GetEndTimeCode() - stage.GetStartTimeCode()) / settings.times_per_second:.2f} seconds @ {settings.times_per_second} FPS\"\n", - " )\n", - "elif not COMPUTE_FULL_TIME_SERIES:\n", - " print(\"⏭️ Skipping Alterra full time series (COMPUTE_FULL_TIME_SERIES = False)\")" + "# Select primvar for coloring\n", + "primvars = usd_tools.list_mesh_primvars(str(output_usd), vessel_path)\n", + "color_primvar = usd_tools.pick_color_primvar(\n", + " primvars, keywords=tuple(colormap_primvar_substrs)\n", + ")\n", + "print(f\"Chosen primvar = {color_primvar}\")\n", + "\n", + "if color_primvar:\n", + " print(f\"\\nApplying colormap to '{color_primvar}' using {colormap_name}\")\n", + " usd_tools.apply_colormap_from_primvar(\n", + " str(output_usd),\n", + " vessel_path,\n", + " color_primvar,\n", + " intensity_range=(colormap_range_min, colormap_range_max),\n", + " cmap=colormap_name,\n", + " use_sigmoid_scale=True,\n", + " bind_vertex_color_material=True,\n", + " )" ] } ], diff --git a/experiments/Convert_VTK_To_USD/convert_chop_heart_vtk_to_usd.ipynb b/experiments/Convert_VTK_To_USD/convert_chop_heart_vtk_to_usd.ipynb index ef64038..c79526c 100644 --- a/experiments/Convert_VTK_To_USD/convert_chop_heart_vtk_to_usd.ipynb +++ b/experiments/Convert_VTK_To_USD/convert_chop_heart_vtk_to_usd.ipynb @@ -33,19 +33,42 @@ " \"RightVentricle\",\n", "]\n", "\n", + "input_dir = (\n", + " Path.cwd().parent.parent / \"data\" / \"CHOP-Valve4D\" / \"CT\" / \"Simpleware\" / \"parts\"\n", + ")\n", + "output_dir = Path.cwd() / \"results\" / \"heart\"\n", + "\n", + "output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "all_files = []\n", "for vtkname, usdname in zip(vtknames, usdnames):\n", " if os.path.exists(Path.absolute(Path(f\"RVOT28-Dias-{usdname}.usd\"))):\n", " os.remove(Path.absolute(Path(f\"RVOT28-Dias-{usdname}.usd\")))\n", + "\n", + " in_name = input_dir / f\"{vtkname}.vtk\"\n", + " all_files.append(in_name)\n", + "\n", " converter = WorkflowConvertVTKToUSD(\n", - " vtk_files=[f\"../../data/CHOP-Valve4D/Simpleware/parts/{vtkname}.vtk\"],\n", - " output_usd=Path.absolute(Path(f\"RVOT28-Dias-{usdname}.usd\")),\n", + " vtk_files=[in_name],\n", + " output_usd=Path.absolute(output_dir / Path(f\"RVOT28-Dias-{usdname}.usd\")),\n", " separate_by_connectivity=False,\n", " separate_by_cell_type=False,\n", " mesh_name=f\"RVOT28Dias_{usdname}\",\n", " appearance=\"anatomy\",\n", " anatomy_type=\"heart\",\n", " )\n", - " converter.run()" + " converter.run()\n", + "\n", + "converter = WorkflowConvertVTKToUSD(\n", + " vtk_files=all_files,\n", + " output_usd=Path.absolute(output_dir / Path(\"RVOT28-Dias-WholeHeart.usd\")),\n", + " separate_by_connectivity=False,\n", + " separate_by_cell_type=False,\n", + " mesh_name=\"RVOT28Dias_WholeHeart\",\n", + " appearance=\"anatomy\",\n", + " anatomy_type=\"heart\",\n", + ")\n", + "converter.run()" ] } ], diff --git a/experiments/Convert_VTK_To_USD/convert_chop_tpv25_valve_to_usd.ipynb b/experiments/Convert_VTK_To_USD/convert_chop_tpv25_valve_to_usd.ipynb index 9fa4b66..b8f05fe 100644 --- a/experiments/Convert_VTK_To_USD/convert_chop_tpv25_valve_to_usd.ipynb +++ b/experiments/Convert_VTK_To_USD/convert_chop_tpv25_valve_to_usd.ipynb @@ -25,19 +25,6 @@ "5. Create multiple variations (full resolution, subsampled, etc.)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "import re\n", - "import time as time_module\n", - "\n", - "import shutil" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -53,31 +40,12 @@ "metadata": {}, "outputs": [], "source": [ - "# Configuration: Control which conversions to run\n", - "# Set to True to compute full time series (all frames) - takes longer\n", - "# Set to False to only compute subsampled time series (faster, for preview)\n", - "COMPUTE_FULL_TIME_SERIES = True # Default: only subsampled\n", - "\n", - "print(\"Time Series Configuration:\")\n", - "print(f\" - Compute Full Time Series: {COMPUTE_FULL_TIME_SERIES}\")\n", - "print(\" - Compute Subsampled Time Series: Always enabled\")\n", - "print()\n", - "if not COMPUTE_FULL_TIME_SERIES:\n", - " print(\"⚠️ Full time series conversion is DISABLED for faster execution.\")\n", - " print(\" Set COMPUTE_FULL_TIME_SERIES = True to enable full conversion.\")\n", - "else:\n", - " print(\"✓ Full time series conversion is ENABLED (this will take longer).\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", + "from pathlib import Path\n", + "import re\n", + "import time as time_module\n", + "\n", "import numpy as np\n", - "from pxr import Usd, UsdGeom\n", + "\n", "\n", "# Import the vtk_to_usd library\n", "from physiomotion4d.vtk_to_usd import (\n", @@ -92,9 +60,7 @@ "# Import USDTools for post-processing colormap\n", "from physiomotion4d.usd_tools import USDTools\n", "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n", - "\n", - "# Configure logging\n", - "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s: %(message)s\")" + "from physiomotion4d.notebook_utils import running_as_test" ] }, { @@ -110,315 +76,45 @@ "metadata": {}, "outputs": [], "source": [ + "# Set to True to use as a test. Automatically done by\n", + "# running_as_test() helper function.\n", + "quick_run = running_as_test()\n", + "quick_run_step = 4\n", + "\n", "# Define data directories (TPV25 only)\n", "data_dir = Path.cwd().parent.parent / \"data\" / \"CHOP-Valve4D\"\n", "tpv25_dir = data_dir / \"TPV25\"\n", - "output_dir = Path.cwd() / \"output\" / \"valve4d\"\n", - "output_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "print(f\"Data directory: {data_dir}\")\n", - "print(f\"Output directory: {output_dir}\")\n", - "print(\"\\nDirectory status:\")\n", - "print(f\" TPV25: {'✓' if tpv25_dir.exists() else '✗'} {tpv25_dir}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def discover_time_series(directory, pattern=r\"\\.t(\\d+)\\.vtk$\"):\n", - " \"\"\"Discover and sort time-series VTK files.\n", - "\n", - " Args:\n", - " directory: Directory containing VTK files\n", - " pattern: Regex pattern to extract time step number\n", - "\n", - " Returns:\n", - " list: Sorted list of (time_step, file_path) tuples\n", - " \"\"\"\n", - " vtk_files = list(Path(directory).glob(\"*.vtk\"))\n", - "\n", - " # Extract time step numbers and pair with files\n", - " time_series = []\n", - " for vtk_file in vtk_files:\n", - " match = re.search(pattern, vtk_file.name)\n", - " if match:\n", - " time_step = int(match.group(1))\n", - " time_series.append((time_step, vtk_file))\n", - "\n", - " # Sort by time step\n", - " time_series.sort(key=lambda x: x[0])\n", - "\n", - " return time_series\n", - "\n", - "\n", - "# Discover TPV25 time series\n", - "tpv25_series = discover_time_series(tpv25_dir)\n", - "\n", - "print(\"=\" * 60)\n", - "print(\"Time-Series Discovery (TPV25)\")\n", - "print(\"=\" * 60)\n", - "print(\"\\nTPV25:\")\n", - "print(f\" Files found: {len(tpv25_series)}\")\n", - "if tpv25_series:\n", - " print(f\" Time range: t{tpv25_series[0][0]} to t{tpv25_series[-1][0]}\")\n", - " print(f\" First file: {tpv25_series[0][1].name}\")\n", - " print(f\" Last file: {tpv25_series[-1][1].name}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Inspect First Frame\n", - "\n", - "Examine the first time step to understand the data structure." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Read first frame of TPV25\n", - "if tpv25_series:\n", - " print(\"=\" * 60)\n", - " print(\"TPV25 - First Frame Analysis\")\n", - " print(\"=\" * 60)\n", - "\n", - " first_file = tpv25_series[0][1]\n", - " mesh_data = read_vtk_file(first_file, extract_surface=True)\n", - "\n", - " print(f\"\\nFile: {first_file.name}\")\n", - " print(\"\\nGeometry:\")\n", - " print(f\" Points: {len(mesh_data.points):,}\")\n", - " print(f\" Faces: {len(mesh_data.face_vertex_counts):,}\")\n", - " print(f\" Normals: {'Yes' if mesh_data.normals is not None else 'No'}\")\n", - " print(f\" Colors: {'Yes' if mesh_data.colors is not None else 'No'}\")\n", - "\n", - " # Bounding box\n", - " bbox_min = np.min(mesh_data.points, axis=0)\n", - " bbox_max = np.max(mesh_data.points, axis=0)\n", - " bbox_size = bbox_max - bbox_min\n", - " print(\"\\nBounding Box:\")\n", - " print(f\" Min: [{bbox_min[0]:.3f}, {bbox_min[1]:.3f}, {bbox_min[2]:.3f}]\")\n", - " print(f\" Max: [{bbox_max[0]:.3f}, {bbox_max[1]:.3f}, {bbox_max[2]:.3f}]\")\n", - " print(f\" Size: [{bbox_size[0]:.3f}, {bbox_size[1]:.3f}, {bbox_size[2]:.3f}]\")\n", - "\n", - " print(f\"\\nData Arrays ({len(mesh_data.generic_arrays)}):\")\n", - " for i, array in enumerate(mesh_data.generic_arrays, 1):\n", - " print(f\" {i}. {array.name}:\")\n", - " print(f\" - Type: {array.data_type.value}\")\n", - " print(f\" - Components: {array.num_components}\")\n", - " print(f\" - Interpolation: {array.interpolation}\")\n", - " print(f\" - Elements: {len(array.data):,}\")\n", - " if array.data.size > 0:\n", - " print(f\" - Range: [{np.min(array.data):.6f}, {np.max(array.data):.6f}]\")\n", - "\n", - " # Cell types (face vertex count) - TPV data has multiple cell types (triangle, quad, etc.)\n", - " unique_counts, num_each = np.unique(\n", - " mesh_data.face_vertex_counts, return_counts=True\n", - " )\n", - " print(\"\\nCell types (faces by vertex count):\")\n", - " for u, n in zip(unique_counts, num_each):\n", - " name = cell_type_name_for_vertex_count(int(u))\n", - " print(f\" {name} ({u} vertices): {n:,} faces\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Note: Helper functions removed - now using USDTools for primvar inspection and colorization\n", - "# The workflow has changed to: convert to USD first, then apply colormap post-processing\n", - "\n", - "# Configuration: choose colormap for visualization\n", - "DEFAULT_COLORMAP = \"viridis\" # matplotlib colormap name\n", "\n", - "# Enable automatic colorization (will pick strain/stress primvars if available)\n", - "ENABLE_AUTO_COLORIZATION = True\n", + "output_dir = Path.cwd() / \"results\" / \"valve4d-tpv25\"\n", + "if quick_run:\n", + " output_usd = output_dir / \"tpv25_quick.usd\"\n", + "else:\n", + " output_usd = output_dir / \"tpv25_full.usd\"\n", "\n", - "print(\"Colorization will be applied after USD conversion using USDTools methods\")\n", - "print(\" - USDTools.list_mesh_primvars() for inspection\")\n", - "print(\" - USDTools.pick_color_primvar() for selection\")\n", - "print(\" - USDTools.apply_colormap_from_primvar() for coloring\")\n", - "print(f\" - Colormap: {DEFAULT_COLORMAP}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "## 2. Configure Conversion Settings\n", + "colormap_primvar_substrs = [\"stress\", \"strain\"]\n", + "colormap_name = \"jet\" # matplotlib colormap name\n", + "colormap_range_min = 25\n", + "colormap_range_max = 200\n", "\n", - "# Create converter settings\n", - "settings = ConversionSettings(\n", + "conversion_settings = ConversionSettings(\n", " triangulate_meshes=True,\n", " compute_normals=False, # Use existing normals if available\n", " preserve_point_arrays=True,\n", " preserve_cell_arrays=True,\n", " separate_objects_by_cell_type=False,\n", - " separate_objects_by_connectivity=True,\n", + " separate_objects_by_connectivity=True, # Essential for tpv25 vtk file\n", " up_axis=\"Y\",\n", " times_per_second=60.0, # 60 FPS for smooth animation\n", " use_time_samples=True,\n", ")\n", "\n", - "print(\"Conversion settings configured\")\n", - "print(f\" - Triangulate: {settings.triangulate_meshes}\")\n", - "print(f\" - Separate objects by cell type: {settings.separate_objects_by_cell_type}\")\n", - "print(f\" - FPS: {settings.times_per_second}\")\n", - "print(f\" - Up axis: {settings.up_axis}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Convert Full Time Series - TPV25" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create material for TPV25\n", - "# Note: Vertex colors will be applied post-conversion by USDTools\n", - "# Create material for TPV25\n", - "# Note: Vertex colors will be applied post-conversion by USDTools\n", - "tpv25_material = MaterialData(\n", + "stent_material = MaterialData(\n", " name=\"tpv25_valve\",\n", - " diffuse_color=(0.85, 0.4, 0.4),\n", + " diffuse_color=(0.5, 0.5, 0.5),\n", " roughness=0.4,\n", - " metallic=0.0,\n", - " use_vertex_colors=False, # USDTools will bind vertex color material during colorization\n", - ")\n", - "\n", - "print(\"=\" * 60)\n", - "print(\"Converting TPV25 Time Series\")\n", - "print(\"=\" * 60)\n", - "print(f\"Dataset: {len(tpv25_series)} frames\")\n", - "\n", - "# Convert TPV25 (full resolution)\n", - "if COMPUTE_FULL_TIME_SERIES and tpv25_series:\n", - " converter = VTKToUSDConverter(settings)\n", - "\n", - " tpv25_files = [file_path for _, file_path in tpv25_series]\n", - " tpv25_times = [float(time_step) for time_step, _ in tpv25_series]\n", - "\n", - " output_usd = output_dir / \"tpv25_full.usd\"\n", - "\n", - " print(f\"\\nConverting to: {output_usd}\")\n", - " print(f\"Time codes: {tpv25_times[0]:.1f} to {tpv25_times[-1]:.1f}\")\n", - " print(\"\\nThis may take several minutes...\\n\")\n", - "\n", - " start_time = time_module.time()\n", - "\n", - " # Read MeshData\n", - " mesh_data_sequence = [read_vtk_file(f, extract_surface=True) for f in tpv25_files]\n", - "\n", - " # Validate topology consistency across time series\n", - " validation_report = validate_time_series_topology(\n", - " mesh_data_sequence, filenames=tpv25_files\n", - " )\n", - " if not validation_report[\"is_consistent\"]:\n", - " print(\n", - " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", - " )\n", - " if validation_report[\"topology_changes\"]:\n", - " print(\n", - " f\" Topology changes in {len(validation_report['topology_changes'])} frames\"\n", - " )\n", - "\n", - " # Convert to USD (preserves all primvars from VTK)\n", - " stage = converter.convert_mesh_data_sequence(\n", - " mesh_data_sequence=mesh_data_sequence,\n", - " output_usd=output_usd,\n", - " mesh_name=\"TPV25Valve\",\n", - " time_codes=tpv25_times,\n", - " material=tpv25_material,\n", - " )\n", - "\n", - " shutil.copy(output_usd, output_usd.with_suffix(\".save.usd\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if COMPUTE_FULL_TIME_SERIES and tpv25_series:\n", - " # Post-process: apply colormap visualization using USDTools\n", - " if ENABLE_AUTO_COLORIZATION:\n", - " usd_tools = USDTools()\n", - " usd_anatomy_tools = USDAnatomyTools(stage)\n", - " if settings.separate_objects_by_connectivity is True:\n", - " mesh_path1 = \"/World/Meshes/TPV25Valve_object4\"\n", - " mesh_path2 = \"/World/Meshes/TPV25Valve_object3\"\n", - " elif settings.separate_objects_by_cell_type is True:\n", - " mesh_path1 = \"/World/Meshes/TPV25Valve_triangle1\"\n", - " mesh_path2 = \"/World/Meshes/TPV25Valve_triangle1\"\n", - " else:\n", - " mesh_path1 = \"/World/Meshes/TPV25Valve\"\n", - " mesh_path2 = None\n", - "\n", - " # Inspect and select primvar for coloring\n", - " primvars = usd_tools.list_mesh_primvars(str(output_usd), mesh_path1)\n", - " print(primvars)\n", - " color_primvar = usd_tools.pick_color_primvar(\n", - " primvars, keywords=(\"strain\", \"stress\")\n", - " )\n", - "\n", - " if color_primvar:\n", - " print(f\"\\nApplying colormap to '{color_primvar}' using {DEFAULT_COLORMAP}\")\n", - " usd_tools.apply_colormap_from_primvar(\n", - " str(output_usd),\n", - " mesh_path1,\n", - " color_primvar,\n", - " # intensity_range=(75, 200),\n", - " cmap=\"hot\",\n", - " # use_sigmoid_scale=True,\n", - " bind_vertex_color_material=True,\n", - " )\n", - " if mesh_path2 is not None:\n", - " mesh_prim = stage.GetPrimAtPath(mesh_path2)\n", - " usd_anatomy_tools.apply_anatomy_material_to_prim(\n", - " mesh_prim, usd_anatomy_tools.bone_params\n", - " )\n", - "\n", - " if not validation_report[\"is_consistent\"]:\n", - " print(\n", - " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", - " )\n", - " if validation_report[\"topology_changes\"]:\n", - " print(\"\\nNo strain/stress primvar found for coloring\")\n", - "\n", - " print(f\" Size: {output_usd.stat().st_size / (1024 * 1024):.2f} MB\")\n", - " print(f\" Time range: {stage.GetStartTimeCode()} - {stage.GetEndTimeCode()}\")\n", - " print(\n", - " f\" Duration: {(stage.GetEndTimeCode() - stage.GetStartTimeCode()) / settings.times_per_second:.2f} seconds @ {settings.times_per_second} FPS\"\n", - " )\n", - "elif not COMPUTE_FULL_TIME_SERIES:\n", - " print(\"⏭️ Skipping TPV25 full time series (COMPUTE_FULL_TIME_SERIES = False)\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Convert Subsampled Time Series - TPV25 (single mesh)\n", - "\n", - "Convert TPV25 with every 5th frame to **tpv25_subsample_5x.usd**. Uses default settings (no split); one mesh prim `TPV25Valve`." + " metallic=0.9,\n", + " use_vertex_colors=False,\n", + ")" ] }, { @@ -427,220 +123,30 @@ "metadata": {}, "outputs": [], "source": [ - "# Subsample TPV25 (every 5th frame)\n", - "if tpv25_series:\n", - " subsample_rate = 5\n", - " tpv25_subsampled = tpv25_series[::subsample_rate]\n", - "\n", - " print(\"=\" * 60)\n", - " print(f\"Converting Subsampled TPV25 (every {subsample_rate}th frame)\")\n", - " print(\"=\" * 60)\n", - " print(f\"Frames: {len(tpv25_series)} → {len(tpv25_subsampled)}\")\n", - "\n", - " converter = VTKToUSDConverter(settings)\n", - "\n", - " tpv25_files_sub = [file_path for _, file_path in tpv25_subsampled]\n", - " tpv25_times_sub = [float(time_step) for time_step, _ in tpv25_subsampled]\n", - "\n", - " output_usd_sub = output_dir / f\"tpv25_subsample_{subsample_rate}x.usd\"\n", - "\n", - " print(f\"\\nConverting to: {output_usd_sub}\")\n", - "\n", - " start_time = time_module.time()\n", - "\n", - " # Read MeshData\n", - " mesh_data_sequence = [\n", - " read_vtk_file(f, extract_surface=True) for f in tpv25_files_sub\n", - " ]\n", - "\n", - " # Validate topology consistency across time series\n", - " validation_report = validate_time_series_topology(\n", - " mesh_data_sequence, filenames=tpv25_files_sub\n", - " )\n", - " if not validation_report[\"is_consistent\"]:\n", - " print(\n", - " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", - " )\n", - " if validation_report[\"topology_changes\"]:\n", - " print(\n", - " f\" Topology changes in {len(validation_report['topology_changes'])} frames\"\n", - " )\n", - "\n", - " # Convert to USD (preserves all primvars from VTK)\n", - " stage_sub = converter.convert_mesh_data_sequence(\n", - " mesh_data_sequence=mesh_data_sequence,\n", - " output_usd=output_usd_sub,\n", - " mesh_name=\"TPV25Valve\",\n", - " time_codes=tpv25_times_sub,\n", - " material=tpv25_material,\n", - " )\n", - "\n", - " # Post-process: apply colormap visualization using USDTools\n", - " if ENABLE_AUTO_COLORIZATION:\n", - " usd_tools = USDTools()\n", - " if settings.separate_objects_by_connectivity is True:\n", - " mesh_path = \"/World/Meshes/object3\"\n", - " elif settings.separate_objects_by_cell_type is True:\n", - " mesh_path = \"/World/Meshes/triangle1\"\n", - " else:\n", - " mesh_path = \"/World/Meshes/TPV25Valve\"\n", - "\n", - " # Inspect and select primvar for coloring\n", - " primvars = usd_tools.list_mesh_primvars(str(output_usd_sub), mesh_path)\n", - " color_primvar = usd_tools.pick_color_primvar(\n", - " primvars, keywords=(\"strain\", \"stress\")\n", - " )\n", - "\n", - " if color_primvar:\n", - " print(f\"\\nApplying colormap to '{color_primvar}' using {DEFAULT_COLORMAP}\")\n", - " usd_tools.apply_colormap_from_primvar(\n", - " str(output_usd_sub),\n", - " mesh_path,\n", - " color_primvar,\n", - " cmap=DEFAULT_COLORMAP,\n", - " bind_vertex_color_material=True,\n", - " )\n", - " else:\n", - " print(\"\\nNo strain/stress primvar found for coloring\")\n", - "\n", - " elapsed = time_module.time() - start_time\n", - "\n", - " print(f\"\\n✓ Conversion completed in {elapsed:.1f} seconds\")\n", - " print(f\" Output: {output_usd_sub}\")\n", - " print(f\" Size: {output_usd_sub.stat().st_size / (1024 * 1024):.2f} MB\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. TPV25 Subsampled — Split by Cell Type\n", - "\n", - "When `separate_objects_by_cell_type=True`, the converter splits the mesh into **separate USD prims** by cell type (triangle, quad, etc.). Output: **tpv25_subsample_5x_by_cell_type.usd** (distinct from the single-mesh subsample).\n", - "\n", - "TPV data contains multiple cell types (see first-frame analysis). Here we convert the same subsampled TPV25 sequence with triangulation off so quads remain quads; the stage has one mesh per cell type (e.g. `Triangle_0`, `Quad_0`)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Convert TPV25 subsampled with separate meshes per cell type (triangulate=False to preserve quads)\n", - "if tpv25_series:\n", - " settings_by_cell_type = ConversionSettings(\n", - " triangulate_meshes=False, # Keep quads so we get both Triangle_0 and Quad_0\n", - " compute_normals=False,\n", - " preserve_point_arrays=True,\n", - " preserve_cell_arrays=True,\n", - " separate_objects_by_cell_type=True,\n", - " separate_objects_by_connectivity=False,\n", - " up_axis=\"Y\",\n", - " times_per_second=60.0,\n", - " use_time_samples=True,\n", - " )\n", + "output_dir.mkdir(parents=True, exist_ok=True)\n", "\n", - " subsample_rate = 5\n", - " tpv25_subsampled = tpv25_series[::subsample_rate]\n", - " tpv25_files_sub = [file_path for _, file_path in tpv25_subsampled]\n", - " tpv25_times_sub = [float(t) for t, _ in tpv25_subsampled]\n", + "vtk_files = list(Path(tpv25_dir).glob(\"*.vtk\"))\n", + "pattern = r\"\\.t(\\d+)\\.vtk$\"\n", "\n", - " output_by_cell_type = output_dir / \"tpv25_subsample_5x_by_cell_type.usd\"\n", - " print(\"Converting TPV25 (subsampled) with separate objects by cell type...\")\n", - " print(\n", - " \" triangulate_meshes=False → triangles and quads preserved as separate meshes\"\n", - " )\n", - " print(f\" Output: {output_by_cell_type.name}\")\n", - "\n", - " converter_ct = VTKToUSDConverter(settings_by_cell_type)\n", - " mesh_data_sequence = [\n", - " read_vtk_file(f, extract_surface=True) for f in tpv25_files_sub\n", - " ]\n", - " stage_ct = converter_ct.convert_mesh_data_sequence(\n", - " mesh_data_sequence=mesh_data_sequence,\n", - " output_usd=output_by_cell_type,\n", - " mesh_name=\"TPV25Valve\", # base name when not splitting; ignored when splitting\n", - " time_codes=tpv25_times_sub,\n", - " material=tpv25_material,\n", - " )\n", + "# Extract time step numbers and pair with files\n", + "tpv25_series = []\n", + "for vtk_file in vtk_files:\n", + " match = re.search(pattern, vtk_file.name)\n", + " if match:\n", + " time_step = int(match.group(1))\n", + " tpv25_series.append((time_step, vtk_file))\n", "\n", - " # List mesh prims under /World/Meshes (each cell type is a separate prim)\n", - " meshes_prim = stage_ct.GetPrimAtPath(\"/World/Meshes\")\n", - " if meshes_prim:\n", - " children = meshes_prim.GetChildren()\n", - " print(f\"\\nMesh prims created (by cell type): {len(children)}\")\n", - " for child in children:\n", - " print(f\" - {child.GetPath().pathString}\")\n", - " print(f\"\\n✓ Saved: {output_by_cell_type}\")" + "# Sort by time step\n", + "tpv25_series.sort(key=lambda x: x[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 6. TPV25 Subsampled — Split by Connectivity\n", - "\n", - "When `separate_objects_by_connectivity=True`, the converter splits the mesh into **separate USD prims** by connected component (object1, object2, ...). Output: **tpv25_subsample_5x_by_connectivity.usd** (distinct from single-mesh and by-cell-type).\n", - "\n", - "Only one of `separate_objects_by_cell_type` and `separate_objects_by_connectivity` can be enabled at a time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Convert TPV25 subsampled with separate meshes per connected component\n", - "if tpv25_series:\n", - " settings_by_connectivity = ConversionSettings(\n", - " triangulate_meshes=True,\n", - " compute_normals=False,\n", - " preserve_point_arrays=True,\n", - " preserve_cell_arrays=True,\n", - " separate_objects_by_cell_type=False,\n", - " separate_objects_by_connectivity=True,\n", - " up_axis=\"Y\",\n", - " times_per_second=60.0,\n", - " use_time_samples=True,\n", - " )\n", - "\n", - " subsample_rate = 5\n", - " tpv25_subsampled = tpv25_series[::subsample_rate]\n", - " tpv25_files_sub = [file_path for _, file_path in tpv25_subsampled]\n", - " tpv25_times_sub = [float(t) for t, _ in tpv25_subsampled]\n", - "\n", - " output_by_connectivity = output_dir / \"tpv25_subsample_5x_by_connectivity.usd\"\n", - " print(\"Converting TPV25 (subsampled) with separate objects by connectivity...\")\n", - " print(f\" Output: {output_by_connectivity.name}\")\n", - "\n", - " converter_conn = VTKToUSDConverter(settings_by_connectivity)\n", - " mesh_data_sequence = [\n", - " read_vtk_file(f, extract_surface=True) for f in tpv25_files_sub\n", - " ]\n", - " stage_conn = converter_conn.convert_mesh_data_sequence(\n", - " mesh_data_sequence=mesh_data_sequence,\n", - " output_usd=output_by_connectivity,\n", - " mesh_name=\"TPV25Valve\",\n", - " time_codes=tpv25_times_sub,\n", - " material=tpv25_material,\n", - " )\n", + "## 2. Inspect First Frame\n", "\n", - " meshes_prim = stage_conn.GetPrimAtPath(\"/World/Meshes\")\n", - " if meshes_prim:\n", - " children = meshes_prim.GetChildren()\n", - " print(f\"\\nMesh prims created (by connectivity): {len(children)}\")\n", - " for child in children:\n", - " print(f\" - {child.GetPath().pathString}\")\n", - " print(f\"\\n✓ Saved: {output_by_connectivity}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. Summary and File Inspection" + "Examine the first time step to understand the data structure." ] }, { @@ -649,64 +155,49 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", - "print(\"=\" * 60)\n", - "print(\"Conversion Summary\")\n", - "print(\"=\" * 60)\n", - "\n", - "# List all generated USD files\n", - "usd_files = list(output_dir.glob(\"*.usd\"))\n", - "usd_files.extend(output_dir.glob(\"*.usda\"))\n", - "usd_files.extend(output_dir.glob(\"*.usdc\"))\n", - "\n", - "total_size = 0\n", - "\n", - "for usd_file in sorted(usd_files):\n", - " size_mb = os.path.getsize(usd_file) / (1024 * 1024)\n", - " total_size += size_mb\n", - "\n", - " print(f\"\\n{usd_file.name}:\")\n", - " print(f\" Size: {size_mb:.2f} MB\")\n", - "\n", - " # Open and inspect\n", - " stage = Usd.Stage.Open(str(usd_file))\n", - " if stage:\n", - " if stage.HasAuthoredTimeCodeRange():\n", - " duration = (\n", - " stage.GetEndTimeCode() - stage.GetStartTimeCode()\n", - " ) / stage.GetTimeCodesPerSecond()\n", - " print(\n", - " f\" Time range: {stage.GetStartTimeCode():.0f} - {stage.GetEndTimeCode():.0f}\"\n", - " )\n", - " print(\n", - " f\" Duration: {duration:.2f} seconds @ {stage.GetTimeCodesPerSecond():.0f} FPS\"\n", - " )\n", - " print(\n", - " f\" Frames: {int(stage.GetEndTimeCode() - stage.GetStartTimeCode() + 1)}\"\n", - " )\n", - "\n", - " # Count meshes\n", - " mesh_count = 0\n", - " for prim in stage.Traverse():\n", - " if prim.IsA(UsdGeom.Mesh):\n", - " mesh_count += 1\n", - " print(f\" Meshes: {mesh_count}\")\n", - "\n", - "print(f\"\\n{'=' * 60}\")\n", - "print(f\"Total size: {total_size:.2f} MB\")\n", - "print(f\"Total files: {len(usd_files)}\")\n", - "print(f\"Output directory: {output_dir}\")\n", - "print(f\"{'=' * 60}\")" + "# Debuggin\n", + "first_file = tpv25_series[0][1]\n", + "mesh_data = read_vtk_file(first_file, extract_surface=True)\n", + "\n", + "print(f\"\\nFile: {first_file.name}\")\n", + "print(\"\\nGeometry:\")\n", + "print(f\" Points: {len(mesh_data.points):,}\")\n", + "print(f\" Faces: {len(mesh_data.face_vertex_counts):,}\")\n", + "print(f\" Normals: {'Yes' if mesh_data.normals is not None else 'No'}\")\n", + "print(f\" Colors: {'Yes' if mesh_data.colors is not None else 'No'}\")\n", + "\n", + "# Bounding box\n", + "bbox_min = np.min(mesh_data.points, axis=0)\n", + "bbox_max = np.max(mesh_data.points, axis=0)\n", + "bbox_size = bbox_max - bbox_min\n", + "print(\"\\nBounding Box:\")\n", + "print(f\" Min: [{bbox_min[0]:.3f}, {bbox_min[1]:.3f}, {bbox_min[2]:.3f}]\")\n", + "print(f\" Max: [{bbox_max[0]:.3f}, {bbox_max[1]:.3f}, {bbox_max[2]:.3f}]\")\n", + "print(f\" Size: [{bbox_size[0]:.3f}, {bbox_size[1]:.3f}, {bbox_size[2]:.3f}]\")\n", + "\n", + "print(f\"\\nData Arrays ({len(mesh_data.generic_arrays)}):\")\n", + "for i, array in enumerate(mesh_data.generic_arrays, 1):\n", + " print(f\" {i}. {array.name}:\")\n", + " print(f\" - Type: {array.data_type.value}\")\n", + " print(f\" - Components: {array.num_components}\")\n", + " print(f\" - Interpolation: {array.interpolation}\")\n", + " print(f\" - Elements: {len(array.data):,}\")\n", + " if array.data.size > 0:\n", + " print(f\" - Range: [{np.min(array.data):.6f}, {np.max(array.data):.6f}]\")\n", + "\n", + "# Cell types (face vertex count = triangle, quad, etc.)\n", + "unique_counts, num_each = np.unique(mesh_data.face_vertex_counts, return_counts=True)\n", + "print(\"\\nCell types (faces by vertex count):\")\n", + "for u, n in zip(unique_counts, num_each):\n", + " name = cell_type_name_for_vertex_count(int(u))\n", + " print(f\" {name} ({u} vertices): {n:,} faces\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 8. Detailed USD Inspection\n", - "\n", - "Examine the converted USD files to verify data preservation." + "## 3. Convert TPV25" ] }, { @@ -715,161 +206,45 @@ "metadata": {}, "outputs": [], "source": [ - "# Inspect one of the converted files in detail\n", - "inspect_file = output_dir / \"tpv25_subsample_5x.usd\"\n", - "\n", - "if inspect_file.exists():\n", - " print(\"=\" * 60)\n", - " print(f\"Detailed Inspection: {inspect_file.name}\")\n", - " print(\"=\" * 60)\n", - "\n", - " stage = Usd.Stage.Open(str(inspect_file))\n", - "\n", - " # Find mesh prim\n", - " mesh_prim = None\n", - " for prim in stage.Traverse():\n", - " if prim.IsA(UsdGeom.Mesh):\n", - " mesh_prim = prim\n", - " break\n", - "\n", - " if mesh_prim:\n", - " mesh = UsdGeom.Mesh(mesh_prim)\n", - "\n", - " print(f\"\\nMesh: {mesh_prim.GetPath()}\")\n", - "\n", - " # Geometry at first frame\n", - " first_time = stage.GetStartTimeCode()\n", - " points = mesh.GetPointsAttr().Get(first_time)\n", - " faces = mesh.GetFaceVertexCountsAttr().Get()\n", - "\n", - " print(f\"\\nGeometry (at t={first_time:.0f}):\")\n", - " print(f\" Points: {len(points):,}\")\n", - " print(f\" Faces: {len(faces):,}\")\n", - "\n", - " # Check time-varying attributes\n", - " print(\"\\nTime-Varying Attributes:\")\n", - " points_attr = mesh.GetPointsAttr()\n", - " if points_attr.GetNumTimeSamples() > 0:\n", - " print(f\" Points: {points_attr.GetNumTimeSamples()} time samples\")\n", - "\n", - " # List primvars\n", - " primvars_api = UsdGeom.PrimvarsAPI(mesh)\n", - " primvars = primvars_api.GetPrimvars()\n", - "\n", - " print(f\"\\nPrimvars ({len(primvars)}):\")\n", - " for primvar in primvars:\n", - " name = primvar.GetPrimvarName()\n", - " interpolation = primvar.GetInterpolation()\n", - " type_name = primvar.GetTypeName()\n", - " value = primvar.Get(first_time)\n", - " size = len(value) if value else 0\n", - "\n", - " print(f\" - {name}:\")\n", - " print(f\" Type: {type_name}\")\n", - " print(f\" Interpolation: {interpolation}\")\n", - " print(f\" Elements: {size:,}\")\n", - "\n", - " # Check if time-varying\n", - " if primvar.GetAttr().GetNumTimeSamples() > 0:\n", - " print(f\" Time samples: {primvar.GetAttr().GetNumTimeSamples()}\")\n", - "\n", - " # Material binding\n", - " from pxr import UsdShade\n", - "\n", - " binding_api = UsdShade.MaterialBindingAPI(mesh)\n", - " material_binding = binding_api.GetDirectBinding()\n", - " if material_binding:\n", - " print(f\"\\nMaterial: {material_binding.GetMaterialPath()}\")\n", - "else:\n", - " print(f\"File not found: {inspect_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8.5. Post-Process USD with USDTools\n", + "converter = VTKToUSDConverter(conversion_settings)\n", "\n", - "Demonstrate using the new `USDTools` methods to inspect primvars and apply colormap visualization to existing USD files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Example: Post-process an existing USD file to add colormap visualization\n", - "from physiomotion4d.usd_tools import USDTools\n", + "tpv25_files = [file_path for _, file_path in tpv25_series]\n", + "tpv25_times = [float(time_step) for time_step, _ in tpv25_series]\n", "\n", - "usd_tools = USDTools()\n", + "if quick_run:\n", + " tpv25_files = tpv25_files[::quick_run_step]\n", + " tpv25_times = tpv25_times[::quick_run_step]\n", "\n", - "# Pick a USD file to post-process\n", - "postprocess_file = output_dir / \"tpv25_subsample_5x.usd\"\n", + "print(f\"\\nConverting to: {output_usd}\")\n", + "print(f\"Number of time steps: {len(tpv25_times)}\")\n", + "print(\"\\nThis may take several minutes...\\n\")\n", "\n", - "if postprocess_file.exists():\n", - " print(\"=\" * 60)\n", - " print(f\"Post-Processing: {postprocess_file.name}\")\n", - " print(\"=\" * 60)\n", + "start_time = time_module.time()\n", "\n", - " # 1. List available primvars on the mesh\n", - " mesh_path = \"/World/Meshes/AlterraValve\"\n", - " primvars = usd_tools.list_mesh_primvars(str(postprocess_file), mesh_path)\n", + "# Read MeshData\n", + "mesh_data_sequence = [read_vtk_file(f, extract_surface=True) for f in tpv25_files]\n", "\n", - " print(f\"\\nAvailable primvars on {mesh_path}:\")\n", - " for pv in primvars:\n", - " time_info = (\n", - " f\", {pv['num_time_samples']} time samples\"\n", - " if pv[\"num_time_samples\"] > 0\n", - " else \"\"\n", - " )\n", - " range_info = (\n", - " f\", range={pv['range'][0]:.3g}..{pv['range'][1]:.3g}\" if pv[\"range\"] else \"\"\n", - " )\n", + "# Validate topology consistency across time series\n", + "validation_report = validate_time_series_topology(\n", + " mesh_data_sequence, filenames=tpv25_files\n", + ")\n", + "if not validation_report[\"is_consistent\"]:\n", + " print(\n", + " f\"Warning: Found {len(validation_report['warnings'])} topology/primvar issues\"\n", + " )\n", + " if validation_report[\"topology_changes\"]:\n", " print(\n", - " f\" - {pv['name']}: {pv['interpolation']}, {pv['elements']} elements{time_info}{range_info}\"\n", + " f\" Topology changes in {len(validation_report['topology_changes'])} frames\"\n", " )\n", "\n", - " # 2. Pick best primvar for coloring (prefer strain/stress)\n", - " color_primvar = usd_tools.pick_color_primvar(primvars)\n", - " print(f\"\\nAuto-selected for coloring: {color_primvar}\")\n", - "\n", - " # 3. Apply colormap to create displayColor visualization\n", - " # Note: This modifies the USD file in-place\n", - " if color_primvar:\n", - " print(f\"\\nApplying 'plasma' colormap to '{color_primvar}'...\")\n", - "\n", - " # Create a copy for demonstration (optional)\n", - " demo_file = output_dir / f\"{postprocess_file.stem}_colored.usd\"\n", - " import shutil\n", - "\n", - " shutil.copy(postprocess_file, demo_file)\n", - "\n", - " usd_tools.apply_colormap_from_primvar(\n", - " str(demo_file),\n", - " mesh_path,\n", - " color_primvar,\n", - " cmap=\"plasma\",\n", - " write_default_at_t0=True,\n", - " bind_vertex_color_material=True,\n", - " )\n", - "\n", - " print(f\"\\n✓ Created colored visualization: {demo_file.name}\")\n", - " print(f\" - displayColor primvar added with colormap from {color_primvar}\")\n", - " print(\" - Vertex color material bound for immediate visualization\")\n", - " print(\" - Ready to open in Omniverse with default coloring\")\n", - " else:\n", - " print(\"\\n⚠️ No suitable primvar found for coloring\")\n", - "else:\n", - " print(f\"File not found: {postprocess_file}\")\n", - " print(\"Run the conversion cells first to generate USD files.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 9. Performance Analysis" + "# Convert to USD (preserves all primvars from VTK)\n", + "stage = converter.convert_mesh_data_sequence(\n", + " mesh_data_sequence=mesh_data_sequence,\n", + " output_usd=output_usd,\n", + " mesh_name=\"TPV25Valve\",\n", + " time_codes=tpv25_times,\n", + " material=stent_material,\n", + ")" ] }, { @@ -878,88 +253,33 @@ "metadata": {}, "outputs": [], "source": [ - "# Analyze conversion performance\n", - "print(\"=\" * 60)\n", - "print(\"Performance Analysis\")\n", - "print(\"=\" * 60)\n", - "\n", - "# Read a few frames to estimate per-frame metrics\n", - "if tpv25_series:\n", - " sample_files = [\n", - " tpv25_series[0][1],\n", - " tpv25_series[len(tpv25_series) // 2][1],\n", - " tpv25_series[-1][1],\n", - " ]\n", - "\n", - " total_points = 0\n", - " total_faces = 0\n", - " total_arrays = 0\n", - "\n", - " for sample_file in sample_files:\n", - " mesh_data = read_vtk_file(sample_file, extract_surface=True)\n", - " total_points += len(mesh_data.points)\n", - " total_faces += len(mesh_data.face_vertex_counts)\n", - " total_arrays += len(mesh_data.generic_arrays)\n", - "\n", - " avg_points = total_points / len(sample_files)\n", - " avg_faces = total_faces / len(sample_files)\n", - " avg_arrays = total_arrays / len(sample_files)\n", - "\n", - " print(\"\\nTPV25 Dataset:\")\n", - " print(f\" Average points per frame: {avg_points:,.0f}\")\n", - " print(f\" Average faces per frame: {avg_faces:,.0f}\")\n", - " print(f\" Average data arrays per frame: {avg_arrays:.0f}\")\n", - " print(f\" Total frames: {len(tpv25_series)}\")\n", - " print(f\" Estimated total points: {avg_points * len(tpv25_series):,.0f}\")\n", - " print(f\" Estimated total faces: {avg_faces * len(tpv25_series):,.0f}\")\n", - "\n", - "print(f\"\\n{'=' * 60}\")\n", - "print(\"\\n✓ All conversions completed!\")\n", - "print(\"\\nView the results:\")\n", - "print(\" - USDView: usdview .usd\")\n", - "print(\" - Omniverse: Open in Create/View/Composer\")\n", - "print(f\"\\nOutput files: {output_dir}\")\n", - "print(\"=\" * 60)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Conclusion\n", - "\n", - "This notebook demonstrated converting large-scale time-varying cardiac valve simulation data to USD:\n", - "\n", - "### Key Accomplishments\n", - "\n", - "1. **Discovered and organized** 200+ frame time-series datasets\n", - "2. **Converted full-resolution** datasets to animated USD\n", - "3. **Created subsampled versions** for faster preview\n", - "4. **Preserved all simulation data** as USD primvars\n", - "5. **Applied custom materials** for visualization\n", - "6. **Handled coordinate systems** (RAS → Y-up)\n", - "\n", - "### File Outputs\n", - "\n", - "- `tpv25_full.usd` - Complete 265-frame animation (single mesh)\n", - "- `tpv25_subsample_5x.usd` - Subsampled, single mesh\n", - "- `tpv25_subsample_5x_by_cell_type.usd` - Subsampled, split by cell type (Triangle_0, Quad_0, ...)\n", - "- `tpv25_subsample_5x_by_connectivity.usd` - Subsampled, split by connectivity (object1, object2, ...)\n", - "\n", - "### Performance Notes\n", - "\n", - "- Full conversions may take several minutes due to large frame counts\n", - "- Subsampling provides faster iteration during development\n", - "- All VTK point and cell data arrays are preserved as primvars\n", - "- Time-sampled attributes enable efficient animation\n", - "\n", - "### Next Steps\n", + "usd_tools = USDTools()\n", + "usd_anatomy_tools = USDAnatomyTools(stage)\n", + "if conversion_settings.separate_objects_by_connectivity is True:\n", + " vessel_path = \"/World/Meshes/TPV25Valve_object4\"\n", + "elif conversion_settings.separate_objects_by_cell_type is True:\n", + " vessel_path = \"/World/Meshes/TPV25Valve_triangle1\"\n", + "else:\n", + " vessel_path = \"/World/Meshes/TPV25Valve\"\n", "\n", - "1. **View animations** in USDView or Omniverse\n", - "2. **Analyze primvars** to visualize simulation data\n", - "3. **Create custom materials** based on data arrays\n", - "4. **Compose scenes** or add multiple assets for comparison\n", - "5. **Add cameras and lighting** for publication-quality renders" + "# Select primvar for coloring\n", + "primvars = usd_tools.list_mesh_primvars(str(output_usd), vessel_path)\n", + "color_primvar = usd_tools.pick_color_primvar(\n", + " primvars, keywords=tuple(colormap_primvar_substrs)\n", + ")\n", + "print(f\"Chosen primvar = {color_primvar}\")\n", + "\n", + "if color_primvar:\n", + " print(f\"\\nApplying colormap to '{color_primvar}' using {colormap_name}\")\n", + " usd_tools.apply_colormap_from_primvar(\n", + " str(output_usd),\n", + " vessel_path,\n", + " color_primvar,\n", + " intensity_range=(colormap_range_min, colormap_range_max),\n", + " cmap=colormap_name,\n", + " use_sigmoid_scale=True,\n", + " bind_vertex_color_material=True,\n", + " )" ] } ], diff --git a/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient-CHOPValve.ipynb b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient-CHOPValve.ipynb new file mode 100644 index 0000000..5224ab8 --- /dev/null +++ b/experiments/Heart-Statistical_Model_To_Patient/heart_model_to_patient-CHOPValve.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "import itk\n", + "import pyvista as pv\n", + "\n", + "# Import from PhysioMotion4D package\n", + "from physiomotion4d import (\n", + " WorkflowFitStatisticalModelToPatient,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define File Paths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Patient CT image (defines coordinate frame)\n", + "patient_data_dir = Path.cwd().parent.parent / \"data\" / \"CHOP-Valve4D\" / \"CT\"\n", + "patient_ct_path = patient_data_dir / \"RVOT28-Dias.mha\"\n", + "\n", + "# Template model (moving)\n", + "model_data_dir = Path.cwd().parent.parent / \"data\" / \"KCL-Heart-Model\"\n", + "model_labelmap_path = model_data_dir / \"labelmap\" / \"average_labelmap_with_bkg.mha\"\n", + "model_pca_data_dir = (\n", + " Path.cwd().parent / \"Heart-Create_Statistical_Model\" / \"kcl-heart-model\"\n", + ")\n", + "model_pca_json_path = model_pca_data_dir / \"pca_model.json\"\n", + "model_mesh_path = model_pca_data_dir / \"pca_mean.vtp\"\n", + "model_pca_n_modes = 10\n", + "\n", + "# Output directory\n", + "output_dir = Path.cwd() / \"results-chop\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient_image = itk.imread(str(patient_ct_path))\n", + "\n", + "template_model = pv.read(str(model_mesh_path))\n", + "template_model_surface = template_model.extract_surface()\n", + "template_labelmap = itk.imread(str(model_labelmap_path))\n", + "\n", + "with open(model_pca_json_path, encoding=\"utf-8\") as f:\n", + " model_pca_data = json.load(f)\n", + "\n", + "os.makedirs(output_dir, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "registrar = WorkflowFitStatisticalModelToPatient(\n", + " template_model=template_model,\n", + " patient_image=patient_image,\n", + " segmentation_method=\"simpleware_heart\",\n", + ")\n", + "\n", + "registrar.set_use_pca_registration(\n", + " True, pca_model=model_pca_data, pca_number_of_modes=model_pca_n_modes\n", + ")\n", + "\n", + "registrar.set_use_mask_to_mask_registration(True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient_image = registrar.patient_image\n", + "itk.imwrite(\n", + " patient_image, str(output_dir / \"patient_image_preprocessed.mha\"), compression=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = registrar.run_workflow()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "registered_model = results[\"registered_template_model\"]\n", + "registered_model_surface = results[\"registered_template_model_surface\"]\n", + "\n", + "registered_model.save(str(output_dir / \"registered_model.vtp\"))\n", + "registered_model_surface.save(str(output_dir / \"registered_model_surface.vtp\"))\n", + "registered_labelmap = results[\"registered_template_labelmap\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pca_model = registrar.pca_template_model\n", + "pca_model_surface = registrar.pca_template_model_surface\n", + "pca_labelmap = registrar.pca_template_labelmap\n", + "\n", + "pca_model.save(str(output_dir / \"pca_model.vtu\"))\n", + "pca_model_surface.save(str(output_dir / \"pca_model_surface.vtp\"))\n", + "itk.imwrite(pca_labelmap, str(output_dir / \"pca_labelmap.mha\"), compression=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Final Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load meshes from registrar member variables\n", + "patient_surface = registrar.patient_model_surface\n", + "\n", + "# Create side-by-side comparison\n", + "plotter = pv.Plotter(shape=(1, 2))\n", + "\n", + "# After rough alignment\n", + "plotter.subplot(0, 0)\n", + "plotter.add_mesh(patient_surface, color=\"red\", opacity=0.5, label=\"Patient\")\n", + "plotter.add_mesh(pca_model_surface, color=\"green\", opacity=0.8, label=\"After PCA\")\n", + "plotter.add_title(\"PCA Alignment\")\n", + "\n", + "# After deformable registration\n", + "plotter.subplot(0, 1)\n", + "plotter.add_mesh(patient_surface, color=\"red\", opacity=0.5, label=\"Patient\")\n", + "plotter.add_mesh(\n", + " registered_model_surface, color=\"green\", opacity=0.8, label=\"Registered\"\n", + ")\n", + "plotter.add_title(\"Final Registration\")\n", + "\n", + "plotter.link_views()\n", + "plotter.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Deformation Magnitude" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The transformed mesh has deformation magnitude stored as point data\n", + "if \"DeformationMagnitude\" in registered_model_surface.point_data:\n", + " plotter = pv.Plotter()\n", + " plotter.add_mesh(\n", + " registered_model_surface,\n", + " scalars=\"DeformationMagnitude\",\n", + " cmap=\"jet\",\n", + " show_scalar_bar=True,\n", + " scalar_bar_args={\"title\": \"Deformation (mm)\"},\n", + " )\n", + " plotter.add_title(\"Deformation Magnitude\")\n", + " plotter.show()\n", + "\n", + " # Print statistics\n", + " deformation = registered_model_surface[\"DeformationMagnitude\"]\n", + " print(\"Deformation statistics:\")\n", + " print(f\" Min: {deformation.min():.2f} mm\")\n", + " print(f\" Max: {deformation.max():.2f} mm\")\n", + " print(f\" Mean: {deformation.mean():.2f} mm\")\n", + " print(f\" Std: {deformation.std():.2f} mm\")\n", + "else:\n", + " print(\"DeformationMagnitude not found in mesh point data\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb b/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb index 950a991..0ff7ee4 100644 --- a/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb +++ b/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb @@ -85,12 +85,13 @@ " reference_image_num = 7\n", "\n", " # Registration parameters - both ANTs and ICON for full run\n", - " registration_methods = [\"ants\", \"icon\", \"ants_icon\"]\n", + " registration_methods = [\"ants\"] # , \"icon\", \"ants_icon\"]\n", " number_of_iterations_list = [\n", - " [30, 15, 7, 3], # For ANTs\n", - " 20, # For ICON\n", - " [[30, 15, 7, 3], 20], # For ants_icon\n", - " ]\n", + " [30, 15, 7, 3],\n", + " ] # For ANTs\n", + " # 20, # For ICON\n", + " # [[30, 15, 7, 3], 20], # For ants_icon\n", + " # ]\n", "\n", "# Common parameters\n", "reference_image_file = os.path.join(\n", diff --git a/pyproject.toml b/pyproject.toml index eba1270..857845b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,10 +83,13 @@ dependencies = [ "matplotlib>=3.5.0", "jupyterlab>=4.0.0", "typing-extensions>=4.0.0", - "cupy-cuda12x>=13.6.0", + "cupy-cuda13x>=13.6.0", ] +[tool.uv] +link-mode = "copy" + [tool.uv.sources] torch = [ { index = "pytorch-cu130" } @@ -163,6 +166,7 @@ Changelog = "https://github.com/aylward/PhysioMotion4d/blob/main/CHANGELOG.md" [project.scripts] # CLI commands installed via pip # Entry points reference the main() functions in the cli submodule +physiomotion4d-convert-ct-to-vtk = "physiomotion4d.cli.convert_ct_to_vtk:main" physiomotion4d-heart-gated-ct = "physiomotion4d.cli.convert_heart_gated_ct_to_usd:main" physiomotion4d-convert-vtk-to-usd = "physiomotion4d.cli.convert_vtk_to_usd:main" physiomotion4d-create-statistical-model = "physiomotion4d.cli.create_statistical_model:main" diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index b1d7183..b5ff9bb 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -57,6 +57,7 @@ from .usd_tools import USDTools # Core workflow processor +from .workflow_convert_ct_to_vtk import WorkflowConvertCTToVTK from .workflow_convert_heart_gated_ct_to_usd import WorkflowConvertHeartGatedCTToUSD from .workflow_convert_vtk_to_usd import WorkflowConvertVTKToUSD from .workflow_reconstruct_highres_4d_ct import WorkflowReconstructHighres4DCT @@ -67,6 +68,7 @@ __all__ = [ # Workflow classes + "WorkflowConvertCTToVTK", "WorkflowConvertHeartGatedCTToUSD", "WorkflowConvertVTKToUSD", "WorkflowCreateStatisticalModel", diff --git a/src/physiomotion4d/cli/convert_ct_to_vtk.py b/src/physiomotion4d/cli/convert_ct_to_vtk.py new file mode 100644 index 0000000..6d72cfa --- /dev/null +++ b/src/physiomotion4d/cli/convert_ct_to_vtk.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python +"""Command-line interface for the CT-to-VTK segmentation workflow. + +Segments a 3D CT image using a chosen backend and writes per-anatomy-group VTP +surfaces and VTU voxel meshes annotated with anatomy labels and colors. +""" + +import argparse +import os +import sys +import traceback + +import itk + +from physiomotion4d import WorkflowConvertCTToVTK + + +def main() -> int: + """CLI entry point for CT to VTK conversion.""" + parser = argparse.ArgumentParser( + description="Segment a CT image and export anatomy groups as VTK surfaces and meshes.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Anatomy groups +-------------- + heart, lung, major_vessels, bone, soft_tissue, other, contrast + (empty groups are skipped automatically) + +Output files — combined mode (default) +--------------------------------------- + {prefix}_surfaces.vtp all surfaces merged into one file + {prefix}_meshes.vtu all voxel meshes merged into one file + +Output files — split mode (--split-files) +------------------------------------------ + {prefix}_{group}.vtp one surface per anatomy group + {prefix}_{group}.vtu one voxel mesh per anatomy group + +Examples +-------- + # Segment with TotalSegmentator, combined output + %(prog)s \\ + --input-image chest_ct.nii.gz \\ + --output-dir ./results + + # VISTA-3D, contrast-enhanced, split per group + %(prog)s \\ + --input-image chest_ct.nii.gz \\ + --segmentation-method vista_3d \\ + --contrast \\ + --split-files \\ + --output-dir ./results \\ + --output-prefix patient01 + + # Simpleware heart-only, cardiac anatomy groups, combined output + %(prog)s \\ + --input-image chest_ct.nii.gz \\ + --segmentation-method simpleware_heart \\ + --anatomy-groups heart major_vessels \\ + --output-dir ./results \\ + --output-prefix patient01 + + # Also save the ITK segmentation labelmap + %(prog)s \\ + --input-image chest_ct.nii.gz \\ + --output-dir ./results \\ + --save-labelmap + """, + ) + + # ── Required ────────────────────────────────────────────────────────── + parser.add_argument( + "--input-image", + required=True, + help="Path to the input CT image (.nii.gz, .nrrd, .mha, …).", + ) + parser.add_argument( + "--output-dir", + required=True, + help="Directory for output files (created if absent).", + ) + + # ── Segmentation ────────────────────────────────────────────────────── + parser.add_argument( + "--segmentation-method", + default="total_segmentator", + choices=list(WorkflowConvertCTToVTK.SEGMENTATION_METHODS), + help=( + "Segmentation backend. " + "total_segmentator (default) | vista_3d | simpleware_heart" + ), + ) + parser.add_argument( + "--contrast", + action="store_true", + default=False, + help="Enable contrast-enhanced blood segmentation (default: disabled).", + ) + parser.add_argument( + "--anatomy-groups", + nargs="+", + metavar="GROUP", + choices=list(WorkflowConvertCTToVTK.ANATOMY_GROUPS), + default=None, + help=( + "Anatomy groups to extract. Default: all non-empty groups. " + "Choices: " + " ".join(WorkflowConvertCTToVTK.ANATOMY_GROUPS) + ), + ) + + # ── Output ──────────────────────────────────────────────────────────── + parser.add_argument( + "--output-prefix", + default="", + help="Filename prefix for output files (default: no prefix).", + ) + parser.add_argument( + "--split-files", + action="store_true", + default=False, + help=( + "Write one VTP and one VTU file per anatomy group instead of " + "merging all groups into a single VTP and VTU (default: combined)." + ), + ) + parser.add_argument( + "--save-labelmap", + action="store_true", + default=False, + help="Also save the detailed per-structure segmentation labelmap as a NIfTI file.", + ) + + args = parser.parse_args() + + # ── Validate inputs ──────────────────────────────────────────────────── + if not os.path.exists(args.input_image): + print(f"Error: input image not found: {args.input_image}") + return 1 + + # ── Load image ───────────────────────────────────────────────────────── + print(f"Loading input image: {args.input_image}") + try: + input_image = itk.imread(args.input_image) + except (FileNotFoundError, OSError, RuntimeError) as exc: + print(f"Error loading image: {exc}") + traceback.print_exc() + return 1 + + # ── Run workflow ──────────────────────────────────────────────────────── + print(f"Segmentation method : {args.segmentation_method}") + print(f"Contrast enhanced : {args.contrast}") + print(f"Anatomy groups : {args.anatomy_groups or 'all'}") + print("=" * 70) + + try: + workflow = WorkflowConvertCTToVTK( + segmentation_method=args.segmentation_method, + ) + result = workflow.run_workflow( + input_image=input_image, + contrast_enhanced_study=args.contrast, + anatomy_groups=args.anatomy_groups, + ) + except (ValueError, RuntimeError, OSError) as exc: + print(f"Error during workflow: {exc}") + traceback.print_exc() + return 1 + + surfaces = result["surfaces"] + meshes = result["meshes"] + + if not surfaces and not meshes: + print("No anatomy groups produced any output. Check the input image.") + return 1 + + # ── Save results ──────────────────────────────────────────────────────── + print("=" * 70) + print("Saving results...") + os.makedirs(args.output_dir, exist_ok=True) + prefix = args.output_prefix + + try: + if args.split_files: + # One file per anatomy group + if surfaces: + saved_surfaces = WorkflowConvertCTToVTK.save_surfaces( + surfaces, args.output_dir, prefix=prefix + ) + for group, path in saved_surfaces.items(): + print(f" Surface [{group:15s}] → {path}") + if meshes: + saved_meshes = WorkflowConvertCTToVTK.save_meshes( + meshes, args.output_dir, prefix=prefix + ) + for group, path in saved_meshes.items(): + print(f" Mesh [{group:15s}] → {path}") + else: + # Combined single-file output + if surfaces: + surface_file = WorkflowConvertCTToVTK.save_combined_surface( + surfaces, args.output_dir, prefix=prefix + ) + print(f" Combined surface → {surface_file}") + if meshes: + mesh_file = WorkflowConvertCTToVTK.save_combined_mesh( + meshes, args.output_dir, prefix=prefix + ) + print(f" Combined mesh → {mesh_file}") + + if args.save_labelmap: + labelmap = result["labelmap"] + stem = f"{prefix}_labelmap" if prefix else "labelmap" + labelmap_file = os.path.join(args.output_dir, f"{stem}.nii.gz") + itk.imwrite(labelmap, labelmap_file) + print(f" Labelmap → {labelmap_file}") + + except (ValueError, OSError, RuntimeError) as exc: + print(f"Error saving results: {exc}") + traceback.print_exc() + return 1 + + print("=" * 70) + print("Conversion completed successfully.") + print(f"Output directory: {args.output_dir}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index 3420dca..6832405 100644 --- a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -176,7 +176,10 @@ def create_mask_from_mesh( if hasattr(mesh, "n_faces_strict"): # PyVista PolyData - faces = mesh.faces.reshape((mesh.n_faces_strict, 4))[:, 1:] + num_points_per_face = len(mesh.faces) // mesh.n_faces_strict + faces = mesh.faces.reshape((mesh.n_faces_strict, num_points_per_face))[ + :, 1: + ] else: # Handle other mesh types faces = mesh.faces.reshape((-1, 4))[:, 1:] diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 0011639..fb1a8c9 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -74,6 +74,21 @@ def __init__(self, log_level: int | str = logging.INFO) -> None: self.number_of_iterations: int = 50 self.use_multi_modality: bool = False self.use_mass_preservation: bool = False + self.weights_path: Optional[str] = None + + def set_weights_path(self, weights_path: str) -> None: + """Set a custom weights file for the uniGradICON network. + + Use this to load a fine-tuned checkpoint instead of the default + pretrained weights. Clears any previously loaded network so the new + weights are applied on the next call to register(). + + Args: + weights_path: Path to a uniGradICON checkpoint, e.g. + "results/duke_4d_finetune/checkpoints/network_weights_100" + """ + self.weights_path = weights_path + self.net = None # force reload on next register() call def set_number_of_iterations(self, number_of_iterations: int) -> None: """Set the number of iterations for ICON registration. @@ -211,11 +226,13 @@ def registration_method( loss_fn=icon.LNCC(sigma=5), # loss_fn=icon.losses.MINDSSC(radius=2, dilation=2), apply_intensity_conservation_loss=self.use_mass_preservation, + weights_location=self.weights_path, ) else: self.net = get_unigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, + weights_location=self.weights_path, ) inverse_transform = None diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index 63776a2..7ba6ddc 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -6,6 +6,7 @@ structure mappings. """ +import csv import logging import os import subprocess @@ -64,6 +65,8 @@ def __init__(self, log_level: int | str = logging.INFO): """ super().__init__(log_level=log_level) + self.landmarks: dict[str, tuple[float, float, float]] = {} + self.target_spacing = 1.0 # Heart structure IDs for Simpleware Medical ASCardio output @@ -280,6 +283,18 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: labelmap_array == 0, mask_array, labelmap_array ) + landmarks_file = os.path.join(tmp_dir, "landmarks.csv") + self.landmarks.clear() + with open(landmarks_file, newline="", encoding="utf-8-sig") as fh: + next(fh) # skip line 1 (file header) + for row in csv.DictReader(fh): + coords = row["Measurement"].replace(" mm", "").split(",") + self.landmarks[row["Name"]] = ( + float(coords[0]), + float(coords[1]), + float(coords[2]), + ) + interior_image = itk.GetImageFromArray(interior_array.astype(np.uint8)) interior_image.CopyInformation(preprocessed_image) imMath = tube.ImageMath.New(interior_image) @@ -309,6 +324,10 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: return labelmap_image + def get_landmarks(self) -> dict[str, tuple[float, float, float]]: + """Get the landmarks.""" + return self.landmarks + def trim_mask_to_essentials(self, labelmap_image: itk.image) -> itk.image: """Trim mask to essentials.""" diff --git a/src/physiomotion4d/simpleware_medical/SimplewareScript_heart_segmentation.py b/src/physiomotion4d/simpleware_medical/SimplewareScript_heart_segmentation.py index c6baac8..6aa156c 100644 --- a/src/physiomotion4d/simpleware_medical/SimplewareScript_heart_segmentation.py +++ b/src/physiomotion4d/simpleware_medical/SimplewareScript_heart_segmentation.py @@ -36,3 +36,6 @@ mask_name = mask.GetName() fixed_name = mask_name.replace(" ", "_").lower() mask.MetaImageExport(os.path.join(output_dir, f"mask_{fixed_name}.mhd")) + +landmarks = doc.GetMeasurements() +landmarks.Export(os.path.join(output_dir, "landmarks.csv")) diff --git a/src/physiomotion4d/usd_anatomy_tools.py b/src/physiomotion4d/usd_anatomy_tools.py index 79215ca..829d326 100644 --- a/src/physiomotion4d/usd_anatomy_tools.py +++ b/src/physiomotion4d/usd_anatomy_tools.py @@ -196,6 +196,33 @@ def get_anatomy_types(self) -> list[str]: """Return list of supported anatomy type names for apply_anatomy_material_to_mesh.""" return list(self._anatomy_params_by_type.keys()) + def get_anatomy_diffuse_color( + self, anatomy_type: str + ) -> tuple[float, float, float]: + """Return the diffuse reflection RGB color for the given anatomy type. + + This accessor does not require a USD stage and may be called on an instance + created with ``stage=None`` purely for color look-up purposes. + + Args: + anatomy_type: One of: heart, lung, bone, major_vessels, contrast, + soft_tissue, other, liver, spleen, kidney. + + Returns: + RGB tuple of floats in ``[0, 1]``. + + Raises: + ValueError: If *anatomy_type* is not supported. + """ + params = self._anatomy_params_by_type.get(anatomy_type.lower()) + if params is None: + raise ValueError( + f"Unknown anatomy_type '{anatomy_type}'. " + f"Supported: {', '.join(self.get_anatomy_types())}" + ) + color = params["diffuse_reflection_color"] + return (float(color[0]), float(color[1]), float(color[2])) + def apply_anatomy_material_to_mesh(self, mesh_path: str, anatomy_type: str) -> None: """Apply an anatomic OmniSurface material to a single mesh prim by type. diff --git a/src/physiomotion4d/workflow_convert_ct_to_vtk.py b/src/physiomotion4d/workflow_convert_ct_to_vtk.py new file mode 100644 index 0000000..a522bfc --- /dev/null +++ b/src/physiomotion4d/workflow_convert_ct_to_vtk.py @@ -0,0 +1,469 @@ +"""Workflow for segmenting a CT image and converting anatomy groups to VTK surfaces and meshes. + +The workflow segments a 3D CT image using a chosen backend, then extracts one VTP +(surface) and one VTU (voxel mesh) per non-empty anatomy group. Each output object +carries anatomy metadata and solid color from :class:`USDAnatomyTools` as field and +cell data so that downstream tools (PyVista, Paraview, USD pipeline) can use them +directly. + +Typical usage:: + + import itk + from physiomotion4d import WorkflowConvertCTToVTK + + ct = itk.imread('chest_ct.nii.gz') + workflow = WorkflowConvertCTToVTK(segmentation_method='total_segmentator') + result = workflow.run_workflow(ct, contrast_enhanced_study=True) + + # Combined single-file output (default) + WorkflowConvertCTToVTK.save_combined_surface(result['surfaces'], './out', prefix='patient') + WorkflowConvertCTToVTK.save_combined_mesh(result['meshes'], './out', prefix='patient') + + # Per-group split output + WorkflowConvertCTToVTK.save_surfaces(result['surfaces'], './out', prefix='patient') + WorkflowConvertCTToVTK.save_meshes(result['meshes'], './out', prefix='patient') +""" + +import logging +import os +from typing import Any, Optional, cast + +import itk +import numpy as np +import pyvista as pv + +from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase +from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase +from physiomotion4d.usd_anatomy_tools import USDAnatomyTools + +#: Ordered tuple of anatomy group names matching :meth:`SegmentAnatomyBase.segment` keys. +ANATOMY_GROUPS: tuple[str, ...] = ( + "heart", + "lung", + "major_vessels", + "bone", + "soft_tissue", + "other", + "contrast", +) + +#: Supported segmentation backend identifiers. +SEGMENTATION_METHODS: tuple[str, ...] = ( + "total_segmentator", + "vista_3d", + "simpleware_heart", +) + + +class WorkflowConvertCTToVTK(PhysioMotion4DBase): + """Segment a CT image and produce per-anatomy-group VTK surfaces and meshes. + + **Segmentation backends** + + - ``'total_segmentator'`` — :class:`SegmentChestTotalSegmentator` (CPU-capable, + default). + - ``'vista_3d'`` — :class:`SegmentChestVista3D` (GPU-accelerated MONAI VISTA-3D). + - ``'simpleware_heart'`` — :class:`SegmentHeartSimpleware` (cardiac only; requires + a Simpleware Medical installation). + + **Output anatomy groups** + + ``heart``, ``lung``, ``major_vessels``, ``bone``, ``soft_tissue``, ``other``, + ``contrast``. Groups that are empty after segmentation are silently skipped. + + **VTK object annotation** + + Each :class:`pyvista.PolyData` surface and :class:`pyvista.UnstructuredGrid` mesh + returned by :meth:`run_workflow` carries: + + - ``field_data['AnatomyGroup']`` — anatomy group name, e.g. ``'heart'``. + - ``field_data['SegmentationLabelNames']`` — individual structure names within the + group (e.g. ``['left_ventricle', 'right_ventricle', …]``). + - ``field_data['SegmentationLabelIds']`` — corresponding integer label IDs. + - ``field_data['AnatomyColor']`` — RGB float color from :class:`USDAnatomyTools`. + - ``cell_data['Color']`` — RGBA uint8 array (n_cells × 4) for direct VTK rendering. + + **I/O contract** + + :meth:`run_workflow` performs *no* file I/O. Use the static helpers + :meth:`save_surfaces`, :meth:`save_meshes`, :meth:`save_combined_surface`, and + :meth:`save_combined_mesh` — or the CLI ``physiomotion4d-convert-ct-to-vtk`` — to + write results to disk. + """ + + #: Valid anatomy group names. + ANATOMY_GROUPS: tuple[str, ...] = ANATOMY_GROUPS + #: Valid segmentation method identifiers. + SEGMENTATION_METHODS: tuple[str, ...] = SEGMENTATION_METHODS + + def __init__( + self, + segmentation_method: str = "total_segmentator", + log_level: int | str = logging.INFO, + ) -> None: + """Initialize the workflow. + + Args: + segmentation_method: Segmentation backend to use. One of + ``'total_segmentator'`` (default), ``'vista_3d'``, or + ``'simpleware_heart'``. + log_level: Logging level. Default: ``logging.INFO``. + + Raises: + ValueError: If *segmentation_method* is not one of + :attr:`SEGMENTATION_METHODS`. + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + if segmentation_method not in self.SEGMENTATION_METHODS: + raise ValueError( + f"Unknown segmentation_method '{segmentation_method}'. " + f"Choose from: {', '.join(self.SEGMENTATION_METHODS)}" + ) + + self.segmentation_method_name: str = segmentation_method + self._segmenter: Optional[SegmentAnatomyBase] = None + self._contour_tools: ContourTools = ContourTools(log_level=log_level) + + # Build anatomy-group → RGB color from USDAnatomyTools. + # USDAnatomyTools sets up its color dicts entirely in __init__ without + # accessing the stage, so stage=None is safe for this lookup-only use. + _anatomy_tools = USDAnatomyTools(stage=None, log_level=log_level) + supported_types = set(_anatomy_tools.get_anatomy_types()) + self._anatomy_color_map: dict[str, tuple[float, float, float]] = { + group: _anatomy_tools.get_anatomy_diffuse_color(group) + for group in ANATOMY_GROUPS + if group in supported_types + } + + # ─────────────────────────── Internal helpers ────────────────────────── + + def _create_segmenter(self) -> SegmentAnatomyBase: + """Instantiate the chosen segmentation backend (lazy import).""" + if self.segmentation_method_name == "total_segmentator": + from physiomotion4d.segment_chest_total_segmentator import ( + SegmentChestTotalSegmentator, + ) + + return SegmentChestTotalSegmentator(log_level=self.log_level) + if self.segmentation_method_name == "vista_3d": + from physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D + + return SegmentChestVista3D(log_level=self.log_level) + if self.segmentation_method_name == "simpleware_heart": + from physiomotion4d.segment_heart_simpleware import SegmentHeartSimpleware + + segmenter = SegmentHeartSimpleware(log_level=self.log_level) + segmenter.set_trim_mask_to_essentials(True) + return segmenter + raise ValueError( + f"Unknown segmentation method: {self.segmentation_method_name}" + ) + + def _get_label_info_for_group(self, group: str) -> tuple[list[str], list[int]]: + """Return ``(label_names, label_ids)`` for *group* from the active segmenter. + + Reads the segmenter's ``_mask_ids`` dictionary. Returns empty lists + if the attribute does not exist (e.g. simpleware_heart lacks lung/bone ids). + """ + assert self._segmenter is not None, ( + "_create_segmenter() must be called before _get_label_info_for_group()" + ) + mask_ids: dict[int, str] = getattr(self._segmenter, f"{group}_mask_ids", {}) + return list(mask_ids.values()), list(mask_ids.keys()) + + @staticmethod + def _annotate( + vtk_obj: pv.DataSet, + group: str, + label_names: list[str], + label_ids: list[int], + color_rgb: tuple[float, float, float], + ) -> None: + """Attach anatomy metadata and solid RGBA color to a VTK object **in-place**. + + Sets: + + - ``field_data['AnatomyGroup']`` — group name. + - ``field_data['SegmentationLabelNames']`` — individual label names. + - ``field_data['SegmentationLabelIds']`` — integer label IDs (int32). + - ``field_data['AnatomyColor']`` — RGB float32 color. + - ``cell_data['Color']`` — RGBA uint8 solid color (n_cells × 4). + """ + vtk_obj.field_data["AnatomyGroup"] = np.array([group]) + vtk_obj.field_data["SegmentationLabelNames"] = np.array( + label_names if label_names else [group] + ) + vtk_obj.field_data["SegmentationLabelIds"] = np.array(label_ids, dtype=np.int32) + vtk_obj.field_data["AnatomyColor"] = np.array(color_rgb, dtype=np.float32) + + r, g, b = color_rgb + rgba = np.array([int(r * 255), int(g * 255), int(b * 255), 255], dtype=np.uint8) + if vtk_obj.n_cells > 0: + vtk_obj.cell_data["Color"] = np.tile(rgba, (vtk_obj.n_cells, 1)) + + def _extract_surface(self, mask_image: Any) -> Optional[pv.PolyData]: + """Extract a smoothed triangulated surface (VTP) from a binary mask image. + + Delegates to :meth:`ContourTools.extract_contours`. + + Returns: + Smoothed :class:`pyvista.PolyData`, or ``None`` if the mask is empty. + """ + arr = itk.GetArrayFromImage(mask_image) + if int(arr.sum()) == 0: + return None + return self._contour_tools.extract_contours(mask_image) + + def _extract_mesh(self, mask_image: Any) -> Optional[pv.UnstructuredGrid]: + """Extract a voxel-based volumetric mesh (VTU) from a binary mask image. + + Wraps the ITK image as a VTK ImageData and thresholds at 0.5 to obtain + hexahedral voxel cells for non-zero voxels. + + Returns: + :class:`pyvista.UnstructuredGrid` of labeled voxels, or ``None`` if empty. + """ + arr = itk.GetArrayFromImage(mask_image) + if int(arr.sum()) == 0: + return None + + vtk_image = pv.wrap(itk.vtk_image_from_image(mask_image)) + if not isinstance(vtk_image, pv.ImageData): + self.log_warning( + "Expected pv.ImageData from vtk_image_from_image, got %s — skipping mesh", + type(vtk_image).__name__, + ) + return None + + thresholded = vtk_image.threshold(0.5) + if isinstance(thresholded, pv.UnstructuredGrid): + return thresholded + return cast(pv.UnstructuredGrid, thresholded.cast_to_unstructured_grid()) + + # ─────────────────────────── Main workflow ───────────────────────────── + + def run_workflow( + self, + input_image: Any, + contrast_enhanced_study: bool = False, + anatomy_groups: Optional[list[str]] = None, + ) -> dict[str, Any]: + """Segment the CT image and extract per-anatomy-group VTK objects. + + Args: + input_image: Input 3D CT image (``itk.Image``). + contrast_enhanced_study: If ``True``, an additional connected-component + pass identifies contrast-enhanced blood. Default: ``False``. + anatomy_groups: Subset of anatomy groups to process. ``None`` (default) + processes all non-empty groups. Valid names: ``'heart'``, + ``'lung'``, ``'major_vessels'``, ``'bone'``, ``'soft_tissue'``, + ``'other'``, ``'contrast'``. + + Returns: + ``dict`` with the following keys: + + - ``'surfaces'`` — ``dict[str, pv.PolyData]``: smoothed surface per group. + - ``'meshes'`` — ``dict[str, pv.UnstructuredGrid]``: voxel mesh per group. + - ``'labelmap'`` — ``itk.Image``: detailed per-structure segmentation + labelmap from the segmenter. + - ``'segmentation_masks'`` — ``dict[str, itk.Image]``: per-group binary + masks used to produce the VTK objects. + + Raises: + ValueError: If any name in *anatomy_groups* is invalid. + """ + self.log_section("STARTING CT TO VTK WORKFLOW") + + # Validate requested groups + if anatomy_groups is not None: + invalid = [g for g in anatomy_groups if g not in self.ANATOMY_GROUPS] + if invalid: + raise ValueError( + f"Unknown anatomy groups: {invalid}. " + f"Valid: {list(self.ANATOMY_GROUPS)}" + ) + groups_to_process: list[str] = list(anatomy_groups) + else: + groups_to_process = list(self.ANATOMY_GROUPS) + + # Create and run segmenter + self.log_info("Creating segmenter: %s", self.segmentation_method_name) + self._segmenter = self._create_segmenter() + + self.log_section("Running segmentation") + seg_result: dict[str, Any] = self._segmenter.segment( + input_image, contrast_enhanced_study=contrast_enhanced_study + ) + + # Extract VTK objects per anatomy group + self.log_section("Extracting VTK objects") + surfaces: dict[str, pv.PolyData] = {} + meshes: dict[str, pv.UnstructuredGrid] = {} + seg_masks: dict[str, Any] = {} + + for group in groups_to_process: + if group not in seg_result: + self.log_warning( + "Group %s absent from segmentation result — skipping", group + ) + continue + + mask_image = seg_result[group] + if int(itk.GetArrayFromImage(mask_image).sum()) == 0: + self.log_info("Group %s is empty — skipping", group) + continue + + self.log_info("Processing anatomy group: %s", group) + seg_masks[group] = mask_image + + label_names, label_ids = self._get_label_info_for_group(group) + color = self._anatomy_color_map.get(group, (0.7, 0.7, 0.7)) + + self.log_info(" Extracting surface for: %s", group) + surface = self._extract_surface(mask_image) + if surface is not None: + self._annotate(surface, group, label_names, label_ids, color) + surfaces[group] = surface + + self.log_info(" Extracting voxel mesh for: %s", group) + mesh = self._extract_mesh(mask_image) + if mesh is not None: + self._annotate(mesh, group, label_names, label_ids, color) + meshes[group] = mesh + + self.log_section("CT TO VTK WORKFLOW COMPLETE") + self.log_info("Surfaces extracted: %d", len(surfaces)) + self.log_info("Meshes extracted: %d", len(meshes)) + + return { + "surfaces": surfaces, + "meshes": meshes, + "labelmap": seg_result["labelmap"], + "segmentation_masks": seg_masks, + } + + # ─────────────────────────── I/O helpers ─────────────────────────────── + + @staticmethod + def save_surfaces( + surfaces: dict[str, pv.PolyData], + output_dir: str, + prefix: str = "", + ) -> dict[str, str]: + """Save each group surface to its own VTP file. + + Args: + surfaces: Mapping of anatomy group name → surface (from + :meth:`run_workflow`). + output_dir: Directory to write files into (created if absent). + prefix: Optional filename prefix. Each file is named + ``{prefix}_{group}.vtp`` (or ``{group}.vtp`` when *prefix* is empty). + + Returns: + Mapping of anatomy group name → absolute path of the saved file. + """ + os.makedirs(output_dir, exist_ok=True) + saved: dict[str, str] = {} + for name, surface in surfaces.items(): + stem = f"{prefix}_{name}" if prefix else name + path = os.path.join(output_dir, f"{stem}.vtp") + surface.save(path) + saved[name] = path + return saved + + @staticmethod + def save_meshes( + meshes: dict[str, pv.UnstructuredGrid], + output_dir: str, + prefix: str = "", + ) -> dict[str, str]: + """Save each group voxel mesh to its own VTU file. + + Args: + meshes: Mapping of anatomy group name → mesh (from :meth:`run_workflow`). + output_dir: Directory to write files into (created if absent). + prefix: Optional filename prefix. Each file is named + ``{prefix}_{group}.vtu`` (or ``{group}.vtu`` when *prefix* is empty). + + Returns: + Mapping of anatomy group name → absolute path of the saved file. + """ + os.makedirs(output_dir, exist_ok=True) + saved: dict[str, str] = {} + for name, mesh in meshes.items(): + stem = f"{prefix}_{name}" if prefix else name + path = os.path.join(output_dir, f"{stem}.vtu") + mesh.save(path) + saved[name] = path + return saved + + @staticmethod + def save_combined_surface( + surfaces: dict[str, pv.PolyData], + output_dir: str, + prefix: str = "", + ) -> str: + """Merge all group surfaces into a single VTP file. + + The merged mesh retains per-cell ``Color`` (RGBA uint8) from each group's + annotation, enabling colour-by-anatomy rendering in Paraview, PyVista, etc. + Per-object ``field_data`` is not preserved in the merged file. + + Args: + surfaces: Mapping of anatomy group name → surface. + output_dir: Directory to write the file into (created if absent). + prefix: Optional filename prefix. Output is ``{prefix}_surfaces.vtp`` + (or ``surfaces.vtp`` when *prefix* is empty). + + Returns: + Absolute path to the saved VTP file. + + Raises: + ValueError: If *surfaces* is empty. + """ + if not surfaces: + raise ValueError("No surfaces to save.") + os.makedirs(output_dir, exist_ok=True) + stem = f"{prefix}_surfaces" if prefix else "surfaces" + output_file = os.path.join(output_dir, f"{stem}.vtp") + merged = cast( + pv.PolyData, pv.merge(list(surfaces.values()), merge_points=False) + ) + merged.save(output_file) + return output_file + + @staticmethod + def save_combined_mesh( + meshes: dict[str, pv.UnstructuredGrid], + output_dir: str, + prefix: str = "", + ) -> str: + """Merge all group meshes into a single VTU file. + + The merged mesh retains per-cell ``Color`` (RGBA uint8) from each group's + annotation. Per-object ``field_data`` is not preserved in the merged file. + + Args: + meshes: Mapping of anatomy group name → voxel mesh. + output_dir: Directory to write the file into (created if absent). + prefix: Optional filename prefix. Output is ``{prefix}_meshes.vtu`` + (or ``meshes.vtu`` when *prefix* is empty). + + Returns: + Absolute path to the saved VTU file. + + Raises: + ValueError: If *meshes* is empty. + """ + if not meshes: + raise ValueError("No meshes to save.") + os.makedirs(output_dir, exist_ok=True) + stem = f"{prefix}_meshes" if prefix else "meshes" + output_file = os.path.join(output_dir, f"{stem}.vtu") + merged = cast( + pv.UnstructuredGrid, pv.merge(list(meshes.values()), merge_points=False) + ) + merged.save(output_file) + return output_file diff --git a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index dfe31cf..8f9c907 100644 --- a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -38,6 +38,7 @@ from physiomotion4d.register_models_icp import RegisterModelsICP from physiomotion4d.register_models_pca import RegisterModelsPCA from physiomotion4d.transform_tools import TransformTools +from physiomotion4d.workflow_convert_ct_to_vtk import WorkflowConvertCTToVTK class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase): @@ -122,8 +123,9 @@ class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase): def __init__( self, template_model: pv.PolyData, - patient_models: list[pv.PolyData], + patient_models: list[pv.PolyData] | None = None, patient_image: Optional[itk.Image] = None, + segmentation_method: str = "simpleware_heart", log_level: int | str = logging.INFO, ): """Initialize the model-to-image-and-model registration pipeline. @@ -150,6 +152,19 @@ def __init__( self.template_labelmap_organ_extra_ids: Optional[list[int]] = None self.template_labelmap_background_ids: Optional[list[int]] = None + if patient_models is None and patient_image is not None: + convert_ct_to_vtk = WorkflowConvertCTToVTK( + segmentation_method=segmentation_method, + log_level=log_level, + ) + patient_models_data = convert_ct_to_vtk.run_workflow( + input_image=patient_image, + contrast_enhanced_study=False, + anatomy_groups=["heart"], + ) + patient_models = [patient_models_data["meshes"]["heart"]] + elif patient_models is None: + raise ValueError("Either patient_models or patient_image must be provided.") self.patient_models = patient_models patient_models_surfaces = [model.extract_surface() for model in patient_models] self.combined_patient_model = pv.merge(patient_models_surfaces) @@ -189,8 +204,8 @@ def __init__( self.patient_roi = None # Parameters for mask generation and processing - self.mask_dilation_mm: float = 5.0 # For auto-generated mask dilation - self.roi_dilation_mm: float = 20.0 # For ROI mask generation + self.mask_dilation_mm: float = 0.0 # For auto-generated mask dilation + self.roi_dilation_mm: float = 25.0 # For ROI mask generation # Stage 1: ICP alignment results self.icp_registrar: Optional[RegisterModelsICP] = None @@ -232,6 +247,7 @@ def __init__( # Final result self.registered_template_model: Optional[pv.UnstructuredGrid] = None self.registered_template_model_surface: Optional[pv.PolyData] = None + self.registered_template_labelmap: Optional[itk.Image] = None def _auto_generate_mask( self, models: list[pv.UnstructuredGrid], dilate_mm: Optional[float] = None @@ -479,6 +495,7 @@ def register_model_to_model_icp(self) -> dict: self.registered_template_model_surface = self.icp_template_model_surface self.registered_template_model = self.icp_template_model + self.registered_template_labelmap = self.icp_template_labelmap return { "inverse_point_transform": self.icp_inverse_point_transform, @@ -599,6 +616,8 @@ def register_model_to_model_pca(self) -> dict: else: self.pca_template_labelmap = None + self.registered_template_labelmap = self.pca_template_labelmap + self.log_info("Stage 2 complete: PCA registration finished.") return { @@ -667,6 +686,8 @@ def register_mask_to_mask( else: self.m2m_template_labelmap = None + self.registered_template_labelmap = self.m2m_template_labelmap + self.log_info("Stage 3 complete: Mask-to-mask registration finished.") return { @@ -786,6 +807,8 @@ def register_labelmap_to_image( self.registered_template_model_surface = self.m2i_template_model_surface + self.registered_template_labelmap = self.m2i_template_labelmap + return { "inverse_transform": self.m2i_inverse_transform, "forward_transform": self.m2i_forward_transform, @@ -926,4 +949,5 @@ def run_workflow( return { "registered_template_model": self.registered_template_model, "registered_template_model_surface": self.registered_template_model_surface, + "registered_template_labelmap": self.registered_template_labelmap, }