diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py index f179417..7d5b694 100644 --- a/fastsafetensors/common.py +++ b/fastsafetensors/common.py @@ -212,7 +212,7 @@ def __init__( self.framework = framework ser = json.loads(string, object_pairs_hook=OrderedDict) self.metadata = ser.get("__metadata__", "") - if self.metadata: + if "__metadata__" in ser: del ser["__metadata__"] self.tensors: Dict[str, TensorFrame] = {} self.header_length = header_length diff --git a/tests/unit/test_fastsafetensors.py b/tests/unit/test_fastsafetensors.py index a051aee..26e9903 100644 --- a/tests/unit/test_fastsafetensors.py +++ b/tests/unit/test_fastsafetensors.py @@ -950,6 +950,23 @@ def dribbling_read(fd: int, length: int) -> bytes: assert meta.tensors["a0"].shape == [4, 8] +def test_from_file_empty_metadata(tmp_dir, framework) -> None: + # Regression test: a safetensors file whose header carries an empty + # __metadata__: {} object must parse without raising. Previously, + # __metadata__ was stripped only when truthy, so {} survived into the + # tensor sort and raised KeyError: 'data_offsets' (see e.g. loading + # MiniMaxAI/MiniMax-M3-MXFP8 via vLLM --load-format fastsafetensors). + device, _ = get_and_check_device(framework) + filename = os.path.join(tmp_dir, "empty_metadata.safetensors") + a0 = framework.randn((4, 8), device=device, dtype=DType.F32) + save_safetensors_file({"a0": a0.get_raw()}, filename, {}, framework) + + meta = SafeTensorsMetadata.from_file(filename, framework) + assert "a0" in meta.tensors + assert meta.tensors["a0"].shape == [4, 8] + assert meta.metadata == {} + + def test_no_module_level_torch_import_outside_frameworks() -> None: # Policy: framework-specific imports live behind the frameworks # abstraction. Only the torch backend (frameworks/_torch.py) may import