Skip to content

Commit 4ba3ac1

Browse files
committed
Update the fix and add more test cases
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent d78797b commit 4ba3ac1

3 files changed

Lines changed: 152 additions & 32 deletions

File tree

examples/diffusers/quantization/config.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232

3333
INT8_DEFAULT_CONFIG = {
3434
"quant_cfg": {
35-
"*weight_quantizer": {"num_bits": 8, "axis": 0},
36-
"*input_quantizer": {"num_bits": 8, "axis": 0},
35+
"*weight_quantizer": {"num_bits": 8, "axis": None},
36+
"*input_quantizer": {"num_bits": 8, "axis": None},
3737
"*output_quantizer": {"enable": False},
3838
"default": {"enable": False},
3939
},
@@ -112,8 +112,10 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **
112112

113113

114114
def reset_set_int8_config(quant_config, percentile, n_steps, collect_method, backbone):
115-
"""
116-
Configure INT8 quantization with different settings for Conv2d and Linear layers.
115+
"""Add PercentileCalibrator to Conv2d input quantizers.
116+
117+
Linear layers are left unchanged — their axis settings come from the base
118+
quant_config (e.g. INT8_SMOOTHQUANT_CFG or INT8_DEFAULT_CONFIG).
117119
118120
Args:
119121
quant_config: The quantization configuration dictionary
@@ -122,31 +124,9 @@ def reset_set_int8_config(quant_config, percentile, n_steps, collect_method, bac
122124
collect_method: Method for collecting calibration statistics
123125
backbone: The model backbone to analyze layer types
124126
"""
125-
126-
# Build a mapping of layer names to their types
127-
layer_type_map = {}
128127
for name, module in backbone.named_modules():
129-
if isinstance(module, (nn.Linear, nn.Conv2d)):
130-
layer_type_map[name] = type(module)
131-
132-
quant_config["quant_cfg"] = {}
133-
for layer_name, layer_type in layer_type_map.items():
134-
wq_name = f"*{layer_name}*weight_quantizer*"
135-
aq_name = f"*{layer_name}*input_quantizer*"
136-
if layer_type is nn.Linear:
137-
quant_config["quant_cfg"][wq_name] = {
138-
"num_bits": 8,
139-
"axis": 0,
140-
}
141-
quant_config["quant_cfg"][aq_name] = {
142-
"num_bits": 8,
143-
"axis": -1,
144-
}
145-
else:
146-
quant_config["quant_cfg"][wq_name] = {
147-
"num_bits": 8,
148-
"axis": 0,
149-
}
128+
if isinstance(module, nn.Conv2d):
129+
aq_name = f"*{name}*input_quantizer*"
150130
quant_config["quant_cfg"][aq_name] = {
151131
"num_bits": 8,
152132
"axis": None,
(newly added file — its path was not captured in this view)

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from pathlib import Path
17+
from typing import NamedTuple
18+
19+
import pytest
20+
from _test_utils.examples.models import FLUX_SCHNELL_PATH, SDXL_1_0_PATH
21+
from _test_utils.examples.run_command import run_example_command
22+
from _test_utils.torch.misc import minimum_sm
23+
24+
25+
class DiffuserHfExportModel(NamedTuple):
26+
name: str
27+
path: str
28+
dtype: str
29+
format_type: str
30+
quant_algo: str
31+
collect_method: str
32+
33+
def quantize_and_export_hf(self, tmp_path: Path) -> Path:
34+
hf_ckpt_dir = tmp_path / f"{self.name}_{self.format_type}_hf_ckpt"
35+
cmd_args = [
36+
"python",
37+
"quantize.py",
38+
"--model",
39+
self.name,
40+
"--override-model-path",
41+
self.path,
42+
"--calib-size",
43+
"8",
44+
"--batch-size",
45+
"2",
46+
"--n-steps",
47+
"20",
48+
"--percentile",
49+
"1.0",
50+
"--alpha",
51+
"0.8",
52+
"--format",
53+
self.format_type,
54+
"--quant-algo",
55+
self.quant_algo,
56+
"--collect-method",
57+
self.collect_method,
58+
"--trt-high-precision-dtype",
59+
self.dtype,
60+
"--hf-ckpt-dir",
61+
str(hf_ckpt_dir),
62+
]
63+
run_example_command(cmd_args, "diffusers/quantization")
64+
return hf_ckpt_dir
65+
66+
67+
@pytest.mark.parametrize(
68+
"model",
69+
[
70+
DiffuserHfExportModel(
71+
name="sdxl-1.0",
72+
path=SDXL_1_0_PATH,
73+
dtype="Half",
74+
format_type="int8",
75+
quant_algo="smoothquant",
76+
collect_method="min-mean",
77+
),
78+
DiffuserHfExportModel(
79+
name="flux-schnell",
80+
path=FLUX_SCHNELL_PATH,
81+
dtype="BFloat16",
82+
format_type="int8",
83+
quant_algo="smoothquant",
84+
collect_method="min-mean",
85+
),
86+
pytest.param(
87+
DiffuserHfExportModel(
88+
name="sdxl-1.0",
89+
path=SDXL_1_0_PATH,
90+
dtype="Half",
91+
format_type="fp8",
92+
quant_algo="max",
93+
collect_method="default",
94+
),
95+
marks=minimum_sm(89),
96+
),
97+
pytest.param(
98+
DiffuserHfExportModel(
99+
name="flux-schnell",
100+
path=FLUX_SCHNELL_PATH,
101+
dtype="BFloat16",
102+
format_type="fp4",
103+
quant_algo="max",
104+
collect_method="default",
105+
),
106+
marks=minimum_sm(89),
107+
),
108+
],
109+
ids=[
110+
"sdxl_1.0_int8_smoothquant_min_mean",
111+
"flux_schnell_int8_smoothquant_min_mean",
112+
"sdxl_1.0_fp8_max_default",
113+
"flux_schnell_fp4_max_default",
114+
],
115+
)
116+
def test_diffusers_hf_ckpt_export(model: DiffuserHfExportModel, tmp_path: Path) -> None:
117+
hf_ckpt_dir = model.quantize_and_export_hf(tmp_path)
118+
119+
assert hf_ckpt_dir.exists(), f"HF checkpoint directory was not created: {hf_ckpt_dir}"
120+
121+
config_files = list(hf_ckpt_dir.rglob("config.json"))
122+
assert len(config_files) > 0, f"No config.json found in {hf_ckpt_dir}"
123+
124+
weight_files = list(hf_ckpt_dir.rglob("*.safetensors")) + list(hf_ckpt_dir.rglob("*.bin"))
125+
assert len(weight_files) > 0, f"No weight files (.safetensors or .bin) found in {hf_ckpt_dir}"

tests/unit/torch/export/test_export_diffusers.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pytest
1919
from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet
20+
from _test_utils.torch.misc import minimum_sm
2021

2122
pytest.importorskip("diffusers")
2223

@@ -87,19 +88,33 @@ def _process_stub(*_args, **_kwargs):
8788

8889

8990
@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux])
90-
def test_export_diffusers_real_quantized(tmp_path, model_factory):
91+
@pytest.mark.parametrize(
92+
("config_id", "quant_cfg"),
93+
[
94+
("int8", mtq.INT8_DEFAULT_CFG),
95+
("int8_smoothquant", mtq.INT8_SMOOTHQUANT_CFG),
96+
("fp8", mtq.FP8_DEFAULT_CFG),
97+
pytest.param("fp4", mtq.NVFP4_DEFAULT_CFG, marks=minimum_sm(89)),
98+
],
99+
)
100+
def test_export_diffusers_real_quantized(tmp_path, model_factory, config_id, quant_cfg):
91101
model = model_factory()
92-
export_dir = tmp_path / f"export_{type(model).__name__}_real_quant"
102+
export_dir = tmp_path / f"export_{type(model).__name__}_{config_id}_real_quant"
93103

94104
def _calib_fn(m):
95105
param = next(m.parameters())
96106
dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype)
97107
assert dummy_inputs is not None
98108
m(**dummy_inputs)
99109

100-
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=_calib_fn)
110+
mtq.quantize(model, quant_cfg, forward_loop=_calib_fn)
101111

102-
export_hf_checkpoint(model, export_dir=export_dir)
112+
try:
113+
export_hf_checkpoint(model, export_dir=export_dir)
114+
except AssertionError as e:
115+
if "block size" in str(e) and config_id == "fp4":
116+
pytest.skip(f"Tiny model weights incompatible with FP4 block quantization: {e}")
117+
raise
103118

104119
config_path = export_dir / "config.json"
105120
assert config_path.exists()

0 commit comments

Comments
 (0)