Skip to content

Commit 28d5424

Browse files
committed
[None][fix] Honor Qwen Image quant ignore list
Signed-off-by: Alex Steiner <asteiner@nvidia.com>
1 parent 7193f41 commit 28d5424

3 files changed

Lines changed: 123 additions & 0 deletions

File tree

tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tensorrt_llm._torch.visual_gen.models.modeling import BaseDiffusionModel
3333
from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
3434
from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader
35+
from tensorrt_llm.models.modeling_utils import QuantConfig
3536

3637
_WEIGHT_KEY_REMAPS = [
3738
(".net.0.proj.", ".up_proj."),
@@ -825,6 +826,8 @@ def __init__(
825826
**linear_kwargs,
826827
)
827828

829+
self.apply_quant_config_exclude_modules()
830+
828831
@property
829832
def device(self) -> torch.device:
830833
return self.proj_out.weight.device
@@ -873,6 +876,26 @@ def to_inference_dtype(self) -> "QwenImageTransformer2DModel":
873876
buffer.data = buffer.data.to(target_dtype)
874877
return self
875878

879+
def apply_quant_config_exclude_modules(self) -> None:
880+
quant_config = self.model_config.quant_config
881+
if quant_config is None or quant_config.exclude_modules is None:
882+
return
883+
884+
kv_cache_quant_algo = quant_config.kv_cache_quant_algo if quant_config else None
885+
no_quant_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
886+
887+
for name, module in self.named_modules():
888+
if isinstance(module, Linear):
889+
is_excluded = quant_config.is_module_excluded_from_quantization(name)
890+
if is_excluded and getattr(module, "quant_config", None) is not None:
891+
module.quant_config = no_quant_config
892+
if getattr(module, "_weights_created", False):
893+
# Rebuild weights so quant_method and parameter layout match the no-quant config.
894+
module._weights_created = False
895+
module._parameters.clear()
896+
module._buffers.clear()
897+
module.create_weights()
898+
876899
def load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
877900
"""Load HF ``transformer/*.safetensors`` state_dict.
878901

tests/integration/defs/examples/visual_gen/test_visual_gen.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,6 +1499,68 @@ def test_qwen_image_example(_visual_gen_deps, llm_root, llm_venv):
14991499
assert os.path.isfile(output_path), f"Example did not produce output at {output_path}"
15001500

15011501

1502+
def test_qwen_image_example_with_quant_ignore(_visual_gen_deps, llm_root, llm_venv):
1503+
"""Run Qwen-Image end-to-end with dynamic quantization and an ignore list."""
1504+
scratch_space = conftest.llm_models_root()
1505+
model_path = os.path.join(scratch_space, QWEN_IMAGE_MODEL_SUBPATH)
1506+
_skip_if_missing(model_path, "Qwen-Image checkpoint", is_dir=True)
1507+
model_index_path = os.path.join(model_path, "model_index.json")
1508+
if not os.path.isfile(model_index_path):
1509+
pytest.skip(
1510+
f"Qwen-Image checkpoint is incomplete: {model_path} (missing {model_index_path})"
1511+
)
1512+
1513+
out_dir = os.path.join(
1514+
llm_venv.get_working_directory(), "visual_gen_output", "qwen_image_quant_ignore"
1515+
)
1516+
os.makedirs(out_dir, exist_ok=True)
1517+
output_path = os.path.join(out_dir, "qwen_image_quant_ignore_output.png")
1518+
config_path = os.path.join(out_dir, "qwen_image_quant_ignore.yaml")
1519+
with open(config_path, "w") as f:
1520+
f.write(
1521+
textwrap.dedent(
1522+
"""\
1523+
quant_config:
1524+
quant_algo: FP8_BLOCK_SCALES
1525+
dynamic: true
1526+
ignore:
1527+
- "transformer_blocks.0*"
1528+
- "transformer_blocks.1.*"
1529+
- "transformer_blocks.58*"
1530+
- "transformer_blocks.59*"
1531+
- "img_in"
1532+
- "txt_in"
1533+
- "time_text_embed*"
1534+
- "proj_out"
1535+
attention_config:
1536+
backend: VANILLA
1537+
parallel_config:
1538+
cfg_size: 1
1539+
ulysses_size: 1
1540+
cuda_graph_config:
1541+
enable: false
1542+
"""
1543+
)
1544+
)
1545+
1546+
script_path = os.path.join(llm_root, "examples", "visual_gen", "models", "qwen_image.py")
1547+
assert os.path.isfile(script_path), f"Example script not found: {script_path}"
1548+
1549+
venv_check_call(
1550+
llm_venv,
1551+
[
1552+
script_path,
1553+
"--model",
1554+
model_path,
1555+
"--visual_gen_args",
1556+
config_path,
1557+
"--output_path",
1558+
output_path,
1559+
],
1560+
)
1561+
assert os.path.isfile(output_path), f"Example did not produce output at {output_path}"
1562+
1563+
15021564
def test_cosmos3_example(_visual_gen_deps, llm_root, llm_venv):
15031565
"""Run examples/visual_gen/models/cosmos3_ti2v.py with FP8 config end-to-end.
15041566

tests/unittest/_torch/visual_gen/test_qwen_image_registry.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pytest
1414
import torch
1515

16+
from tensorrt_llm._torch.modules.linear import NVFP4LinearMethod, UnquantizedLinearMethod
1617
# Importing the models package side-effects the ``@register_pipeline``
1718
# decorator on ``QwenImagePipeline`` being applied, which is what we are
1819
# testing here.
@@ -23,6 +24,8 @@
2324
QwenImageTransformer2DModel,
2425
)
2526
from tensorrt_llm._torch.visual_gen.pipeline_registry import PIPELINE_REGISTRY, AutoPipeline
27+
from tensorrt_llm.models.modeling_utils import QuantConfig
28+
from tensorrt_llm.quantization.mode import QuantAlgo
2629
from tensorrt_llm.visual_gen.args import AttentionConfig
2730

2831

@@ -65,6 +68,41 @@ def test_transformer_load_weights_detects_mismatch():
6568
model.load_weights({})
6669

6770

71+
def test_transformer_applies_quant_config_ignore_list() -> None:
72+
"""Qwen-Image should honor selective dynamic quantization exclusions."""
73+
model_config = DiffusionModelConfig(
74+
quant_config=QuantConfig(
75+
quant_algo=QuantAlgo.NVFP4,
76+
exclude_modules=[
77+
"transformer_blocks.0*",
78+
"img_in",
79+
"proj_out",
80+
],
81+
),
82+
dynamic_weight_quant=True,
83+
force_dynamic_quantization=True,
84+
)
85+
model = QwenImageTransformer2DModel(model_config=model_config, num_layers=2)
86+
87+
assert model.img_in.quant_config.quant_algo is None
88+
assert model.proj_out.quant_config.quant_algo is None
89+
assert model.transformer_blocks[0].attn.add_q_proj.quant_config.quant_algo is None
90+
assert model.transformer_blocks[0].img_mlp.up_proj.quant_config.quant_algo is None
91+
assert isinstance(model.img_in.quant_method, UnquantizedLinearMethod)
92+
assert isinstance(model.proj_out.quant_method, UnquantizedLinearMethod)
93+
assert isinstance(
94+
model.transformer_blocks[0].attn.add_q_proj.quant_method, UnquantizedLinearMethod
95+
)
96+
assert isinstance(
97+
model.transformer_blocks[0].img_mlp.up_proj.quant_method, UnquantizedLinearMethod
98+
)
99+
100+
assert model.txt_in.quant_config.quant_algo == QuantAlgo.NVFP4
101+
assert model.transformer_blocks[1].attn.add_q_proj.quant_config.quant_algo == QuantAlgo.NVFP4
102+
assert isinstance(model.txt_in.quant_method, NVFP4LinearMethod)
103+
assert isinstance(model.transformer_blocks[1].attn.add_q_proj.quant_method, NVFP4LinearMethod)
104+
105+
68106
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
69107
@pytest.mark.parametrize("with_text_mask", [False, True])
70108
def test_transformer_forward_sanity(with_text_mask):

0 commit comments

Comments
 (0)