9696absl .logging .set_verbosity (absl .logging .INFO ) # for max_logging.log
9797
9898
99- def get_hf_dict_from_safetensor (local_path ):
99+ def get_hf_dict_from_safetensor (local_path , framework = "pt" ):
100100 """
101101 If the safetensor contains more HF keys than MaxText model,
102102 these HF keys will be loaded but ignored during conversion.
@@ -109,7 +109,7 @@ def get_hf_dict_from_safetensor(local_path):
109109 max_logging .log (f"Loading { len (ckpt_paths )} checkpoints" )
110110 for i , ckpt_path in tqdm (enumerate (ckpt_paths ), total = len (ckpt_paths )):
111111 # max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
112- with safe_open (ckpt_path , framework = "pt" , device = "cpu" ) as f :
112+ with safe_open (ckpt_path , framework = framework , device = "cpu" ) as f :
113113 for key in f .keys ():
114114 if key .endswith ("_scale_inv" ):
115115 raise ValueError ("fp8 checkpoint is not supported." )
@@ -658,9 +658,12 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
658658 hf_state_dict_numpy = get_hf_dict_from_pretrained (model_id , token = hf_token )
659659 else :
660660 hf_state_dict_numpy = get_hf_dict_from_pretrained (model_id , token = hf_token , dtype = torch .bfloat16 )
661+ elif test_args .mode == "default-2" :
662+ max_logging .log (f"Loading with `safe_open`, framework=np" )
663+ hf_state_dict_numpy = get_hf_dict_from_safetensor (model_id , framework = "np" )
661664 else :
662- max_logging .log (f"Loading with `safe_open`" )
663- hf_state_dict_numpy = get_hf_dict_from_safetensor (model_id )
665+ max_logging .log (f"Loading with `safe_open`, framework=pt " )
666+ hf_state_dict_numpy = get_hf_dict_from_safetensor (model_id , framework = "pt" )
664667 # print(hf_state_dict_numpy)
665668
666669 unique_dtypes = {tensor .dtype for tensor in hf_state_dict_numpy .values ()}
@@ -681,6 +684,8 @@ def _eager_getter(key):
681684 raise ValueError (f"HuggingFace key { key } not found in state_dict." )
682685 if test_args .mode == "default" :
683686 return hf_state_dict_numpy [key ]
687+ elif test_args .mode == "default-2" :
688+ return hf_state_dict_numpy [key ]
684689 elif test_args .mode == "float16" :
685690 # torch.bfloat16 -> torch.float16 -> np.float16
686691 return hf_state_dict_numpy [key ].to (torch .float16 ).numpy ()
@@ -830,7 +835,7 @@ def _eager_getter(key):
830835 type = str ,
831836 required = False ,
832837 default = "default" ,
833- choices = ["default" , "float16" , "bfloat16-1" , "bfloat16-2" ],
838+ choices = ["default" , "default-2" , " float16" , "bfloat16-1" , "bfloat16-2" ],
834839 help = "" ,
835840 )
836841
@@ -844,6 +849,9 @@ def _eager_getter(key):
844849 assert local_args .hf_model_path != ""
845850 assert local_args .mode != "default"
846851
852+ if local_args .use_from_pretrained_api :
853+ assert local_args .mode != "default-2"
854+
847855 # Set jax environment
848856 jax .config .update ("jax_platforms" , "cpu" )
849857 os .environ ["XLA_FLAGS" ] = f"--xla_force_host_platform_device_count={ local_args .simulated_cpu_devices_count } "
0 commit comments