Skip to content

Commit b21c695

Browse files
committed
test: safe open with np
1 parent 8186fa9 commit b21c695

1 file changed

Lines changed: 13 additions & 5 deletions

File tree

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@
9696
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log
9797

9898

99-
def get_hf_dict_from_safetensor(local_path):
99+
def get_hf_dict_from_safetensor(local_path, framework="pt"):
100100
"""
101101
If the safetensor contains more HF keys than MaxText model,
102102
these HF keys will be loaded but ignored during conversion.
@@ -109,7 +109,7 @@ def get_hf_dict_from_safetensor(local_path):
109109
max_logging.log(f"Loading {len(ckpt_paths)} checkpoints")
110110
for i, ckpt_path in tqdm(enumerate(ckpt_paths), total=len(ckpt_paths)):
111111
# max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
112-
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
112+
with safe_open(ckpt_path, framework=framework, device="cpu") as f:
113113
for key in f.keys():
114114
if key.endswith("_scale_inv"):
115115
raise ValueError("fp8 checkpoint is not supported.")
@@ -658,9 +658,12 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
658658
hf_state_dict_numpy = get_hf_dict_from_pretrained(model_id, token=hf_token)
659659
else:
660660
hf_state_dict_numpy = get_hf_dict_from_pretrained(model_id, token=hf_token, dtype=torch.bfloat16)
661+
elif test_args.mode == "default-2":
662+
max_logging.log(f"Loading with `safe_open`, framework=np")
663+
hf_state_dict_numpy = get_hf_dict_from_safetensor(model_id, framework="np")
661664
else:
662-
max_logging.log(f"Loading with `safe_open`")
663-
hf_state_dict_numpy = get_hf_dict_from_safetensor(model_id)
665+
max_logging.log(f"Loading with `safe_open`, framework=pt")
666+
hf_state_dict_numpy = get_hf_dict_from_safetensor(model_id, framework="pt")
664667
# print(hf_state_dict_numpy)
665668

666669
unique_dtypes = {tensor.dtype for tensor in hf_state_dict_numpy.values()}
@@ -681,6 +684,8 @@ def _eager_getter(key):
681684
raise ValueError(f"HuggingFace key {key} not found in state_dict.")
682685
if test_args.mode == "default":
683686
return hf_state_dict_numpy[key]
687+
elif test_args.mode == "default-2":
688+
return hf_state_dict_numpy[key]
684689
elif test_args.mode == "float16":
685690
# torch.bfloat16 -> torch.float16 -> np.float16
686691
return hf_state_dict_numpy[key].to(torch.float16).numpy()
@@ -830,7 +835,7 @@ def _eager_getter(key):
830835
type=str,
831836
required=False,
832837
default="default",
833-
choices=["default", "float16", "bfloat16-1", "bfloat16-2"],
838+
choices=["default", "default-2", "float16", "bfloat16-1", "bfloat16-2"],
834839
help="",
835840
)
836841

@@ -844,6 +849,9 @@ def _eager_getter(key):
844849
assert local_args.hf_model_path != ""
845850
assert local_args.mode != "default"
846851

852+
if local_args.use_from_pretrained_api:
853+
assert local_args.mode != "default-2"
854+
847855
# Set jax environment
848856
jax.config.update("jax_platforms", "cpu")
849857
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"

0 commit comments

Comments
 (0)