Skip to content

Commit 937a516

Browse files
committed
fix: replace set_track_meta with F.interpolate in RandSimulateLowResolution (#8409)
Signed-off-by: chhayankjain <chhayank44@gmail.com>
1 parent d0650f3 commit 937a516

2 files changed

Lines changed: 70 additions & 25 deletions

File tree

monai/transforms/spatial/array.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.config import USE_COMPILED, DtypeLike
2727
from monai.config.type_definitions import NdarrayOrTensor
2828
from monai.data.box_utils import BoxMode, StandardMode
29-
from monai.data.meta_obj import get_track_meta, set_track_meta
29+
from monai.data.meta_obj import get_track_meta
3030
from monai.data.meta_tensor import MetaTensor
3131
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
3232
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
@@ -3567,31 +3567,32 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
35673567
input_shape = img.shape[1:]
35683568
target_shape = tuple(np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_).tolist())
35693569

3570-
resize_tfm_downsample = Resize(
3571-
spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
3570+
# Use F.interpolate directly on a plain tensor to avoid mutating the global
3571+
# set_track_meta flag, which is not thread-safe (see GitHub issue #8409).
3572+
img_t = convert_to_tensor(img, track_meta=False)
3573+
# F.interpolate requires float input and a batch dimension; cast matches
3574+
# the default dtype=float32 that Resize uses internally.
3575+
img_float = img_t.unsqueeze(0).to(dtype=torch.float32)
3576+
3577+
downsample_mode = str(self.downsample_mode)
3578+
upsample_mode = str(self.upsample_mode)
3579+
# align_corners is only valid for linear/bilinear/bicubic/trilinear modes
3580+
_align_corners_modes = {"linear", "bilinear", "bicubic", "trilinear"}
3581+
downsample_align_corners = self.align_corners if downsample_mode in _align_corners_modes else None
3582+
upsample_align_corners = self.align_corners if upsample_mode in _align_corners_modes else None
3583+
3584+
img_downsampled = torch.nn.functional.interpolate(
3585+
img_float, size=target_shape, mode=downsample_mode, align_corners=downsample_align_corners
35723586
)
3573-
3574-
resize_tfm_upsample = Resize(
3575-
spatial_size=input_shape,
3576-
size_mode="all",
3577-
mode=self.upsample_mode,
3578-
anti_aliasing=False,
3579-
align_corners=self.align_corners,
3580-
)
3581-
# temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
3582-
# post-processing
3583-
original_tack_meta_value = get_track_meta()
3584-
set_track_meta(False)
3585-
3586-
img_downsampled = resize_tfm_downsample(img)
3587-
img_upsampled = resize_tfm_upsample(img_downsampled)
3588-
3589-
# reset metadata tracking to original value
3590-
set_track_meta(original_tack_meta_value)
3591-
3592-
# copy metadata from original image to down-and-upsampled image
3593-
img_upsampled = MetaTensor(img_upsampled)
3594-
img_upsampled.copy_meta_from(img)
3587+
img_upsampled_t = torch.nn.functional.interpolate(
3588+
img_downsampled, size=input_shape, mode=upsample_mode, align_corners=upsample_align_corners
3589+
).squeeze(0)
3590+
3591+
# copy metadata from original image to down-and-upsampled image,
3592+
# respecting the caller's get_track_meta() setting.
3593+
img_upsampled = convert_to_tensor(img_upsampled_t, track_meta=get_track_meta())
3594+
if isinstance(img_upsampled, MetaTensor):
3595+
img_upsampled.copy_meta_from(img)
35953596

35963597
return img_upsampled
35973598

tests/transforms/test_rand_simulate_low_resolution.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111

1212
from __future__ import annotations
1313

14+
import threading
1415
import unittest
1516

1617
import numpy as np
18+
import torch
1719
from parameterized import parameterized
1820

21+
from monai.data.meta_obj import get_track_meta
1922
from monai.transforms import RandSimulateLowResolution
2023
from tests.test_utils import TEST_NDARRAYS, assert_allclose
2124

@@ -78,6 +81,47 @@ def test_value(self, arguments, image, expected_data):
7881
result = randsimlowres(image)
7982
assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor")
8083

84+
def test_track_meta_global_state_unchanged(self):
    """Regression test for GitHub issue #8409: applying the transform must
    leave the global ``set_track_meta`` flag exactly as it found it."""
    transform = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6))
    transform.set_random_state(seed=0)
    image = torch.ones(1, 4, 4, 4)

    # Snapshot the flag, run the transform once, then confirm it is unchanged.
    flag_before = get_track_meta()
    transform(image)
    flag_after = get_track_meta()
    self.assertEqual(flag_after, flag_before, "set_track_meta global state was unexpectedly modified")
95+
def test_thread_safety(self):
    """Regression test for GitHub issue #8409: concurrent calls must not
    corrupt each other's track_meta state."""
    # The baseline flag is read once, before any worker starts, so every
    # thread compares against the same value rather than a snapshot that
    # another thread may already have corrupted.
    failures = []
    baseline_track_meta = get_track_meta()
    n_workers = 8
    barrier = threading.Barrier(n_workers)

    def worker():
        image = torch.ones(1, 4, 4, 4)
        transform = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6))
        barrier.wait()  # synchronise so all threads hammer the transform at once
        try:
            for _ in range(50):
                transform(image)
        except Exception as exc:  # noqa: BLE001
            failures.append(exc)
        if get_track_meta() != baseline_track_meta:
            failures.append(RuntimeError("track_meta state changed in thread"))

    workers = [threading.Thread(target=worker) for _ in range(n_workers)]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    self.assertEqual(failures, [], f"Thread safety errors: {failures}")
124+
81125

82126
if __name__ == "__main__":
83127
unittest.main()

0 commit comments

Comments (0)