Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions src/physiomotion4d/image_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,34 +246,69 @@ def convert_array_to_image_of_vectors(

return itk_image

def flip_image_to_identity_direction(
self, in_image: itk.Image, in_mask: Optional[itk.Image] = None
def flip_image(
self,
in_image: itk.Image,
in_mask: Optional[itk.Image] = None,
flip_x: bool = False,
flip_y: bool = False,
flip_z: bool = False,
flip_and_make_identity: bool = False,
) -> Any | tuple[Any, Any]:
Comment thread
aylward marked this conversation as resolved.
Comment thread
aylward marked this conversation as resolved.
"""
Flip the image to the identity direction.
Flip the image and mask.

Only axis-aligned flips are supported. If ``flip_and_make_identity`` is
True, the image and mask are first flipped along any axes whose
corresponding diagonal entries in the direction matrix are negative
(assuming the direction matrix encodes only axis-aligned flips), then
any additional requested flips are performed, and finally the direction
matrix is set to the identity matrix. This is useful when combining ITK
images with VTK objects (that often do not support a direction matrix).

Args:
in_image: The input image to flip
in_mask: The input mask to flip
flip_x: Flip the image and mask along the x-axis
flip_y: Flip the image and mask along the y-axis
flip_z: Flip the image and mask along the z-axis
flip_and_make_identity: Flip the image and mask and make the direction
matrix identity.
"""
flip0 = np.array(in_image.GetDirection())[0, 0] < 0
flip1 = np.array(in_image.GetDirection())[1, 1] < 0
flip2 = np.array(in_image.GetDirection())[2, 2] < 0
flip0 = False
flip1 = False
flip2 = False
if flip_and_make_identity:
flip0 = np.array(in_image.GetDirection())[0, 0] < 0
flip1 = np.array(in_image.GetDirection())[1, 1] < 0
flip2 = np.array(in_image.GetDirection())[2, 2] < 0
Comment thread
aylward marked this conversation as resolved.
Comment thread
aylward marked this conversation as resolved.
Comment thread
aylward marked this conversation as resolved.
if flip_x:
flip0 = True
if flip_y:
flip1 = True
if flip_z:
flip2 = True
Comment thread
aylward marked this conversation as resolved.
Comment thread
aylward marked this conversation as resolved.
if flip0 or flip1 or flip2:
self.log_info(
f"Flipping image to identity direction: {flip0}, {flip1}, {flip2}"
)
self.log_info(f"Flipping image: {flip0}, {flip1}, {flip2}")
flip_filter = itk.FlipImageFilter.New(Input=in_image)
flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])
flip_filter.SetFlipAboutOrigin(True)
flip_filter.Update()
out_image = flip_filter.GetOutput()
id_mat = itk.Matrix[itk.D, 3, 3]()
id_mat.SetIdentity()
out_image.SetDirection(id_mat)
if flip_and_make_identity:
id_mat = itk.Matrix[itk.D, 3, 3]()
id_mat.SetIdentity()
out_image.SetDirection(id_mat)
Comment thread
aylward marked this conversation as resolved.
if in_mask is not None:
flip_filter = itk.FlipImageFilter.New(Input=in_mask)
flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])
flip_filter.SetFlipAboutOrigin(True)
flip_filter.Update()
out_mask = flip_filter.GetOutput()
out_mask.SetDirection(id_mat)
if flip_and_make_identity:
id_mat = itk.Matrix[itk.D, 3, 3]()
id_mat.SetIdentity()
out_mask.SetDirection(id_mat)
Comment thread
aylward marked this conversation as resolved.
return out_image, out_mask
else:
return out_image
Expand Down
175 changes: 166 additions & 9 deletions tests/test_image_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
scalar and vector images.
"""

from __future__ import annotations

import itk
import numpy as np
import pytest
Expand All @@ -18,11 +20,11 @@ class TestImageTools:
"""Test suite for ImageTools conversions."""

@pytest.fixture
def image_tools(self):
def image_tools(self) -> ImageTools:
"""Create ImageTools instance."""
return ImageTools()

def test_itk_to_sitk_scalar_image(self, image_tools):
def test_itk_to_sitk_scalar_image(self, image_tools: ImageTools) -> None:
"""Test conversion of scalar ITK image to SimpleITK."""
# Create a simple 3D scalar ITK image
size = [10, 20, 30]
Expand Down Expand Up @@ -57,7 +59,7 @@ def test_itk_to_sitk_scalar_image(self, image_tools):

print("✓ ITK to SimpleITK scalar conversion successful")

def test_sitk_to_itk_scalar_image(self, image_tools):
def test_sitk_to_itk_scalar_image(self, image_tools: ImageTools) -> None:
"""Test conversion of scalar SimpleITK image to ITK."""
# Create a simple 3D scalar SimpleITK image
size = [10, 20, 30]
Expand Down Expand Up @@ -89,7 +91,7 @@ def test_sitk_to_itk_scalar_image(self, image_tools):

print("✓ SimpleITK to ITK scalar conversion successful")

def test_roundtrip_scalar_image(self, image_tools):
def test_roundtrip_scalar_image(self, image_tools: ImageTools) -> None:
"""Test roundtrip conversion: ITK -> SimpleITK -> ITK."""
# Create ITK image
size = [15, 25, 35]
Expand Down Expand Up @@ -125,7 +127,7 @@ def test_roundtrip_scalar_image(self, image_tools):

print("✓ Roundtrip scalar conversion successful")

def test_itk_to_sitk_vector_image(self, image_tools):
def test_itk_to_sitk_vector_image(self, image_tools: ImageTools) -> None:
"""Test conversion of vector ITK image to SimpleITK."""
# Create a 3D vector ITK image (like a displacement field)
size = [8, 12, 16]
Expand Down Expand Up @@ -163,7 +165,7 @@ def test_itk_to_sitk_vector_image(self, image_tools):

print("✓ ITK to SimpleITK vector conversion successful")

def test_sitk_to_itk_vector_image(self, image_tools):
def test_sitk_to_itk_vector_image(self, image_tools: ImageTools) -> None:
"""Test conversion of vector SimpleITK image to ITK."""
# Create a 3D vector SimpleITK image
size = [8, 12, 16]
Expand Down Expand Up @@ -194,7 +196,7 @@ def test_sitk_to_itk_vector_image(self, image_tools):

print("✓ SimpleITK to ITK vector conversion successful")

def test_roundtrip_vector_image(self, image_tools):
def test_roundtrip_vector_image(self, image_tools: ImageTools) -> None:
"""Test roundtrip conversion for vector images: ITK -> SimpleITK -> ITK."""
# Create ITK vector image
size = [10, 15, 20]
Expand Down Expand Up @@ -234,8 +236,12 @@ def test_roundtrip_vector_image(self, image_tools):
@pytest.mark.requires_data
@pytest.mark.slow
def test_imwrite_imread_vd3(
self, image_tools, ants_registration_results, test_images, test_directories
):
self,
image_tools: ImageTools,
ants_registration_results: dict,
test_images: list,
test_directories: dict,
) -> None:
"""Test reading and writing double precision vector images."""
from physiomotion4d.transform_tools import TransformTools

Expand Down Expand Up @@ -297,5 +303,156 @@ def test_imwrite_imread_vd3(
assert mean_diff < 1e-6, f"Mean difference too large: {mean_diff}"


def _make_synthetic_itk_image(
shape_xyz: tuple[int, int, int],
arr: np.ndarray | None = None,
direction: np.ndarray | None = None,
) -> itk.Image[itk.F, 3]:
"""Create a small 3D ITK image. shape_xyz is (nx, ny, nz); array from ITK is (nz, ny, nx)."""
# ITK uses (x, y, z) for size; array_from_image gives (z, y, x)
if arr is None:
arr = np.arange(np.prod(shape_xyz), dtype=np.float32).reshape(
shape_xyz[2], shape_xyz[1], shape_xyz[0]
)
ImageType = itk.Image[itk.F, 3]
itk_image = ImageType.New()
region = itk.ImageRegion[3]()
region.SetSize(shape_xyz) # (nx, ny, nz)
itk_image.SetRegions(region)
itk_image.SetSpacing([1.0, 1.0, 1.0])
itk_image.SetOrigin([0.0, 0.0, 0.0])
if direction is not None:
itk_image.SetDirection(itk.matrix_from_array(direction))
itk_image.Allocate()
itk.array_view_from_image(itk_image)[:] = arr
return itk_image


class TestFlipImage:
"""Unit tests for ImageTools.flip_image (axis flips and direction reset)."""

@pytest.fixture
def image_tools(self) -> ImageTools:
return ImageTools()

def test_flip_x_flips_along_last_array_axis(self, image_tools: ImageTools) -> None:
"""flip_x flips the image along the x (last) array dimension."""
# Small image: ITK size (nx, ny, nz) = (3, 2, 2) -> array shape (2, 2, 3)
shape_xyz = (3, 2, 2)
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) # (z, y, x)
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
out = image_tools.flip_image(itk_image, flip_x=True)
out_arr = itk.array_from_image(out)
expected = np.flip(arr, axis=2)
assert np.allclose(out_arr, expected), (
"flip_x should match np.flip(..., axis=2)"
)

def test_flip_y_flips_along_middle_array_axis(
self, image_tools: ImageTools
) -> None:
"""flip_y flips the image along the y (middle) array dimension."""
shape_xyz = (3, 2, 2)
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
out = image_tools.flip_image(itk_image, flip_y=True)
out_arr = itk.array_from_image(out)
expected = np.flip(arr, axis=1)
assert np.allclose(out_arr, expected), (
"flip_y should match np.flip(..., axis=1)"
)

def test_flip_z_flips_along_first_array_axis(self, image_tools: ImageTools) -> None:
"""flip_z flips the image along the z (first) array dimension."""
shape_xyz = (3, 2, 2)
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
out = image_tools.flip_image(itk_image, flip_z=True)
out_arr = itk.array_from_image(out)
expected = np.flip(arr, axis=0)
assert np.allclose(out_arr, expected), (
"flip_z should match np.flip(..., axis=0)"
)

def test_flip_xy_combines_flips(self, image_tools: ImageTools) -> None:
"""flip_x and flip_y together flip both axes."""
shape_xyz = (3, 2, 2)
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
out = image_tools.flip_image(itk_image, flip_x=True, flip_y=True)
out_arr = itk.array_from_image(out)
expected = np.flip(np.flip(arr, axis=2), axis=1)
assert np.allclose(out_arr, expected)

def test_no_flip_returns_same_image(self, image_tools: ImageTools) -> None:
"""With no flip flags, image is returned unchanged."""
shape_xyz = (2, 2, 2)
arr = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
out = image_tools.flip_image(
itk_image, flip_x=False, flip_y=False, flip_z=False
)
out_arr = itk.array_from_image(out)
assert np.allclose(out_arr, arr)

def test_mask_flipped_in_lockstep_with_image(self, image_tools: ImageTools) -> None:
"""When a mask is provided, it is flipped with the same axes as the image."""
shape_xyz = (3, 2, 2)
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
mask_arr = (arr % 2 == 0).astype(np.float32) # 0/1 pattern in lockstep with arr
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
itk_mask = _make_synthetic_itk_image(shape_xyz, arr=mask_arr)
out_image, out_mask = image_tools.flip_image(
itk_image, in_mask=itk_mask, flip_x=True, flip_z=True
)
out_img_arr = itk.array_from_image(out_image)
out_msk_arr = itk.array_from_image(out_mask)
expected_img = np.flip(np.flip(arr, axis=2), axis=0)
expected_msk = np.flip(np.flip(mask_arr, axis=2), axis=0)
assert np.allclose(out_img_arr, expected_img), "Image should be flipped x and z"
assert np.allclose(out_msk_arr, expected_msk), (
"Mask should be flipped in lockstep"
)
# Consistency: mask value should still align with image (same pattern, flipped)
assert np.allclose(out_msk_arr, (out_img_arr % 2 == 0).astype(np.float32))

def test_flip_and_make_identity_sets_direction_to_identity(
self, image_tools: ImageTools
) -> None:
"""flip_and_make_identity flips as needed and sets direction matrix to identity."""
shape_xyz = (2, 2, 2)
arr = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
# Direction with negative diagonal (e.g. flipped along one axis)
direction = np.diag([-1.0, 1.0, 1.0])
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction)
out = image_tools.flip_image(itk_image, flip_and_make_identity=True)
out_direction = np.array(out.GetDirection())
Comment thread
aylward marked this conversation as resolved.
identity = np.eye(3)
assert np.allclose(out_direction, identity), (
"flip_and_make_identity should set direction to identity"
)

def test_flip_and_make_identity_with_mask_sets_both_directions_to_identity(
self, image_tools: ImageTools
) -> None:
"""With mask and flip_and_make_identity, both image and mask get identity direction."""
shape_xyz = (2, 2, 2)
arr = np.ones((2, 2, 2), dtype=np.float32)
mask_arr = np.ones((2, 2, 2), dtype=np.float32)
direction = np.diag([1.0, -1.0, 1.0])
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction)
itk_mask = _make_synthetic_itk_image(
shape_xyz, arr=mask_arr, direction=direction
)
out_image, out_mask = image_tools.flip_image(
itk_image, in_mask=itk_mask, flip_and_make_identity=True
)
for im, name in [(out_image, "image"), (out_mask, "mask")]:
dir_mat = np.array(im.GetDirection())
assert np.allclose(dir_mat, np.eye(3)), (
f"flip_and_make_identity should set {name} direction to identity"
)


Comment thread
aylward marked this conversation as resolved.
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
Loading