@@ -268,11 +268,10 @@ class TestLoad(unittest.TestCase):
268268 @skip_if_quick
269269 def test_load_weights (self , bundle_files , bundle_name , repo , device , model_file ):
270270 with skip_if_downloading_fails ():
271- # download bundle, and load weights from the downloaded path
272271 with tempfile .TemporaryDirectory () as tempdir :
273272 bundle_root = os .path .join (tempdir , bundle_name )
274273 # load weights
275- weights = load (
274+ model_1 = load (
276275 name = bundle_name ,
277276 model_file = model_file ,
278277 bundle_dir = tempdir ,
@@ -288,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
288287 del net_args ["_target_" ]
289288 model = getattr (nets , model_name )(** net_args )
290289 model .to (device )
291- model .load_state_dict (weights )
290+ model .load_state_dict (model_1 )
292291 model .eval ()
293292
294293 # prepare data and test
@@ -334,6 +333,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
334333 output_3 = model_3 .forward (input_tensor )
335334 assert_allclose (output_3 , expected_output , atol = 1e-4 , rtol = 1e-4 , type_test = False )
336335
336+
337337 @parameterized .expand ([TEST_CASE_8 ])
338338 @skip_if_quick
339339 @skipUnless (has_huggingface_hub , "Requires `huggingface_hub`." )
@@ -369,7 +369,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
369369 source = "monaihosting" ,
370370 progress = False ,
371371 device = device ,
372- return_state_dict = False ,
373372 net_override = net_override ,
374373 )
375374
0 commit comments