Skip to content

Commit 1a6db7b

Browse files
authored
ENH: Add method for flipping image and orientations. (#29)
1 parent 48893e3 commit 1a6db7b

2 files changed

Lines changed: 214 additions & 22 deletions

File tree

src/physiomotion4d/image_tools.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -246,34 +246,69 @@ def convert_array_to_image_of_vectors(
246246

247247
return itk_image
248248

249-
def flip_image_to_identity_direction(
250-
self, in_image: itk.Image, in_mask: Optional[itk.Image] = None
249+
def flip_image(
250+
self,
251+
in_image: itk.Image,
252+
in_mask: Optional[itk.Image] = None,
253+
flip_x: bool = False,
254+
flip_y: bool = False,
255+
flip_z: bool = False,
256+
flip_and_make_identity: bool = False,
251257
) -> Any | tuple[Any, Any]:
252258
"""
253-
Flip the image to the identity direction.
259+
Flip the image and mask.
260+
261+
Only axis-aligned flips are supported. If ``flip_and_make_identity`` is
262+
True, the image and mask are first flipped along any axes whose
263+
corresponding diagonal entries in the direction matrix are negative
264+
(assuming the direction matrix encodes only axis-aligned flips), then
265+
any additional requested flips are performed, and finally the direction
266+
matrix is set to the identity matrix. This is useful when combining ITK
267+
images with VTK objects (that often do not support a direction matrix).
268+
269+
Args:
270+
in_image: The input image to flip
271+
in_mask: The input mask to flip
272+
flip_x: Flip the image and mask along the x-axis
273+
flip_y: Flip the image and mask along the y-axis
274+
flip_z: Flip the image and mask along the z-axis
275+
flip_and_make_identity: Flip the image and mask and make the direction
276+
matrix identity.
254277
"""
255-
flip0 = np.array(in_image.GetDirection())[0, 0] < 0
256-
flip1 = np.array(in_image.GetDirection())[1, 1] < 0
257-
flip2 = np.array(in_image.GetDirection())[2, 2] < 0
278+
flip0 = False
279+
flip1 = False
280+
flip2 = False
281+
if flip_and_make_identity:
282+
flip0 = np.array(in_image.GetDirection())[0, 0] < 0
283+
flip1 = np.array(in_image.GetDirection())[1, 1] < 0
284+
flip2 = np.array(in_image.GetDirection())[2, 2] < 0
285+
if flip_x:
286+
flip0 = True
287+
if flip_y:
288+
flip1 = True
289+
if flip_z:
290+
flip2 = True
258291
if flip0 or flip1 or flip2:
259-
self.log_info(
260-
f"Flipping image to identity direction: {flip0}, {flip1}, {flip2}"
261-
)
292+
self.log_info(f"Flipping image: {flip0}, {flip1}, {flip2}")
262293
flip_filter = itk.FlipImageFilter.New(Input=in_image)
263294
flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])
264295
flip_filter.SetFlipAboutOrigin(True)
265296
flip_filter.Update()
266297
out_image = flip_filter.GetOutput()
267-
id_mat = itk.Matrix[itk.D, 3, 3]()
268-
id_mat.SetIdentity()
269-
out_image.SetDirection(id_mat)
298+
if flip_and_make_identity:
299+
id_mat = itk.Matrix[itk.D, 3, 3]()
300+
id_mat.SetIdentity()
301+
out_image.SetDirection(id_mat)
270302
if in_mask is not None:
271303
flip_filter = itk.FlipImageFilter.New(Input=in_mask)
272304
flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])
273305
flip_filter.SetFlipAboutOrigin(True)
274306
flip_filter.Update()
275307
out_mask = flip_filter.GetOutput()
276-
out_mask.SetDirection(id_mat)
308+
if flip_and_make_identity:
309+
id_mat = itk.Matrix[itk.D, 3, 3]()
310+
id_mat.SetIdentity()
311+
out_mask.SetDirection(id_mat)
277312
return out_image, out_mask
278313
else:
279314
return out_image

tests/test_image_tools.py

Lines changed: 166 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
scalar and vector images.
77
"""
88

9+
from __future__ import annotations
10+
911
import itk
1012
import numpy as np
1113
import pytest
@@ -18,11 +20,11 @@ class TestImageTools:
1820
"""Test suite for ImageTools conversions."""
1921

2022
@pytest.fixture
21-
def image_tools(self):
23+
def image_tools(self) -> ImageTools:
2224
"""Create ImageTools instance."""
2325
return ImageTools()
2426

25-
def test_itk_to_sitk_scalar_image(self, image_tools):
27+
def test_itk_to_sitk_scalar_image(self, image_tools: ImageTools) -> None:
2628
"""Test conversion of scalar ITK image to SimpleITK."""
2729
# Create a simple 3D scalar ITK image
2830
size = [10, 20, 30]
@@ -57,7 +59,7 @@ def test_itk_to_sitk_scalar_image(self, image_tools):
5759

5860
print("✓ ITK to SimpleITK scalar conversion successful")
5961

60-
def test_sitk_to_itk_scalar_image(self, image_tools):
62+
def test_sitk_to_itk_scalar_image(self, image_tools: ImageTools) -> None:
6163
"""Test conversion of scalar SimpleITK image to ITK."""
6264
# Create a simple 3D scalar SimpleITK image
6365
size = [10, 20, 30]
@@ -89,7 +91,7 @@ def test_sitk_to_itk_scalar_image(self, image_tools):
8991

9092
print("✓ SimpleITK to ITK scalar conversion successful")
9193

92-
def test_roundtrip_scalar_image(self, image_tools):
94+
def test_roundtrip_scalar_image(self, image_tools: ImageTools) -> None:
9395
"""Test roundtrip conversion: ITK -> SimpleITK -> ITK."""
9496
# Create ITK image
9597
size = [15, 25, 35]
@@ -125,7 +127,7 @@ def test_roundtrip_scalar_image(self, image_tools):
125127

126128
print("✓ Roundtrip scalar conversion successful")
127129

128-
def test_itk_to_sitk_vector_image(self, image_tools):
130+
def test_itk_to_sitk_vector_image(self, image_tools: ImageTools) -> None:
129131
"""Test conversion of vector ITK image to SimpleITK."""
130132
# Create a 3D vector ITK image (like a displacement field)
131133
size = [8, 12, 16]
@@ -163,7 +165,7 @@ def test_itk_to_sitk_vector_image(self, image_tools):
163165

164166
print("✓ ITK to SimpleITK vector conversion successful")
165167

166-
def test_sitk_to_itk_vector_image(self, image_tools):
168+
def test_sitk_to_itk_vector_image(self, image_tools: ImageTools) -> None:
167169
"""Test conversion of vector SimpleITK image to ITK."""
168170
# Create a 3D vector SimpleITK image
169171
size = [8, 12, 16]
@@ -194,7 +196,7 @@ def test_sitk_to_itk_vector_image(self, image_tools):
194196

195197
print("✓ SimpleITK to ITK vector conversion successful")
196198

197-
def test_roundtrip_vector_image(self, image_tools):
199+
def test_roundtrip_vector_image(self, image_tools: ImageTools) -> None:
198200
"""Test roundtrip conversion for vector images: ITK -> SimpleITK -> ITK."""
199201
# Create ITK vector image
200202
size = [10, 15, 20]
@@ -234,8 +236,12 @@ def test_roundtrip_vector_image(self, image_tools):
234236
@pytest.mark.requires_data
235237
@pytest.mark.slow
236238
def test_imwrite_imread_vd3(
237-
self, image_tools, ants_registration_results, test_images, test_directories
238-
):
239+
self,
240+
image_tools: ImageTools,
241+
ants_registration_results: dict,
242+
test_images: list,
243+
test_directories: dict,
244+
) -> None:
239245
"""Test reading and writing double precision vector images."""
240246
from physiomotion4d.transform_tools import TransformTools
241247

@@ -297,5 +303,156 @@ def test_imwrite_imread_vd3(
297303
assert mean_diff < 1e-6, f"Mean difference too large: {mean_diff}"
298304

299305

306+
def _make_synthetic_itk_image(
307+
shape_xyz: tuple[int, int, int],
308+
arr: np.ndarray | None = None,
309+
direction: np.ndarray | None = None,
310+
) -> itk.Image[itk.F, 3]:
311+
"""Create a small 3D ITK image. shape_xyz is (nx, ny, nz); array from ITK is (nz, ny, nx)."""
312+
# ITK uses (x, y, z) for size; array_from_image gives (z, y, x)
313+
if arr is None:
314+
arr = np.arange(np.prod(shape_xyz), dtype=np.float32).reshape(
315+
shape_xyz[2], shape_xyz[1], shape_xyz[0]
316+
)
317+
ImageType = itk.Image[itk.F, 3]
318+
itk_image = ImageType.New()
319+
region = itk.ImageRegion[3]()
320+
region.SetSize(shape_xyz) # (nx, ny, nz)
321+
itk_image.SetRegions(region)
322+
itk_image.SetSpacing([1.0, 1.0, 1.0])
323+
itk_image.SetOrigin([0.0, 0.0, 0.0])
324+
if direction is not None:
325+
itk_image.SetDirection(itk.matrix_from_array(direction))
326+
itk_image.Allocate()
327+
itk.array_view_from_image(itk_image)[:] = arr
328+
return itk_image
329+
330+
331+
class TestFlipImage:
332+
"""Unit tests for ImageTools.flip_image (axis flips and direction reset)."""
333+
334+
@pytest.fixture
335+
def image_tools(self) -> ImageTools:
336+
return ImageTools()
337+
338+
def test_flip_x_flips_along_last_array_axis(self, image_tools: ImageTools) -> None:
339+
"""flip_x flips the image along the x (last) array dimension."""
340+
# Small image: ITK size (nx, ny, nz) = (3, 2, 2) -> array shape (2, 2, 3)
341+
shape_xyz = (3, 2, 2)
342+
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3) # (z, y, x)
343+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
344+
out = image_tools.flip_image(itk_image, flip_x=True)
345+
out_arr = itk.array_from_image(out)
346+
expected = np.flip(arr, axis=2)
347+
assert np.allclose(out_arr, expected), (
348+
"flip_x should match np.flip(..., axis=2)"
349+
)
350+
351+
def test_flip_y_flips_along_middle_array_axis(
352+
self, image_tools: ImageTools
353+
) -> None:
354+
"""flip_y flips the image along the y (middle) array dimension."""
355+
shape_xyz = (3, 2, 2)
356+
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
357+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
358+
out = image_tools.flip_image(itk_image, flip_y=True)
359+
out_arr = itk.array_from_image(out)
360+
expected = np.flip(arr, axis=1)
361+
assert np.allclose(out_arr, expected), (
362+
"flip_y should match np.flip(..., axis=1)"
363+
)
364+
365+
def test_flip_z_flips_along_first_array_axis(self, image_tools: ImageTools) -> None:
366+
"""flip_z flips the image along the z (first) array dimension."""
367+
shape_xyz = (3, 2, 2)
368+
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
369+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
370+
out = image_tools.flip_image(itk_image, flip_z=True)
371+
out_arr = itk.array_from_image(out)
372+
expected = np.flip(arr, axis=0)
373+
assert np.allclose(out_arr, expected), (
374+
"flip_z should match np.flip(..., axis=0)"
375+
)
376+
377+
def test_flip_xy_combines_flips(self, image_tools: ImageTools) -> None:
378+
"""flip_x and flip_y together flip both axes."""
379+
shape_xyz = (3, 2, 2)
380+
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
381+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
382+
out = image_tools.flip_image(itk_image, flip_x=True, flip_y=True)
383+
out_arr = itk.array_from_image(out)
384+
expected = np.flip(np.flip(arr, axis=2), axis=1)
385+
assert np.allclose(out_arr, expected)
386+
387+
def test_no_flip_returns_same_image(self, image_tools: ImageTools) -> None:
388+
"""With no flip flags, image is returned unchanged."""
389+
shape_xyz = (2, 2, 2)
390+
arr = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
391+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
392+
out = image_tools.flip_image(
393+
itk_image, flip_x=False, flip_y=False, flip_z=False
394+
)
395+
out_arr = itk.array_from_image(out)
396+
assert np.allclose(out_arr, arr)
397+
398+
def test_mask_flipped_in_lockstep_with_image(self, image_tools: ImageTools) -> None:
399+
"""When a mask is provided, it is flipped with the same axes as the image."""
400+
shape_xyz = (3, 2, 2)
401+
arr = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
402+
mask_arr = (arr % 2 == 0).astype(np.float32) # 0/1 pattern in lockstep with arr
403+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr)
404+
itk_mask = _make_synthetic_itk_image(shape_xyz, arr=mask_arr)
405+
out_image, out_mask = image_tools.flip_image(
406+
itk_image, in_mask=itk_mask, flip_x=True, flip_z=True
407+
)
408+
out_img_arr = itk.array_from_image(out_image)
409+
out_msk_arr = itk.array_from_image(out_mask)
410+
expected_img = np.flip(np.flip(arr, axis=2), axis=0)
411+
expected_msk = np.flip(np.flip(mask_arr, axis=2), axis=0)
412+
assert np.allclose(out_img_arr, expected_img), "Image should be flipped x and z"
413+
assert np.allclose(out_msk_arr, expected_msk), (
414+
"Mask should be flipped in lockstep"
415+
)
416+
# Consistency: mask value should still align with image (same pattern, flipped)
417+
assert np.allclose(out_msk_arr, (out_img_arr % 2 == 0).astype(np.float32))
418+
419+
def test_flip_and_make_identity_sets_direction_to_identity(
420+
self, image_tools: ImageTools
421+
) -> None:
422+
"""flip_and_make_identity flips as needed and sets direction matrix to identity."""
423+
shape_xyz = (2, 2, 2)
424+
arr = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
425+
# Direction with negative diagonal (e.g. flipped along one axis)
426+
direction = np.diag([-1.0, 1.0, 1.0])
427+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction)
428+
out = image_tools.flip_image(itk_image, flip_and_make_identity=True)
429+
out_direction = np.array(out.GetDirection())
430+
identity = np.eye(3)
431+
assert np.allclose(out_direction, identity), (
432+
"flip_and_make_identity should set direction to identity"
433+
)
434+
435+
def test_flip_and_make_identity_with_mask_sets_both_directions_to_identity(
436+
self, image_tools: ImageTools
437+
) -> None:
438+
"""With mask and flip_and_make_identity, both image and mask get identity direction."""
439+
shape_xyz = (2, 2, 2)
440+
arr = np.ones((2, 2, 2), dtype=np.float32)
441+
mask_arr = np.ones((2, 2, 2), dtype=np.float32)
442+
direction = np.diag([1.0, -1.0, 1.0])
443+
itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction)
444+
itk_mask = _make_synthetic_itk_image(
445+
shape_xyz, arr=mask_arr, direction=direction
446+
)
447+
out_image, out_mask = image_tools.flip_image(
448+
itk_image, in_mask=itk_mask, flip_and_make_identity=True
449+
)
450+
for im, name in [(out_image, "image"), (out_mask, "mask")]:
451+
dir_mat = np.array(im.GetDirection())
452+
assert np.allclose(dir_mat, np.eye(3)), (
453+
f"flip_and_make_identity should set {name} direction to identity"
454+
)
455+
456+
300457
if __name__ == "__main__":
301458
pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)