Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
Expand Down Expand Up @@ -3567,30 +3567,26 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
input_shape = img.shape[1:]
target_shape = tuple(np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_).tolist())

resize_tfm_downsample = Resize(
spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
)

resize_tfm_upsample = Resize(
spatial_size=input_shape,
size_mode="all",
mode=self.upsample_mode,
anti_aliasing=False,
align_corners=self.align_corners,
)
# temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
# post-processing
original_tack_meta_value = get_track_meta()
set_track_meta(False)
# Use F.interpolate directly on a plain tensor to avoid mutating the global
# set_track_meta flag, which is not thread-safe (see GitHub issue #8409).
img_t = convert_to_tensor(img, track_meta=False)
# F.interpolate requires float input and a batch dimension; cast matches
# the default dtype=float32 that Resize uses internally.
img_float = img_t.unsqueeze(0).to(dtype=torch.float32)

img_downsampled = resize_tfm_downsample(img)
img_upsampled = resize_tfm_upsample(img_downsampled)
downsample_mode = self.downsample_mode.value if hasattr(self.downsample_mode, "value") else self.downsample_mode
upsample_mode = self.upsample_mode.value if hasattr(self.upsample_mode, "value") else self.upsample_mode
# align_corners is only valid for linear/bilinear/bicubic/trilinear modes
_align_corners_modes = {"linear", "bilinear", "bicubic", "trilinear"}
upsample_align_corners = self.align_corners if upsample_mode in _align_corners_modes else None

# reset metadata tracking to original value
set_track_meta(original_tack_meta_value)
img_downsampled = torch.nn.functional.interpolate(img_float, size=target_shape, mode=downsample_mode)
img_upsampled_t = torch.nn.functional.interpolate(
img_downsampled, size=input_shape, mode=upsample_mode, align_corners=upsample_align_corners
).squeeze(0)

# copy metadata from original image to down-and-upsampled image
img_upsampled = MetaTensor(img_upsampled)
img_upsampled = MetaTensor(img_upsampled_t)
img_upsampled.copy_meta_from(img)

return img_upsampled
Expand Down
38 changes: 38 additions & 0 deletions tests/transforms/test_rand_simulate_low_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

from __future__ import annotations

import threading
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data.meta_obj import get_track_meta
from monai.transforms import RandSimulateLowResolution
from tests.test_utils import TEST_NDARRAYS, assert_allclose

Expand Down Expand Up @@ -78,6 +81,41 @@ def test_value(self, arguments, image, expected_data):
result = randsimlowres(image)
assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor")

def test_track_meta_global_state_unchanged(self):
# Verify that calling RandSimulateLowResolution does not modify the global
# set_track_meta flag (regression test for GitHub issue #8409).
img = torch.ones(1, 4, 4, 4)
tfm = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6))
tfm.set_random_state(seed=0)

original_track_meta = get_track_meta()
tfm(img)
self.assertEqual(get_track_meta(), original_track_meta, "set_track_meta global state was unexpectedly modified")

def test_thread_safety(self):
# Verify that concurrent calls do not corrupt each other's track_meta state
# (regression test for GitHub issue #8409).
errors = []

def run_transform():
img = torch.ones(1, 4, 4, 4)
tfm = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6))
before = get_track_meta()
try:
tfm(img)
except Exception as e:
errors.append(e)
if get_track_meta() != before:
errors.append(RuntimeError("track_meta state changed in thread"))

threads = [threading.Thread(target=run_transform) for _ in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()

self.assertEqual(errors, [], f"Thread safety errors: {errors}")


if __name__ == "__main__":
unittest.main()
Loading