Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/llm_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| Llama-Nemotron Ultra | ✅ | ❌ | ❌ | ❌ | ❌ |
| Gemma 3 | ✅<sup>2</sup> | - | ✅ | - | - |
| QWen 2, 2.5 <sup>4</sup> | ✅ | ✅ | ✅ | ✅ | ✅ |
| QWen3, 3.5 MOE, Next <sup>6</sup> | ✅ | - | - | - | ✅ |
| QWen3, Next <sup>6</sup> | ✅ | - | - | - | ✅ |
| QWen3.5 (Dense & MoE) <sup>6</sup> | ✅ | - | - | - | ✅ |
| QwQ | ✅ | - | - | - | ✅ |
| DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - | ✅ |
| GLM-4.7<sup>8</sup> | ✅ | - | - | - | ✅ |
Expand Down Expand Up @@ -478,6 +479,8 @@ print(llm_fp8.generate(["What's the age of the earth? "]))
| QWen3 | FP4 | ✅ | ✅ | - |
| QWen3 MoE | FP8 | ✅ | ✅ | ✅ |
| QWen3 MoE | FP4 | ✅ | - | - |
| QWen3.5 Dense | FP8 | ✅ | ✅ | ✅ |
| QWen3.5 MoE | FP8 | ✅ | ✅ | ✅ |
| QWen3.5 MoE | FP4 | - | - | ✅ |
| QWen2.5 | FP8 | ✅ | ✅ | ✅ |
| QWen2.5 | FP4 | ✅ | ✅ | - |
Expand Down
5 changes: 5 additions & 0 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ def build_quant_cfg(
quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})

if model_type == "qwen3_5moe":
Copy link
Copy Markdown
Collaborator

@shengliangxu shengliangxu Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if it is the TRT-LLM loading issue, should we fix TRT-LLM instead?

# TRT-LLM's Qwen3.5-MoE weight loader uses intermediate_size (default hidden_size*2)
# instead of moe_intermediate_size for expert buffer allocation, causing shape mismatches.
quant_cfg["quant_cfg"].append({"quantizer_name": "*experts*", "enable": False})

return quant_cfg


Expand Down
3 changes: 3 additions & 0 deletions examples/vlm_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) fo
| VILA | ✅ | ✅ | ✅ | ✅ | - |
| Phi-3-vision, Phi-4-multimodal | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen2, 2.5-VL | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen3.5-VL (Dense & MoE) | ✅ | - | - | - | - |
| Gemma3 | ✅ | - | - | - | - |

> *<sup>1.</sup>Only TensorRT-LLM checkpoint export is supported. Not compatible with the TensorRT-LLM torch backend* \
Expand All @@ -46,6 +47,8 @@ Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) fo

> *For detailed TensorRT-LLM torch backend multimodal support, please refer to [this doc](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md#multimodal-feature-support-matrix-pytorch-backend)*

> **Qwen3.5 VLM Note:** When quantizing Qwen3.5 VLM models, linear attention (`linear_attn`) layers are not quantized (TRT-LLM compatibility), and MoE expert layers are also excluded from quantization for the MoE variant. The exported checkpoint preserves the original VLM format (`Qwen3_5ForConditionalGeneration` architecture, `model.language_model.*` key prefix) and can be deployed directly on TRT-LLM, vLLM, and SGLang.

> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](../llm_ptq/hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.*

## Framework Scripts
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
"qwen2moeforcausallm",
"qwen3moeforcausallm",
"qwen3nextforcausallm",
"qwen3_5moeforconditionalgeneration",
]
):
linear_names = ["gate_proj", "down_proj", "up_proj"]
Expand Down
2 changes: 2 additions & 0 deletions modelopt/torch/export/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"MPT": "mpt",
"Bloom": "bloom",
"ChatGLM": "chatglm",
"Qwen3_5Moe": "qwen3_5moe",
"Qwen3_5": "qwen3_5",
"Qwen3Moe": "qwen3moe",
"Qwen3Next": "qwen3next",
"QWen": "qwen",
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,12 +1216,12 @@ def _update_svdquant(modules, new_pre_quant_scale):
# Mathematical equivalence:
# Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
# After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention", "Qwen3_5Attention"], ("v_proj", "o_proj")),
# MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
# Mathematical equivalence:
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP", "Qwen3_5MLP"], ("up_proj", "down_proj")),
]


Expand Down
31 changes: 25 additions & 6 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,21 +360,21 @@ def llm_dummy_forward():
[1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
).to(model.device)

if is_vl_model and "nemotron" in model_type:
# For Nemotron VL models, run optimization on just the language model/decoder.
# This avoids needing pixel_values for the vision encoder.
if is_vl_model and any(tag in model_type for tag in ("nemotron", "qwen3_5")):
# For VL models whose vision encoder requires pixel_values (Nemotron, Qwen3.5),
# run optimization on just the language model / decoder to avoid needing
# pixel_values for the vision encoder.
language_model_lineage = get_language_model_from_vl(model)

if language_model_lineage is not None:
language_model = language_model_lineage[-1]
print(
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
)
# Pass use_cache=False to avoid KV cache issues in encoder-decoder models
language_model(fake_input, use_cache=False)
else:
raise ValueError(
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
f"Cannot extract language_model from VL model (type: {model_type}). "
"This is required for requantization/resmoothing optimization. "
"Please ensure the model architecture is supported or file an issue."
)
Expand Down Expand Up @@ -468,7 +468,7 @@ def _export_quantized_weight(
weight_scaling_factor,
)

if hasattr(input_quantizer, "_amax"):
if hasattr(input_quantizer, "_amax") and input_quantizer.is_enabled:
assert input_quantizer is not None
input_quantizer._amax = input_quantizer._amax.to(torch.float32)

Expand Down Expand Up @@ -810,6 +810,25 @@ def _export_transformers_checkpoint(
# Process all quantized modules and export weights
_process_quantized_modules(model, dtype, is_modelopt_qlora)

# Clean up _QuantFusedExperts modules whose quantizers are all disabled.
# When expert quantization is intentionally disabled (e.g. Qwen3.5-MoE to avoid
# TRT-LLM intermediate_size mismatch), the _QuantFusedExperts wrapper still exists
# but _process_quantized_modules skips it (QUANTIZATION_NONE). Remove the
# leftover quantizer attributes so save_pretrained produces clean 3D fused weights.
_fused_experts_attrs = (
"gate_up_proj_weight_quantizers",
"down_proj_weight_quantizers",
"gate_up_proj_input_quantizer",
"down_proj_input_quantizer",
)
for _name, _mod in model.named_modules():
if not hasattr(_mod, "gate_up_proj_weight_quantizers"):
continue
if all(not q.is_enabled for q in _mod.gate_up_proj_weight_quantizers):
for _attr in _fused_experts_attrs:
if hasattr(_mod, _attr):
delattr(_mod, _attr)

# Reconstruct fused MoELinear: per-expert _QuantLinear weights → original 3D format
from modelopt.torch.quantization.plugins.huggingface import _reconstruct_fused_moe_linear

Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def find_quant_cfg_entry_by_path(
"quantizer_name": "*mlp.shared_expert_gate.*",
"enable": False,
}, # Skip the MOE router
{"quantizer_name": "*linear_attn.conv1d*", "enable": False},
{"quantizer_name": "*linear_attn*", "enable": False}, # TRT-LLM linear-attn packing limit
{"quantizer_name": "*mixer.conv1d*", "enable": False}, # Skip mamba conv1d
{"quantizer_name": "*output_layer*", "enable": False},
{"quantizer_name": "output.*", "enable": False},
Expand Down
105 changes: 105 additions & 0 deletions tests/_test_utils/torch/transformers_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,111 @@
SEED = 1234


try:
from transformers import Qwen3_5TextConfig
except ImportError:
Qwen3_5TextConfig = None

try:
from transformers import Qwen3_5MoeTextConfig
except ImportError:
Qwen3_5MoeTextConfig = None


##### Qwen3.5 Dense #####
def get_tiny_qwen3_5(**config_kwargs) -> PreTrainedModel:
    """Build a miniature Qwen3.5 Dense model (hybrid GatedDeltaNet + Softmax attention).

    Any keyword argument overrides the tiny default config. Requires a
    ``transformers`` version that ships ``Qwen3_5TextConfig``; otherwise the
    calling test is skipped.
    """
    if Qwen3_5TextConfig is None:
        pytest.skip("transformers does not have Qwen3_5TextConfig")

    set_seed(SEED)
    # Minimal hybrid config: 3 linear-attention layers + 1 softmax-attention layer.
    defaults = {
        "hidden_size": 32,
        "intermediate_size": 32,
        "num_hidden_layers": 4,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "max_position_embeddings": 64,
        "vocab_size": 32,
        "head_dim": 8,
        "short_chunk_size": 32,
        "attn_type": [0, 0, 0, 1],
    }
    config = Qwen3_5TextConfig(**{**defaults, **config_kwargs})
    return AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)


def create_tiny_qwen3_5_dir(
    tmp_path: Path | str, with_tokenizer: bool = False, return_model: bool = False, **config_kwargs
) -> Path | tuple[Path, PreTrainedModel]:
    """Persist a tiny Qwen3.5 Dense model under *tmp_path* for testing.

    When ``with_tokenizer`` is set, a tiny tokenizer is saved alongside and the
    model vocab is sized to match it. When ``return_model`` is set, the
    in-memory model is returned together with the directory.
    """
    model_dir = Path(tmp_path) / "tiny_qwen3_5"
    if with_tokenizer:
        tok = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        tok.save_pretrained(model_dir)
        # Keep the model's embedding table consistent with the saved tokenizer.
        config_kwargs["vocab_size"] = tok.vocab_size
    model = get_tiny_qwen3_5(**config_kwargs)
    model.save_pretrained(model_dir)

    return (model_dir, model) if return_model else model_dir


##### Qwen3.5 MoE #####
def get_tiny_qwen3_5_moe(**config_kwargs) -> PreTrainedModel:
    """Build a miniature Qwen3.5 MoE model (hybrid attention + mixture-of-experts).

    Any keyword argument overrides the tiny default config. Requires a
    ``transformers`` version that ships ``Qwen3_5MoeTextConfig``; otherwise the
    calling test is skipped.
    """
    if Qwen3_5MoeTextConfig is None:
        pytest.skip("transformers does not have Qwen3_5MoeTextConfig")

    set_seed(SEED)
    # Minimal hybrid-MoE config: 4 experts, top-2 routing, every layer sparse.
    defaults = {
        "hidden_size": 32,
        "intermediate_size": 32,
        "moe_intermediate_size": 32,
        "num_hidden_layers": 4,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "max_position_embeddings": 64,
        "vocab_size": 32,
        "head_dim": 8,
        "short_chunk_size": 32,
        "attn_type": [0, 0, 0, 1],
        "num_experts": 4,
        "num_experts_per_tok": 2,
        "decoder_sparse_step": 1,
    }
    config = Qwen3_5MoeTextConfig(**{**defaults, **config_kwargs})
    return AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)


def create_tiny_qwen3_5_moe_dir(
    tmp_path: Path | str,
    with_tokenizer: bool = False,
    return_model: bool = False,
    **config_kwargs,
) -> Path | tuple[Path, PreTrainedModel]:
    """Save a tiny Qwen3.5 MoE model to disk for testing.

    Args:
        tmp_path: Directory under which the ``tiny_qwen3_5_moe`` folder is created.
        with_tokenizer: If True, also save a tiny tokenizer and size the model
            vocab to match it.
        return_model: If True, also return the in-memory model. Added for
            consistency with ``create_tiny_qwen3_5_dir``; defaults to False so
            existing callers are unaffected.
        **config_kwargs: Overrides forwarded to the tiny MoE config.

    Returns:
        The model directory, or ``(model_dir, model)`` when ``return_model`` is True.
    """
    model_dir = Path(tmp_path) / "tiny_qwen3_5_moe"
    if with_tokenizer:
        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        tokenizer.save_pretrained(model_dir)
        # Keep the model's embedding table consistent with the saved tokenizer.
        config_kwargs["vocab_size"] = tokenizer.vocab_size
    tiny_model = get_tiny_qwen3_5_moe(**config_kwargs)
    tiny_model.save_pretrained(model_dir)

    if return_model:
        return model_dir, tiny_model
    return model_dir


##### Qwen3 #####
def get_tiny_qwen3(**config_kwargs) -> PreTrainedModel:
set_seed(SEED)
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/torch/quantization/plugins/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
create_tiny_llama_dir,
get_tiny_gpt_oss,
get_tiny_llama,
get_tiny_qwen3_5,
get_tiny_qwen3_5_moe,
get_tiny_qwen3_moe,
tf_modelopt_state_and_output_tester,
)
Expand Down Expand Up @@ -243,3 +245,48 @@ def test_hf_decoder_discoverer_registration_path():
assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers(
model
)


def test_qwen3_5_hybrid_attention_quantize():
    """FP8 quantization must disable every linear_attn quantizer while the
    regular self_attn projections remain quantized."""
    model = get_tiny_qwen3_5()
    mtq.quantize(model, mtq.FP8_DEFAULT_CFG, lambda m: m(**m.dummy_inputs))

    # Only inspect modules that actually carry quantizers.
    quantized_modules = (
        (name, mod) for name, mod in model.named_modules() if hasattr(mod, "weight_quantizer")
    )
    for name, mod in quantized_modules:
        if "linear_attn" in name:
            assert not mod.weight_quantizer.is_enabled, (
                f"linear_attn module {name} should have weight_quantizer disabled"
            )
            assert not mod.input_quantizer.is_enabled, (
                f"linear_attn module {name} should have input_quantizer disabled"
            )
        elif "self_attn" in name and "layernorm" not in name:
            assert mod.weight_quantizer.is_enabled, (
                f"self_attn module {name} should have weight_quantizer enabled"
            )


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.9"),
    reason="torch 2.8 grouped_mm is CUDA-only",
)
def test_qwen3_5_moe_experts_not_quantized():
    """Verify MoE expert quantizers are disabled when build_quant_cfg rules are applied.

    Mirrors the ``*experts*`` disable rule that ``build_quant_cfg`` appends for
    qwen3_5moe, then checks that no expert module ends up with an enabled
    weight quantizer.
    """
    # Idiom fix: keep the function-scoped import at the top of the body
    # instead of in the middle, after side-effecting statements.
    import copy

    model = get_tiny_qwen3_5_moe()

    # Deep-copy so the shared default config is not mutated across tests.
    quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    quant_cfg["quant_cfg"].append({"quantizer_name": "*experts*", "enable": False})

    mtq.quantize(model, quant_cfg, lambda m: m(**m.dummy_inputs))

    for name, module in model.named_modules():
        if not hasattr(module, "weight_quantizer"):
            continue
        if "experts" in name:
            assert not module.weight_quantizer.is_enabled, (
                f"expert module {name} should have weight_quantizer disabled"
            )