
Cannot compile CenterSpatialCrop #8191

Description

@ziw-liu

Describe the bug
Using torch.compile to optimize MONAI transforms generally works (apart from graph breaks), but CenterSpatialCrop (and its dictionary wrapper) fails to compile.

To Reproduce

import torch
from monai.data.meta_obj import set_track_meta
from monai.transforms import (
    CenterSpatialCrop,
    RandAdjustContrast,
    RandAffine,
    RandFlip,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandScaleIntensity,
    RandSpatialCropSamples,
)

# avoid subclassing tensor
set_track_meta(False)

transforms = [
    RandAffine(
        prob=1.0,
        rotate_range=(torch.pi, 0, 0),
        scale_range=(0, 0.3, 0.3),
        padding_mode="zeros",
        mode="bilinear",
    ),
    CenterSpatialCrop(roi_size=(1, 256, 256)),
    RandSpatialCropSamples(roi_size=(1, 256, 256), num_samples=2),
    RandFlip(prob=0.5, spatial_axis=(1, 2)),
    RandAdjustContrast(prob=0.5, gamma=(0.8, 1.2)),
    RandScaleIntensity(factors=0.5, prob=0.5),
    RandGaussianNoise(prob=0.5, mean=0.0, std=0.3),
    RandGaussianSmooth(
        sigma_x=(0.25, 0.75),
        sigma_y=(0.25, 0.75),
        sigma_z=(0.0, 0.0),
        prob=0.5,
    ),
]

img = torch.rand(1, 1, 512, 512, dtype=torch.float32, device="cuda")


@torch.compile
def apply_transform(x, tf):
    # Return the transformed tensor so the compiled function has an output.
    return tf(x)


for tf in transforms:
    try:
        apply_transform(img, tf)
        print(f"{type(tf)} compiled successfully.")
    except Exception as e:
        assert isinstance(tf, CenterSpatialCrop)
        print(f"Failed to compile {type(tf)}.")
        print(e)

The script prints the following error message:

Failed to compile <class 'monai.transforms.croppad.array.CenterSpatialCrop'>.
Failed running call_function <built-in method as_tensor of type object at 0x7f147a4e8240>(*([FakeTensor(..., size=(), dtype=torch.int16), FakeTensor(..., size=(), dtype=torch.int16), FakeTensor(..., size=(), dtype=torch.int16)],), **{'dtype': torch.int16, 'device': 'cpu'}):
The tensor has a non-zero number of elements, but its data is not allocated yet.
If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
   File "/home/user.name/viscy/viscy/scripts/bench_compile_transform.py", line 44, in apply_transform
    tf(x)
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 533, in __call__
    slices=self.compute_slices(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]),
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 522, in compute_slices
    return super().compute_slices(roi_center=roi_center, roi_size=roi_size)
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 392, in compute_slices
    roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/utils/type_conversion.py", line 174, in convert_to_tensor
    return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/utils/type_conversion.py", line 149, in _convert_tensor
    tensor = torch.as_tensor(tensor, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Expected behavior
CenterSpatialCrop should compile just like the other transforms.

Environment

Output of python -c "import monai; monai.config.print_debug_info()":

================================
Printing MONAI config...
================================
MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.5.0+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /hpc/mydata/<username>/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.0+cu124
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.2
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Rocky Linux 8.10 (Green Obsidian)
Platform: Linux-4.18.0-553.16.1.el8_10.x86_64-x86_64-with-glibc2.28
Processor: x86_64
Machine: x86_64
Python version: 3.11.9
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='/home/<username>/.vscode-server/data/logs/20241031T101853/remoteagent.log', fd=19, position=5336, mode='a', flags=33793), popenfile(path='/home/<username>/.vscode-server/data/logs/20241031T101853/ptyhost.log', fd=20, position=4686, mode='a', flags=33793)]
Num physical CPUs: 16
Num logical CPUs: 16
Num usable CPUs: 16
CPU usage (%): [8.5, 8.5, 3.5, 8.1, 3.9, 3.9, 3.2, 4.6, 5.0, 18.6, 23.9, 3.5, 3.5, 3.6, 4.3, 4.6]
CPU freq. (MHz): 2935
Load avg. in last 1, 5, 15 mins (%): [0.6, 0.5, 1.4]
Disk usage (%): 93.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 503.8
Available memory (GB): 440.0
Used memory (GB): 27.3

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.4
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 90100
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A40
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 44.7
GPU 0 CUDA capability (maj.min): 8.6

Additional context
The traceback points to the tensor conversion (convert_to_tensor, which ends in torch.as_tensor on a list of int16 values) called from compute_slices in the Crop class. Curiously, the other cropping transform (RandSpatialCropSamples) does compile.
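
To illustrate the step where the conversion fails, here is a minimal sketch of computing center-crop slices with plain Python integers instead of building an int16 CPU tensor from the shape values. The helper name center_crop_slices and its clamping rules are hypothetical and are not MONAI's implementation. I have not verified that this variant traces cleanly under torch.compile; it only shows that the slice arithmetic itself does not need a tensor.

```python
from typing import Sequence, Tuple


def center_crop_slices(
    spatial_shape: Sequence[int], roi_size: Sequence[int]
) -> Tuple[slice, ...]:
    """Hypothetical helper: center-crop slices computed with Python ints only."""
    slices = []
    for dim, roi in zip(spatial_shape, roi_size):
        dim, roi = int(dim), int(roi)
        center = dim // 2
        # Clamp the start at 0 and the end at the axis length.
        start = max(center - roi // 2, 0)
        end = min(start + roi, dim)
        slices.append(slice(start, end))
    return tuple(slices)


# Spatial shape of the image in the script above (channel axis excluded).
crop = center_crop_slices((1, 512, 512), (1, 256, 256))
print(crop)
```

The resulting tuple could be applied as img[(slice(None), *crop)] on a channel-first tensor, assuming the first axis is the channel axis as in the reproduction script.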

Edit: fix typo
