Skip to content

Commit 815c500

Browse files
committed
fix ci and format
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 0c98432 commit 815c500

4 files changed

Lines changed: 11 additions & 31 deletions

File tree

monai/bundle/scripts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,7 @@ def load(
647647
workflow_name: str | BundleWorkflow | None = None,
648648
args_file: str | None = None,
649649
copy_model_args: dict | None = None,
650+
net_override: dict | None = None,
650651
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
651652
"""
652653
Load model weights or TorchScript module of a bundle.
@@ -692,7 +693,7 @@ def load(
692693
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
693694
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
694695
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
695-
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
696+
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
696697
697698
Returns:
698699
1. If `load_ts_module` is `False` and `model` is `None`,

monai/networks/nets/swin_unetr.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,13 @@ def __init__(
101101
Examples::
102102
103103
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
104-
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
104+
>>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)
105105
106106
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
107-
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
107+
>>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))
108108
109109
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
110-
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
110+
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
111111
112112
"""
113113

@@ -118,12 +118,9 @@ def __init__(
118118

119119
self.patch_size = patch_size
120120

121-
img_size = ensure_tuple_rep(img_size, spatial_dims)
122121
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
123122
window_size = ensure_tuple_rep(window_size, spatial_dims)
124123

125-
self._check_input_size(img_size)
126-
127124
if not (0 <= drop_rate <= 1):
128125
raise ValueError("dropout rate should be between 0 and 1.")
129126

@@ -1097,7 +1094,7 @@ def filter_swinunetr(key, value):
10971094
from monai.networks.utils import copy_model_state
10981095
from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr
10991096
1100-
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
1097+
model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
11011098
resource = (
11021099
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
11031100
)

monai/utils/jupyter_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def plot_engine_status(
234234

235235

236236
def _get_loss_from_output(
237-
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
237+
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
238238
) -> torch.Tensor:
239239
"""Returns a single value from the network output, which is a dict or tensor."""
240240

tests/networks/nets/test_swin_unetr.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
"spatial_dims": len(img_size),
5252
"in_channels": in_channels,
5353
"out_channels": out_channels,
54-
"img_size": img_size,
5554
"feature_size": feature_size,
5655
"depths": depth,
5756
"norm_name": norm_name,
@@ -67,7 +66,7 @@
6766

6867
TEST_CASE_FILTER = [
6968
[
70-
{"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True},
69+
{"in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True},
7170
"swinViT.layers1.0.blocks.0.norm1.weight",
7271
torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]),
7372
]
@@ -85,30 +84,13 @@ def test_shape(self, input_param, input_shape, expected_shape):
8584

8685
def test_ill_arg(self):
8786
with self.assertRaises(ValueError):
88-
SwinUNETR(
89-
in_channels=1,
90-
out_channels=3,
91-
img_size=(128, 128, 128),
92-
feature_size=24,
93-
norm_name="instance",
94-
attn_drop_rate=4,
95-
)
87+
SwinUNETR(spatial_dims=1, in_channels=1, out_channels=2, feature_size=48, norm_name="instance")
9688

9789
with self.assertRaises(ValueError):
98-
SwinUNETR(in_channels=1, out_channels=2, img_size=(96, 96), feature_size=48, norm_name="instance")
90+
SwinUNETR(in_channels=1, out_channels=4, feature_size=50, norm_name="instance")
9991

10092
with self.assertRaises(ValueError):
101-
SwinUNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), feature_size=50, norm_name="instance")
102-
103-
with self.assertRaises(ValueError):
104-
SwinUNETR(
105-
in_channels=1,
106-
out_channels=3,
107-
img_size=(85, 85, 85),
108-
feature_size=24,
109-
norm_name="instance",
110-
drop_rate=0.4,
111-
)
93+
SwinUNETR(in_channels=1, out_channels=3, feature_size=24, norm_name="instance", drop_rate=-1)
11294

11395
def test_patch_merging(self):
11496
dim = 10

0 commit comments

Comments
 (0)