Skip to content

Commit b887eae

Browse files
eclipse0922, sewon.jeon, and ericspod
authored
Fix Invertd (#8651)
Fixes #8396. ### Description This PR fixes a regression where `Invertd` and direct `Compose.inverse()` calls could fail when later post-processing steps had appended unrelated entries to the same transform history. Root cause: - inverse execution assumed that the most recent history entry had to belong to the transform currently being inverted - once post-processing appended additional history entries, inverse would pop from the tail, hit the wrong transform, and raise an ID mismatch Solution: - centralize inverse matching in the shared `TraceableTransform` history lookup path - search backward for the most recent matching transform entry instead of assuming the tail entry must match - pop only the matched preprocessing entry, leaving unrelated post-processing history intact - remove the group-tracking plumbing that was previously introduced in `Compose`, `TraceKeys`, and `Invertd` Regression coverage added for: - the original `Invertd` reproducer from #8396 - multiple unrelated trailing post-processing transforms - same-class trailing transform history on `MetaTensor` - trace-key history with a different instance of the same invertible transform class - direct `Compose.inverse()` with mixed post-processing history ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
--------- Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com> Signed-off-by: sewon jeon <irocks0922@gmail.com> Signed-off-by: sewon.jeon <irocks0922@gmail.com> Co-authored-by: sewon.jeon <sewon.jeon@connecteve.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 5ef6b93 commit b887eae

File tree

15 files changed

+289
-21
lines changed

15 files changed

+289
-21
lines changed

monai/apps/detection/transforms/box_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def convert_box_to_mask(
267267
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
268268
# apply to global mask
269269
slicing = [b]
270-
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
270+
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type: ignore
271271
boxes_mask_np[tuple(slicing)] = boxes_only_mask
272272
return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]
273273

monai/auto3dseg/analyzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
105105
raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
106106
root: str
107107
child_key: str
108-
(root, _, child_key) = keys
108+
root, _, child_key = keys
109109
if root not in self.ops:
110110
self.ops[root] = [{}]
111111
self.ops[root][0].update({child_key: None})

monai/bundle/scripts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1948,7 +1948,7 @@ def create_workflow(
19481948
19491949
"""
19501950
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
1951-
(workflow_name, config_file) = _pop_args(
1951+
workflow_name, config_file = _pop_args(
19521952
_args, workflow_name=ConfigWorkflow, config_file=None
19531953
) # the default workflow name is "ConfigWorkflow"
19541954
if isinstance(workflow_name, str):

monai/data/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class DatasetFunc(Dataset):
139139
"""
140140

141141
def __init__(self, data: Any, func: Callable, **kwargs) -> None:
142-
super().__init__(data=None, transform=None) # type:ignore
142+
super().__init__(data=None, transform=None) # type: ignore
143143
self.src = data
144144
self.func = func
145145
self.kwargs = kwargs
@@ -1635,7 +1635,7 @@ def _cachecheck(self, item_transformed):
16351635
return (_data, _meta)
16361636
return _data
16371637
else:
1638-
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
1638+
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type: ignore
16391639
for i, _item in enumerate(item_transformed):
16401640
for k in _item:
16411641
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")

monai/handlers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def stopping_fn_from_loss() -> Callable[[Engine], Any]:
4848
"""
4949

5050
def stopping_fn(engine: Engine) -> Any:
51-
return -engine.state.output # type:ignore
51+
return -engine.state.output # type: ignore
5252

5353
return stopping_fn
5454

monai/metrics/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def get_edge_surface_distance(
320320
edges_spacing = None
321321
if use_subvoxels:
322322
edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))
323-
(edges_pred, edges_gt, *areas) = get_mask_edges(
323+
edges_pred, edges_gt, *areas = get_mask_edges(
324324
y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False
325325
)
326326
if not edges_gt.any():

monai/transforms/inverse.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def get_transform_info(self) -> dict:
119119
"""
120120
Return a dictionary with the relevant information pertaining to an applied transform.
121121
"""
122+
self._init_trace_threadlocal()
123+
122124
vals = (
123125
self.__class__.__name__,
124126
id(self),
@@ -300,25 +302,45 @@ def track_transform_meta(
300302
return out_obj
301303

302304
def check_transforms_match(self, transform: Mapping) -> None:
303-
"""Check transforms are of same instance."""
304-
xform_id = transform.get(TraceKeys.ID, "")
305-
if xform_id == id(self):
306-
return
307-
# TraceKeys.NONE to skip the id check
308-
if xform_id == TraceKeys.NONE:
305+
"""Check whether a traced transform entry matches this transform.
306+
307+
When multiprocessing uses ``spawn``, transform instances are recreated,
308+
so matching can fall back to the transform class name instead of the
309+
original instance ID.
310+
"""
311+
if self._transforms_match(transform):
309312
return
313+
314+
xform_id = transform.get(TraceKeys.ID, "")
310315
xform_name = transform.get(TraceKeys.CLASS_NAME, "")
311316
warning_msg = transform.get(TraceKeys.EXTRA_INFO, {}).get("warn")
312317
if warning_msg:
313318
warnings.warn(warning_msg)
314-
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
315-
if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
316-
return
317319
raise RuntimeError(
318320
f"Error {self.__class__.__name__} getting the most recently "
319321
f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
320322
)
321323

324+
def _transforms_match(self, transform: Mapping) -> bool:
325+
"""Return whether a traced transform entry matches this transform.
326+
327+
Matching succeeds when the traced ID matches this instance, when the ID
328+
check is explicitly disabled with ``TraceKeys.NONE``, or when
329+
multiprocessing uses ``spawn`` and the traced class name matches this
330+
transform class.
331+
"""
332+
xform_id = transform.get(TraceKeys.ID, "")
333+
if xform_id == id(self):
334+
return True
335+
# TraceKeys.NONE to skip the id check
336+
if xform_id == TraceKeys.NONE:
337+
return True
338+
xform_name = transform.get(TraceKeys.CLASS_NAME, "")
339+
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
340+
if torch.multiprocessing.get_start_method(allow_none=True) == "spawn" and xform_name == self.__class__.__name__:
341+
return True
342+
return False
343+
322344
def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
323345
"""
324346
Get most recent matching transform for the current class from the sequence of applied operations.
@@ -350,10 +372,16 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
350372
if not all_transforms:
351373
raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
352374

375+
match_idx = len(all_transforms) - 1
353376
if check:
354-
self.check_transforms_match(all_transforms[-1])
377+
for idx in range(len(all_transforms) - 1, -1, -1):
378+
if self._transforms_match(all_transforms[idx]):
379+
match_idx = idx
380+
break
381+
else:
382+
self.check_transforms_match(all_transforms[-1])
355383

356-
return all_transforms.pop(-1) if pop else all_transforms[-1]
384+
return all_transforms.pop(match_idx) if pop else all_transforms[match_idx]
357385

358386
def pop_transform(self, data, key: Hashable = None, check: bool = True):
359387
"""

monai/transforms/io/array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212
A collection of "vanilla" transforms for IO functions.
1313
"""
14+
1415
from __future__ import annotations
1516

1617
import inspect

monai/transforms/utility/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ def __init__(
702702
# if the root log level is higher than INFO, set a separate stream handler to record
703703
console = logging.StreamHandler(sys.stdout)
704704
console.setLevel(logging.INFO)
705-
console.is_data_stats_handler = True # type:ignore[attr-defined]
705+
console.is_data_stats_handler = True # type: ignore[attr-defined]
706706
_logger.addHandler(console)
707707

708708
def __call__(

tests/integration/test_loader_semaphore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111
"""this test should not generate errors or
1212
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores"""
13+
1314
from __future__ import annotations
1415

1516
import multiprocessing as mp

0 commit comments

Comments (0)