Skip to content

Commit 1fed118

Browse files
aymuos15ericspod
authored andcommitted
Fix nested Compose map_items in forward and inverse paths (Project-MONAI#8787)
## Summary Fixes Project-MONAI#7932, Project-MONAI#7565 When a child `Compose` has a different `map_items` setting than its parent, the parent's `apply_transform` would expand list/tuple data before the child ever sees it — silently overriding the child's `map_items`. This PR makes three coordinated changes so the child's `map_items` is respected: - **Forward path** (`apply_transform`): Skip list expansion when the transform is a `Compose` instance, letting it handle expansion via its own `map_items` in `execute_compose`. - **Inverse path** (`_inverse_one` helper): Delegate directly to `Compose.inverse()` for nested `Compose` objects (including `RandomOrder` and `SomeOf`) instead of routing through `apply_transform(t.inverse, ...)`. - **`flatten()`**: Only inline nested `Compose` objects that share the same `map_items` as the parent. Children with a different `map_items` are preserved as-is. ## Test plan - [x] `test_child_map_items_false_receives_list` — parent `map_items=True`, child `map_items=False`: child receives list as-is - [x] `test_inverse_respects_child_map_items` — inverse roundtrip with nested Compose - [x] `test_parent_no_map_child_map` — parent `map_items=False`, child `map_items=True`: child maps over items - [x] `test_flatten_preserves_different_map_items` — `flatten()` does not merge children with different `map_items` --------- Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent d1a8557 commit 1fed118

3 files changed

Lines changed: 209 additions & 16 deletions

File tree

monai/transforms/compose.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@
3737
__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"]
3838

3939

40+
def _inverse_one(
41+
t: InvertibleTransform, data: Any, map_items: bool | int, unpack_items: bool, log_stats: bool | str
42+
) -> Any:
43+
"""Invert a single transform, delegating directly to nested ``Compose`` objects.
44+
45+
When ``t`` is a ``Compose`` instance its own ``inverse()`` is called so that
46+
the child's ``map_items`` setting is respected. For all other invertible
47+
transforms, ``apply_transform`` is used with ``lazy=False``.
48+
49+
Args:
50+
t: The invertible transform to invert.
51+
data: Data to be inverted.
52+
map_items: Whether to map over list/tuple items (forwarded to
53+
``apply_transform`` for non-``Compose`` transforms).
54+
unpack_items: Whether to unpack data as parameters.
55+
log_stats: Logger name or boolean for logging.
56+
57+
Returns:
58+
The inverted data.
59+
"""
60+
if isinstance(t, Compose):
61+
return t.inverse(data)
62+
return apply_transform(t.inverse, data, map_items, unpack_items, lazy=False, log_stats=log_stats)
63+
64+
4065
def execute_compose(
4166
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
4267
transforms: Sequence[Any],
@@ -315,20 +340,32 @@ def get_index_of_first(self, predicate):
315340
return None
316341

317342
def flatten(self):
318-
"""Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
343+
"""Return a Composition with a flattened list of transforms.
344+
345+
Nested ``Compose`` objects that share the same ``map_items`` setting as
346+
the parent are inlined. Nested ``Compose`` objects with a *different*
347+
``map_items`` value are kept as-is so their item-mapping behaviour is
348+
preserved at runtime and during inversion.
319349
320350
e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()`
321351
will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.
322352
323353
"""
324354
new_transforms = []
325355
for t in self.transforms:
326-
if type(t) is Compose: # nopep8
356+
if type(t) is Compose and t.map_items == self.map_items:
327357
new_transforms += t.flatten().transforms
328358
else:
329359
new_transforms.append(t)
330360

331-
return Compose(new_transforms)
361+
return Compose(
362+
new_transforms,
363+
map_items=self.map_items,
364+
unpack_items=self.unpack_items,
365+
log_stats=self.log_stats,
366+
lazy=self._lazy,
367+
overrides=self.overrides,
368+
)
332369

333370
def __len__(self):
334371
"""Return number of transformations."""
@@ -365,9 +402,7 @@ def inverse(self, data):
365402
)
366403
# loop backwards over transforms
367404
for t in reversed(invertible_transforms):
368-
data = apply_transform(
369-
t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats
370-
)
405+
data = _inverse_one(t, data, self.map_items, self.unpack_items, self.log_stats)
371406
return data
372407

373408
@staticmethod
@@ -622,9 +657,7 @@ def inverse(self, data):
622657
# loop backwards over transforms
623658
for o in reversed(applied_order):
624659
if isinstance(self.transforms[o], InvertibleTransform):
625-
data = apply_transform(
626-
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
627-
)
660+
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)
628661
return data
629662

630663

@@ -789,8 +822,6 @@ def inverse(self, data):
789822
# loop backwards over transforms
790823
for o in reversed(applied_order):
791824
if isinstance(self.transforms[o], InvertibleTransform):
792-
data = apply_transform(
793-
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
794-
)
825+
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)
795826

796827
return data

monai/transforms/transform.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,13 @@ def apply_transform(
143143
try:
144144
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
145145
if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
146-
return [
147-
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
148-
for item in data
149-
]
146+
# If the transform is a Compose with its own map_items, let it handle list/tuple
147+
# expansion internally so that nested Compose map_items settings are respected.
148+
if not isinstance(transform, transforms.compose.Compose):
149+
return [
150+
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
151+
for item in data
152+
]
150153
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
151154
except Exception as e:
152155
# if in debug mode, don't swallow exception so that the breakpoint

tests/transforms/compose/test_compose.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,165 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
775775
self.assertEqual(expected, actual)
776776

777777

778+
class TestNestedComposeMapItems(unittest.TestCase):
779+
"""Tests for nested Compose respecting child map_items (issues #7932, #7565)."""
780+
781+
def test_child_map_items_false_receives_list(self):
782+
"""Parent map_items=True, child map_items=False: child receives list as-is."""
783+
784+
def split(x):
785+
return [x + 1, x + 2]
786+
787+
def sum_list(items):
788+
return sum(items)
789+
790+
# The child Compose(map_items=False) should receive the list from split()
791+
# and pass it as-is to sum_list, rather than the parent expanding the list.
792+
pipeline = mt.Compose([split, mt.Compose([sum_list], map_items=False)])
793+
result = pipeline(10)
794+
self.assertEqual(result, 23) # (10+1) + (10+2) = 23
795+
796+
def test_inverse_respects_child_map_items(self):
797+
"""Inverse path should delegate to child Compose.inverse directly."""
798+
pipeline = mt.Compose([mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False)])
799+
data = torch.randn(1, 4, 4)
800+
result = pipeline(data)
801+
restored = pipeline.inverse(result)
802+
torch.testing.assert_close(data, restored)
803+
804+
def test_parent_no_map_child_map(self):
805+
"""Parent map_items=False, child map_items=True: child maps over items."""
806+
807+
def double(x):
808+
return x * 2
809+
810+
# Parent treats the list as a single value; child maps double() over each item.
811+
pipeline = mt.Compose([mt.Compose([double], map_items=True)], map_items=False)
812+
result = pipeline([1, 2, 3])
813+
self.assertEqual(result, [2, 4, 6])
814+
815+
def test_flatten_preserves_different_map_items(self):
816+
"""flatten() should not merge a child Compose with different map_items."""
817+
818+
def noop(x):
819+
return x
820+
821+
parent = mt.Compose([noop, mt.Compose([noop, noop], map_items=False), noop])
822+
flat = parent.flatten()
823+
# The inner Compose(map_items=False) should NOT be flattened
824+
self.assertEqual(len(flat.transforms), 3)
825+
self.assertIsInstance(flat.transforms[1], mt.Compose)
826+
827+
def test_multiple_children_with_mixed_map_items(self):
828+
"""Multiple internal Composes with different map_items should be handled correctly."""
829+
830+
def add_one(items):
831+
if isinstance(items, list):
832+
return [x + 1 for x in items]
833+
return items + 1
834+
835+
def multiply_two(items):
836+
if isinstance(items, list):
837+
return [x * 2 for x in items]
838+
return items * 2
839+
840+
# Parent with map_items=False processes the entire input as one unit
841+
# Child 1 (map_items=True) will map over each item in what it receives
842+
# Child 2 (map_items=False) will process the entire thing
843+
pipeline = mt.Compose(
844+
[mt.Compose([add_one], map_items=True), mt.Compose([multiply_two], map_items=False)], map_items=False
845+
)
846+
847+
# Input [1, 2, 3]
848+
# First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
849+
# Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
850+
result = pipeline([1, 2, 3])
851+
self.assertEqual(result, [4, 6, 8])
852+
853+
def test_flatten_with_multiple_children_preserves_both(self):
854+
"""flatten() should preserve child with different map_items but flatten child with same."""
855+
856+
def noop(x):
857+
return x
858+
859+
parent = mt.Compose(
860+
[
861+
noop,
862+
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
863+
mt.Compose([noop, noop], map_items=False), # Different, will be preserved
864+
noop,
865+
]
866+
)
867+
flat = parent.flatten()
868+
# First nested Compose(map_items=True) will be flattened into parent
869+
# Second nested Compose(map_items=False) will be preserved
870+
# Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
871+
self.assertEqual(len(flat.transforms), 5)
872+
# Check that the preserved one is at the correct position
873+
self.assertIsInstance(flat.transforms[3], mt.Compose)
874+
self.assertEqual(flat.transforms[3].map_items, False)
875+
876+
def test_three_level_nesting_respects_different_map_items(self):
877+
"""Three-level nesting with different map_items at each level."""
878+
879+
def add_one(x):
880+
return x + 1
881+
882+
# Level 1 (outermost): map_items=True (default)
883+
# Level 2: map_items=False
884+
# Level 3: map_items=True (same as level 2, so will be flattened into level 2)
885+
innermost = mt.Compose([add_one], map_items=True)
886+
middle = mt.Compose([add_one, innermost], map_items=False)
887+
outer = mt.Compose([middle])
888+
889+
# Test with a simple value
890+
# outer has map_items=True (default), middle has map_items=False
891+
# So middle should be preserved and receive the input as-is
892+
result = outer(5)
893+
# outer(5) -> maps to middle -> middle(5) with map_items=False
894+
# middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
895+
# innermost(6) -> add_one(6) = 7
896+
self.assertEqual(result, 7)
897+
898+
def test_inverse_with_multiple_children_different_map_items(self):
899+
"""Inverse should work correctly with multiple children having different map_items."""
900+
pipeline = mt.Compose(
901+
[mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False), mt.Compose([mt.Flip(0)], map_items=True)]
902+
)
903+
data = torch.randn(2, 4, 4)
904+
result = pipeline(data)
905+
restored = pipeline.inverse(result)
906+
torch.testing.assert_close(data, restored)
907+
908+
def test_flatten_with_mixed_same_and_different_map_items(self):
909+
"""flatten() should merge children with same map_items but preserve those with different."""
910+
911+
def noop(x):
912+
return x
913+
914+
# Parent has map_items=True (default)
915+
# Child 1 has map_items=True (same as parent) -> should be flattened
916+
# Child 2 has map_items=False (different from parent) -> should NOT be flattened
917+
parent = mt.Compose(
918+
[
919+
noop,
920+
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
921+
mt.Compose([noop, noop], map_items=False), # Different from parent, will be preserved
922+
noop,
923+
]
924+
)
925+
flat = parent.flatten()
926+
# After flatten:
927+
# - noop (preserved)
928+
# - 2 noops from first Compose (flattened because map_items=True matches parent)
929+
# - Compose([noop, noop], map_items=False) (preserved because different)
930+
# - noop (preserved)
931+
# Total: 5 transforms
932+
self.assertEqual(len(flat.transforms), 5)
933+
self.assertIsInstance(flat.transforms[3], mt.Compose)
934+
self.assertEqual(flat.transforms[3].map_items, False)
935+
936+
778937
class TestComposeCallableInput(unittest.TestCase):
779938

780939
def test_value_error_when_not_sequence(self):

0 commit comments

Comments
 (0)