Skip to content

Commit 272355d

Browse files
committed
feat: add header format check and ignore the __metadata__ key
1 parent af99ce3 commit 272355d

1 file changed

Lines changed: 18 additions & 1 deletion

File tree

checkpoint_engine/ps.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import zmq
2020
from loguru import logger
2121
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
22-
from safetensors.torch import _getdtype, safe_open
22+
from safetensors.torch import _TYPES, _getdtype, safe_open
2323
from torch.multiprocessing.reductions import reduce_tensor
2424

2525
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -499,6 +499,23 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
499499
time.sleep(3)
500500
header_tensor = t[flag_size:start_pos]
501501
header = json.loads(header_tensor.numpy().tobytes())
502+
if "__metadata__" in header:
503+
header.pop("__metadata__")
504+
505+
# check the header format
506+
for name, meta in header.items():
507+
assert isinstance(name, str), f"parameter name {name} should be str"
508+
assert isinstance(meta, dict), f"parameter meta {meta} should be dict"
509+
assert "data_offsets" in meta, f"data_offsets not in meta for parameter {name}"
510+
assert "shape" in meta, f"shape not in meta for parameter {name}"
511+
assert "dtype" in meta, f"dtype not in meta for parameter {name}"
512+
assert meta["dtype"] in _TYPES, (
513+
f"dtype {meta['dtype']} not a valid safetensors dtype, supported dtypes: {_TYPES.keys()}"
514+
)
515+
assert isinstance(meta["shape"], list), f"shape {meta['shape']} should be list"
516+
assert isinstance(meta["data_offsets"], list) and len(meta["data_offsets"]) == 2, (
517+
f"data_offsets {meta['data_offsets']} should be list of length 2"
518+
)
502519

503520
metas: list[ParameterMeta] = []
504521
offset = 0

0 commit comments

Comments
 (0)