|
26 | 26 | from monai.config import USE_COMPILED, DtypeLike |
27 | 27 | from monai.config.type_definitions import NdarrayOrTensor |
28 | 28 | from monai.data.box_utils import BoxMode, StandardMode |
29 | | -from monai.data.meta_obj import get_track_meta, set_track_meta |
| 29 | +from monai.data.meta_obj import get_track_meta |
30 | 30 | from monai.data.meta_tensor import MetaTensor |
31 | 31 | from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine |
32 | 32 | from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull |
@@ -3567,31 +3567,32 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: |
3567 | 3567 | input_shape = img.shape[1:] |
3568 | 3568 | target_shape = tuple(np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_).tolist()) |
3569 | 3569 |
|
3570 | | - resize_tfm_downsample = Resize( |
3571 | | - spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False |
| 3570 | + # Use F.interpolate directly on a plain tensor to avoid mutating the global |
| 3571 | + # set_track_meta flag, which is not thread-safe (see GitHub issue #8409). |
| 3572 | + img_t = convert_to_tensor(img, track_meta=False) |
| 3573 | + # F.interpolate requires float input and a batch dimension; cast matches |
| 3574 | + # the default dtype=float32 that Resize uses internally. |
| 3575 | + img_float = img_t.unsqueeze(0).to(dtype=torch.float32) |
| 3576 | + |
| 3577 | + downsample_mode = str(self.downsample_mode) |
| 3578 | + upsample_mode = str(self.upsample_mode) |
| 3579 | + # align_corners is only valid for linear/bilinear/bicubic/trilinear modes |
| 3580 | + _align_corners_modes = {"linear", "bilinear", "bicubic", "trilinear"} |
| 3581 | + downsample_align_corners = self.align_corners if downsample_mode in _align_corners_modes else None |
| 3582 | + upsample_align_corners = self.align_corners if upsample_mode in _align_corners_modes else None |
| 3583 | + |
| 3584 | + img_downsampled = torch.nn.functional.interpolate( |
| 3585 | + img_float, size=target_shape, mode=downsample_mode, align_corners=downsample_align_corners |
3572 | 3586 | ) |
3573 | | - |
3574 | | - resize_tfm_upsample = Resize( |
3575 | | - spatial_size=input_shape, |
3576 | | - size_mode="all", |
3577 | | - mode=self.upsample_mode, |
3578 | | - anti_aliasing=False, |
3579 | | - align_corners=self.align_corners, |
3580 | | - ) |
3581 | | - # temporarily disable metadata tracking, since we do not want to invert the two Resize functions during |
3582 | | - # post-processing |
3583 | | - original_tack_meta_value = get_track_meta() |
3584 | | - set_track_meta(False) |
3585 | | - |
3586 | | - img_downsampled = resize_tfm_downsample(img) |
3587 | | - img_upsampled = resize_tfm_upsample(img_downsampled) |
3588 | | - |
3589 | | - # reset metadata tracking to original value |
3590 | | - set_track_meta(original_tack_meta_value) |
3591 | | - |
3592 | | - # copy metadata from original image to down-and-upsampled image |
3593 | | - img_upsampled = MetaTensor(img_upsampled) |
3594 | | - img_upsampled.copy_meta_from(img) |
| 3587 | + img_upsampled_t = torch.nn.functional.interpolate( |
| 3588 | + img_downsampled, size=input_shape, mode=upsample_mode, align_corners=upsample_align_corners |
| 3589 | + ).squeeze(0) |
| 3590 | + |
| 3591 | + # copy metadata from original image to down-and-upsampled image, |
| 3592 | + # respecting the caller's get_track_meta() setting. |
| 3593 | + img_upsampled = convert_to_tensor(img_upsampled_t, track_meta=get_track_meta()) |
| 3594 | + if isinstance(img_upsampled, MetaTensor): |
| 3595 | + img_upsampled.copy_meta_from(img) |
3595 | 3596 |
|
3596 | 3597 | return img_upsampled |
3597 | 3598 |
|
|
0 commit comments