Skip to content

Commit bbfbdb9

Browse files
author
YASH Nankani
committed
Address review comments
Signed-off-by: YASH Nankani <ynankani@2u1g-x570-0073.ipp2a1.colossus.nvidia.com>
1 parent 1468975 commit bbfbdb9

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

modelopt/torch/export/diffusers_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ def _roundup(a: int, b: int) -> int:
889889
(_roundup(scale.shape[1], 16) - scale.shape[1]) if padding_strategy == "row_col" else 0
890890
)
891891

892-
if pad_r > 0 or pad_c_w > 0:
892+
if pad_r > 0 or pad_c_w > 0 or pad_c_s > 0:
893893
state_dict[w_key] = torch.nn.functional.pad(weight, (0, pad_c_w, 0, pad_r))
894894
state_dict[s_key] = torch.nn.functional.pad(scale, (0, pad_c_s, 0, pad_r))
895895
padded_count += 1

tests/unit/torch/export/test_nvfp4_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pytest
2121
import torch
22+
from safetensors import safe_open
2223
from safetensors.torch import load_file, save_file
2324

2425
from modelopt.torch.export.diffusers_utils import (
@@ -159,6 +160,13 @@ def test_metadata_injection(self, tmp_path):
159160

160161
reloaded = load_file(str(tmp_path / "model.safetensors"))
161162
assert torch.allclose(reloaded["weight"], sd["weight"])
163+
with safe_open(str(tmp_path / "model.safetensors"), framework="pt", device="cpu") as f:
164+
metadata = f.metadata()
165+
assert json.loads(metadata["quantization_config"]) == hf_quant_config
166+
assert json.loads(metadata["_quantization_metadata"]) == {
167+
"format_version": "1.0",
168+
"layers": {},
169+
}
162170

163171
def test_padding_and_swizzle(self, tmp_path):
164172
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
@@ -176,6 +184,7 @@ def test_padding_and_swizzle(self, tmp_path):
176184
reloaded = load_file(str(tmp_path / "model.safetensors"))
177185
assert reloaded["layer0.weight"].shape[0] == 32
178186
assert reloaded["layer0.weight_scale"].dtype == torch.float8_e4m3fn
187+
assert reloaded["layer0.weight_scale"].shape == (128, 64 // 16)
179188

180189
def test_sharded_guard(self, tmp_path):
181190
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors

0 commit comments

Comments
 (0)