Skip to content

Commit dbf9767

Browse files
committed
Address review comments
Signed-off-by: Michal Guzek <mguzek@nvidia.com>
1 parent 11ab076 commit dbf9767

5 files changed

Lines changed: 27 additions & 21 deletions

File tree

docs/source/models/supported-models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
9595
| `Qwen2_5_VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
9696
| `Qwen3VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
9797
| `Qwen3VLMoeForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
98+
| `Qwen3_5MoeForConditionalGeneration` | Yes | Yes | Untested | Yes | Yes | No | Untested | Yes | L + I + V |
9899

99100
Note:
100101
- L: Language

tensorrt_llm/_torch/models/modeling_qwen3_5.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
# runtime layer asks the model module how to load its own config.
3636
#
3737
# There are two entry points:
38-
# - `_Qwen35ConfigCompat.normalize(config_dict)` — for text-only
38+
# - `Qwen35ConfigCompat.normalize(config_dict)` — for text-only
3939
# Qwen3.5 (MoE and dense). Returns a dict that
4040
# `transformers.Qwen3NextConfig.from_dict(...)` can consume, so the
4141
# existing Qwen3Next runtime is reused unchanged.
@@ -45,7 +45,7 @@
4545
# while keeping `text_config` / `vision_config` composite.
4646

4747

48-
class _Qwen35ConfigCompat:
48+
class Qwen35ConfigCompat:
4949
"""Temporary shim for flattening Qwen3.5 text configs into Qwen3NextConfig.
5050
5151
We normalize to `Qwen3NextConfig` (rather than to a Qwen3.5-native
@@ -66,9 +66,9 @@ class _Qwen35ConfigCompat:
6666
@staticmethod
6767
def normalize(config_dict: dict) -> dict:
6868
"""Entry point: raw config.json dict -> flat Qwen3NextConfig-compatible dict."""
69-
text_config = _Qwen35ConfigCompat._extract_text_config(config_dict)
70-
text_config = _Qwen35ConfigCompat._inherit_quantization_config(config_dict, text_config)
71-
text_config = _Qwen35ConfigCompat._flatten_rope(text_config)
69+
text_config = Qwen35ConfigCompat._extract_text_config(config_dict)
70+
text_config = Qwen35ConfigCompat._inherit_quantization_config(config_dict, text_config)
71+
text_config = Qwen35ConfigCompat._flatten_rope(text_config)
7272

7373
# Detect dense vs MoE and set architecture + MoE defaults accordingly
7474
is_moe = "num_experts" in text_config and text_config["num_experts"] > 0
@@ -93,7 +93,7 @@ def normalize(config_dict: dict) -> dict:
9393
def _extract_text_config(config_dict: dict) -> dict:
9494
"""Pull nested text_config from VLM checkpoints, or use dict as-is."""
9595
architectures = config_dict.get("architectures") or []
96-
if architectures and architectures[0] in _Qwen35ConfigCompat._VLM_ARCHITECTURES:
96+
if architectures and architectures[0] in Qwen35ConfigCompat._VLM_ARCHITECTURES:
9797
text_config = dict(config_dict.get("text_config") or {})
9898
else:
9999
text_config = dict(config_dict)
@@ -116,10 +116,10 @@ def _inherit_quantization_config(config_dict: dict, text_config: dict) -> dict:
116116

117117
quantization_config = dict(config_dict["quantization_config"])
118118
if "modules_to_not_convert" in quantization_config:
119-
modules = _Qwen35ConfigCompat._normalize_exclude_modules(
119+
modules = Qwen35ConfigCompat._normalize_exclude_modules(
120120
quantization_config["modules_to_not_convert"]
121121
)
122-
modules = _Qwen35ConfigCompat._add_qkvz_bf16_workaround(text_config, modules)
122+
modules = Qwen35ConfigCompat._add_qkvz_bf16_workaround(text_config, modules)
123123
quantization_config["modules_to_not_convert"] = sorted(set(modules))
124124
text_config["quantization_config"] = quantization_config
125125
return text_config
@@ -209,7 +209,7 @@ def _normalize_qwen35_mrope_config(text_config) -> None:
209209
return
210210
if hasattr(rope_parameters, "to_dict"):
211211
rope_parameters = rope_parameters.to_dict()
212-
flattened = _Qwen35ConfigCompat._flatten_rope(
212+
flattened = Qwen35ConfigCompat._flatten_rope(
213213
{
214214
"rope_parameters": dict(rope_parameters),
215215
"rope_scaling": dict(getattr(text_config, "rope_scaling", None) or {}),
@@ -245,9 +245,9 @@ def _normalize_qwen35_quantization_config(model_config) -> None:
245245
return
246246

247247
text_config = getattr(model_config, "text_config", None)
248-
normalized_modules = _Qwen35ConfigCompat._normalize_exclude_modules(modules)
248+
normalized_modules = Qwen35ConfigCompat._normalize_exclude_modules(modules)
249249
if text_config is not None:
250-
normalized_modules = _Qwen35ConfigCompat._add_qkvz_bf16_workaround(
250+
normalized_modules = Qwen35ConfigCompat._add_qkvz_bf16_workaround(
251251
text_config.to_dict(), normalized_modules
252252
)
253253
quantization_config["modules_to_not_convert"] = sorted(set(normalized_modules))
@@ -331,7 +331,7 @@ class Qwen3_5ForCausalLM(Qwen3NextForCausalLM):
331331
332332
Same reuse pattern as Qwen3_5MoeForCausalLM, but for the dense 27B
333333
variant which uses GatedMLP instead of SparseMoeBlock. The config
334-
normalizer (_Qwen35ConfigCompat) sets num_experts=0 so that
334+
normalizer (Qwen35ConfigCompat) sets num_experts=0 so that
335335
Qwen3NextModel selects GatedMLP for the feed-forward layers.
336336
"""
337337

@@ -340,6 +340,7 @@ def __init__(self, model_config):
340340
super().__init__(model_config)
341341

342342

343+
# TODO: Add tests for disaggregated support.
343344
@support_multimodal_disaggregated
344345
@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
345346
@register_auto_model("Qwen3_5MoeForConditionalGeneration")

tensorrt_llm/_torch/pyexecutor/config_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,9 @@ def load_pretrained_config(model_name_or_path: str,
379379
)):
380380
# Qwen3.5 text-only: flatten to Qwen3NextConfig via the model-side shim.
381381
from tensorrt_llm._torch.models.modeling_qwen3_5 import \
382-
_Qwen35ConfigCompat
382+
Qwen35ConfigCompat
383383
model_config = transformers.Qwen3NextConfig.from_dict(
384-
_Qwen35ConfigCompat.normalize(config_dict))
384+
Qwen35ConfigCompat.normalize(config_dict))
385385
elif (model_type == "exaone4" and config_dict.get("sliding_window") is None
386386
and config_dict.get("layer_types") is None):
387387
# transformers 5.5.x Exaone4Config.__post_init__ first forces

tests/integration/test_lists/test-db/l0_l40s.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ l0_l40s:
2323
- unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all
2424
- unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all
2525
- unittest/_torch/modeling/test_modeling_qwen3vl.py::TestQwen3VL::test_all
26+
- unittest/_torch/modeling/test_modeling_qwen3_5_vl_moe.py::TestQwen3_5MoeVL::test_all
2627
- test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
2728
- unittest/llmapi/apps/_test_openai_chat_multimodal.py::test_single_chat_session_image_embeds -m needs_l40s
2829
# MMMU sanity check

tests/unittest/_torch/modeling/test_modeling_qwen3_5_vl_moe.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,9 @@ def create_trtllm_model(
325325
model = model_class(model_config, **kwargs).to("cuda")
326326

327327
if load_weights:
328-
weight_mapper_class = self.get_weight_mapper_class()
329-
if weight_mapper_class is not None:
330-
weight_mapper = weight_mapper_class()
331-
weight_mapper.init_model_and_config(model, trtllm_config)
332-
model.load_weights(hf_model_state_dict, weight_mapper)
333-
else:
334-
model.load_weights(hf_model_state_dict)
328+
weight_mapper = self.get_weight_mapper_class()()
329+
weight_mapper.init_model_and_config(model, trtllm_config)
330+
model.load_weights(hf_model_state_dict, weight_mapper)
335331

336332
for module in model.modules():
337333
if hasattr(module, "post_load_weights") and not getattr(
@@ -346,6 +342,13 @@ def _dummy_request_kwargs(self, scenario):
346342
position-id buffer allocated at dummy-request time."""
347343
return {"use_mrope": True}
348344

345+
def get_tolerance(self):
346+
"""Tighten `rtol` to `0.1` (4x tighter than the base 0.4
347+
default) while keeping `atol` at `0.4` to absorb single-logit
348+
tail outliers seen on `multiple_image` / `video`.
349+
"""
350+
return 0.4, 0.1
351+
349352
def get_trtllm_inputs(
350353
self,
351354
input_ids,

0 commit comments

Comments
 (0)