|
19 | 19 | import zmq |
20 | 20 | from loguru import logger |
21 | 21 | from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema |
22 | | -from safetensors.torch import _getdtype, safe_open |
| 22 | +from safetensors.torch import _TYPES, _getdtype, safe_open |
23 | 23 | from torch.multiprocessing.reductions import reduce_tensor |
24 | 24 |
|
25 | 25 | from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid |
@@ -499,6 +499,23 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer: |
499 | 499 | time.sleep(3) |
500 | 500 | header_tensor = t[flag_size:start_pos] |
501 | 501 | header = json.loads(header_tensor.numpy().tobytes()) |
| 502 | + if "__metadata__" in header: |
| 503 | + header.pop("__metadata__") |
| 504 | + |
| 505 | + # check the header format |
| 506 | + for name, meta in header.items(): |
| 507 | + assert isinstance(name, str), f"parameter name {name} should be str" |
| 508 | + assert isinstance(meta, dict), f"parameter meta {meta} should be dict" |
| 509 | + assert "data_offsets" in meta, f"data_offsets not in meta for parameter {name}" |
| 510 | + assert "shape" in meta, f"shape not in meta for parameter {name}" |
| 511 | + assert "dtype" in meta, f"dtype not in meta for parameter {name}" |
| 512 | + assert meta["dtype"] in _TYPES, ( |
| 513 | + f"dtype {meta['dtype']} not a valid safetensors dtype, supported dtypes: {_TYPES.keys()}" |
| 514 | + ) |
| 515 | + assert isinstance(meta["shape"], list), f"shape {meta['shape']} should be list" |
| 516 | + assert isinstance(meta["data_offsets"], list) and len(meta["data_offsets"]) == 2, ( |
| 517 | + f"data_offsets {meta['data_offsets']} should be list of length 2" |
| 518 | + ) |
502 | 519 |
|
503 | 520 | metas: list[ParameterMeta] = [] |
504 | 521 | offset = 0 |
|
0 commit comments