1919
2020from collections .abc import Callable , Hashable , Mapping , Sequence
2121from copy import deepcopy
22- from typing import Any
22+ from typing import Any , Optional , Union , cast
2323
2424import numpy as np
2525import torch
5050from monai .transforms .traits import LazyTrait , MultiSampleTrait
5151from monai .transforms .transform import LazyTransform , MapTransform , Randomizable
5252from monai .transforms .utils import is_positive
53- from monai .utils import MAX_SEED , Method , PytorchPadMode , ensure_tuple_rep
53+ from monai .utils import MAX_SEED , Method , PytorchPadMode , TraceKeys , ensure_tuple_rep
5454
5555__all__ = [
5656 "Padd" ,
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
431431 - a spatial center and size
432432 - the start and end coordinates of the ROI
433433
434+ ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
435+ string dictionary keys. When a string is provided, the actual coordinate values are read from the
436+ data dictionary at call time. This enables pipelines where coordinates are computed by earlier
437+ transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
438+ data dictionary under the given key.
439+
440+ Example::
441+
442+ from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd
443+
444+ pipeline = Compose([
445+ TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
446+ TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
447+ SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
448+ ])
449+
434450 This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
435451 for more information.
436452 """
437453
438454 def __init__ (
439455 self ,
440456 keys : KeysCollection ,
441- roi_center : Sequence [int ] | int | None = None ,
442- roi_size : Sequence [int ] | int | None = None ,
443- roi_start : Sequence [int ] | int | None = None ,
444- roi_end : Sequence [int ] | int | None = None ,
457+ roi_center : Sequence [int ] | int | str | None = None ,
458+ roi_size : Sequence [int ] | int | str | None = None ,
459+ roi_start : Sequence [int ] | int | str | None = None ,
460+ roi_end : Sequence [int ] | int | str | None = None ,
445461 roi_slices : Sequence [slice ] | None = None ,
446462 allow_missing_keys : bool = False ,
447463 lazy : bool = False ,
@@ -450,19 +466,134 @@ def __init__(
450466 Args:
451467 keys: keys of the corresponding items to be transformed.
452468 See also: :py:class:`monai.transforms.compose.MapTransform`
453- roi_center: voxel coordinates for center of the crop ROI.
469+ roi_center: voxel coordinates for center of the crop ROI, or a string key to look up
470+ the coordinates from the data dictionary.
454471 roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,
455- will not crop that dimension of the image.
456- roi_start: voxel coordinates for start of the crop ROI.
472+ will not crop that dimension of the image. Can also be a string key.
473+ roi_start: voxel coordinates for start of the crop ROI, or a string key to look up
474+ the coordinates from the data dictionary.
457475 roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
458- use the end coordinate of image.
476+ use the end coordinate of image. Can also be a string key.
459477 roi_slices: list of slices for each of the spatial dimensions.
460478 allow_missing_keys: don't raise exception if key is missing.
461479 lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
462480 """
463- cropper = SpatialCrop (roi_center , roi_size , roi_start , roi_end , roi_slices , lazy = lazy )
481+ self ._roi_center = roi_center
482+ self ._roi_size = roi_size
483+ self ._roi_start = roi_start
484+ self ._roi_end = roi_end
485+ self ._roi_slices = roi_slices
486+ self ._has_str_roi = any (isinstance (v , str ) for v in [roi_center , roi_size , roi_start , roi_end ])
487+
488+ if not self ._has_str_roi :
489+ _roi_t = Optional [Union [Sequence [int ], int ]]
490+ cropper = SpatialCrop (
491+ cast (_roi_t , roi_center ),
492+ cast (_roi_t , roi_size ),
493+ cast (_roi_t , roi_start ),
494+ cast (_roi_t , roi_end ),
495+ roi_slices ,
496+ lazy = lazy ,
497+ )
498+ else :
499+ # Placeholder cropper for the string-key path. Replaced on self.cropper at
500+ # __call__ time once string keys are resolved from the data dictionary.
501+ cropper = SpatialCrop (roi_start = [0 ], roi_end = [1 ], lazy = lazy )
464502 super ().__init__ (keys , cropper = cropper , allow_missing_keys = allow_missing_keys , lazy = lazy )
465503
504+ @staticmethod
505+ def _resolve_roi_param (val , d ):
506+ """Resolve an ROI parameter from the data dictionary if it is a string key.
507+
508+ Args:
509+ val: the ROI parameter value. If a string, it is used as a key to look up
510+ the actual value from ``d``. Otherwise returned as-is.
511+ d: the data dictionary.
512+
513+ Returns:
514+ The resolved ROI parameter. Tensors and numpy arrays are flattened to 1-D
515+ and rounded to int64 so they can be consumed by ``Crop.compute_slices``.
516+
517+ Raises:
518+ KeyError: if ``val`` is a string key that does not exist in ``d``.
519+ """
520+ if not isinstance (val , str ):
521+ return val
522+ if val not in d :
523+ raise KeyError (f"ROI key '{ val } ' not found in the data dictionary." )
524+ resolved = d [val ]
525+ # ApplyTransformToPoints outputs tensors of shape (C, N, dims).
526+ # A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3).
527+ # Flatten to 1-D and round to integers for compute_slices.
528+ # Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates.
529+ if isinstance (resolved , np .ndarray ):
530+ resolved = torch .from_numpy (resolved )
531+ if isinstance (resolved , torch .Tensor ):
532+ resolved = torch .round (resolved .flatten ()).to (torch .int64 )
533+ return resolved
534+
535+ @property
536+ def requires_current_data (self ) -> bool :
537+ """Returns True if ROI values are derived from dictionary members, False if constant members."""
538+ return self ._has_str_roi
539+
540+ def __call__ (self , data : Mapping [Hashable , torch .Tensor ], lazy : bool | None = None ) -> dict [Hashable , torch .Tensor ]:
541+ """
542+ Args:
543+ data: dictionary of data items to be transformed.
544+ lazy: whether to execute lazily. If ``None``, uses the instance default.
545+
546+ Returns:
547+ Dictionary with cropped data for each key.
548+ """
549+ if not self .requires_current_data :
550+ return super ().__call__ (data , lazy = lazy )
551+
552+ d = dict (data )
553+ roi_center = self ._resolve_roi_param (self ._roi_center , d )
554+ roi_size = self ._resolve_roi_param (self ._roi_size , d )
555+ roi_start = self ._resolve_roi_param (self ._roi_start , d )
556+ roi_end = self ._resolve_roi_param (self ._roi_end , d )
557+
558+ lazy_ = self .lazy if lazy is None else lazy
559+ cropper = SpatialCrop (
560+ roi_center = roi_center ,
561+ roi_size = roi_size ,
562+ roi_start = roi_start ,
563+ roi_end = roi_end ,
564+ roi_slices = self ._roi_slices ,
565+ lazy = lazy_ ,
566+ )
567+ for key in self .key_iterator (d ):
568+ d [key ] = cropper (d [key ], lazy = lazy_ )
569+ return d
570+
571+ def inverse (self , data : Mapping [Hashable , MetaTensor ]) -> dict [Hashable , MetaTensor ]:
572+ """
573+ Inverse of the crop transform, restoring the original spatial dimensions via padding.
574+
575+ For the string-key path, the cropper used in ``__call__`` is a per-invocation local
576+ instance, so its ``id()`` won't match the one stored in the MetaTensor's transform stack.
577+ This override bypasses the ID check and applies the inverse directly using the crop info
578+ stored in the MetaTensor.
579+
580+ Args:
581+ data: dictionary of cropped ``MetaTensor`` items.
582+
583+ Returns:
584+ Dictionary with inverse-transformed (padded) data for each key.
585+ """
586+ if not self .requires_current_data :
587+ return super ().inverse (data )
588+ d = dict (data )
589+ for key in self .key_iterator (d ):
590+ transform = self .cropper .pop_transform (d [key ], check = False )
591+ cropped = transform [TraceKeys .EXTRA_INFO ]["cropped" ]
592+ inverse_transform = BorderPad (cropped )
593+ with inverse_transform .trace_transform (False ):
594+ d [key ] = inverse_transform (d [key ]) # type: ignore[assignment]
595+ return d
596+
466597
467598class CenterSpatialCropd (Cropd ):
468599 """
0 commit comments