|
1 | 1 | import re |
2 | | -from types import SimpleNamespace |
3 | | -from typing import Dict, List |
4 | 2 |
|
5 | | -import torch |
6 | | -from transformers import PretrainedConfig |
7 | | - |
8 | | -from ...inputs import ( |
9 | | - ContentFormat, |
10 | | - MultimodalPlaceholderMetadata, |
11 | | - MultimodalPlaceholderPlacement, |
12 | | - register_input_processor, |
13 | | - support_multimodal_disaggregated, |
14 | | -) |
15 | | -from ..pyexecutor.config_utils import get_qwen3_hybrid_layer_types |
16 | | -from .checkpoints.base_weight_mapper import BaseWeightMapper |
17 | | -from .checkpoints.hf.qwen3_5_weight_mapper import Qwen3_5MoeHfWeightMapper |
18 | | -from .modeling_multimodal_utils import _is_disagg |
19 | 3 | from .modeling_qwen3_next import Qwen3NextForCausalLM |
20 | | -from .modeling_qwen3vl import ( |
21 | | - Qwen3VisionModel, |
22 | | - Qwen3VisionModelBase, |
23 | | - Qwen3VLInputProcessorBase, |
24 | | - Qwen3VLModelBase, |
25 | | -) |
26 | | -from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder |
| 4 | +from .modeling_utils import register_auto_model |
27 | 5 |
|
28 | 6 | _LANG_PREFIX = "model.language_model." |
29 | 7 |
|
@@ -73,248 +51,6 @@ def _translate_mtp_pattern(name, n_hidden_layers): |
73 | 51 | return None |
74 | 52 |
|
75 | 53 |
|
76 | | -# --- Config adapters -------------------------------------------------------- |
77 | | -# |
78 | | -# These run from `load_pretrained_config` in |
79 | | -# `tensorrt_llm/_torch/pyexecutor/config_utils.py` via lazy import — the |
80 | | -# runtime layer asks the model module how to load its own config. |
81 | | -# |
82 | | -# There are two entry points: |
83 | | -# - `Qwen35ConfigCompat.normalize(config_dict)` — for text-only |
84 | | -# Qwen3.5 (MoE and dense). Returns a dict that |
85 | | -# `transformers.Qwen3NextConfig.from_dict(...)` can consume, so the |
86 | | -# existing Qwen3Next runtime is reused unchanged. |
87 | | -# - `_normalize_qwen35_moe_vl_config(model_config)` — for the |
88 | | -# Qwen3.5-MoE VLM. Mutates the HF-native `transformers.Qwen3_5MoeConfig` |
89 | | -# in place, attaching the runtime aliases the Qwen3Next-based LM expects |
90 | | -# while keeping `text_config` / `vision_config` composite. |
91 | | - |
92 | | - |
93 | | -class Qwen35ConfigCompat: |
94 | | - """Temporary shim for flattening Qwen3.5 text configs into Qwen3NextConfig. |
95 | | -
|
96 | | - We normalize to `Qwen3NextConfig` (rather than to a Qwen3.5-native |
97 | | - schema) so the runtime can reuse the existing `Qwen3NextForCausalLM` |
98 | | - model implementation unchanged — Qwen3.5 text is structurally identical |
99 | | - to Qwen3Next, so matching the config schema lets the same code serve |
100 | | - both. |
101 | | -
|
102 | | - This is used for Qwen3.5 text-only configs and for shared helper logic such |
103 | | - as RoPE and quantization exclude-module normalization. Qwen3.5-MoE VLM |
104 | | - configs should stay composite and use transformers.Qwen3_5MoeConfig plus |
105 | | - _normalize_qwen35_moe_vl_config instead. |
106 | | -
|
107 | | - To remove: delete this class and the elif branch in |
108 | | - load_pretrained_config that flattens Qwen3.5 text configs. |
109 | | - """ |
110 | | - |
111 | | - @staticmethod |
112 | | - def normalize(config_dict: dict) -> dict: |
113 | | - """Entry point: raw config.json dict -> flat Qwen3NextConfig-compatible dict.""" |
114 | | - text_config = Qwen35ConfigCompat._extract_text_config(config_dict) |
115 | | - text_config = Qwen35ConfigCompat._inherit_quantization_config(config_dict, text_config) |
116 | | - text_config = Qwen35ConfigCompat._flatten_rope(text_config) |
117 | | - |
118 | | - # Detect dense vs MoE and set architecture + MoE defaults accordingly |
119 | | - is_moe = "num_experts" in text_config and text_config["num_experts"] > 0 |
120 | | - if is_moe: |
121 | | - text_config["architectures"] = ["Qwen3_5MoeForCausalLM"] |
122 | | - else: |
123 | | - text_config["architectures"] = ["Qwen3_5ForCausalLM"] |
124 | | - # Ensure MoE fields are zeroed so Qwen3NextConfig defaults don't |
125 | | - # accidentally enable MoE for the dense model. |
126 | | - text_config.setdefault("num_experts", 0) |
127 | | - text_config.setdefault("num_experts_per_tok", 0) |
128 | | - text_config.setdefault("moe_intermediate_size", 0) |
129 | | - text_config.setdefault("shared_expert_intermediate_size", 0) |
130 | | - return text_config |
131 | | - |
132 | | - _VLM_ARCHITECTURES = { |
133 | | - "Qwen3_5MoeForConditionalGeneration", |
134 | | - "Qwen3_5ForConditionalGeneration", |
135 | | - } |
136 | | - |
137 | | - @staticmethod |
138 | | - def _extract_text_config(config_dict: dict) -> dict: |
139 | | - """Pull nested text_config from VLM checkpoints, or use dict as-is.""" |
140 | | - architectures = config_dict.get("architectures") or [] |
141 | | - if architectures and architectures[0] in Qwen35ConfigCompat._VLM_ARCHITECTURES: |
142 | | - text_config = dict(config_dict.get("text_config") or {}) |
143 | | - else: |
144 | | - text_config = dict(config_dict) |
145 | | - if not text_config: |
146 | | - raise ValueError("Qwen3.5 config is missing a usable text_config") |
147 | | - return text_config |
148 | | - |
149 | | - @staticmethod |
150 | | - def _inherit_quantization_config(config_dict: dict, text_config: dict) -> dict: |
151 | | - """Copy top-level quantization_config into text_config with name normalization. |
152 | | -
|
153 | | - Also adds a temporary workaround that keeps packed linear-attention |
154 | | - in_proj_qkvz on the bf16 path until FP8 block-scale TP loading is |
155 | | - fixed for that layout. |
156 | | - """ |
157 | | - if "quantization_config" in text_config: |
158 | | - return text_config |
159 | | - if "quantization_config" not in config_dict: |
160 | | - return text_config |
161 | | - |
162 | | - quantization_config = dict(config_dict["quantization_config"]) |
163 | | - if "modules_to_not_convert" in quantization_config: |
164 | | - modules = Qwen35ConfigCompat._normalize_exclude_modules( |
165 | | - quantization_config["modules_to_not_convert"] |
166 | | - ) |
167 | | - modules = Qwen35ConfigCompat._add_qkvz_bf16_workaround(text_config, modules) |
168 | | - quantization_config["modules_to_not_convert"] = sorted(set(modules)) |
169 | | - text_config["quantization_config"] = quantization_config |
170 | | - return text_config |
171 | | - |
172 | | - @staticmethod |
173 | | - def _normalize_exclude_modules(modules: list[str]) -> list[str]: |
174 | | - """Translate HF quantization exclude-module paths to TRT-LLM names. |
175 | | -
|
176 | | - - Strip model.language_model. prefix -> model. |
177 | | - - Drop model.visual.* and mtp.* entries |
178 | | - - Map split projection names to packed TRT-LLM names |
179 | | - """ |
180 | | - normalized = set() |
181 | | - for name in modules: |
182 | | - if name.startswith("model.language_model."): |
183 | | - name = "model." + name[len("model.language_model.") :] |
184 | | - if name.startswith("model.visual.") or name.startswith("mtp."): |
185 | | - continue |
186 | | - name = re.sub(r"\.in_proj_[ab]$", ".in_proj_ba", name) |
187 | | - name = re.sub(r"\.in_proj_(q|k|v|z|qkv)$", ".in_proj_qkvz", name) |
188 | | - normalized.add(name) |
189 | | - return sorted(normalized) |
190 | | - |
191 | | - @staticmethod |
192 | | - def _add_qkvz_bf16_workaround(text_config: dict, modules: list[str]) -> list[str]: |
193 | | - """Keep packed linear-attention qkvz on bf16 path for all linear-attention layers. |
194 | | -
|
195 | | - Temporary until FP8 block-scale TP loading is fixed for this layout. |
196 | | - """ |
197 | | - try: |
198 | | - layer_types = get_qwen3_hybrid_layer_types(SimpleNamespace(**text_config)) |
199 | | - except (ValueError, AttributeError): |
200 | | - return modules |
201 | | - for layer_idx, layer_type in enumerate(layer_types): |
202 | | - if layer_type == "linear_attention": |
203 | | - modules.append(f"model.layers.{layer_idx}.linear_attn.in_proj_qkvz") |
204 | | - return modules |
205 | | - |
206 | | - @staticmethod |
207 | | - def _flatten_rope(text_config: dict) -> dict: |
208 | | - """Flatten rope_parameters into top-level rope_theta / partial_rotary_factor / rope_scaling. |
209 | | -
|
210 | | - Qwen3.5 nests these inside a rope_parameters dict and uses rope_type |
211 | | - instead of type in rope_scaling. Qwen3NextConfig expects them as |
212 | | - top-level fields with rope_scaling.type. |
213 | | - """ |
214 | | - rope_parameters = dict(text_config.pop("rope_parameters", {}) or {}) |
215 | | - rope_scaling = dict(text_config.get("rope_scaling") or {}) |
216 | | - if rope_parameters: |
217 | | - rope_theta = rope_parameters.pop("rope_theta", None) |
218 | | - if rope_theta is not None: |
219 | | - text_config.setdefault("rope_theta", rope_theta) |
220 | | - partial_rotary_factor = rope_parameters.pop("partial_rotary_factor", None) |
221 | | - if partial_rotary_factor is not None: |
222 | | - text_config.setdefault("partial_rotary_factor", partial_rotary_factor) |
223 | | - if rope_parameters: |
224 | | - rope_scaling = rope_parameters | rope_scaling |
225 | | - if rope_scaling: |
226 | | - has_mrope = "mrope_section" in rope_scaling or rope_scaling.get( |
227 | | - "mrope_interleaved", False |
228 | | - ) |
229 | | - if has_mrope: |
230 | | - rope_scaling["type"] = "mrope" |
231 | | - rope_scaling.pop("rope_type", None) |
232 | | - elif "type" not in rope_scaling and "rope_type" in rope_scaling: |
233 | | - rope_type = rope_scaling.pop("rope_type") |
234 | | - # "default" means standard RoPE (no scaling) — don't set |
235 | | - # rope_scaling to avoid triggering scaling code paths. |
236 | | - if rope_type == "default": |
237 | | - rope_scaling = {} |
238 | | - else: |
239 | | - rope_scaling["type"] = rope_type |
240 | | - if rope_scaling: |
241 | | - text_config["rope_scaling"] = rope_scaling |
242 | | - return text_config |
243 | | - |
244 | | - |
245 | | -def _normalize_qwen35_mrope_config(text_config) -> None: |
246 | | - """Materialize Qwen3.5 mRoPE aliases needed by the Qwen3-VL path. |
247 | | -
|
248 | | - HF stores RoPE metadata under `rope_parameters`; the shared Qwen3-VL |
249 | | - wrapper reads `rope_theta`, `partial_rotary_factor`, and |
250 | | - `rope_scaling` directly on the text config. |
251 | | - """ |
252 | | - rope_parameters = getattr(text_config, "rope_parameters", None) |
253 | | - if not rope_parameters: |
254 | | - return |
255 | | - if hasattr(rope_parameters, "to_dict"): |
256 | | - rope_parameters = rope_parameters.to_dict() |
257 | | - flattened = Qwen35ConfigCompat._flatten_rope( |
258 | | - { |
259 | | - "rope_parameters": dict(rope_parameters), |
260 | | - "rope_scaling": dict(getattr(text_config, "rope_scaling", None) or {}), |
261 | | - } |
262 | | - ) |
263 | | - for attr in ("rope_theta", "partial_rotary_factor", "rope_scaling"): |
264 | | - value = flattened.get(attr) |
265 | | - if value is not None: |
266 | | - setattr(text_config, attr, value) |
267 | | - |
268 | | - |
269 | | -def _normalize_qwen35_qwen3next_text_aliases(text_config) -> None: |
270 | | - """Materialize Qwen3Next-style text aliases used by the shared runtime.""" |
271 | | - if getattr(text_config, "intermediate_size", None) is None: |
272 | | - moe_intermediate_size = getattr(text_config, "moe_intermediate_size", None) |
273 | | - num_experts_per_tok = getattr(text_config, "num_experts_per_tok", None) |
274 | | - shared_expert_intermediate_size = ( |
275 | | - getattr(text_config, "shared_expert_intermediate_size", 0) or 0 |
276 | | - ) |
277 | | - if moe_intermediate_size is not None and num_experts_per_tok is not None: |
278 | | - text_config.intermediate_size = ( |
279 | | - num_experts_per_tok * moe_intermediate_size + shared_expert_intermediate_size |
280 | | - ) |
281 | | - |
282 | | - |
283 | | -def _normalize_qwen35_quantization_config(model_config) -> None: |
284 | | - quantization_config = getattr(model_config, "quantization_config", None) |
285 | | - if not isinstance(quantization_config, dict): |
286 | | - return |
287 | | - |
288 | | - modules = quantization_config.get("modules_to_not_convert") |
289 | | - if modules is None: |
290 | | - return |
291 | | - |
292 | | - text_config = getattr(model_config, "text_config", None) |
293 | | - normalized_modules = Qwen35ConfigCompat._normalize_exclude_modules(modules) |
294 | | - if text_config is not None: |
295 | | - normalized_modules = Qwen35ConfigCompat._add_qkvz_bf16_workaround( |
296 | | - text_config.to_dict(), normalized_modules |
297 | | - ) |
298 | | - quantization_config["modules_to_not_convert"] = sorted(set(normalized_modules)) |
299 | | - |
300 | | - |
301 | | -def _normalize_qwen35_moe_vl_config(model_config) -> None: |
302 | | - """Adapt HF Qwen3.5-MoE VLM config to TRT-LLM runtime conventions.""" |
303 | | - if not getattr(model_config, "architectures", None): |
304 | | - model_config.architectures = ["Qwen3_5MoeForConditionalGeneration"] |
305 | | - |
306 | | - text_config = getattr(model_config, "text_config", None) |
307 | | - if text_config is None: |
308 | | - raise ValueError("Qwen3.5-MoE VLM config is missing text_config") |
309 | | - |
310 | | - text_config.architectures = ["Qwen3_5MoeForCausalLM"] |
311 | | - _normalize_qwen35_qwen3next_text_aliases(text_config) |
312 | | - _normalize_qwen35_mrope_config(text_config) |
313 | | - |
314 | | - model_config.get_text_config = lambda decoder=False: text_config |
315 | | - _normalize_qwen35_quantization_config(model_config) |
316 | | - |
317 | | - |
318 | 54 | def _normalize_qwen35_exclude_modules(model_config): |
319 | 55 | """Normalize NVFP4/FP8 exclude_modules from HF naming to TRT-LLM naming. |
320 | 56 |
|
@@ -390,56 +126,10 @@ class Qwen3_5ForCausalLM(Qwen3NextForCausalLM): |
390 | 126 |
|
391 | 127 | Same reuse pattern as Qwen3_5MoeForCausalLM, but for the dense 27B |
392 | 128 | variant which uses GatedMLP instead of SparseMoeBlock. The config |
393 | | - normalizer (Qwen35ConfigCompat) sets num_experts=0 so that |
| 129 | + normalizer (_Qwen35ConfigCompat) sets num_experts=0 so that |
394 | 130 | Qwen3NextModel selects GatedMLP for the feed-forward layers. |
395 | 131 | """ |
396 | 132 |
|
397 | 133 | def __init__(self, model_config): |
398 | 134 | _normalize_qwen35_exclude_modules(model_config) |
399 | 135 | super().__init__(model_config) |
400 | | - |
401 | | - |
402 | | -# TODO: Add tests for disaggregated support. |
403 | | -@support_multimodal_disaggregated |
404 | | -@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel) |
405 | | -@register_auto_model("Qwen3_5MoeForConditionalGeneration") |
406 | | -@register_input_processor( |
407 | | - Qwen3VLInputProcessorBase, |
408 | | - model_type="qwen3_5_moe", |
409 | | - placeholder_metadata=MultimodalPlaceholderMetadata( |
410 | | - placeholder_map={ |
411 | | - "image": "<|vision_start|><|image_pad|><|vision_end|>", |
412 | | - "video": "<|vision_start|><|video_pad|><|vision_end|>", |
413 | | - }, |
414 | | - placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, |
415 | | - placeholders_separator="", |
416 | | - content_format=ContentFormat.STRING, |
417 | | - ), |
418 | | -) |
419 | | -class Qwen3_5MoeVLModel(Qwen3VLModelBase): |
420 | | - """VLM wrapper composing Qwen3 vision encoder with Qwen3.5 MoE text decoder.""" |
421 | | - |
422 | | - def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs): |
423 | | - kwargs["vision_model_class"] = Qwen3VisionModel |
424 | | - kwargs["disable_fuse_rope"] = kwargs.get("disable_fuse_rope", False) |
425 | | - super().__init__(model_config, *args, **kwargs) |
426 | | - |
427 | | - @property |
428 | | - def multimodal_data_device_paths(self) -> List[str]: |
429 | | - return [ |
430 | | - "image.pixel_values", |
431 | | - "video.pixel_values_videos", |
432 | | - "multimodal_embedding", |
433 | | - ] |
434 | | - |
435 | | - def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper): |
436 | | - if not _is_disagg(): |
437 | | - self.mm_encoder.load_weights(weights) |
438 | | - |
439 | | - weight_mapper = Qwen3_5MoeHfWeightMapper() |
440 | | - weight_mapper.init_model_and_config(self.llm, self.model_config) |
441 | | - filtered_weights = {k: v for k, v in weights.items() if not k.startswith("model.visual.")} |
442 | | - params_map = { |
443 | | - r"^model\.language_model\.(.*)$": r"model.\1", |
444 | | - } |
445 | | - self.llm.load_weights(filtered_weights, weight_mapper, params_map=params_map) |
0 commit comments