Skip to content

Commit 376f84f

Browse files
authored
ENH: Simplify heart models for inter-patient consistency (#31)
* ENH: Simplify heart models for inter-patient consistency * ENH: Update PCA and simpleware heart seg for mesh PCA * BUG: Fixed trim_mask_to_essentials - logic cut/paste error * ENH: Clean up trimming of heart to essentials to avoid overlap
1 parent 6436958 commit 376f84f

9 files changed

Lines changed: 357 additions & 112 deletions

experiments/Convert_VTK_To_USD/convert_chop_valve_to_usd.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1132,4 +1132,4 @@
11321132
},
11331133
"nbformat": 4,
11341134
"nbformat_minor": 4
1135-
}
1135+
}

experiments/Heart-Simpleware_Segmentation/simpleware_heart_segmentation.ipynb

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@
367367
" heart_array = itk.array_from_image(result[\"heart\"])\n",
368368
" vessels_array = itk.array_from_image(result[\"major_vessels\"])\n",
369369
"\n",
370+
" labelmap_essentials = segmenter.trim_mask_to_essentials(result[\"labelmap\"])\n",
371+
" labelmap_essentials_array = itk.array_from_image(labelmap_essentials)\n",
372+
"\n",
370373
" # Select middle slice\n",
371374
" mid_slice = image_array.shape[0] // 2\n",
372375
"\n",
@@ -380,7 +383,8 @@
380383
"\n",
381384
" axes[0, 1].imshow(image_array[mid_slice, :, :], cmap=\"gray\", vmin=-200, vmax=400)\n",
382385
" labelmap_overlay = np.ma.masked_where(\n",
383-
" labelmap_array[mid_slice, :, :] == 0, labelmap_array[mid_slice, :, :]\n",
386+
" labelmap_essentials_array[mid_slice, :, :] == 0,\n",
387+
" labelmap_essentials_array[mid_slice, :, :],\n",
384388
" )\n",
385389
" axes[0, 1].imshow(labelmap_overlay, cmap=\"jet\", alpha=0.5, vmin=1, vmax=10)\n",
386390
" axes[0, 1].set_title(\"Labelmap Overlay\")\n",
@@ -407,7 +411,8 @@
407411
" mid_sagittal = image_array.shape[2] // 2\n",
408412
" axes[1, 1].imshow(image_array[:, :, mid_sagittal], cmap=\"gray\", vmin=-200, vmax=400)\n",
409413
" sagittal_overlay = np.ma.masked_where(\n",
410-
" labelmap_array[:, :, mid_sagittal] == 0, labelmap_array[:, :, mid_sagittal]\n",
414+
" labelmap_essentials_array[:, :, mid_sagittal] == 0,\n",
415+
" labelmap_essentials_array[:, :, mid_sagittal],\n",
411416
" )\n",
412417
" axes[1, 1].imshow(sagittal_overlay, cmap=\"jet\", alpha=0.5, vmin=1, vmax=10)\n",
413418
" axes[1, 1].set_title(\"Sagittal View\")\n",
@@ -417,7 +422,8 @@
417422
" mid_coronal = image_array.shape[1] // 2\n",
418423
" axes[1, 2].imshow(image_array[:, mid_coronal, :], cmap=\"gray\", vmin=-200, vmax=400)\n",
419424
" coronal_overlay = np.ma.masked_where(\n",
420-
" labelmap_array[:, mid_coronal, :] == 0, labelmap_array[:, mid_coronal, :]\n",
425+
" labelmap_essentials_array[:, mid_coronal, :] == 0,\n",
426+
" labelmap_essentials_array[:, mid_coronal, :],\n",
421427
" )\n",
422428
" axes[1, 2].imshow(coronal_overlay, cmap=\"jet\", alpha=0.5, vmin=1, vmax=10)\n",
423429
" axes[1, 2].set_title(\"Coronal View\")\n",
@@ -457,6 +463,7 @@
457463
"\n",
458464
"# Convert heart mask to VTK\n",
459465
"heart_vtk = itk.vtk_image_from_image(result[\"heart\"])\n",
466+
"heart_essentials_vtk = itk.vtk_image_from_image(labelmap_essentials)\n",
460467
"vessels_vtk = itk.vtk_image_from_image(result[\"major_vessels\"])\n",
461468
"\n",
462469
"# Create PyVista plotter\n",
@@ -466,7 +473,15 @@
466473
"heart_grid = pv.wrap(heart_vtk)\n",
467474
"heart_surface = heart_grid.contour([0.5])\n",
468475
"if heart_surface.n_points > 0:\n",
469-
" plotter.add_mesh(heart_surface, color=\"red\", opacity=1.0, label=\"Heart\")\n",
476+
" plotter.add_mesh(heart_surface, color=\"red\", opacity=0.5, label=\"Heart\")\n",
477+
"\n",
478+
"# Extract heart surface\n",
479+
"heart_essentials_grid = pv.wrap(heart_essentials_vtk)\n",
480+
"heart_essentials_surface = heart_essentials_grid.contour([0.5])\n",
481+
"if heart_essentials_surface.n_points > 0:\n",
482+
" plotter.add_mesh(\n",
483+
" heart_essentials_surface, color=\"grey\", opacity=1.0, label=\"Heart Essential\"\n",
484+
" )\n",
470485
"\n",
471486
"# Extract vessels surface\n",
472487
"vessels_grid = pv.wrap(vessels_vtk)\n",
@@ -486,6 +501,14 @@
486501
"print(f\"3D visualization saved to: {screenshot_path}\")"
487502
]
488503
},
504+
{
505+
"cell_type": "code",
506+
"execution_count": null,
507+
"id": "f6046393",
508+
"metadata": {},
509+
"outputs": [],
510+
"source": []
511+
},
489512
{
490513
"cell_type": "markdown",
491514
"id": "o3p4q5r6",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ dependencies = [
6565
# AI/ML and segmentation
6666
"monai>=1.3.0",
6767
"torch>=2.0.0,<3.0.0",
68-
"transformers>=4.21.0",
68+
"transformers>=4.21.0,<5.0.0",
6969
"totalsegmentator>=2.0.0",
7070

7171
# Registration

src/physiomotion4d/cli/fit_statistical_model_to_patient.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def main() -> int:
234234
pca_model=pca_model,
235235
pca_number_of_modes=args.pca_number_of_modes,
236236
)
237+
if args.use_mask_to_mask:
238+
workflow.set_use_mask_to_mask_registration(args.use_mask_to_mask)
237239
if args.use_mask_to_image:
238240
workflow.set_use_mask_to_image_registration(
239241
True,
@@ -252,8 +254,6 @@ def main() -> int:
252254
print("\nStarting registration pipeline...")
253255
print("=" * 70)
254256
result = workflow.run_workflow(
255-
use_mask_to_mask_registration=args.use_mask_to_mask,
256-
use_mask_to_image_registration=args.use_mask_to_image,
257257
use_icon_registration_refinement=args.use_icon_refinement,
258258
)
259259

src/physiomotion4d/cli/visualize_pca_modes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,11 @@ def main() -> int:
140140
traceback.print_exc()
141141
return 1
142142

143-
if not isinstance(mean_mesh, pv.PolyData):
144-
print("Error: PCA mean surface must be a PolyData (.vtp).")
143+
if not isinstance(mean_mesh, (pv.PolyData, pv.UnstructuredGrid)):
144+
print(
145+
"Error: PCA mean surface must be PolyData or UnstructuredGrid.",
146+
f"Type: {type(mean_mesh)}",
147+
)
145148
return 1
146149

147150
try:

src/physiomotion4d/contour_tools.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,15 @@ def create_reference_image(
165165

166166
def create_mask_from_mesh(
167167
self,
168-
mesh: pv.DataSet,
168+
mesh: pv.DataSet | pv.UnstructuredGrid,
169169
reference_image: itk.Image,
170170
) -> itk.Image:
171171
ref_spacing = np.array(reference_image.GetSpacing())
172172

173173
# Create trimesh object with LPS coordinates
174+
if isinstance(mesh, pv.UnstructuredGrid):
175+
mesh = mesh.extract_surface()
176+
174177
if hasattr(mesh, "n_faces_strict"):
175178
# PyVista PolyData
176179
faces = mesh.faces.reshape((mesh.n_faces_strict, 4))[:, 1:]
@@ -248,7 +251,7 @@ def create_mask_from_mesh(
248251

249252
def create_distance_map(
250253
self,
251-
mesh: pv.DataSet,
254+
mesh: pv.DataSet | pv.UnstructuredGrid,
252255
reference_image: itk.Image,
253256
squared_distance: bool = False,
254257
negative_inside: bool = True,

src/physiomotion4d/segment_heart_simpleware.py

Lines changed: 145 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(self, log_level: int | str = logging.INFO):
9898
# From Base Class
9999
# self.contrast_mask_ids = {135: "contrast"}
100100

101-
self.trim_mesh_to_essentials = False
101+
self._trim_mask = False
102102

103103
self.set_other_and_all_mask_ids()
104104

@@ -112,13 +112,13 @@ def __init__(self, log_level: int | str = logging.INFO):
112112
"SimplewareScript_heart_segmentation.py",
113113
)
114114

115-
def set_trim_mesh_to_essentials(self, trim_mesh_to_essentials: bool) -> None:
116-
"""Set whether to trim mesh to common and critical structures.
115+
def set_trim_mask_to_essentials(self, trim_mask: bool) -> None:
116+
"""Set whether to trim mask to common and critical structures.
117117
118118
Args:
119-
trim_mesh_to_essentials (bool): Whether to reduce to essential.
119+
trim_mask (bool): Whether to reduce to essential.
120120
"""
121-
self.trim_mesh_to_essentials = trim_mesh_to_essentials
121+
self._trim_mask = trim_mask
122122

123123
def set_simpleware_executable_path(self, path: str) -> None:
124124
"""Set the path to the Simpleware Medical console executable.
@@ -283,8 +283,9 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
283283
interior_image = itk.GetImageFromArray(interior_array.astype(np.uint8))
284284
interior_image.CopyInformation(preprocessed_image)
285285
imMath = tube.ImageMath.New(interior_image)
286-
imMath.Dilate(7, 1, 0)
287-
imMath.Erode(4, 1, 0)
286+
spacing = interior_image.GetSpacing()
287+
imMath.Dilate(round(7 / spacing[0]), 1, 0)
288+
imMath.Erode(round(4 / spacing[0]), 1, 0)
288289
exterior_image = imMath.GetOutputUChar()
289290
exterior_array = itk.GetArrayFromImage(exterior_image)
290291
mask_id = 6 # Heart mask id
@@ -300,17 +301,144 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image:
300301
"ensure the ASCardio module ran successfully."
301302
)
302303

303-
if self.trim_mesh_to_essentials:
304-
z = labelmap_array.shape[2] - 1
305-
z_classes = np.unique(labelmap_array[z, :, :])
306-
heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
307-
while heart_count < 3 and z > 0:
308-
z -= 1
309-
z_classes = np.unique(labelmap_array[z, :, :])
310-
heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
311-
if z < labelmap_array.shape[2] - 3:
312-
labelmap_array[(z + 3) :, :, :] = 0
313304
labelmap_image = itk.GetImageFromArray(labelmap_array.astype(np.uint8))
314305
labelmap_image.CopyInformation(preprocessed_image)
315306

307+
if self._trim_mask:
308+
labelmap_image = self.trim_mask_to_essentials(labelmap_image)
309+
316310
return labelmap_image
311+
312+
def trim_mask_to_essentials(self, labelmap_image: itk.image) -> itk.image:
313+
"""Trim mask to essentials."""
314+
315+
# Reference code for cropping aorta and pulmonary artery to
316+
# portions adjacent to the heart.
317+
# Trim z-axis
318+
# z = labelmap_array.shape[2] - 1
319+
# z_classes = np.unique(labelmap_array[z, :, :])
320+
# heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
321+
# while heart_count < 3 and z > 0:
322+
# z -= 1
323+
# z_classes = np.unique(labelmap_array[z, :, :])
324+
# heart_count = np.sum((c in [1, 2, 3, 4, 5]) for c in z_classes)
325+
# if z < labelmap_array.shape[2] - 3:
326+
# labelmap_array[(z + 3) :, :, :] = 0
327+
328+
# In labelmap,
329+
# if pixel is in keep_mask, was left or right atrium, then keep as
330+
# left or right atrium
331+
332+
# 1) Erase Heart and Myo label
333+
labelmap_arr = itk.array_from_image(labelmap_image)
334+
335+
heart_arr = itk.array_from_image(labelmap_image)
336+
heart_arr[heart_arr == 6] = 0
337+
heart_arr[heart_arr == 5] = 0
338+
339+
img = itk.image_from_array(heart_arr)
340+
img.CopyInformation(labelmap_image)
341+
imMath = tube.ImageMath.New(img)
342+
343+
# 2) Erode then Dilate Left Atrium label to clip vessels
344+
spacing = labelmap_image.GetSpacing()
345+
imMath.Erode(round(7 / spacing[0]), 3, 0)
346+
imMath.Dilate(round(7 / spacing[0]), 3, 0)
347+
348+
# 3) Erode then Dilate Right Atrium label to clip vessels
349+
imMath.Erode(round(7 / spacing[0]), 4, 0)
350+
imMath.Dilate(round(7 / spacing[0]), 4, 0)
351+
simple_img = imMath.GetOutput()
352+
simple_arr = itk.array_from_image(simple_img)
353+
354+
# Keep the largest component of the left atrium
355+
simple_arr_3 = simple_arr.copy()
356+
simple_arr_3[simple_arr_3 != 3] = 0
357+
simple_arr_3[simple_arr_3 == 3] = 1
358+
simple_img_3 = itk.image_from_array(simple_arr_3)
359+
connComp = tube.SegmentConnectedComponents.New(simple_img_3)
360+
connComp.SetKeepOnlyLargestComponent(True)
361+
connComp.Update()
362+
mask_img_3 = connComp.GetOutput()
363+
mask_arr_3 = itk.array_from_image(mask_img_3)
364+
simple_arr_3[mask_arr_3 == 0] = 0
365+
366+
# Keep the largest component of the right atrium
367+
simple_arr_4 = simple_arr.copy()
368+
simple_arr_4[simple_arr_4 != 4] = 0
369+
simple_arr_4[simple_arr_4 == 4] = 1
370+
simple_img_4 = itk.image_from_array(simple_arr_4)
371+
connComp = tube.SegmentConnectedComponents.New(simple_img_4)
372+
connComp.SetKeepOnlyLargestComponent(True)
373+
connComp.Update()
374+
mask_img_4 = connComp.GetOutput()
375+
mask_arr_4 = itk.array_from_image(mask_img_4)
376+
simple_arr_4[mask_arr_4 == 0] = 0
377+
378+
# Replace the left and right atrium labels with the largest components
379+
simple_arr[simple_arr == 3] = 0
380+
simple_arr[simple_arr == 4] = 0
381+
simple_arr[simple_arr_3 > 0] = 3
382+
simple_arr[simple_arr_4 > 0] = 4
383+
simple_img = itk.image_from_array(simple_arr)
384+
simple_img.CopyInformation(labelmap_image)
385+
386+
# 4) Dilate all others = keep_mask
387+
keep_mask_arr = heart_arr.copy()
388+
keep_mask_arr[keep_mask_arr == 2] = 1
389+
keep_mask_arr[keep_mask_arr == 5] = 1
390+
keep_mask_arr[keep_mask_arr != 1] = 0
391+
keep_mask = itk.image_from_array(keep_mask_arr)
392+
keep_mask.CopyInformation(labelmap_image)
393+
imMath.SetInput(keep_mask)
394+
imMath.Dilate(round(7 / spacing[0]), 1, 0)
395+
keep_mask = imMath.GetOutput()
396+
keep_mask_arr = itk.array_from_image(keep_mask)
397+
398+
# Add the left and right atrium labels to the keep_mask
399+
heart_arr = heart_arr * keep_mask_arr
400+
heart_arr[simple_arr == 3] = 3
401+
heart_arr[simple_arr == 4] = 4
402+
heart_img = itk.image_from_array(heart_arr)
403+
heart_img.CopyInformation(labelmap_image)
404+
405+
# Dilate the keep_mask to simulate 3mm (heart)
406+
keep_mask_arr = heart_arr.copy()
407+
keep_mask_arr[keep_mask_arr == 1] = 0
408+
keep_mask_arr[keep_mask_arr > 0] = 1
409+
keep_mask = itk.image_from_array(keep_mask_arr)
410+
keep_mask.CopyInformation(labelmap_image)
411+
imMath.SetInput(keep_mask)
412+
imMath.Dilate(round(5 / spacing[0]), 1, 0)
413+
imMath.Erode(round(2 / spacing[0]), 1, 0)
414+
heart_mask = imMath.GetOutput()
415+
416+
# Insert the heart and myo labels back into the labelmap
417+
heart_mask_arr = itk.array_from_image(heart_mask)
418+
heart_mask_arr[heart_arr > 0] = 0
419+
heart_arr[heart_mask_arr > 0] = 6
420+
heart_arr_myo = itk.array_from_image(labelmap_image)
421+
heart_arr[heart_arr_myo == 5] = 5
422+
heart_arr[heart_arr_myo == 1] = 1
423+
heart_img = itk.image_from_array(heart_arr)
424+
heart_img.CopyInformation(labelmap_image)
425+
426+
# Add in missing pieces / gaps of the myocardium
427+
lv_arr = heart_arr.copy()
428+
lv_arr[lv_arr != 1] = 0
429+
lv_img = itk.image_from_array(lv_arr)
430+
lv_img.CopyInformation(labelmap_image)
431+
imMath.SetInput(lv_img)
432+
imMath.Dilate(round(2 / spacing[0]), 1, 0)
433+
lv_img = imMath.GetOutput()
434+
lv_arr = itk.array_from_image(lv_img)
435+
lv_arr = lv_arr * 5 # Myocardium label is 5
436+
437+
# Add the gap-filled myocardium back into the labelmap
438+
heart_arr = np.where(heart_arr == 0, lv_arr, heart_arr)
439+
# Eliminate overlap with other labels
440+
heart_arr = np.where(labelmap_arr > 6, 0, heart_arr)
441+
heart_img = itk.image_from_array(heart_arr)
442+
heart_img.CopyInformation(labelmap_image)
443+
444+
return heart_img

0 commit comments

Comments
 (0)