@@ -11,7 +11,7 @@ class MyError(Exception):
1111
1212
1313@pytest .mark .parametrize (
14- "dir_name, safetensor_filenames, expected_safetensor_filenames" ,
14+ "dir_name, safetensor_filenames, expected_safetensor_filenames, use_consolidated " ,
1515 [
1616 (
1717 "foo" ,
@@ -21,6 +21,18 @@ class MyError(Exception):
2121 "consolidated.safetensors" ,
2222 ],
2323 ["model-00001-of-00002.safetensors" , "model-000002-of-00002.safetensors" ],
24+ False ,
25+ ),
26+ # If use_consolidated specified explicitly.
27+ (
28+ "foo" ,
29+ [
30+ "model-00001-of-00002.safetensors" ,
31+ "model-000002-of-00002.safetensors" ,
32+ "consolidated.safetensors" ,
33+ ],
34+ ["consolidated.safetensors" ],
35+ True ,
2436 ),
2537 (
2638 "foo" ,
@@ -29,12 +41,14 @@ class MyError(Exception):
2941 "foo-consolidated.safetensors" ,
3042 ],
3143 [f"model-0000{ i } -of-00010.safetensors" for i in range (1 , 11 )],
44+ False ,
3245 ),
3346 # If there is only a consolidated safetensor, that one should still be used.
3447 (
3548 "foo" ,
3649 ["consolidated.safetensors" ],
3750 ["consolidated.safetensors" ],
51+ False ,
3852 ),
3953 # If the directory contains "consolidated" in its name, but its contents are sharded tensors.
4054 (
@@ -45,6 +59,7 @@ class MyError(Exception):
4559 "consolidated.safetensors" ,
4660 ],
4761 ["model-00001-of-00002.safetensors" , "model-000002-of-00002.safetensors" ],
62+ False ,
4863 ),
4964 ],
5065)
@@ -53,6 +68,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
5368 dir_name : str ,
5469 safetensor_filenames : list [str ],
5570 expected_safetensor_filenames : list [str ],
71+ use_consolidated : bool ,
5672):
5773 checkpoint_dir = tmp_path / dir_name
5874 checkpoint_dir .mkdir ()
@@ -70,7 +86,9 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
7086 mock .patch .object (loader , "prefetch_files" ) as prefetch_files ,
7187 pytest .raises (MyError ),
7288 ):
73- loader .load_weights (checkpoint_dir = str (checkpoint_dir ), mapping = Mapping ())
89+ loader .load_weights (
90+ checkpoint_dir = str (checkpoint_dir ), mapping = Mapping (), use_consolidated = use_consolidated
91+ )
7492
7593 prefetch_files .assert_called_once ()
7694 prefetched_files = prefetch_files .call_args [0 ][0 ]
0 commit comments