Skip to content

Commit 68f7e2e

Browse files
committed
Consistently return d unchanged for empty key iteration across spatial transforms
Unify the first_key == () handling in RandAffined, Rand2DElasticd, Rand3DElasticd, RandZoomd, and RandSimulateLowResolutiond to return the dict directly instead of calling convert_to_tensor on the entire dict, which could inadvertently convert non-tensor metadata. Signed-off-by: Heyang Qin <qysnn1@gmail.com>
1 parent 7d5bf42 commit 68f7e2e

1 file changed

Lines changed: 5 additions & 10 deletions

File tree

monai/transforms/spatial/dictionary.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,8 +1161,7 @@ def __call__(
11611161
d = dict(data)
11621162
first_key: Hashable = self.first_key(d)
11631163
if first_key == ():
1164-
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
1165-
return out
1164+
return d
11661165

11671166
self.randomize(None)
11681167
# all the keys share the same random Affine factor
@@ -1322,8 +1321,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
13221321
first_key: Hashable = self.first_key(d)
13231322

13241323
if first_key == ():
1325-
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
1326-
return out
1324+
return d
13271325

13281326
self.randomize(None)
13291327
device = self.rand_2d_elastic.device
@@ -1473,8 +1471,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
14731471
first_key: Hashable = self.first_key(d)
14741472

14751473
if first_key == ():
1476-
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
1477-
return out
1474+
return d
14781475

14791476
self.randomize(None)
14801477
if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore
@@ -2134,8 +2131,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No
21342131
d = dict(data)
21352132
first_key: Hashable = self.first_key(d)
21362133
if first_key == ():
2137-
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
2138-
return out
2134+
return d
21392135

21402136
self.randomize(None)
21412137

@@ -2633,8 +2629,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
26332629
d = dict(data)
26342630
first_key: Hashable = self.first_key(d)
26352631
if first_key == ():
2636-
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
2637-
return out
2632+
return d
26382633

26392634
self.randomize(None)
26402635

0 commit comments

Comments
 (0)