Skip to content

Commit 439cf69

Browse files
committed
add test cases
Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent 66ce914 commit 439cf69

3 files changed

Lines changed: 215 additions & 2 deletions

File tree

examples/diffusers/quantization/quantize.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,23 @@ def export_hf_ckpt(self, pipe: Any, model_config: ModelConfig | None = None) ->
324324
for key in ("enable_swizzle_layout", "enable_layerwise_quant_metadata"):
325325
val = model_config.extra_params.get(key)
326326
if val is not None:
327-
kwargs[key] = str(val).lower() in ("true", "1", "yes")
327+
normalized = str(val).strip().lower()
328+
if normalized in ("true", "1", "yes"):
329+
kwargs[key] = True
330+
elif normalized in ("false", "0", "no"):
331+
kwargs[key] = False
332+
else:
333+
raise ValueError(
334+
f"Invalid value for {key}: {val!r}. "
335+
"Expected true/false, 1/0, or yes/no."
336+
)
328337
padding = model_config.extra_params.get("padding_strategy")
329-
if padding:
338+
if padding is not None:
339+
padding = str(padding).strip().lower()
340+
if padding not in ("row", "row_col"):
341+
raise ValueError(
342+
f"Invalid padding_strategy: {padding!r}. Expected 'row' or 'row_col'."
343+
)
330344
kwargs["padding_strategy"] = padding
331345
export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir, **kwargs)
332346
self.logger.info("HuggingFace checkpoint export completed successfully")

modelopt/torch/export/unified_export_hf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def _postprocess_safetensors(
187187
if not safetensor_files:
188188
return
189189

190+
if list(export_dir.glob("*.safetensors.index.json")) and (
191+
merged_base_safetensor_path is not None or enable_layerwise_quant_metadata
192+
):
193+
raise NotImplementedError(
194+
"Post-processing sharded safetensors is not supported. "
195+
"Export with a larger max_shard_size or disable merge/metadata options."
196+
)
197+
190198
for sf_path in safetensor_files:
191199
sd = load_file(str(sf_path))
192200

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for NVFP4 utility functions (pad, swizzle, metadata) and _postprocess_safetensors."""
17+
18+
import json
19+
from pathlib import Path
20+
21+
import pytest
22+
import torch
23+
from safetensors.torch import load_file, save_file
24+
25+
from modelopt.torch.export.diffusers_utils import (
26+
build_layerwise_quant_metadata,
27+
pad_nvfp4_weights,
28+
swizzle_nvfp4_scales,
29+
)
30+
31+
32+
def _make_nvfp4_state_dict(rows=32, cols=64):
33+
"""Create a minimal NVFP4 state dict with one quantized layer."""
34+
return {
35+
"layer0.weight": torch.randint(0, 255, (rows, cols), dtype=torch.uint8),
36+
"layer0.weight_scale": torch.randn(rows, cols // 16).to(torch.float8_e4m3fn),
37+
"layer0.weight_scale_2": torch.randn(rows, 1),
38+
"layer0.bias": torch.randn(rows),
39+
}
40+
41+
42+
# ---------------------------------------------------------------------------
43+
# _find_nvfp4_layers (tested implicitly via pad / swizzle that rely on it)
44+
# ---------------------------------------------------------------------------
45+
46+
47+
class TestBuildLayerwiseQuantMetadata:
    """Behavior of build_layerwise_quant_metadata's JSON payload."""

    def test_basic(self):
        # A single NVFP4 layer should show up under "layers" with the nvfp4 format tag.
        state = _make_nvfp4_state_dict()
        meta = json.loads(build_layerwise_quant_metadata(state, {"quant_algo": "NVFP4"}))

        assert meta["format_version"] == "1.0"
        assert "layer0" in meta["layers"]
        assert meta["layers"]["layer0"]["format"] == "nvfp4"

    def test_no_quantized_layers(self):
        # Plain float tensors carry no NVFP4 companions, so the layer map stays empty.
        state = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}
        meta = json.loads(build_layerwise_quant_metadata(state, {"quant_algo": "FP8"}))
        assert meta["layers"] == {}

    def test_multiple_layers(self):
        # Two quantized layers must each get their own metadata entry.
        state = dict(_make_nvfp4_state_dict())
        state["layer1.weight"] = torch.randint(0, 255, (16, 32), dtype=torch.uint8)
        state["layer1.weight_scale"] = torch.randn(16, 2).to(torch.float8_e4m3fn)
        state["layer1.weight_scale_2"] = torch.randn(16, 1)

        meta = json.loads(build_layerwise_quant_metadata(state, {"quant_algo": "NVFP4"}))
        assert "layer0" in meta["layers"]
        assert "layer1" in meta["layers"]
71+
72+
73+
class TestPadNvfp4Weights:
    """Behavior of pad_nvfp4_weights for row / row_col strategies."""

    def test_row_padding(self):
        # 20 rows is not 16-aligned; padding must round rows up to 32.
        state = _make_nvfp4_state_dict(rows=20, cols=64)
        padded = pad_nvfp4_weights(state, "row")

        assert padded["layer0.weight"].shape[0] % 16 == 0
        assert padded["layer0.weight_scale"].shape[0] % 16 == 0
        assert padded["layer0.weight"].shape[0] == 32

    def test_row_col_padding(self):
        # Both axes of weight and scale must end up 16-aligned.
        state = _make_nvfp4_state_dict(rows=20, cols=48)
        padded = pad_nvfp4_weights(state, "row_col")

        weight, scale = padded["layer0.weight"], padded["layer0.weight_scale"]
        for dim in (0, 1):
            assert weight.shape[dim] % 16 == 0
            assert scale.shape[dim] % 16 == 0

    def test_already_aligned(self):
        # Aligned inputs pass through with their shape untouched.
        state = _make_nvfp4_state_dict(rows=32, cols=64)
        shape_before = state["layer0.weight"].shape
        padded = pad_nvfp4_weights(state, "row")

        assert padded["layer0.weight"].shape == shape_before

    def test_invalid_strategy(self):
        # Unknown strategies are rejected rather than silently ignored.
        state = _make_nvfp4_state_dict()
        with pytest.raises(ValueError, match="padding_strategy"):
            pad_nvfp4_weights(state, "invalid")

    def test_non_nvfp4_tensors_untouched(self):
        # Padding must leave tensors without NVFP4 companions (e.g. bias) alone.
        state = _make_nvfp4_state_dict(rows=20, cols=64)
        original_bias = state["layer0.bias"].clone()
        pad_nvfp4_weights(state, "row")
        assert torch.equal(state["layer0.bias"], original_bias)
110+
111+
112+
class TestSwizzleNvfp4Scales:
    """Behavior of swizzle_nvfp4_scales on NVFP4 block-scale tensors."""

    def test_shape_preserved(self):
        # With 128-row-aligned scales the swizzle is a pure layout permutation.
        state = _make_nvfp4_state_dict(rows=128, cols=64)
        shape_before = state["layer0.weight_scale"].shape
        swizzled = swizzle_nvfp4_scales(state)

        assert swizzled["layer0.weight_scale"].shape == shape_before

    def test_dtype_is_fp8(self):
        # Swizzling must not change the FP8 storage dtype of the scales.
        state = _make_nvfp4_state_dict(rows=128, cols=64)
        swizzled = swizzle_nvfp4_scales(state)

        assert swizzled["layer0.weight_scale"].dtype == torch.float8_e4m3fn

    def test_non_nvfp4_tensors_untouched(self):
        # Tensors without NVFP4 companions (e.g. bias) must survive unmodified.
        state = _make_nvfp4_state_dict(rows=128, cols=64)
        original_bias = state["layer0.bias"].clone()
        swizzle_nvfp4_scales(state)
        assert torch.equal(state["layer0.bias"], original_bias)

    def test_small_scale_needs_internal_padding(self):
        """Scales with rows < 128 trigger internal padding in _to_blocked."""
        state = _make_nvfp4_state_dict(rows=16, cols=64)
        swizzled = swizzle_nvfp4_scales(state)
        # _to_blocked pads rows up to the next multiple of 128
        assert swizzled["layer0.weight_scale"].shape == (128, 64 // 16)
138+
139+
140+
class TestPostprocessSafetensors:
    """End-to-end checks of _postprocess_safetensors on real safetensor files."""

    def test_metadata_injection(self, tmp_path):
        from modelopt.torch.export.unified_export_hf import _postprocess_safetensors

        # Non-NVFP4 weights should round-trip unchanged even with metadata enabled.
        state = {"weight": torch.randn(4, 4)}
        save_file(state, str(tmp_path / "model.safetensors"))

        _postprocess_safetensors(
            tmp_path,
            hf_quant_config={"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"},
            enable_layerwise_quant_metadata=True,
        )

        round_tripped = load_file(str(tmp_path / "model.safetensors"))
        assert torch.allclose(round_tripped["weight"], state["weight"])

    def test_padding_and_swizzle(self, tmp_path):
        from modelopt.torch.export.unified_export_hf import _postprocess_safetensors

        # Row padding (20 -> 32) and swizzling are applied to the stored tensors.
        state = _make_nvfp4_state_dict(rows=20, cols=64)
        save_file(state, str(tmp_path / "model.safetensors"))

        _postprocess_safetensors(
            tmp_path,
            padding_strategy="row",
            enable_swizzle_layout=True,
            enable_layerwise_quant_metadata=False,
        )

        round_tripped = load_file(str(tmp_path / "model.safetensors"))
        assert round_tripped["layer0.weight"].shape[0] == 32
        assert round_tripped["layer0.weight_scale"].dtype == torch.float8_e4m3fn

    def test_sharded_guard(self, tmp_path):
        from modelopt.torch.export.unified_export_hf import _postprocess_safetensors

        # An index file marks a sharded export, which post-processing rejects.
        save_file({"w": torch.randn(2, 2)}, str(tmp_path / "model.safetensors"))
        (tmp_path / "model.safetensors.index.json").write_text("{}")

        with pytest.raises(NotImplementedError, match="sharded"):
            _postprocess_safetensors(
                tmp_path,
                merged_base_safetensor_path="/fake/path.safetensors",
                model_type="ltx2",
                enable_layerwise_quant_metadata=True,
            )

    def test_no_safetensor_files(self, tmp_path):
        from modelopt.torch.export.unified_export_hf import _postprocess_safetensors

        # An empty directory is a no-op, not an error.
        _postprocess_safetensors(tmp_path)

0 commit comments

Comments
 (0)