@@ -47,6 +47,19 @@ class SwinUNETR(nn.Module):
4747 Swin UNETR based on: "Hatamizadeh et al.,
4848 Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
4949 <https://arxiv.org/abs/2201.01266>"
50+
51+ Spatial Shape Constraints:
52+ Each spatial dimension of the input must be divisible by ``patch_size ** 5``.
53+ With the default ``patch_size=2``, this means each spatial dimension must be divisible by **32**
54+ (i.e., 2^5 = 32). This requirement comes from the patch embedding step followed by 4 stages
55+ of PatchMerging downsampling, each halving the spatial resolution.
56+
57+ For a custom ``patch_size``, the divisibility requirement is ``patch_size ** 5``.
58+
59+ Examples of valid 3D input sizes (with default ``patch_size=2``):
60+ ``(32, 32, 32)``, ``(64, 64, 64)``, ``(96, 96, 96)``, ``(128, 128, 128)``, ``(64, 32, 192)``.
61+
62+ A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.
5063 """
5164
5265 def __init__ (
@@ -76,7 +89,8 @@ def __init__(
7689 Args:
7790 in_channels: dimension of input channels.
7891 out_channels: dimension of output channels.
79- patch_size: size of the patch token.
92+ patch_size: size of the patch token. Input spatial dimensions must be divisible by
93+ ``patch_size ** 5`` (e.g., divisible by 32 when ``patch_size=2``).
8094 feature_size: dimension of network feature size.
8195 depths: number of layers in each stage.
8296 num_heads: number of attention heads.
@@ -108,6 +122,10 @@ def __init__(
108122 # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
109123 >>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
110124
125+ Raises:
126+ ValueError: When a spatial dimension of the input is not divisible by ``patch_size ** 5``.
127+ Use ``net._check_input_size(spatial_shape)`` to validate a shape before inference.
128+
111129 """
112130
113131 super ().__init__ ()
0 commit comments