Skip to content

Commit 83527c0

Browse files
author
YASH Nankani
committed
Address review comments
Signed-off-by: YASH Nankani <ynankani@2u1g-x570-0073.ipp2a1.colossus.nvidia.com>
1 parent 5f91a7e commit 83527c0

3 files changed

Lines changed: 43 additions & 17 deletions

File tree

modelopt/torch/export/diffusers_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,11 @@ def _ceil_div(a: int, b: int) -> int:
918918
return (a + b - 1) // b
919919

920920
def _to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
921-
"""Rearrange scale matrix to cuBLAS 2-D block-scaling-factors layout."""
921+
"""Rearrange scale matrix to cuBLAS 2-D block-scaling-factors layout.
922+
923+
Note: rows are padded to multiples of 128 for cuBLAS alignment, so the
924+
output shape may differ from the input (e.g. (16, 4) -> (128, 4)).
925+
"""
922926
rows, cols = input_matrix.shape
923927
n_row_blocks = _ceil_div(rows, 128)
924928
n_col_blocks = _ceil_div(cols, 4)

modelopt/torch/export/unified_export_hf.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
import torch
3030
import torch.nn as nn
31-
from safetensors.torch import load_file, save_file
31+
from safetensors import safe_open
32+
from safetensors.torch import save_file
3233

3334
from .diffusers_utils import build_layerwise_quant_metadata, pad_nvfp4_weights, swizzle_nvfp4_scales
3435

@@ -180,8 +181,6 @@ def _postprocess_safetensors(
180181
padding_strategy: ``"row"``, ``"row_col"``, or None.
181182
enable_swizzle_layout: Whether to swizzle block scales.
182183
"""
183-
import struct
184-
185184
safetensor_files = sorted(export_dir.glob("*.safetensors"))
186185
if not safetensor_files:
187186
return
@@ -195,22 +194,16 @@ def _postprocess_safetensors(
195194
)
196195

197196
for sf_path in safetensor_files:
198-
sd = load_file(str(sf_path))
199-
200-
with open(sf_path, "rb") as f:
201-
header_size = struct.unpack("<Q", f.read(8))[0]
202-
header = json.loads(f.read(header_size))
203-
metadata = header.get("__metadata__", None) or {}
204-
205-
# Clone tensors so the memory-mapped file handle from load_file is
206-
# released before we overwrite the same path (required on Windows).
207-
sd = {k: v.clone() for k, v in sd.items()}
197+
with safe_open(str(sf_path), framework="pt") as f:
198+
metadata = dict(f.metadata() or {})
199+
sd = {k: f.get_tensor(k).clone() for k in f.keys()}
208200

209201
if merged_base_safetensor_path is not None and model_type is not None:
210202
sd, base_metadata = merge_diffusion_checkpoint(
211-
sd, merged_base_safetensor_path, model_type, hf_quant_config
203+
sd, merged_base_safetensor_path, model_type, hf_quant_config=None
212204
)
213-
metadata.update(base_metadata)
205+
base_metadata.update(metadata)
206+
metadata = base_metadata
214207

215208
if padding_strategy is not None:
216209
sd = pad_nvfp4_weights(sd, padding_strategy)

tests/unit/torch/export/test_nvfp4_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -200,6 +200,35 @@ def test_sharded_guard(self, tmp_path):
200200
enable_layerwise_quant_metadata=True,
201201
)
202202

203+
def test_preserves_existing_metadata(self, tmp_path):
204+
"""Simulate save_pretrained output: safetensors with pre-existing metadata."""
205+
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
206+
207+
sd = _make_nvfp4_state_dict(rows=20, cols=64)
208+
preexisting_metadata = {"format": "pt", "_class_name": "MyModel"}
209+
save_file(sd, str(tmp_path / "model.safetensors"), metadata=preexisting_metadata)
210+
211+
hf_quant_config = {"quant_algo": "NVFP4"}
212+
_postprocess_safetensors(
213+
tmp_path,
214+
hf_quant_config=hf_quant_config,
215+
padding_strategy="row",
216+
enable_swizzle_layout=True,
217+
enable_layerwise_quant_metadata=True,
218+
)
219+
220+
reloaded = load_file(str(tmp_path / "model.safetensors"))
221+
assert reloaded["layer0.weight"].shape[0] == 32
222+
assert reloaded["layer0.weight_scale"].shape == (128, 64 // 16)
223+
224+
with safe_open(str(tmp_path / "model.safetensors"), framework="pt") as f:
225+
metadata = f.metadata()
226+
assert metadata["format"] == "pt"
227+
assert metadata["_class_name"] == "MyModel"
228+
assert json.loads(metadata["quantization_config"]) == hf_quant_config
229+
layer_meta = json.loads(metadata["_quantization_metadata"])
230+
assert "layer0" in layer_meta["layers"]
231+
203232
def test_no_safetensor_files(self, tmp_path):
204233
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
205234

0 commit comments

Comments (0)