diff --git a/tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py b/tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py
index ad00bf60bdfc..2e60812e4cd4 100644
--- a/tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py
+++ b/tensorrt_llm/_torch/visual_gen/models/qwen_image/transformer_qwen_image.py
@@ -32,6 +32,7 @@
 from tensorrt_llm._torch.visual_gen.models.modeling import BaseDiffusionModel
 from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode
 from tensorrt_llm._torch.visual_gen.quantization.loader import DynamicLinearWeightLoader
+from tensorrt_llm.models.modeling_utils import QuantConfig
 
 _WEIGHT_KEY_REMAPS = [
     (".net.0.proj.", ".up_proj."),
@@ -825,6 +826,8 @@ def __init__(
             **linear_kwargs,
         )
 
+        self.apply_quant_config_exclude_modules()
+
     @property
     def device(self) -> torch.device:
         return self.proj_out.weight.device
@@ -873,6 +876,36 @@ def to_inference_dtype(self) -> "QwenImageTransformer2DModel":
                 buffer.data = buffer.data.to(target_dtype)
         return self
 
+    def apply_quant_config_exclude_modules(self) -> None:
+        """Switch Linear layers matched by ``exclude_modules`` to a no-quant config.
+
+        Does nothing when the model has no quant config or no exclusion list.
+        Every excluded ``Linear`` that carries a quant config is re-pointed at a
+        shared ``QuantConfig`` that only preserves ``kv_cache_quant_algo``; layers
+        whose weights were already created are rebuilt under that config.
+        """
+        quant_config = self.model_config.quant_config
+        if quant_config is None or quant_config.exclude_modules is None:
+            return
+
+        # quant_config cannot be None past the guard above, so read the attribute directly.
+        no_quant_config = QuantConfig(kv_cache_quant_algo=quant_config.kv_cache_quant_algo)
+
+        for name, module in self.named_modules():
+            if not isinstance(module, Linear):
+                continue
+            is_excluded = quant_config.is_module_excluded_from_quantization(name)
+            if not is_excluded or getattr(module, "quant_config", None) is None:
+                continue
+
+            module.quant_config = no_quant_config
+            if getattr(module, "_weights_created", False):
+                # Rebuild weights so quant_method and parameter layout match the no-quant config.
+                module._weights_created = False
+                module._parameters.clear()
+                module._buffers.clear()
+                module.create_weights()
+
     def load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
         """Load HF ``transformer/*.safetensors`` state_dict.
 
diff --git a/tests/integration/defs/examples/visual_gen/test_visual_gen.py b/tests/integration/defs/examples/visual_gen/test_visual_gen.py
index 41058e88ed44..938ce6d82f76 100644
--- a/tests/integration/defs/examples/visual_gen/test_visual_gen.py
+++ b/tests/integration/defs/examples/visual_gen/test_visual_gen.py
@@ -1499,6 +1499,68 @@ def test_qwen_image_example(_visual_gen_deps, llm_root, llm_venv):
     assert os.path.isfile(output_path), f"Example did not produce output at {output_path}"
 
 
+def test_qwen_image_example_with_quant_ignore(_visual_gen_deps, llm_root, llm_venv):
+    """Run Qwen-Image end-to-end with dynamic quantization and an ignore list."""
+    scratch_space = conftest.llm_models_root()
+    model_path = os.path.join(scratch_space, QWEN_IMAGE_MODEL_SUBPATH)
+    _skip_if_missing(model_path, "Qwen-Image checkpoint", is_dir=True)
+    model_index_path = os.path.join(model_path, "model_index.json")
+    if not os.path.isfile(model_index_path):
+        pytest.skip(
+            f"Qwen-Image checkpoint is incomplete: {model_path} (missing {model_index_path})"
+        )
+
+    out_dir = os.path.join(
+        llm_venv.get_working_directory(), "visual_gen_output", "qwen_image_quant_ignore"
+    )
+    os.makedirs(out_dir, exist_ok=True)
+    output_path = os.path.join(out_dir, "qwen_image_quant_ignore_output.png")
+    config_path = os.path.join(out_dir, "qwen_image_quant_ignore.yaml")
+    with open(config_path, "w") as f:
+        f.write(
+            textwrap.dedent(
+                """\
+                quant_config:
+                  quant_algo: FP8_BLOCK_SCALES
+                  dynamic: true
+                  ignore:
+                    - "transformer_blocks.0*"
+                    - "transformer_blocks.1.*"
+                    - "transformer_blocks.58*"
+                    - "transformer_blocks.59*"
+                    - "img_in"
+                    - "txt_in"
+                    - "time_text_embed*"
+                    - "proj_out"
+                attention_config:
+                  backend: VANILLA
+                parallel_config:
+                  cfg_size: 1
+                  ulysses_size: 1
+                cuda_graph_config:
+                  enable: false
+                """
+            )
+        )
+
+    script_path = os.path.join(llm_root, "examples", "visual_gen", "models", "qwen_image.py")
+    assert os.path.isfile(script_path), f"Example script not found: {script_path}"
+
+    venv_check_call(
+        llm_venv,
+        [
+            script_path,
+            "--model",
+            model_path,
+            "--visual_gen_args",
+            config_path,
+            "--output_path",
+            output_path,
+        ],
+    )
+    assert os.path.isfile(output_path), f"Example did not produce output at {output_path}"
+
+
 def test_cosmos3_example(_visual_gen_deps, llm_root, llm_venv):
     """Run examples/visual_gen/models/cosmos3_ti2v.py with FP8 config end-to-end.
 
diff --git a/tests/unittest/_torch/visual_gen/test_qwen_image_registry.py b/tests/unittest/_torch/visual_gen/test_qwen_image_registry.py
index 64e82d27837b..63f0b9c998bf 100644
--- a/tests/unittest/_torch/visual_gen/test_qwen_image_registry.py
+++ b/tests/unittest/_torch/visual_gen/test_qwen_image_registry.py
@@ -13,6 +13,7 @@
 import pytest
 import torch
 
+from tensorrt_llm._torch.modules.linear import NVFP4LinearMethod, UnquantizedLinearMethod
 # Importing the models package side-effects the ``@register_pipeline``
 # decorator on ``QwenImagePipeline`` being applied, which is what we are
 # testing here.
@@ -23,6 +24,8 @@
     QwenImageTransformer2DModel,
 )
 from tensorrt_llm._torch.visual_gen.pipeline_registry import PIPELINE_REGISTRY, AutoPipeline
+from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization.mode import QuantAlgo
 from tensorrt_llm.visual_gen.args import AttentionConfig
 
 
@@ -65,6 +68,41 @@ def test_transformer_load_weights_detects_mismatch():
         model.load_weights({})
 
 
+def test_transformer_applies_quant_config_ignore_list() -> None:
+    """Qwen-Image should honor selective dynamic quantization exclusions."""
+    model_config = DiffusionModelConfig(
+        quant_config=QuantConfig(
+            quant_algo=QuantAlgo.NVFP4,
+            exclude_modules=[
+                "transformer_blocks.0*",
+                "img_in",
+                "proj_out",
+            ],
+        ),
+        dynamic_weight_quant=True,
+        force_dynamic_quantization=True,
+    )
+    model = QwenImageTransformer2DModel(model_config=model_config, num_layers=2)
+
+    assert model.img_in.quant_config.quant_algo is None
+    assert model.proj_out.quant_config.quant_algo is None
+    assert model.transformer_blocks[0].attn.add_q_proj.quant_config.quant_algo is None
+    assert model.transformer_blocks[0].img_mlp.up_proj.quant_config.quant_algo is None
+    assert isinstance(model.img_in.quant_method, UnquantizedLinearMethod)
+    assert isinstance(model.proj_out.quant_method, UnquantizedLinearMethod)
+    assert isinstance(
+        model.transformer_blocks[0].attn.add_q_proj.quant_method, UnquantizedLinearMethod
+    )
+    assert isinstance(
+        model.transformer_blocks[0].img_mlp.up_proj.quant_method, UnquantizedLinearMethod
+    )
+
+    assert model.txt_in.quant_config.quant_algo == QuantAlgo.NVFP4
+    assert model.transformer_blocks[1].attn.add_q_proj.quant_config.quant_algo == QuantAlgo.NVFP4
+    assert isinstance(model.txt_in.quant_method, NVFP4LinearMethod)
+    assert isinstance(model.transformer_blocks[1].attn.add_q_proj.quant_method, NVFP4LinearMethod)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("with_text_mask", [False, True])
 def test_transformer_forward_sanity(with_text_mask):