Skip to content

Commit 6c5c644

Browse files
Support FP8 quantization for Qwen3.5
1 parent 952a62b commit 6c5c644

File tree

10 files changed

+195
-10
lines changed

10 files changed

+195
-10
lines changed

examples/llm_ptq/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
108108
| Llama-Nemotron Ultra ||||||
109109
| Gemma 3 | ✅<sup>2</sup> | - || - | - |
110110
| QWen 2, 2.5 <sup>4</sup> ||||||
111-
| QWen3, 3.5 MOE, Next <sup>6</sup> || - | - | - ||
111+
| QWen3, Next <sup>6</sup> | ✅ | - | - | - | ✅ |
112+
| QWen3.5 (Dense & MoE) <sup>6</sup> | ✅ | - | - | - | ✅ |
112113
| QwQ || - | - | - ||
113114
| DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - ||
114115
| GLM-4.7<sup>8</sup> || - | - | - ||
@@ -478,6 +479,8 @@ print(llm_fp8.generate(["What's the age of the earth? "]))
478479
| QWen3 | FP4 ||| - |
479480
| QWen3 MoE | FP8 ||||
480481
| QWen3 MoE | FP4 || - | - |
482+
| QWen3.5 Dense | FP8 | ✅ | ✅ | ✅ |
483+
| QWen3.5 MoE | FP8 | ✅ | ✅ | ✅ |
481484
| QWen3.5 MoE | FP4 | - | - ||
482485
| QWen2.5 | FP8 ||||
483486
| QWen2.5 | FP4 ||| - |

examples/llm_ptq/example_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,11 @@ def build_quant_cfg(
252252
quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
253253
quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
254254

255+
if model_type == "qwen3_5moe":
256+
# TRT-LLM's Qwen3.5-MoE weight loader uses intermediate_size (default hidden_size*2)
257+
# instead of moe_intermediate_size for expert buffer allocation, causing shape mismatches.
258+
quant_cfg["quant_cfg"].append({"quantizer_name": "*experts*", "enable": False})
259+
255260
return quant_cfg
256261

257262

examples/vlm_ptq/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) fo
3838
| VILA ||||| - |
3939
| Phi-3-vision, Phi-4-multimodal ||||||
4040
| Qwen2, 2.5-VL ||||||
41+
| Qwen3.5-VL (Dense & MoE) | ✅ | - | - | - | - |
4142
| Gemma3 || - | - | - | - |
4243

4344
> *<sup>1.</sup>Only TensorRT-LLM checkpoint export is supported. Not compatible with the TensorRT-LLM torch backend* \
@@ -46,6 +47,8 @@ Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) fo
4647
4748
> *For detailed TensorRT-LLM torch backend multimodal support, please refer to [this doc](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md#multimodal-feature-support-matrix-pytorch-backend)*
4849
50+
> **Qwen3.5 VLM Note:** When quantizing Qwen3.5 VLM models, linear attention (`linear_attn`) layers are not quantized (for TRT-LLM compatibility), and MoE expert layers are also excluded from quantization for the MoE variant. The exported checkpoint preserves the original VLM format (`Qwen3_5ForConditionalGeneration` architecture, `model.language_model.*` key prefix) and can be deployed directly on TRT-LLM, vLLM, and SGLang.
51+
4952
> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](../llm_ptq/hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.*
5053
5154
## Framework Scripts

modelopt/torch/export/layer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
9595
"qwen2moeforcausallm",
9696
"qwen3moeforcausallm",
9797
"qwen3nextforcausallm",
98+
"qwen3_5moeforconditionalgeneration",
9899
]
99100
):
100101
linear_names = ["gate_proj", "down_proj", "up_proj"]

modelopt/torch/export/model_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
"MPT": "mpt",
3030
"Bloom": "bloom",
3131
"ChatGLM": "chatglm",
32+
"Qwen3_5Moe": "qwen3_5moe",
33+
"Qwen3_5": "qwen3_5",
3234
"Qwen3Moe": "qwen3moe",
3335
"Qwen3Next": "qwen3next",
3436
"QWen": "qwen",

modelopt/torch/export/quant_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,12 +1216,12 @@ def _update_svdquant(modules, new_pre_quant_scale):
12161216
# Mathematical equivalence:
12171217
# Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
12181218
# After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
1219-
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
1219+
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention", "Qwen3_5Attention"], ("v_proj", "o_proj")),
12201220
# MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
12211221
# Mathematical equivalence:
12221222
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
12231223
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
1224-
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
1224+
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP", "Qwen3_5MLP"], ("up_proj", "down_proj")),
12251225
]
12261226

12271227

modelopt/torch/export/unified_export_hf.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,21 +360,21 @@ def llm_dummy_forward():
360360
[1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
361361
).to(model.device)
362362

363-
if is_vl_model and "nemotron" in model_type:
364-
# For Nemotron VL models, run optimization on just the language model/decoder.
365-
# This avoids needing pixel_values for the vision encoder.
363+
if is_vl_model and any(tag in model_type for tag in ("nemotron", "qwen3_5")):
364+
# For VL models whose vision encoder requires pixel_values (Nemotron, Qwen3.5),
365+
# run optimization on just the language model / decoder to avoid needing
366+
# pixel_values for the vision encoder.
366367
language_model_lineage = get_language_model_from_vl(model)
367368

368369
if language_model_lineage is not None:
369370
language_model = language_model_lineage[-1]
370371
print(
371372
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
372373
)
373-
# Pass use_cache=False to avoid KV cache issues in encoder-decoder models
374374
language_model(fake_input, use_cache=False)
375375
else:
376376
raise ValueError(
377-
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
377+
f"Cannot extract language_model from VL model (type: {model_type}). "
378378
"This is required for requantization/resmoothing optimization. "
379379
"Please ensure the model architecture is supported or file an issue."
380380
)
@@ -468,7 +468,7 @@ def _export_quantized_weight(
468468
weight_scaling_factor,
469469
)
470470

471-
if hasattr(input_quantizer, "_amax"):
471+
if hasattr(input_quantizer, "_amax") and input_quantizer.is_enabled:
472472
assert input_quantizer is not None
473473
input_quantizer._amax = input_quantizer._amax.to(torch.float32)
474474

@@ -810,6 +810,25 @@ def _export_transformers_checkpoint(
810810
# Process all quantized modules and export weights
811811
_process_quantized_modules(model, dtype, is_modelopt_qlora)
812812

813+
# Clean up _QuantFusedExperts modules whose quantizers are all disabled.
814+
# When expert quantization is intentionally disabled (e.g. Qwen3.5-MoE to avoid
815+
# TRT-LLM intermediate_size mismatch), the _QuantFusedExperts wrapper still exists
816+
# but _process_quantized_modules skips it (QUANTIZATION_NONE). Remove the
817+
# leftover quantizer attributes so save_pretrained produces clean 3D fused weights.
818+
_fused_experts_attrs = (
819+
"gate_up_proj_weight_quantizers",
820+
"down_proj_weight_quantizers",
821+
"gate_up_proj_input_quantizer",
822+
"down_proj_input_quantizer",
823+
)
824+
for _name, _mod in model.named_modules():
825+
if not hasattr(_mod, "gate_up_proj_weight_quantizers"):
826+
continue
827+
if all(not q.is_enabled for q in _mod.gate_up_proj_weight_quantizers):
828+
for _attr in _fused_experts_attrs:
829+
if hasattr(_mod, _attr):
830+
delattr(_mod, _attr)
831+
813832
# Reconstruct fused MoELinear: per-expert _QuantLinear weights → original 3D format
814833
from modelopt.torch.quantization.plugins.huggingface import _reconstruct_fused_moe_linear
815834

modelopt/torch/quantization/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def find_quant_cfg_entry_by_path(
227227
"quantizer_name": "*mlp.shared_expert_gate.*",
228228
"enable": False,
229229
}, # Skip the MOE router
230-
{"quantizer_name": "*linear_attn.conv1d*", "enable": False},
230+
{"quantizer_name": "*linear_attn*", "enable": False}, # TRT-LLM linear-attn packing limit
231231
{"quantizer_name": "*mixer.conv1d*", "enable": False}, # Skip mamba conv1d
232232
{"quantizer_name": "*output_layer*", "enable": False},
233233
{"quantizer_name": "output.*", "enable": False},

tests/_test_utils/torch/transformers_models.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,111 @@
4040
SEED = 1234
4141

4242

43+
try:
44+
from transformers import Qwen3_5TextConfig
45+
except ImportError:
46+
Qwen3_5TextConfig = None
47+
48+
try:
49+
from transformers import Qwen3_5MoeTextConfig
50+
except ImportError:
51+
Qwen3_5MoeTextConfig = None
52+
53+
54+
##### Qwen3.5 Dense #####
55+
def get_tiny_qwen3_5(**config_kwargs) -> PreTrainedModel:
56+
"""Create a tiny Qwen3.5 Dense model (hybrid GatedDeltaNet + Softmax attention).
57+
58+
Requires ``transformers`` with ``Qwen3_5TextConfig`` support.
59+
"""
60+
if Qwen3_5TextConfig is None:
61+
pytest.skip("transformers does not have Qwen3_5TextConfig")
62+
63+
set_seed(SEED)
64+
kwargs = {
65+
"hidden_size": 32,
66+
"intermediate_size": 32,
67+
"num_hidden_layers": 4,
68+
"num_attention_heads": 4,
69+
"num_key_value_heads": 2,
70+
"max_position_embeddings": 64,
71+
"vocab_size": 32,
72+
"head_dim": 8,
73+
"short_chunk_size": 32,
74+
"attn_type": [0, 0, 0, 1],
75+
}
76+
kwargs.update(**config_kwargs)
77+
config = Qwen3_5TextConfig(**kwargs)
78+
tiny_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
79+
return tiny_model
80+
81+
82+
def create_tiny_qwen3_5_dir(
83+
tmp_path: Path | str, with_tokenizer: bool = False, return_model: bool = False, **config_kwargs
84+
) -> Path | tuple[Path, PreTrainedModel]:
85+
"""Save a tiny Qwen3.5 Dense model to disk for testing."""
86+
model_dir = Path(tmp_path) / "tiny_qwen3_5"
87+
if with_tokenizer:
88+
tokenizer = AutoTokenizer.from_pretrained(
89+
"hf-internal-testing/tiny-random-LlamaForCausalLM"
90+
)
91+
tokenizer.save_pretrained(model_dir)
92+
config_kwargs["vocab_size"] = tokenizer.vocab_size
93+
tiny_model = get_tiny_qwen3_5(**config_kwargs)
94+
tiny_model.save_pretrained(model_dir)
95+
96+
if return_model:
97+
return model_dir, tiny_model
98+
return model_dir
99+
100+
101+
##### Qwen3.5 MoE #####
102+
def get_tiny_qwen3_5_moe(**config_kwargs) -> PreTrainedModel:
103+
"""Create a tiny Qwen3.5 MoE model (hybrid attention + mixture-of-experts).
104+
105+
Requires ``transformers`` with ``Qwen3_5MoeTextConfig`` support.
106+
"""
107+
if Qwen3_5MoeTextConfig is None:
108+
pytest.skip("transformers does not have Qwen3_5MoeTextConfig")
109+
110+
set_seed(SEED)
111+
kwargs = {
112+
"hidden_size": 32,
113+
"intermediate_size": 32,
114+
"moe_intermediate_size": 32,
115+
"num_hidden_layers": 4,
116+
"num_attention_heads": 4,
117+
"num_key_value_heads": 2,
118+
"max_position_embeddings": 64,
119+
"vocab_size": 32,
120+
"head_dim": 8,
121+
"short_chunk_size": 32,
122+
"attn_type": [0, 0, 0, 1],
123+
"num_experts": 4,
124+
"num_experts_per_tok": 2,
125+
"decoder_sparse_step": 1,
126+
}
127+
kwargs.update(**config_kwargs)
128+
config = Qwen3_5MoeTextConfig(**kwargs)
129+
tiny_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
130+
return tiny_model
131+
132+
133+
def create_tiny_qwen3_5_moe_dir(
134+
tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
135+
) -> Path:
136+
"""Save a tiny Qwen3.5 MoE model to disk for testing."""
137+
model_dir = Path(tmp_path) / "tiny_qwen3_5_moe"
138+
if with_tokenizer:
139+
tokenizer = AutoTokenizer.from_pretrained(
140+
"hf-internal-testing/tiny-random-LlamaForCausalLM"
141+
)
142+
tokenizer.save_pretrained(model_dir)
143+
config_kwargs["vocab_size"] = tokenizer.vocab_size
144+
get_tiny_qwen3_5_moe(**config_kwargs).save_pretrained(model_dir)
145+
return model_dir
146+
147+
43148
##### Qwen3 #####
44149
def get_tiny_qwen3(**config_kwargs) -> PreTrainedModel:
45150
set_seed(SEED)

tests/unit/torch/quantization/plugins/test_huggingface.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
create_tiny_llama_dir,
2525
get_tiny_gpt_oss,
2626
get_tiny_llama,
27+
get_tiny_qwen3_5,
28+
get_tiny_qwen3_5_moe,
2729
get_tiny_qwen3_moe,
2830
tf_modelopt_state_and_output_tester,
2931
)
@@ -243,3 +245,48 @@ def test_hf_decoder_discoverer_registration_path():
243245
assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers(
244246
model
245247
)
248+
249+
250+
def test_qwen3_5_hybrid_attention_quantize():
251+
"""Verify FP8 quantization disables all linear_attn quantizers while self_attn is quantized."""
252+
model = get_tiny_qwen3_5()
253+
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, lambda m: m(**m.dummy_inputs))
254+
255+
for name, module in model.named_modules():
256+
if not hasattr(module, "weight_quantizer"):
257+
continue
258+
if "linear_attn" in name:
259+
assert not module.weight_quantizer.is_enabled, (
260+
f"linear_attn module {name} should have weight_quantizer disabled"
261+
)
262+
assert not module.input_quantizer.is_enabled, (
263+
f"linear_attn module {name} should have input_quantizer disabled"
264+
)
265+
elif "self_attn" in name and "layernorm" not in name:
266+
assert module.weight_quantizer.is_enabled, (
267+
f"self_attn module {name} should have weight_quantizer enabled"
268+
)
269+
270+
271+
@pytest.mark.skipif(
272+
Version(torch.__version__) < Version("2.9"),
273+
reason="torch 2.8 grouped_mm is CUDA-only",
274+
)
275+
def test_qwen3_5_moe_experts_not_quantized():
276+
"""Verify MoE expert quantizers are disabled when build_quant_cfg rules are applied."""
277+
model = get_tiny_qwen3_5_moe()
278+
279+
import copy
280+
281+
quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
282+
quant_cfg["quant_cfg"].append({"quantizer_name": "*experts*", "enable": False})
283+
284+
mtq.quantize(model, quant_cfg, lambda m: m(**m.dummy_inputs))
285+
286+
for name, module in model.named_modules():
287+
if not hasattr(module, "weight_quantizer"):
288+
continue
289+
if "experts" in name:
290+
assert not module.weight_quantizer.is_enabled, (
291+
f"expert module {name} should have weight_quantizer disabled"
292+
)

0 commit comments

Comments
 (0)