Skip to content

Commit ba8fc01

Browse files
committed
fix #8453
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent a170243 commit ba8fc01

2 files changed

Lines changed: 3 additions & 6 deletions

File tree

monai/bundle/scripts.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,6 @@ def load(
702702
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
703703
the corresponding metadata dict, and extra files dict.
704704
please check `monai.data.load_net_with_metadata` for more details.
705-
4. If `return_state_dict` is True, return model weights, only used for compatibility
706-
when `model` and `net_name` are all `None`.
707705
708706
"""
709707
bundle_dir_ = _process_bundle_dir(bundle_dir)

tests/bundle/test_bundle_download.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,10 @@ class TestLoad(unittest.TestCase):
268268
@skip_if_quick
269269
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
270270
with skip_if_downloading_fails():
271-
# download bundle, and load weights from the downloaded path
272271
with tempfile.TemporaryDirectory() as tempdir:
273272
bundle_root = os.path.join(tempdir, bundle_name)
274273
# load weights
275-
weights = load(
274+
model_1 = load(
276275
name=bundle_name,
277276
model_file=model_file,
278277
bundle_dir=tempdir,
@@ -288,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
288287
del net_args["_target_"]
289288
model = getattr(nets, model_name)(**net_args)
290289
model.to(device)
291-
model.load_state_dict(weights)
290+
model.load_state_dict(model_1)
292291
model.eval()
293292

294293
# prepare data and test
@@ -334,6 +333,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
334333
output_3 = model_3.forward(input_tensor)
335334
assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
336335

336+
337337
@parameterized.expand([TEST_CASE_8])
338338
@skip_if_quick
339339
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
@@ -369,7 +369,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
369369
source="monaihosting",
370370
progress=False,
371371
device=device,
372-
return_state_dict=False,
373372
net_override=net_override,
374373
)
375374

0 commit comments

Comments
 (0)