|
19 | 19 |
|
20 | 20 | from collections.abc import Callable, Hashable, Mapping, Sequence |
21 | 21 | from copy import deepcopy |
22 | | -from typing import Any |
| 22 | +from typing import Any, Optional, Union, cast |
23 | 23 |
|
24 | 24 | import numpy as np |
25 | 25 | import torch |
@@ -486,7 +486,15 @@ def __init__( |
486 | 486 | self._has_str_roi = any(isinstance(v, str) for v in [roi_center, roi_size, roi_start, roi_end]) |
487 | 487 |
|
488 | 488 | if not self._has_str_roi: |
489 | | - cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy) |
| 489 | + _Roi = Optional[Union[Sequence[int], int]] |
| 490 | + cropper = SpatialCrop( |
| 491 | + cast(_Roi, roi_center), |
| 492 | + cast(_Roi, roi_size), |
| 493 | + cast(_Roi, roi_start), |
| 494 | + cast(_Roi, roi_end), |
| 495 | + roi_slices, |
| 496 | + lazy=lazy, |
| 497 | + ) |
490 | 498 | else: |
491 | 499 | # Placeholder cropper for the string-key path. Replaced on self.cropper at |
492 | 500 | # __call__ time once string keys are resolved from the data dictionary. |
@@ -583,7 +591,7 @@ def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTen |
583 | 591 | cropped = transform[TraceKeys.EXTRA_INFO]["cropped"] |
584 | 592 | inverse_transform = BorderPad(cropped) |
585 | 593 | with inverse_transform.trace_transform(False): |
586 | | - d[key] = inverse_transform(d[key]) |
| 594 | + d[key] = inverse_transform(d[key]) # type: ignore[assignment] |
587 | 595 | return d |
588 | 596 |
|
589 | 597 |
|
|
0 commit comments