-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[None][feat] Cosmos3 reasoner only support #15117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a773c0e
6c91cab
1ed327c
08a1e42
8783f6a
16bd6d5
b338fc2
439bdaf
997c829
13a1d7c
1b95b94
eae018c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -230,7 +230,11 @@ Visual Generation Serving | |
| trtllm-serve black-forest-labs/FLUX.2-dev \ | ||
| --visual_gen_args config.yml | ||
|
|
||
| The ``--visual_gen_args`` flag accepts a YAML file that configures quantization, parallelism, and TeaCache. Available visual generation endpoints include ``/v1/images/generations``, ``/v1/videos``, ``/v1/videos/generations``, and video management APIs. | ||
| # Video generation (Cosmos3 hybrid checkpoint) | ||
| trtllm-serve nvidia/Cosmos3-Nano \ | ||
| --enable_visual_gen | ||
|
|
||
| For checkpoints that support both LLM and Visual Generation, such as Cosmos3, pass ``--enable_visual_gen`` to select the VisualGen runtime when ``--visual_gen_args`` is not specified. The ``--visual_gen_args`` flag accepts a YAML file that configures quantization, parallelism, and TeaCache. Available visual generation endpoints include ``/v1/images/generations``, ``/v1/videos``, ``/v1/videos/generations``, and video management APIs. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIK, this is the first hybrid model that supports both text output and multimodal generation serving. Can we improve the general Cosmos3 model doc to reflect that this model supports multimodal output? My understanding is:
So do we really need to add Looking ahead to Omni, if this kind of hybrid model becomes more popular, maybe we can consider adding a new trtllm-serve argument to explicitly select the output modality, to better align which engine path to use. |
||
|
|
||
| For full details, see the :doc:`../../models/visual-generation.md` feature documentation. Example client scripts are available in the `examples/visual_gen/serve/ <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/serve>`_ directory. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Adapted from upstream transformers Cosmos3OmniConfig: | ||
| # https://github.com/huggingface/transformers/blob/main/src/transformers/models/cosmos3_omni/configuration_cosmos3_omni.py | ||
| # | ||
| # Workaround until TRT-LLM upgrades to a transformers release that registers cosmos3_omni natively. | ||
|
|
||
| import os | ||
|
|
||
| from huggingface_hub.dataclasses import strict | ||
| from transformers.configuration_utils import PreTrainedConfig | ||
| from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig | ||
|
|
||
|
|
||
| @strict | ||
| class Cosmos3Config(PreTrainedConfig): | ||
| model_type = "cosmos3" | ||
| sub_configs = {"vision_config": AutoConfig, "text_config": AutoConfig} | ||
| keys_to_ignore_at_inference = ["past_key_values"] | ||
|
|
||
| text_config: dict | PreTrainedConfig | None = None | ||
| vision_config: dict | PreTrainedConfig | None = None | ||
| image_token_id: int = 151655 | ||
| video_token_id: int = 151656 | ||
| vision_start_token_id: int = 151652 | ||
| vision_end_token_id: int = 151653 | ||
| tie_word_embeddings: bool = False | ||
|
|
||
| def __post_init__(self, **kwargs): | ||
| if isinstance(self.vision_config, dict): | ||
| model_type = self.vision_config.pop("model_type", "qwen3_vl_vision") | ||
| if model_type == "qwen3_vl": | ||
| model_type = "qwen3_vl_vision" | ||
| self.vision_config = CONFIG_MAPPING[model_type](**self.vision_config) | ||
| elif self.vision_config is None: | ||
| self.vision_config = CONFIG_MAPPING["qwen3_vl_vision"]() | ||
|
|
||
| if isinstance(self.text_config, dict): | ||
| model_type = self.text_config.get("model_type", "qwen3_vl_text") | ||
| self.text_config = CONFIG_MAPPING[model_type](**self.text_config) | ||
| elif self.text_config is None: | ||
| self.text_config = CONFIG_MAPPING["qwen3_vl_text"]() | ||
|
|
||
| super().__post_init__(**kwargs) | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, config_dict, **kwargs): | ||
| config = super().from_dict(config_dict, **kwargs) | ||
| # PreTrainedConfig.from_pretrained / from_dict hydrate declared fields from | ||
| # config.json only. ``_name_or_path`` is runtime metadata (not stored in the | ||
| # JSON) and is not set to the checkpoint directory automatically. Cosmos3 | ||
| # needs that path to locate ``transformer/`` and ``vision_encoder/``; TRT-LLM | ||
| # does not pass it separately on ModelConfig, so set it when callers provide | ||
| # one via kwargs. | ||
| name_or_path = kwargs.get("_name_or_path") or kwargs.get("name_or_path") | ||
| if name_or_path and ( | ||
| not getattr(config, "_name_or_path", None) or len(str(config._name_or_path)) < 2 | ||
| ): | ||
| config._name_or_path = os.fspath(name_or_path) | ||
| return config | ||
|
|
||
| @classmethod | ||
| def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): | ||
| config = super().from_pretrained(pretrained_model_name_or_path, **kwargs) | ||
| # Same as from_dict above: copy the resolved checkpoint identifier onto the | ||
| # config so Cosmos3Model can find the unified checkpoint root. | ||
| if not getattr(config, "_name_or_path", None) or len(str(config._name_or_path)) < 2: | ||
| config._name_or_path = os.fspath(pretrained_model_name_or_path) | ||
| return config |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import re | ||
|
|
||
| from tensorrt_llm._torch.models.checkpoints.hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper | ||
| from tensorrt_llm._torch.models.modeling_utils import register_mapper | ||
|
|
||
|
|
||
| @register_mapper("HF", "Cosmos3ForConditionalGeneration") | ||
| class Cosmos3HfWeightMapper(Qwen3VLHfWeightMapper): | ||
| """ | ||
| Cosmos3 unified checkpoints store the Reasoner LLM with the old `model.` prefix | ||
| stripped off and the ViT under flat `blocks.*` / `merger.*` / `patch_embed.*` / | ||
| `pos_embed.*` / `deepstack_merger_list.*`. Re-target both to the nested Qwen3-VL | ||
| layout (`model.language_model.*` and `model.visual.*`). Newer checkpoints also | ||
| use Diffusers-style attention names; map them back to Qwen3-VL module names. | ||
| """ | ||
|
|
||
| KEYS_TO_DROP = ( | ||
| # Generator (image / video diffusion) MoT expert + cross-modal projections # codespell:ignore | ||
| r"\.add_q_proj\.", | ||
| r"\.add_k_proj\.", | ||
| r"\.add_v_proj\.", | ||
| r"\.to_add_out\.", | ||
| r"\.norm_added_q\.", | ||
| r"\.norm_added_k\.", | ||
| r"moe_gen", | ||
| r"^proj_out\.", | ||
| r"^proj_in\.", | ||
| r"^time_embedder\.", | ||
| # Sound tower | ||
| r"^audio_proj_out\.", | ||
| r"^audio_proj_in\.", | ||
| r"^audio_modality_embed$", | ||
| # Action tower | ||
| r"^action_proj_out\.", | ||
| r"^action_proj_in\.", | ||
| r"^action_modality_embed$", | ||
| ) | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| self.prefix_params_map = { | ||
| r"^(layers\.|embed_tokens\.|norm\.)": r"model.language_model.\1", | ||
| r"^(blocks\.|merger\.|patch_embed\.|pos_embed\.|deepstack_merger_list\.)": r"model.visual.\1", | ||
| } | ||
| self.attn_params_map = { | ||
| r"(.*)\.self_attn\.to_q\.(.*)": r"\1.self_attn.q_proj.\2", | ||
| r"(.*)\.self_attn\.to_k\.(.*)": r"\1.self_attn.k_proj.\2", | ||
| r"(.*)\.self_attn\.to_v\.(.*)": r"\1.self_attn.v_proj.\2", | ||
| r"(.*)\.self_attn\.to_out\.(.*)": r"\1.self_attn.o_proj.\2", | ||
| r"(.*)\.self_attn\.norm_q\.(.*)": r"\1.self_attn.q_norm.\2", | ||
| r"(.*)\.self_attn\.norm_k\.(.*)": r"\1.self_attn.k_norm.\2", | ||
| } | ||
|
|
||
| @classmethod | ||
| def should_drop_checkpoint_key(cls, key: str) -> bool: | ||
| return any(re.search(pattern, key) for pattern in cls.KEYS_TO_DROP) | ||
|
|
||
| def preprocess_weights(self, weights: dict) -> dict: | ||
| weights = { | ||
| key: value for key, value in weights.items() if not self.should_drop_checkpoint_key(key) | ||
| } | ||
| weights = self.rename_by_params_map(self.prefix_params_map, weights) | ||
| return self.rename_by_params_map(self.attn_params_map, weights) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import os | ||
| from typing import Dict, Tuple | ||
|
|
||
| import safetensors.torch | ||
| import torch | ||
| from transformers import PretrainedConfig | ||
|
|
||
| from ...inputs import ( | ||
| ContentFormat, | ||
| MultimodalPlaceholderMetadata, | ||
| MultimodalPlaceholderPlacement, | ||
| register_input_processor, | ||
| support_multimodal_disaggregated, | ||
| ) | ||
| from ..model_config import ModelConfig | ||
| from .checkpoints.base_weight_mapper import BaseWeightMapper | ||
| from .checkpoints.hf.cosmos3_weight_mapper import Cosmos3HfWeightMapper | ||
| from .modeling_qwen3vl import ( | ||
| Qwen3VisionModel, | ||
| Qwen3VisionModelBase, | ||
| Qwen3VLInputProcessorBase, | ||
| Qwen3VLModel, | ||
| ) | ||
| from .modeling_utils import register_auto_model, register_vision_encoder | ||
|
|
||
|
|
||
| def _get_cosmos3_model_paths(config: PretrainedConfig) -> Tuple[str, str, str]: | ||
| """Resolve unified Cosmos3 checkpoint paths from the omni config. | ||
|
|
||
| Unified Cosmos3 checkpoints use `transformer/` for LLM weights and | ||
| `vision_encoder/` for the vision tower. | ||
| """ | ||
| root_path = config._name_or_path | ||
|
|
||
| root_path = os.fspath(root_path) | ||
| llm_path = os.path.join(root_path, "transformer") | ||
| vision_path = os.path.join(root_path, "vision_encoder") | ||
|
|
||
| if not os.path.isdir(llm_path): | ||
| raise FileNotFoundError(f"Cosmos3 transformer weights not found under {llm_path}.") | ||
|
|
||
| return root_path, llm_path, vision_path | ||
|
|
||
|
|
||
| PLACEHOLDER_METADATA = MultimodalPlaceholderMetadata( | ||
| placeholder_map={ | ||
| "image": "<|vision_start|><|image_pad|><|vision_end|>", | ||
| "video": "<|vision_start|><|video_pad|><|vision_end|>", | ||
| }, | ||
| placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, | ||
| placeholders_separator="", | ||
| content_format=ContentFormat.STRING, | ||
| ) | ||
|
|
||
|
|
||
| @support_multimodal_disaggregated | ||
| @register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel) | ||
| @register_auto_model("Cosmos3ForConditionalGeneration") | ||
| @register_input_processor( | ||
| Qwen3VLInputProcessorBase, model_type="cosmos3", placeholder_metadata=PLACEHOLDER_METADATA | ||
| ) | ||
| # cosmos3_omni is the backward-compat alias for cosmos3, remove it when checkpoints migrate to cosmos3 | ||
| @register_input_processor( | ||
| Qwen3VLInputProcessorBase, model_type="cosmos3_omni", placeholder_metadata=PLACEHOLDER_METADATA | ||
| ) | ||
| class Cosmos3Model(Qwen3VLModel): | ||
| def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs): | ||
| omni_config = model_config.pretrained_config | ||
| if omni_config is None: | ||
| raise ValueError( | ||
| "Cosmos3Model requires model_config.pretrained_config to resolve " | ||
| "the LLM and vision encoder checkpoint paths, but it was None." | ||
| ) | ||
| if not getattr(omni_config, "_name_or_path", None): | ||
| raise ValueError( | ||
| "Cosmos3Model requires model_config.pretrained_config._name_or_path to resolve " | ||
| "the LLM and vision encoder checkpoint paths, but it was None or empty." | ||
| ) | ||
|
|
||
| (self._checkpoint_root, self.llm_path, self._vision_encoder_path) = ( | ||
| _get_cosmos3_model_paths(omni_config) | ||
| ) | ||
|
|
||
| super().__init__(model_config, *args, **kwargs) | ||
|
|
||
| @property | ||
| def llm_checkpoint_dir(self) -> str: | ||
| """Return the directory of the LLM checkpoint (``transformer/`` subdir).""" | ||
| return self.llm_path | ||
|
|
||
| def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper): | ||
| vision_weights_file = os.path.join(self._vision_encoder_path, "model.safetensors") | ||
| if not os.path.isfile(vision_weights_file): | ||
| raise FileNotFoundError( | ||
| f"Cosmos3 vision encoder weights not found at {vision_weights_file}." | ||
| ) | ||
| weights.update(safetensors.torch.load_file(vision_weights_file)) | ||
|
|
||
| if not isinstance(weight_mapper, Cosmos3HfWeightMapper): | ||
| weight_mapper = Cosmos3HfWeightMapper() | ||
| weights = weight_mapper.preprocess_weights(weights) | ||
| super().load_weights(weights, weight_mapper) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Non-blocking:
This PR seems to enable
trtllm-serve. What about our offline inference paths, such astrtllm-benchor the multimodal quickstart examples? are they supported as well?