Skip to content

Commit 96a4a09

Browse files
moraxu and nv-guomingz authored
[TRTLLM-12500][feat] Add support for Qwen3.5 VL MoE (#14164)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Signed-off-by: Michal Guzek <mguzek@nvidia.com> Co-authored-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
1 parent 24b68fc commit 96a4a09

14 files changed

Lines changed: 1037 additions & 173 deletions

File tree

docs/source/models/supported-models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
9696
| `Qwen2_5_VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
9797
| `Qwen3VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
9898
| `Qwen3VLMoeForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
99+
| `Qwen3_5MoeForConditionalGeneration` | Yes | Yes | Untested | Yes | Yes | No | Untested | Yes | L + I + V |
99100

100101
Note:
101102
- L: Language

tensorrt_llm/_torch/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
Qwen2ForRewardModel)
3737
from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
3838
from .modeling_qwen3 import Qwen3ForCausalLM
39-
from .modeling_qwen3_5 import Qwen3_5ForCausalLM, Qwen3_5MoeForCausalLM
39+
from .modeling_qwen3_5 import (Qwen3_5ForCausalLM, Qwen3_5MoeForCausalLM,
40+
Qwen3_5MoeVLModel)
4041
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
4142
from .modeling_qwen3_next import Qwen3NextForCausalLM
4243
from .modeling_qwen3vl import Qwen3VLModel
@@ -88,6 +89,7 @@
8889
"Qwen3MoeForCausalLM",
8990
"Qwen3_5ForCausalLM",
9091
"Qwen3_5MoeForCausalLM",
92+
"Qwen3_5MoeVLModel",
9193
"Qwen3NextForCausalLM",
9294
"Qwen3MoeVLModel",
9395
"GptOssForCausalLM",

tensorrt_llm/_torch/models/checkpoints/hf/qwen3_5_weight_mapper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
@register_mapper("HF", "Qwen3_5MoeForCausalLM")
16+
@register_mapper("HF", "Qwen3_5MoeForConditionalGeneration")
1617
@register_mapper("HF", "Qwen3_5ForCausalLM")
1718
class Qwen3_5MoeHfWeightMapper(Qwen3NextHfWeightMapper):
1819
"""Weight mapper for Qwen3.5 MoE text checkpoints.

tensorrt_llm/_torch/models/modeling_qwen3_5.py

Lines changed: 312 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,29 @@
11
import re
2+
from types import SimpleNamespace
3+
from typing import Dict, List
24

5+
import torch
6+
from transformers import PretrainedConfig
7+
8+
from ...inputs import (
9+
ContentFormat,
10+
MultimodalPlaceholderMetadata,
11+
MultimodalPlaceholderPlacement,
12+
register_input_processor,
13+
support_multimodal_disaggregated,
14+
)
15+
from ..pyexecutor.config_utils import get_qwen3_hybrid_layer_types
16+
from .checkpoints.base_weight_mapper import BaseWeightMapper
17+
from .checkpoints.hf.qwen3_5_weight_mapper import Qwen3_5MoeHfWeightMapper
18+
from .modeling_multimodal_utils import _is_disagg
319
from .modeling_qwen3_next import Qwen3NextForCausalLM
4-
from .modeling_utils import register_auto_model
20+
from .modeling_qwen3vl import (
21+
Qwen3VisionModel,
22+
Qwen3VisionModelBase,
23+
Qwen3VLInputProcessorBase,
24+
Qwen3VLModelBase,
25+
)
26+
from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder
527

628
_LANG_PREFIX = "model.language_model."
729

@@ -51,6 +73,248 @@ def _translate_mtp_pattern(name, n_hidden_layers):
5173
return None
5274

5375

76+
# --- Config adapters --------------------------------------------------------
77+
#
78+
# These run from `load_pretrained_config` in
79+
# `tensorrt_llm/_torch/pyexecutor/config_utils.py` via lazy import — the
80+
# runtime layer asks the model module how to load its own config.
81+
#
82+
# There are two entry points:
83+
# - `Qwen35ConfigCompat.normalize(config_dict)` — for text-only
84+
# Qwen3.5 (MoE and dense). Returns a dict that
85+
# `transformers.Qwen3NextConfig.from_dict(...)` can consume, so the
86+
# existing Qwen3Next runtime is reused unchanged.
87+
# - `_normalize_qwen35_moe_vl_config(model_config)` — for the
88+
# Qwen3.5-MoE VLM. Mutates the HF-native `transformers.Qwen3_5MoeConfig`
89+
# in place, attaching the runtime aliases the Qwen3Next-based LM expects
90+
# while keeping `text_config` / `vision_config` composite.
91+
92+
93+
class Qwen35ConfigCompat:
    """Temporary shim for flattening Qwen3.5 text configs into Qwen3NextConfig.

    We normalize to `Qwen3NextConfig` (rather than to a Qwen3.5-native
    schema) so the runtime can reuse the existing `Qwen3NextForCausalLM`
    model implementation unchanged — Qwen3.5 text is structurally identical
    to Qwen3Next, so matching the config schema lets the same code serve
    both.

    This is used for Qwen3.5 text-only configs and for shared helper logic such
    as RoPE and quantization exclude-module normalization. Qwen3.5-MoE VLM
    configs should stay composite and use transformers.Qwen3_5MoeConfig plus
    _normalize_qwen35_moe_vl_config instead.

    To remove: delete this class and the elif branch in
    load_pretrained_config that flattens Qwen3.5 text configs.
    """

    # MoE sizing fields that are forced to 0 for dense checkpoints.
    _MOE_FIELDS = (
        "num_experts",
        "num_experts_per_tok",
        "moe_intermediate_size",
        "shared_expert_intermediate_size",
    )

    @staticmethod
    def normalize(config_dict: dict) -> dict:
        """Entry point: raw config.json dict -> flat Qwen3NextConfig-compatible dict.

        Args:
            config_dict: Parsed `config.json`. It is not modified; the result
                is built from shallow copies.

        Returns:
            A flat text-config dict with `architectures` set to the matching
            causal-LM class name and RoPE fields hoisted to the top level.

        Raises:
            ValueError: If no usable text config can be extracted.
        """
        text_config = Qwen35ConfigCompat._extract_text_config(config_dict)
        text_config = Qwen35ConfigCompat._inherit_quantization_config(config_dict, text_config)
        text_config = Qwen35ConfigCompat._flatten_rope(text_config)

        # Detect dense vs MoE and set architecture + MoE defaults accordingly.
        # `or 0` treats an explicit `"num_experts": null` as dense instead of
        # raising TypeError on `None > 0`.
        is_moe = (text_config.get("num_experts") or 0) > 0
        if is_moe:
            text_config["architectures"] = ["Qwen3_5MoeForCausalLM"]
        else:
            text_config["architectures"] = ["Qwen3_5ForCausalLM"]
            # Ensure MoE fields are zeroed so Qwen3NextConfig defaults don't
            # accidentally enable MoE for the dense model. Missing and
            # explicit-None values are both replaced; other values are kept.
            for field in Qwen35ConfigCompat._MOE_FIELDS:
                if text_config.get(field) is None:
                    text_config[field] = 0
        return text_config

    # Architectures whose config.json nests the language-model settings under
    # `text_config`.
    _VLM_ARCHITECTURES = {
        "Qwen3_5MoeForConditionalGeneration",
        "Qwen3_5ForConditionalGeneration",
    }

    @staticmethod
    def _extract_text_config(config_dict: dict) -> dict:
        """Pull nested text_config from VLM checkpoints, or use dict as-is.

        Only the first entry of `architectures` is inspected. The returned
        dict is a shallow copy, so top-level edits do not touch the input.

        Raises:
            ValueError: If the selected config is empty or missing.
        """
        architectures = config_dict.get("architectures") or []
        if architectures and architectures[0] in Qwen35ConfigCompat._VLM_ARCHITECTURES:
            text_config = dict(config_dict.get("text_config") or {})
        else:
            text_config = dict(config_dict)
        if not text_config:
            raise ValueError("Qwen3.5 config is missing a usable text_config")
        return text_config

    @staticmethod
    def _inherit_quantization_config(config_dict: dict, text_config: dict) -> dict:
        """Copy top-level quantization_config into text_config with name normalization.

        Also adds a temporary workaround that keeps packed linear-attention
        in_proj_qkvz on the bf16 path until FP8 block-scale TP loading is
        fixed for that layout.

        A `quantization_config` already present on `text_config` wins and is
        left untouched. `text_config` is updated in place and returned.
        """
        if "quantization_config" in text_config:
            return text_config
        if "quantization_config" not in config_dict:
            return text_config

        # Shallow copy so the caller's top-level quantization_config is not edited.
        quantization_config = dict(config_dict["quantization_config"])
        if "modules_to_not_convert" in quantization_config:
            modules = Qwen35ConfigCompat._normalize_exclude_modules(
                quantization_config["modules_to_not_convert"]
            )
            modules = Qwen35ConfigCompat._add_qkvz_bf16_workaround(text_config, modules)
            quantization_config["modules_to_not_convert"] = sorted(set(modules))
        text_config["quantization_config"] = quantization_config
        return text_config

    @staticmethod
    def _normalize_exclude_modules(modules: list[str]) -> list[str]:
        """Translate HF quantization exclude-module paths to TRT-LLM names.

        - Strip model.language_model. prefix -> model.
        - Drop model.visual.* and mtp.* entries
        - Map split projection names to packed TRT-LLM names

        Returns:
            A new, de-duplicated, sorted list; `modules` is not modified.
        """
        normalized = set()
        for name in modules:
            if name.startswith("model.language_model."):
                name = "model." + name[len("model.language_model.") :]
            if name.startswith("model.visual.") or name.startswith("mtp."):
                continue
            # in_proj_a / in_proj_b collapse into the packed in_proj_ba name.
            name = re.sub(r"\.in_proj_[ab]$", ".in_proj_ba", name)
            # q / k / v / z / qkv collapse into the packed in_proj_qkvz name.
            name = re.sub(r"\.in_proj_(q|k|v|z|qkv)$", ".in_proj_qkvz", name)
            normalized.add(name)
        return sorted(normalized)

    @staticmethod
    def _add_qkvz_bf16_workaround(text_config: dict, modules: list[str]) -> list[str]:
        """Keep packed linear-attention qkvz on bf16 path for all linear-attention layers.

        Temporary until FP8 block-scale TP loading is fixed for this layout.

        Returns:
            A new list holding `modules` plus one `in_proj_qkvz` entry per
            linear-attention layer. The caller's list is never mutated. If
            the layer types cannot be derived (ValueError / AttributeError
            from the helper), `modules` is returned unchanged.
        """
        try:
            layer_types = get_qwen3_hybrid_layer_types(SimpleNamespace(**text_config))
        except (ValueError, AttributeError):
            return modules
        extra = [
            f"model.layers.{layer_idx}.linear_attn.in_proj_qkvz"
            for layer_idx, layer_type in enumerate(layer_types)
            if layer_type == "linear_attention"
        ]
        return [*modules, *extra]

    @staticmethod
    def _flatten_rope(text_config: dict) -> dict:
        """Flatten rope_parameters into top-level rope_theta / partial_rotary_factor / rope_scaling.

        Qwen3.5 nests these inside a rope_parameters dict and uses rope_type
        instead of type in rope_scaling. Qwen3NextConfig expects them as
        top-level fields with rope_scaling.type.

        `text_config` is updated in place (its `rope_parameters` key is
        removed) and returned. Existing top-level `rope_theta` /
        `partial_rotary_factor` values take precedence over the nested ones,
        and existing `rope_scaling` entries take precedence over leftovers
        from `rope_parameters`.
        """
        rope_parameters = dict(text_config.pop("rope_parameters", {}) or {})
        rope_scaling = dict(text_config.get("rope_scaling") or {})
        if rope_parameters:
            rope_theta = rope_parameters.pop("rope_theta", None)
            if rope_theta is not None:
                text_config.setdefault("rope_theta", rope_theta)
            partial_rotary_factor = rope_parameters.pop("partial_rotary_factor", None)
            if partial_rotary_factor is not None:
                text_config.setdefault("partial_rotary_factor", partial_rotary_factor)
            if rope_parameters:
                # Whatever is left describes scaling; right operand wins on conflicts.
                rope_scaling = rope_parameters | rope_scaling
        if rope_scaling:
            has_mrope = "mrope_section" in rope_scaling or rope_scaling.get(
                "mrope_interleaved", False
            )
            if has_mrope:
                rope_scaling["type"] = "mrope"
                rope_scaling.pop("rope_type", None)
            elif "type" not in rope_scaling and "rope_type" in rope_scaling:
                rope_type = rope_scaling.pop("rope_type")
                # "default" means standard RoPE (no scaling) — don't set
                # rope_scaling to avoid triggering scaling code paths.
                if rope_type == "default":
                    rope_scaling = {}
                else:
                    rope_scaling["type"] = rope_type
        # NOTE(review): when the result is empty, a pre-existing
        # text_config["rope_scaling"] is left as it was rather than removed —
        # confirm that is intended for inputs like {"rope_type": "default"}.
        if rope_scaling:
            text_config["rope_scaling"] = rope_scaling
        return text_config
243+
244+
245+
def _normalize_qwen35_mrope_config(text_config) -> None:
246+
"""Materialize Qwen3.5 mRoPE aliases needed by the Qwen3-VL path.
247+
248+
HF stores RoPE metadata under `rope_parameters`; the shared Qwen3-VL
249+
wrapper reads `rope_theta`, `partial_rotary_factor`, and
250+
`rope_scaling` directly on the text config.
251+
"""
252+
rope_parameters = getattr(text_config, "rope_parameters", None)
253+
if not rope_parameters:
254+
return
255+
if hasattr(rope_parameters, "to_dict"):
256+
rope_parameters = rope_parameters.to_dict()
257+
flattened = Qwen35ConfigCompat._flatten_rope(
258+
{
259+
"rope_parameters": dict(rope_parameters),
260+
"rope_scaling": dict(getattr(text_config, "rope_scaling", None) or {}),
261+
}
262+
)
263+
for attr in ("rope_theta", "partial_rotary_factor", "rope_scaling"):
264+
value = flattened.get(attr)
265+
if value is not None:
266+
setattr(text_config, attr, value)
267+
268+
269+
def _normalize_qwen35_qwen3next_text_aliases(text_config) -> None:
270+
"""Materialize Qwen3Next-style text aliases used by the shared runtime."""
271+
if getattr(text_config, "intermediate_size", None) is None:
272+
moe_intermediate_size = getattr(text_config, "moe_intermediate_size", None)
273+
num_experts_per_tok = getattr(text_config, "num_experts_per_tok", None)
274+
shared_expert_intermediate_size = (
275+
getattr(text_config, "shared_expert_intermediate_size", 0) or 0
276+
)
277+
if moe_intermediate_size is not None and num_experts_per_tok is not None:
278+
text_config.intermediate_size = (
279+
num_experts_per_tok * moe_intermediate_size + shared_expert_intermediate_size
280+
)
281+
282+
283+
def _normalize_qwen35_quantization_config(model_config) -> None:
284+
quantization_config = getattr(model_config, "quantization_config", None)
285+
if not isinstance(quantization_config, dict):
286+
return
287+
288+
modules = quantization_config.get("modules_to_not_convert")
289+
if modules is None:
290+
return
291+
292+
text_config = getattr(model_config, "text_config", None)
293+
normalized_modules = Qwen35ConfigCompat._normalize_exclude_modules(modules)
294+
if text_config is not None:
295+
normalized_modules = Qwen35ConfigCompat._add_qkvz_bf16_workaround(
296+
text_config.to_dict(), normalized_modules
297+
)
298+
quantization_config["modules_to_not_convert"] = sorted(set(normalized_modules))
299+
300+
301+
def _normalize_qwen35_moe_vl_config(model_config) -> None:
302+
"""Adapt HF Qwen3.5-MoE VLM config to TRT-LLM runtime conventions."""
303+
if not getattr(model_config, "architectures", None):
304+
model_config.architectures = ["Qwen3_5MoeForConditionalGeneration"]
305+
306+
text_config = getattr(model_config, "text_config", None)
307+
if text_config is None:
308+
raise ValueError("Qwen3.5-MoE VLM config is missing text_config")
309+
310+
text_config.architectures = ["Qwen3_5MoeForCausalLM"]
311+
_normalize_qwen35_qwen3next_text_aliases(text_config)
312+
_normalize_qwen35_mrope_config(text_config)
313+
314+
model_config.get_text_config = lambda decoder=False: text_config
315+
_normalize_qwen35_quantization_config(model_config)
316+
317+
54318
def _normalize_qwen35_exclude_modules(model_config):
55319
"""Normalize NVFP4/FP8 exclude_modules from HF naming to TRT-LLM naming.
56320
@@ -126,10 +390,56 @@ class Qwen3_5ForCausalLM(Qwen3NextForCausalLM):
126390
127391
Same reuse pattern as Qwen3_5MoeForCausalLM, but for the dense 27B
128392
variant which uses GatedMLP instead of SparseMoeBlock. The config
129-
normalizer (_Qwen35ConfigCompat) sets num_experts=0 so that
393+
normalizer (Qwen35ConfigCompat) sets num_experts=0 so that
130394
Qwen3NextModel selects GatedMLP for the feed-forward layers.
131395
"""
132396

133397
def __init__(self, model_config):
134398
_normalize_qwen35_exclude_modules(model_config)
135399
super().__init__(model_config)
400+
401+
402+
# TODO: Add tests for disaggregated support.
@support_multimodal_disaggregated
@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
@register_auto_model("Qwen3_5MoeForConditionalGeneration")
@register_input_processor(
    Qwen3VLInputProcessorBase,
    model_type="qwen3_5_moe",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>",
        },
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
        placeholders_separator="",
        content_format=ContentFormat.STRING,
    ),
)
class Qwen3_5MoeVLModel(Qwen3VLModelBase):
    """Qwen3.5-MoE vision-language model built on the shared Qwen3-VL base.

    Uses `Qwen3VisionModel` as the vision model class and loads the language
    model weights through `Qwen3_5MoeHfWeightMapper`.
    """

    def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs):
        # The vision model class is always forced; RoPE fusion stays enabled
        # unless the caller explicitly passed `disable_fuse_rope`.
        kwargs["vision_model_class"] = Qwen3VisionModel
        kwargs.setdefault("disable_fuse_rope", False)
        super().__init__(model_config, *args, **kwargs)

    @property
    def multimodal_data_device_paths(self) -> List[str]:
        """Dotted paths of the multimodal inputs listed for this model."""
        paths = ["image.pixel_values", "video.pixel_values_videos"]
        paths.append("multimodal_embedding")
        return paths

    def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper):
        """Load vision-encoder and language-model weights from one checkpoint dict.

        The vision encoder is loaded only when `_is_disagg()` is false. The
        `weight_mapper` argument is replaced by a fresh
        `Qwen3_5MoeHfWeightMapper` bound to the language model. Keys under
        `model.visual.` are dropped, and `model.language_model.*` is renamed
        to `model.*` via `params_map`.
        """
        if not _is_disagg():
            self.mm_encoder.load_weights(weights)

        weight_mapper = Qwen3_5MoeHfWeightMapper()
        weight_mapper.init_model_and_config(self.llm, self.model_config)

        language_weights = {
            key: tensor for key, tensor in weights.items() if not key.startswith("model.visual.")
        }
        rename_rules = {r"^model\.language_model\.(.*)$": r"model.\1"}
        self.llm.load_weights(language_weights, weight_mapper, params_map=rename_rules)

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,9 +973,18 @@ def get_model_defaults(cls, llm_args: 'TorchLlmArgs') -> dict:
973973
# is supported for Mamba/SSM-based models
974974
return {"kv_cache_config": {"enable_block_reuse": False}}
975975

976-
def load_weights(self,
                 weights: dict,
                 weight_mapper: BaseWeightMapper,
                 params_map: Optional[Dict[str, str]] = None,
                 allow_partial_loading: bool = False):
    """Preprocess the checkpoint with the mapper, then delegate to the parent loader.

    `params_map` and `allow_partial_loading` are forwarded unchanged to the
    parent `load_weights`.
    """
    preprocessed = weight_mapper.preprocess_weights(weights)
    forwarded = {
        "weight_mapper": weight_mapper,
        "params_map": params_map,
        "allow_partial_loading": allow_partial_loading,
    }
    super().load_weights(preprocessed, **forwarded)
979988

980989
def post_load_weights(self):
981990
for idx, layer in enumerate(

0 commit comments

Comments
 (0)