@@ -101,13 +101,13 @@ def __init__(
101101 Examples::
102102
103103 # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
104- >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
104+ >>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)
105105
106106 # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
107- >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
107+ >>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))
108108
109109 # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
110- >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
110+ >>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
111111
112112 """
113113
@@ -118,12 +118,9 @@ def __init__(
118118
119119 self .patch_size = patch_size
120120
121- img_size = ensure_tuple_rep (img_size , spatial_dims )
122121 patch_sizes = ensure_tuple_rep (self .patch_size , spatial_dims )
123122 window_size = ensure_tuple_rep (window_size , spatial_dims )
124123
125- self ._check_input_size (img_size )
126-
127124 if not (0 <= drop_rate <= 1 ):
128125 raise ValueError ("dropout rate should be between 0 and 1." )
129126
@@ -1097,7 +1094,7 @@ def filter_swinunetr(key, value):
10971094 from monai.networks.utils import copy_model_state
10981095 from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr
10991096
1100- model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
1097+ model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
11011098 resource = (
11021099 "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
11031100 )
0 commit comments