diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index ccb974d..c6e4023 100644 --- a/src/physiomotion4d/image_tools.py +++ b/src/physiomotion4d/image_tools.py @@ -246,34 +246,69 @@ def convert_array_to_image_of_vectors( return itk_image - def flip_image_to_identity_direction( - self, in_image: itk.Image, in_mask: Optional[itk.Image] = None + def flip_image( + self, + in_image: itk.Image, + in_mask: Optional[itk.Image] = None, + flip_x: bool = False, + flip_y: bool = False, + flip_z: bool = False, + flip_and_make_identity: bool = False, ) -> Any | tuple[Any, Any]: """ - Flip the image to the identity direction. + Flip the image and mask. + + Only axis-aligned flips are supported. If ``flip_and_make_identity`` is + True, the image and mask are first flipped along any axes whose + corresponding diagonal entries in the direction matrix are negative + (assuming the direction matrix encodes only axis-aligned flips), then + any additional requested flips are performed, and finally the direction + matrix is set to the identity matrix. This is useful when combining ITK + images with VTK objects (that often do not support a direction matrix). + + Args: + in_image: The input image to flip + in_mask: The input mask to flip + flip_x: Flip the image and mask along the x-axis + flip_y: Flip the image and mask along the y-axis + flip_z: Flip the image and mask along the z-axis + flip_and_make_identity: Flip the image and mask and make the direction + matrix identity. """ - flip0 = np.array(in_image.GetDirection())[0, 0] < 0 - flip1 = np.array(in_image.GetDirection())[1, 1] < 0 - flip2 = np.array(in_image.GetDirection())[2, 2] < 0 + flip0 = False + flip1 = False + flip2 = False + if flip_and_make_identity: + flip0 = np.array(in_image.GetDirection())[0, 0] < 0 + flip1 = np.array(in_image.GetDirection())[1, 1] < 0 + flip2 = np.array(in_image.GetDirection())[2, 2] < 0 + if flip_x: + flip0 = True + if flip_y: + flip1 = True + if flip_z: + flip2 = True if flip0 or flip1 or flip2: - self.log_info( - f"Flipping image to identity direction: {flip0}, {flip1}, {flip2}" - ) + self.log_info(f"Flipping image: {flip0}, {flip1}, {flip2}") flip_filter = itk.FlipImageFilter.New(Input=in_image) flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)]) flip_filter.SetFlipAboutOrigin(True) flip_filter.Update() out_image = flip_filter.GetOutput() - id_mat = itk.Matrix[itk.D, 3, 3]() - id_mat.SetIdentity() - out_image.SetDirection(id_mat) + if flip_and_make_identity: + id_mat = itk.Matrix[itk.D, 3, 3]() + id_mat.SetIdentity() + out_image.SetDirection(id_mat) if in_mask is not None: flip_filter = itk.FlipImageFilter.New(Input=in_mask) flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)]) flip_filter.SetFlipAboutOrigin(True) flip_filter.Update() out_mask = flip_filter.GetOutput() - out_mask.SetDirection(id_mat) + if flip_and_make_identity: + id_mat = itk.Matrix[itk.D, 3, 3]() + id_mat.SetIdentity() + out_mask.SetDirection(id_mat) return out_image, out_mask else: return out_image diff --git a/tests/test_image_tools.py b/tests/test_image_tools.py index 104a2cb..eb9d11c 100644 --- a/tests/test_image_tools.py +++ b/tests/test_image_tools.py @@ -6,6 +6,8 @@ scalar and vector images. """ +from __future__ import annotations + import itk import numpy as np import pytest @@ -18,11 +20,11 @@ class TestImageTools: """Test suite for ImageTools conversions.""" @pytest.fixture - def image_tools(self): + def image_tools(self) -> ImageTools: """Create ImageTools instance.""" return ImageTools() - def test_itk_to_sitk_scalar_image(self, image_tools): + def test_itk_to_sitk_scalar_image(self, image_tools: ImageTools) -> None: """Test conversion of scalar ITK image to SimpleITK.""" # Create a simple 3D scalar ITK image size = [10, 20, 30] @@ -57,7 +59,7 @@ def test_itk_to_sitk_scalar_image(self, image_tools): print("✓ ITK to SimpleITK scalar conversion successful") - def test_sitk_to_itk_scalar_image(self, image_tools): + def test_sitk_to_itk_scalar_image(self, image_tools: ImageTools) -> None: """Test conversion of scalar SimpleITK image to ITK.""" # Create a simple 3D scalar SimpleITK image size = [10, 20, 30] @@ -89,7 +91,7 @@ def test_sitk_to_itk_scalar_image(self, image_tools): print("✓ SimpleITK to ITK scalar conversion successful") - def test_roundtrip_scalar_image(self, image_tools): + def test_roundtrip_scalar_image(self, image_tools: ImageTools) -> None: """Test roundtrip conversion: ITK -> SimpleITK -> ITK.""" # Create ITK image size = [15, 25, 35] @@ -125,7 +127,7 @@ def test_roundtrip_scalar_image(self, image_tools): print("✓ Roundtrip scalar conversion successful") - def test_itk_to_sitk_vector_image(self, image_tools): + def test_itk_to_sitk_vector_image(self, image_tools: ImageTools) -> None: """Test conversion of vector ITK image to SimpleITK.""" # Create a 3D vector ITK image (like a displacement field) size = [8, 12, 16] @@ -163,7 +165,7 @@ def test_itk_to_sitk_vector_image(self, image_tools): print("✓ ITK to SimpleITK vector conversion successful") - def test_sitk_to_itk_vector_image(self, image_tools): + def test_sitk_to_itk_vector_image(self, image_tools: ImageTools) -> None: """Test conversion of vector SimpleITK image to ITK.""" # Create a 3D vector SimpleITK image size = [8, 12, 16] @@ -194,7 +196,7 @@ def test_sitk_to_itk_vector_image(self, image_tools): print("✓ SimpleITK to ITK vector conversion successful") - def test_roundtrip_vector_image(self, image_tools): + def test_roundtrip_vector_image(self, image_tools: ImageTools) -> None: """Test roundtrip conversion for vector images: ITK -> SimpleITK -> ITK.""" # Create ITK vector image size = [10, 15, 20] @@ -234,8 +236,12 @@ def test_roundtrip_vector_image(self, image_tools): @pytest.mark.requires_data @pytest.mark.slow def test_imwrite_imread_vd3( - self, image_tools, ants_registration_results, test_images, test_directories - ): + self, + image_tools: ImageTools, + ants_registration_results: dict, + test_images: list, + test_directories: dict, + ) -> None: """Test reading and writing double precision vector images.""" from physiomotion4d.transform_tools import TransformTools @@ -297,5 +303,156 @@ def test_imwrite_imread_vd3( assert mean_diff < 1e-6, f"Mean difference too large: {mean_diff}" +def _make_synthetic_itk_image( + shape_xyz: tuple[int, int, int], + arr: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> itk.Image[itk.F, 3]: + """Create a small 3D ITK image. shape_xyz is (nx, ny, nz); array from ITK is (nz, ny, nx).""" + # ITK uses (x, y, z) for size; array_from_image gives (z, y, x) + if arr is None: + arr = np.arange(np.prod(shape_xyz), dtype=np.float32).reshape( + shape_xyz[2], shape_xyz[1], shape_xyz[0] + ) + ImageType = itk.Image[itk.F, 3] + itk_image = ImageType.New() + region = itk.ImageRegion[3]() + region.SetSize(shape_xyz) # (nx, ny, nz) + itk_image.SetRegions(region) + itk_image.SetSpacing([1.0, 1.0, 1.0]) + itk_image.SetOrigin([0.0, 0.0, 0.0]) + if direction is not None: + itk_image.SetDirection(itk.matrix_from_array(direction)) + itk_image.Allocate() + itk.array_view_from_image(itk_image)[:] = arr + return itk_image + + +class TestFlipImage: + """Unit tests for ImageTools.flip_image (axis flips and direction reset).""" + + @pytest.fixture + def image_tools(self) -> ImageTools: + return ImageTools() + + def test_flip_x_flips_along_last_array_axis(self, image_tools: ImageTools) -> None: + """flip_x flips the image along the x (last) array dimension.""" + # Small image: ITK size (nx, ny, nz) = (3, 2, 2) -> array shape (2, 2, 3) + shape_xyz = (3, 2, 2) + arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) # (z, y, x) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr) + out = image_tools.flip_image(itk_image, flip_x=True) + out_arr = itk.array_from_image(out) + expected = np.flip(arr, axis=2) + assert np.allclose(out_arr, expected), ( + "flip_x should match np.flip(..., axis=2)" + ) + + def test_flip_y_flips_along_middle_array_axis( + self, image_tools: ImageTools + ) -> None: + """flip_y flips the image along the y (middle) array dimension.""" + shape_xyz = (3, 2, 2) + arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr) + out = image_tools.flip_image(itk_image, flip_y=True) + out_arr = itk.array_from_image(out) + expected = np.flip(arr, axis=1) + assert np.allclose(out_arr, expected), ( + "flip_y should match np.flip(..., axis=1)" + ) + + def test_flip_z_flips_along_first_array_axis(self, image_tools: ImageTools) -> None: + """flip_z flips the image along the z (first) array dimension.""" + shape_xyz = (3, 2, 2) + arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr) + out = image_tools.flip_image(itk_image, flip_z=True) + out_arr = itk.array_from_image(out) + expected = np.flip(arr, axis=0) + assert np.allclose(out_arr, expected), ( + "flip_z should match np.flip(..., axis=0)" + ) + + def test_flip_xy_combines_flips(self, image_tools: ImageTools) -> None: + """flip_x and flip_y together flip both axes.""" + shape_xyz = (3, 2, 2) + arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr) + out = image_tools.flip_image(itk_image, flip_x=True, flip_y=True) + out_arr = itk.array_from_image(out) + expected = np.flip(np.flip(arr, axis=2), axis=1) + assert np.allclose(out_arr, expected) + + def test_no_flip_returns_same_image(self, image_tools: ImageTools) -> None: + """With no flip flags, image is returned unchanged.""" + shape_xyz = (2, 2, 2) + arr = np.arange(8, dtype=np.float32).reshape(2, 2, 2) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr) + out = image_tools.flip_image( + itk_image, flip_x=False, flip_y=False, flip_z=False + ) + out_arr = itk.array_from_image(out) + assert np.allclose(out_arr, arr) + + def test_mask_flipped_in_lockstep_with_image(self, image_tools: ImageTools) -> None: + """When a mask is provided, it is flipped with the same axes as the image.""" + shape_xyz = (3, 2, 2) + arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) + mask_arr = (arr % 2 == 0).astype(np.float32) # 0/1 pattern in lockstep with arr + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr) + itk_mask = _make_synthetic_itk_image(shape_xyz, arr=mask_arr) + out_image, out_mask = image_tools.flip_image( + itk_image, in_mask=itk_mask, flip_x=True, flip_z=True + ) + out_img_arr = itk.array_from_image(out_image) + out_msk_arr = itk.array_from_image(out_mask) + expected_img = np.flip(np.flip(arr, axis=2), axis=0) + expected_msk = np.flip(np.flip(mask_arr, axis=2), axis=0) + assert np.allclose(out_img_arr, expected_img), "Image should be flipped x and z" + assert np.allclose(out_msk_arr, expected_msk), ( + "Mask should be flipped in lockstep" + ) + # Consistency: mask value should still align with image (same pattern, flipped) + assert np.allclose(out_msk_arr, (out_img_arr % 2 == 0).astype(np.float32)) + + def test_flip_and_make_identity_sets_direction_to_identity( + self, image_tools: ImageTools + ) -> None: + """flip_and_make_identity flips as needed and sets direction matrix to identity.""" + shape_xyz = (2, 2, 2) + arr = np.arange(8, dtype=np.float32).reshape(2, 2, 2) + # Direction with negative diagonal (e.g. flipped along one axis) + direction = np.diag([-1.0, 1.0, 1.0]) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction) + out = image_tools.flip_image(itk_image, flip_and_make_identity=True) + out_direction = np.array(out.GetDirection()) + identity = np.eye(3) + assert np.allclose(out_direction, identity), ( + "flip_and_make_identity should set direction to identity" + ) + + def test_flip_and_make_identity_with_mask_sets_both_directions_to_identity( + self, image_tools: ImageTools + ) -> None: + """With mask and flip_and_make_identity, both image and mask get identity direction.""" + shape_xyz = (2, 2, 2) + arr = np.ones((2, 2, 2), dtype=np.float32) + mask_arr = np.ones((2, 2, 2), dtype=np.float32) + direction = np.diag([1.0, -1.0, 1.0]) + itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction) + itk_mask = _make_synthetic_itk_image( + shape_xyz, arr=mask_arr, direction=direction + ) + out_image, out_mask = image_tools.flip_image( + itk_image, in_mask=itk_mask, flip_and_make_identity=True + ) + for im, name in [(out_image, "image"), (out_mask, "mask")]: + dir_mat = np.array(im.GetDirection()) + assert np.allclose(dir_mat, np.eye(3)), ( + f"flip_and_make_identity should set {name} direction to identity" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"])