Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 8 additions & 28 deletions examples/diffusers/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

INT8_DEFAULT_CONFIG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": 8, "axis": 0},
"*input_quantizer": {"num_bits": 8, "axis": 0},
"*weight_quantizer": {"num_bits": 8, "axis": None},
Comment thread
jingyu-ml marked this conversation as resolved.
Outdated
Copy link

Copilot AI Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weight_quantizer axis should be 0 (per-channel quantization), not None (per-tensor quantization), to match the standard modelopt INT8_DEFAULT_CFG configuration. The modelopt INT8_DEFAULT_CFG uses axis=0 for weight quantization and axis=None for input quantization. This inconsistency could lead to suboptimal quantization quality for weights.

Suggested change
"*weight_quantizer": {"num_bits": 8, "axis": None},
"*weight_quantizer": {"num_bits": 8, "axis": 0},

Copilot uses AI. Check for mistakes.
"*input_quantizer": {"num_bits": 8, "axis": None},
"*output_quantizer": {"enable": False},
"default": {"enable": False},
},
Expand Down Expand Up @@ -112,8 +112,10 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **


def reset_set_int8_config(quant_config, percentile, n_steps, collect_method, backbone):
"""
Configure INT8 quantization with different settings for Conv2d and Linear layers.
"""Add PercentileCalibrator to Conv2d input quantizers.

Linear layers are left unchanged — their axis settings come from the base
quant_config (e.g. INT8_SMOOTHQUANT_CFG or INT8_DEFAULT_CONFIG).

Args:
quant_config: The quantization configuration dictionary
Expand All @@ -122,31 +124,9 @@ def reset_set_int8_config(quant_config, percentile, n_steps, collect_method, bac
collect_method: Method for collecting calibration statistics
backbone: The model backbone to analyze layer types
"""

# Build a mapping of layer names to their types
layer_type_map = {}
for name, module in backbone.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
layer_type_map[name] = type(module)

quant_config["quant_cfg"] = {}
for layer_name, layer_type in layer_type_map.items():
wq_name = f"*{layer_name}*weight_quantizer*"
aq_name = f"*{layer_name}*input_quantizer*"
if layer_type is nn.Linear:
quant_config["quant_cfg"][wq_name] = {
"num_bits": 8,
"axis": 0,
}
quant_config["quant_cfg"][aq_name] = {
"num_bits": 8,
"axis": -1,
}
else:
quant_config["quant_cfg"][wq_name] = {
"num_bits": 8,
"axis": 0,
}
if isinstance(module, nn.Conv2d):
aq_name = f"*{name}*input_quantizer*"
quant_config["quant_cfg"][aq_name] = {
"num_bits": 8,
"axis": None,
Expand Down
125 changes: 125 additions & 0 deletions tests/gpu/torch/export/test_export_diffusers_hf_ckpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import NamedTuple

import pytest
from _test_utils.examples.models import FLUX_SCHNELL_PATH, SDXL_1_0_PATH
from _test_utils.examples.run_command import run_example_command
from _test_utils.torch.misc import minimum_sm


class DiffuserHfExportModel(NamedTuple):
    """Parameters for one diffusers quantize-and-export-HF-checkpoint test case."""

    name: str  # model alias passed to quantize.py via --model
    path: str  # local checkpoint path passed via --override-model-path
    dtype: str  # TensorRT high-precision dtype string, e.g. "Half" or "BFloat16"
    format_type: str  # quantization format: "int8", "fp8", or "fp4"
    quant_algo: str  # quantization algorithm, e.g. "smoothquant" or "max"
    collect_method: str  # calibration statistics collection method

    def quantize_and_export_hf(self, tmp_path: Path) -> Path:
        """Run the quantize.py example and return the HF checkpoint directory.

        Args:
            tmp_path: Temporary directory under which the checkpoint dir is created.

        Returns:
            The path passed as --hf-ckpt-dir to the example script.
        """
        hf_ckpt_dir = tmp_path / f"{self.name}_{self.format_type}_hf_ckpt"
        # Flag/value pairs for the example CLI; calibration knobs are kept
        # small so the test finishes quickly.
        option_pairs = [
            ("--model", self.name),
            ("--override-model-path", self.path),
            ("--calib-size", "8"),
            ("--batch-size", "2"),
            ("--n-steps", "20"),
            ("--percentile", "1.0"),
            ("--alpha", "0.8"),
            ("--format", self.format_type),
            ("--quant-algo", self.quant_algo),
            ("--collect-method", self.collect_method),
            ("--trt-high-precision-dtype", self.dtype),
            ("--hf-ckpt-dir", str(hf_ckpt_dir)),
        ]
        cmd_args = ["python", "quantize.py"]
        for flag, value in option_pairs:
            cmd_args.extend((flag, value))
        run_example_command(cmd_args, "diffusers/quantization")
        return hf_ckpt_dir


@pytest.mark.parametrize(
    "model",
    [
        DiffuserHfExportModel(
            format_type="int8",
            quant_algo="smoothquant",
            collect_method="min-mean",
            name="sdxl-1.0",
            path=SDXL_1_0_PATH,
            dtype="Half",
        ),
        DiffuserHfExportModel(
            format_type="int8",
            quant_algo="smoothquant",
            collect_method="min-mean",
            name="flux-schnell",
            path=FLUX_SCHNELL_PATH,
            dtype="BFloat16",
        ),
        # FP8 and FP4 formats require SM 8.9+ hardware.
        pytest.param(
            DiffuserHfExportModel(
                format_type="fp8",
                quant_algo="max",
                collect_method="default",
                name="sdxl-1.0",
                path=SDXL_1_0_PATH,
                dtype="Half",
            ),
            marks=minimum_sm(89),
        ),
        pytest.param(
            DiffuserHfExportModel(
                format_type="fp4",
                quant_algo="max",
                collect_method="default",
                name="flux-schnell",
                path=FLUX_SCHNELL_PATH,
                dtype="BFloat16",
            ),
            marks=minimum_sm(89),
        ),
    ],
    ids=[
        "sdxl_1.0_int8_smoothquant_min_mean",
        "flux_schnell_int8_smoothquant_min_mean",
        "sdxl_1.0_fp8_max_default",
        "flux_schnell_fp4_max_default",
    ],
)
def test_diffusers_hf_ckpt_export(model: DiffuserHfExportModel, tmp_path: Path) -> None:
    """Quantize a diffusers model via the example script and sanity-check the
    exported HF checkpoint layout (directory exists, config + weight files present)."""
    hf_ckpt_dir = model.quantize_and_export_hf(tmp_path)

    assert hf_ckpt_dir.exists(), f"HF checkpoint directory was not created: {hf_ckpt_dir}"

    # The export must emit at least one config.json somewhere in the tree.
    config_files = [p for p in hf_ckpt_dir.rglob("config.json")]
    assert len(config_files) > 0, f"No config.json found in {hf_ckpt_dir}"

    # Weights may be serialized as safetensors or legacy torch .bin files.
    weight_files = [*hf_ckpt_dir.rglob("*.safetensors"), *hf_ckpt_dir.rglob("*.bin")]
    assert len(weight_files) > 0, f"No weight files (.safetensors or .bin) found in {hf_ckpt_dir}"
23 changes: 19 additions & 4 deletions tests/unit/torch/export/test_export_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pytest
from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet
from _test_utils.torch.misc import minimum_sm

pytest.importorskip("diffusers")

Expand Down Expand Up @@ -87,19 +88,33 @@ def _process_stub(*_args, **_kwargs):


@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux])
def test_export_diffusers_real_quantized(tmp_path, model_factory):
@pytest.mark.parametrize(
("config_id", "quant_cfg"),
[
("int8", mtq.INT8_DEFAULT_CFG),
("int8_smoothquant", mtq.INT8_SMOOTHQUANT_CFG),
("fp8", mtq.FP8_DEFAULT_CFG),
pytest.param("fp4", mtq.NVFP4_DEFAULT_CFG, marks=minimum_sm(89)),
],
)
def test_export_diffusers_real_quantized(tmp_path, model_factory, config_id, quant_cfg):
model = model_factory()
export_dir = tmp_path / f"export_{type(model).__name__}_real_quant"
export_dir = tmp_path / f"export_{type(model).__name__}_{config_id}_real_quant"

def _calib_fn(m):
param = next(m.parameters())
dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype)
assert dummy_inputs is not None
m(**dummy_inputs)

mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=_calib_fn)
mtq.quantize(model, quant_cfg, forward_loop=_calib_fn)

export_hf_checkpoint(model, export_dir=export_dir)
try:
export_hf_checkpoint(model, export_dir=export_dir)
except AssertionError as e:
if "block size" in str(e) and config_id == "fp4":
pytest.skip(f"Tiny model weights incompatible with FP4 block quantization: {e}")
raise

config_path = export_dir / "config.json"
assert config_path.exists()
Expand Down
Loading