Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@
/tensorrt_llm/_torch/models/modeling_nemotron_nano.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/modeling_qwen3vl.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/modeling_qwen3vl_moe.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/modeling_cosmos3.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/modeling_qwen3_5.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/modeling_hunyuan_dense.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/checkpoints/hf/cosmos3_weight_mapper.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_moe_weight_mapper.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/checkpoints/hf/qwen3_5_weight_mapper.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
/tensorrt_llm/_torch/models/checkpoints/hf/llava_next_weight_mapper.py @NVIDIA/trt-llm-multimodal-devs @NVIDIA/trt-llm-torch-models-devs
Expand Down
6 changes: 5 additions & 1 deletion docs/source/commands/trtllm-serve/trtllm-serve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ Visual Generation Serving
trtllm-serve black-forest-labs/FLUX.2-dev \
--visual_gen_args config.yml

The ``--visual_gen_args`` flag accepts a YAML file that configures quantization, parallelism, and TeaCache. Available visual generation endpoints include ``/v1/images/generations``, ``/v1/videos``, ``/v1/videos/generations``, and video management APIs.
# Video generation (Cosmos3 hybrid checkpoint)
trtllm-serve nvidia/Cosmos3-Nano \

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking:
This PR seems to enable trtllm-serve. What about our offline inference paths, such as trtllm-bench or the multimodal quickstart examples? are they supported as well?

--enable_visual_gen

For checkpoints that support both LLM and Visual Generation, such as Cosmos3, pass ``--enable_visual_gen`` to select the VisualGen runtime when ``--visual_gen_args`` is not specified. The ``--visual_gen_args`` flag accepts a YAML file that configures quantization, parallelism, and TeaCache. Available visual generation endpoints include ``/v1/images/generations``, ``/v1/videos``, ``/v1/videos/generations``, and video management APIs.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, this is the first hybrid model that supports both text output and multimodal generation serving. Can we improve the general Cosmos3 model doc to reflect that this model supports multimodal output?

My understanding is:

  • If --configs is passed, we should use the config-defined LLM engine path.
  • If --visual_gen_args is passed, we should use the config-defined VG engine path.
  • If no configs are passed, we can default to the VG engine path.

So do we really need to add enable_visual_gen to the server API?

Looking ahead to Omni, if this kind of hybrid model becomes more popular, maybe we can consider adding a new trtllm-serve argument to explicitly select the output modality, to better align which engine path to use.


For full details, see the :doc:`../../models/visual-generation.md` feature documentation. Example client scripts are available in the `examples/visual_gen/serve/ <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/serve>`_ directory.

Expand Down
1 change: 1 addition & 0 deletions docs/source/models/supported-models.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
| `Qwen3VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
| `Qwen3VLMoeForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I + V |
| `Step3p7ForConditionalGeneration` | Yes | Yes | Untested | Yes | Untested | Untested | Untested | Untested | L + I |
| `Cosmos3ForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Untested | Untested | L + I + V |
| `MiniMaxM3SparseForConditionalGeneration` [^11] | Yes | Yes | Untested | Yes | Untested | No | Untested | Untested | L + I + V |

Note:
Expand Down
16 changes: 13 additions & 3 deletions tensorrt_llm/_torch/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from tensorrt_llm._torch.configs.cosmos3 import Cosmos3Config
from tensorrt_llm._torch.configs.deepseek_v3 import DeepseekV3Config
from tensorrt_llm._torch.configs.laguna import LagunaConfig


def _register_custom_configs_with_transformers() -> None:
# Make AutoConfig.from_pretrained / AutoTokenizer.from_pretrained accept
# model_types that TRT-LLM understands but upstream transformers does not
# (DeepSeek-V3.2, Kimi K2, and Laguna ship config.json with these
# model_types and rely on TRT-LLM's local config workarounds).
# (DeepSeek-V3.2, Kimi K2, Laguna, and Cosmos3 omni ship config.json with
Comment thread
bastefaniak marked this conversation as resolved.
# these model_types and rely on TRT-LLM's local config workarounds).
#
# Without this, transformers 5.5.x falls back to a bare PreTrainedConfig
# that lacks attributes like `max_position_embeddings`, and
Expand All @@ -15,12 +16,21 @@ def _register_custom_configs_with_transformers() -> None:
# consistency check for aliases (for example, DeepseekV3Config.model_type
# is "deepseek_v3") by writing into the underlying mapping directly.
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig

custom_configs = {
# "cosmos3" is the canonical model_type; "cosmos3_omni" is kept as a
Comment thread
2ez4bz marked this conversation as resolved.
# backward-compat alias for checkpoints that predate the rename.
"cosmos3": Cosmos3Config,
"cosmos3_omni": Cosmos3Config,
"deepseek_v32": DeepseekV3Config,
"kimi_k2": DeepseekV3Config,
"laguna": LagunaConfig,
}
# Cosmos3Config resolves vision sub-configs via ``qwen3_vl_vision``; that
# alias is only present in newer transformers releases.
if "qwen3_vl_vision" not in CONFIG_MAPPING:
CONFIG_MAPPING.register("qwen3_vl_vision", Qwen3VLVisionConfig, exist_ok=True)
for model_type, config_class in custom_configs.items():
if model_type in CONFIG_MAPPING:
continue
Expand All @@ -30,4 +40,4 @@ def _register_custom_configs_with_transformers() -> None:
_register_custom_configs_with_transformers()
del _register_custom_configs_with_transformers

__all__ = ["DeepseekV3Config", "LagunaConfig"]
__all__ = ["Cosmos3Config", "DeepseekV3Config", "LagunaConfig"]
82 changes: 82 additions & 0 deletions tensorrt_llm/_torch/configs/cosmos3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from upstream transformers Cosmos3OmniConfig:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cosmos3_omni/configuration_cosmos3_omni.py
#
# Workaround until TRT-LLM upgrades to a transformers release that registers cosmos3_omni natively.

import os

from huggingface_hub.dataclasses import strict
from transformers.configuration_utils import PreTrainedConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig


@strict
class Cosmos3Config(PreTrainedConfig):
model_type = "cosmos3"
sub_configs = {"vision_config": AutoConfig, "text_config": AutoConfig}
keys_to_ignore_at_inference = ["past_key_values"]

text_config: dict | PreTrainedConfig | None = None
vision_config: dict | PreTrainedConfig | None = None
image_token_id: int = 151655
video_token_id: int = 151656
vision_start_token_id: int = 151652
vision_end_token_id: int = 151653
tie_word_embeddings: bool = False

def __post_init__(self, **kwargs):
if isinstance(self.vision_config, dict):
model_type = self.vision_config.pop("model_type", "qwen3_vl_vision")
if model_type == "qwen3_vl":
model_type = "qwen3_vl_vision"
self.vision_config = CONFIG_MAPPING[model_type](**self.vision_config)
elif self.vision_config is None:
self.vision_config = CONFIG_MAPPING["qwen3_vl_vision"]()

if isinstance(self.text_config, dict):
model_type = self.text_config.get("model_type", "qwen3_vl_text")
self.text_config = CONFIG_MAPPING[model_type](**self.text_config)
elif self.text_config is None:
self.text_config = CONFIG_MAPPING["qwen3_vl_text"]()

super().__post_init__(**kwargs)

@classmethod
def from_dict(cls, config_dict, **kwargs):
config = super().from_dict(config_dict, **kwargs)
# PreTrainedConfig.from_pretrained / from_dict hydrate declared fields from
# config.json only. ``_name_or_path`` is runtime metadata (not stored in the
# JSON) and is not set to the checkpoint directory automatically. Cosmos3
# needs that path to locate ``transformer/`` and ``vision_encoder/``; TRT-LLM
# does not pass it separately on ModelConfig, so set it when callers provide
# one via kwargs.
name_or_path = kwargs.get("_name_or_path") or kwargs.get("name_or_path")
if name_or_path and (
not getattr(config, "_name_or_path", None) or len(str(config._name_or_path)) < 2
):
config._name_or_path = os.fspath(name_or_path)
return config

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# Same as from_dict above: copy the resolved checkpoint identifier onto the
# config so Cosmos3Model can find the unified checkpoint root.
if not getattr(config, "_name_or_path", None) or len(str(config._name_or_path)) < 2:
config._name_or_path = os.fspath(pretrained_model_name_or_path)
return config
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .modeling_bert import BertForSequenceClassification
from .modeling_clip import CLIPVisionModel
from .modeling_cohere2 import Cohere2ForCausalLM
from .modeling_cosmos3 import Cosmos3Model
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
from .modeling_exaone4 import Exaone4ForCausalLM
from .modeling_exaone4_5 import Exaone4_5_ForConditionalGeneration
Expand Down Expand Up @@ -64,6 +65,7 @@
"BartForConditionalGeneration",
"BertForSequenceClassification",
"CLIPVisionModel",
"Cosmos3Model",
"DeepseekV3ForCausalLM",
"Exaone4ForCausalLM",
"Exaone4_5_ForConditionalGeneration",
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/models/checkpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .hf.afmoe_weight_mapper import AfmoeHfWeightMapper
from .hf.checkpoint_loader import HfCheckpointLoader
from .hf.config_loader import HfConfigLoader
from .hf.cosmos3_weight_mapper import Cosmos3HfWeightMapper
from .hf.gemma3_weight_mapper import Gemma3HfWeightMapper
from .hf.gemma4_weight_mapper import Gemma4HfWeightMapper
from .hf.llama4_weight_mapper import Llama4HfWeightMapper
Expand Down Expand Up @@ -34,5 +35,5 @@
"Qwen3_5MoeHfWeightMapper", "Qwen3NextHfWeightMapper",
"Gemma4HfWeightMapper", "LlavaNextHfWeightMapper",
"MistralLarge3CheckpointLoader", "MistralLarge3WeightMapper",
"MXCheckpointLoader", "Qwen3VLHfWeightMapper"
"MXCheckpointLoader", "Qwen3VLHfWeightMapper", "Cosmos3HfWeightMapper"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import re

from tensorrt_llm._torch.models.checkpoints.hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper


@register_mapper("HF", "Cosmos3ForConditionalGeneration")
class Cosmos3HfWeightMapper(Qwen3VLHfWeightMapper):
"""
Cosmos3 unified checkpoints store the Reasoner LLM with the old `model.` prefix
stripped off and the ViT under flat `blocks.*` / `merger.*` / `patch_embed.*` /
`pos_embed.*` / `deepstack_merger_list.*`. Re-target both to the nested Qwen3-VL
layout (`model.language_model.*` and `model.visual.*`). Newer checkpoints also
use Diffusers-style attention names; map them back to Qwen3-VL module names.
"""

KEYS_TO_DROP = (
# Generator (image / video diffusion) MoT expert + cross-modal projections # codespell:ignore
r"\.add_q_proj\.",
r"\.add_k_proj\.",
r"\.add_v_proj\.",
r"\.to_add_out\.",
r"\.norm_added_q\.",
r"\.norm_added_k\.",
r"moe_gen",
r"^proj_out\.",
r"^proj_in\.",
r"^time_embedder\.",
# Sound tower
r"^audio_proj_out\.",
r"^audio_proj_in\.",
r"^audio_modality_embed$",
# Action tower
r"^action_proj_out\.",
r"^action_proj_in\.",
r"^action_modality_embed$",
)

def __init__(self):
super().__init__()

self.prefix_params_map = {
r"^(layers\.|embed_tokens\.|norm\.)": r"model.language_model.\1",
r"^(blocks\.|merger\.|patch_embed\.|pos_embed\.|deepstack_merger_list\.)": r"model.visual.\1",
}
self.attn_params_map = {
r"(.*)\.self_attn\.to_q\.(.*)": r"\1.self_attn.q_proj.\2",
r"(.*)\.self_attn\.to_k\.(.*)": r"\1.self_attn.k_proj.\2",
r"(.*)\.self_attn\.to_v\.(.*)": r"\1.self_attn.v_proj.\2",
r"(.*)\.self_attn\.to_out\.(.*)": r"\1.self_attn.o_proj.\2",
r"(.*)\.self_attn\.norm_q\.(.*)": r"\1.self_attn.q_norm.\2",
r"(.*)\.self_attn\.norm_k\.(.*)": r"\1.self_attn.k_norm.\2",
}

@classmethod
def should_drop_checkpoint_key(cls, key: str) -> bool:
return any(re.search(pattern, key) for pattern in cls.KEYS_TO_DROP)

def preprocess_weights(self, weights: dict) -> dict:
weights = {
key: value for key, value in weights.items() if not self.should_drop_checkpoint_key(key)
}
weights = self.rename_by_params_map(self.prefix_params_map, weights)
return self.rename_by_params_map(self.attn_params_map, weights)
105 changes: 105 additions & 0 deletions tensorrt_llm/_torch/models/modeling_cosmos3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Dict, Tuple

import safetensors.torch
import torch
from transformers import PretrainedConfig

from ...inputs import (
ContentFormat,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement,
register_input_processor,
support_multimodal_disaggregated,
)
from ..model_config import ModelConfig
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .checkpoints.hf.cosmos3_weight_mapper import Cosmos3HfWeightMapper
from .modeling_qwen3vl import (
Qwen3VisionModel,
Qwen3VisionModelBase,
Qwen3VLInputProcessorBase,
Qwen3VLModel,
)
from .modeling_utils import register_auto_model, register_vision_encoder


def _get_cosmos3_model_paths(config: PretrainedConfig) -> Tuple[str, str, str]:
"""Resolve unified Cosmos3 checkpoint paths from the omni config.

Unified Cosmos3 checkpoints use `transformer/` for LLM weights and
`vision_encoder/` for the vision tower.
"""
root_path = config._name_or_path

root_path = os.fspath(root_path)
llm_path = os.path.join(root_path, "transformer")
vision_path = os.path.join(root_path, "vision_encoder")

if not os.path.isdir(llm_path):
raise FileNotFoundError(f"Cosmos3 transformer weights not found under {llm_path}.")

return root_path, llm_path, vision_path


PLACEHOLDER_METADATA = MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|vision_start|><|image_pad|><|vision_end|>",
"video": "<|vision_start|><|video_pad|><|vision_end|>",
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
placeholders_separator="",
content_format=ContentFormat.STRING,
)


@support_multimodal_disaggregated
@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
@register_auto_model("Cosmos3ForConditionalGeneration")
@register_input_processor(
Qwen3VLInputProcessorBase, model_type="cosmos3", placeholder_metadata=PLACEHOLDER_METADATA
)
# cosmos3_omni is the backward-compat alias for cosmos3, remove it when checkpoints migrate to cosmos3
@register_input_processor(
Qwen3VLInputProcessorBase, model_type="cosmos3_omni", placeholder_metadata=PLACEHOLDER_METADATA
)
class Cosmos3Model(Qwen3VLModel):
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs):
omni_config = model_config.pretrained_config
if omni_config is None:
raise ValueError(
"Cosmos3Model requires model_config.pretrained_config to resolve "
"the LLM and vision encoder checkpoint paths, but it was None."
)
if not getattr(omni_config, "_name_or_path", None):
raise ValueError(
"Cosmos3Model requires model_config.pretrained_config._name_or_path to resolve "
"the LLM and vision encoder checkpoint paths, but it was None or empty."
)

(self._checkpoint_root, self.llm_path, self._vision_encoder_path) = (
_get_cosmos3_model_paths(omni_config)
)

super().__init__(model_config, *args, **kwargs)

@property
def llm_checkpoint_dir(self) -> str:
"""Return the directory of the LLM checkpoint (``transformer/`` subdir)."""
return self.llm_path

def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper):
vision_weights_file = os.path.join(self._vision_encoder_path, "model.safetensors")
if not os.path.isfile(vision_weights_file):
raise FileNotFoundError(
f"Cosmos3 vision encoder weights not found at {vision_weights_file}."
)
weights.update(safetensors.torch.load_file(vision_weights_file))

if not isinstance(weight_mapper, Cosmos3HfWeightMapper):
weight_mapper = Cosmos3HfWeightMapper()
weights = weight_mapper.preprocess_weights(weights)
super().load_weights(weights, weight_mapper)
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/models/modeling_qwen3vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,10 @@ def __init__(
# itself from this dict in `Qwen2_5_VLVisionAttention.__init__`
# so it does not poison LM lookups.
llm_model_config.extra_attrs = model_config.extra_attrs
if self.original_arch == "Qwen3VLForConditionalGeneration":
if self.original_arch in (
"Qwen3VLForConditionalGeneration",
"Cosmos3ForConditionalGeneration",
):
llm_model_config.pretrained_config.architectures = ["Qwen3ForCausalLM"]
elif self.original_arch == "Qwen3VLMoeForConditionalGeneration":
llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
Expand Down
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def __getitem__(self, key):


_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
cosmos3="Cosmos3Config",
cosmos3_omni=
"Cosmos3Config", # backward-compat alias for pre-rename checkpoints
deepseek_v32="DeepseekV3Config",
kimi_k2="DeepseekV3Config",
glm_moe_dsa="DeepseekV3Config",
Expand Down
Loading
Loading