Skip to content

Commit 7d5bf42

Browse files
committed
Fix RandGridDistortiond crash when transform is skipped
When _do_transform is False, convert_to_tensor was called on the entire data dict, which fails when non-tensor values (e.g. ints, strings) are present — causing "AttributeError: 'int' object has no attribute 'numel'" in the DataLoader collate function. Convert only the keyed tensor items instead, consistent with how other dict transforms handle the no-transform case. Also return dict unchanged when no keys match (first_key == ()). Fixes #8604 Signed-off-by: Heyang Qin <qysnn1@gmail.com>
1 parent 5b71547 commit 7d5bf42

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

monai/transforms/spatial/dictionary.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,13 +2305,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
23052305
d = dict(data)
23062306
self.randomize(None)
23072307
if not self._do_transform:
2308-
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
2309-
return out
2308+
for key in self.key_iterator(d):
2309+
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
2310+
return d
23102311

23112312
first_key: Hashable = self.first_key(d)
23122313
if first_key == ():
2313-
out = convert_to_tensor(d, track_meta=get_track_meta())
2314-
return out
2314+
return d
23152315
if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore
23162316
warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.")
23172317
self.rand_grid_distortion.randomize(d[first_key].shape[1:])

tests/transforms/test_rand_grid_distortiond.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,29 @@
7777

7878

7979
class TestRandGridDistortiond(unittest.TestCase):
80+
"""Test cases for RandGridDistortiond dictionary transform."""
81+
8082
@parameterized.expand(TESTS)
8183
def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask):
84+
"""Verify distortion produces expected output for image and mask keys."""
8285
g = RandGridDistortiond(**input_param)
8386
g.set_random_state(seed=seed)
8487
result = g(input_data)
8588
assert_allclose(result["img"], expected_val_img, type_test=False, rtol=1e-4, atol=1e-4)
8689
assert_allclose(result["mask"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4)
8790

8891

92+
def test_no_transform_with_non_tensor_metadata(self):
93+
"""When _do_transform is False, non-tensor values in the dict should not cause an error."""
94+
img = np.indices([6, 6]).astype(np.float32)
95+
data = {"img": img, "extra_info": 42, "label_name": "tumor"}
96+
g = RandGridDistortiond(keys=["img"], prob=0.0) # prob=0 ensures _do_transform is False
97+
result = g(data)
98+
# non-tensor metadata should pass through unchanged
99+
self.assertEqual(result["extra_info"], 42)
100+
self.assertEqual(result["label_name"], "tumor")
101+
assert_allclose(result["img"], img, type_test=False)
102+
103+
89104
if __name__ == "__main__":
90105
unittest.main()

0 commit comments

Comments
 (0)