diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index a735625620a1..8ad9c525677d 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -23,7 +23,7 @@ from tensorrt_llm._torch.models.modeling_mistral_large3 import ( Mistral3Gate, MistralLarge3ForCausalLM) from tensorrt_llm._torch.models.modeling_multimodal_mixin import ( - MultimodalEncoderOutput, MultimodalModelMixin, PreparedLlmInputs) + MultimodalModelMixin, PreparedLlmInputs) from tensorrt_llm._torch.models.modeling_multimodal_utils import ( _MULTIMODAL_ENV_NAME, _is_mm_disagg) from tensorrt_llm._torch.models.modeling_utils import (DecoderModel, @@ -705,9 +705,9 @@ def infer_max_seq_len(self) -> int: def encode_multimodal_inputs( self, multimodal_params: Sequence[MultimodalParams], - ) -> MultimodalEncoderOutput: + ) -> torch.Tensor: mm_embeds = self._vision_forward(list(multimodal_params)) - return MultimodalEncoderOutput(embeddings=mm_embeds[0]) + return mm_embeds[0] def get_language_model_forward_kwargs( self, diff --git a/tensorrt_llm/_torch/models/modeling_multimodal_mixin.py b/tensorrt_llm/_torch/models/modeling_multimodal_mixin.py index 465dd2fc6208..c32e7d636ddd 100644 --- a/tensorrt_llm/_torch/models/modeling_multimodal_mixin.py +++ b/tensorrt_llm/_torch/models/modeling_multimodal_mixin.py @@ -123,27 +123,6 @@ def _run_on_aux_stream(aux_stream: torch.cuda.Stream) -> Iterator[torch.cuda.Eve exit_event.record() -@dataclass(frozen=True) -class MultimodalEncoderOutput: - """Output produced by a model-owned multimodal encoder hook. - - Contract: - - `embeddings` contains all multimodal embedding rows for the supplied - `multimodal_params`. - - Rows are concatenated in the same order as `multimodal_params`. - - Per-request row counts match `total_embeds_in_request` from runtime - metadata when that metadata is available. 
- - Special multimodal tokens occupy token positions but do not have rows in - this tensor. - - The single-tensor shape is required for chunked-prefill embedding reuse, - which lets later chunks skip the encoder. See - `modeling_multimodal_utils.py` for the caching machinery. - """ - - embeddings: torch.Tensor - - @dataclass(frozen=True) class PreparedLlmInputs: """Prepared inputs returned by `MultimodalModelMixin`.""" @@ -181,8 +160,13 @@ def convert(tensor: torch.Tensor) -> torch.Tensor: def encode_multimodal_inputs( self, multimodal_params: Sequence[MultimodalParams], - ) -> MultimodalEncoderOutput: - """Run model-specific multimodal encoder work.""" + ) -> torch.Tensor: + """Run model-specific multimodal encoder work. + + Returns the single primary multimodal embedding tensor for the supplied params. Rows are + expected to be concatenated in request order, and special multimodal tokens occupy token + positions but do not have rows here. + """ raise NotImplementedError @property @@ -222,16 +206,16 @@ def after_full_multimodal_embeddings( *, input_ids: torch.Tensor, multimodal_params: Sequence[MultimodalParams], - encoder_output: MultimodalEncoderOutput, + embeddings: torch.Tensor, **forward_kwargs: Any, - ) -> tuple[torch.Tensor, MultimodalEncoderOutput]: + ) -> tuple[torch.Tensor, torch.Tensor]: """Optional hook before active chunk rows are selected. Runs after cache lookup or encoder execution has produced full per-request multimodal embeddings, but before the mixin selects rows active in the current forward chunk. 
""" - return input_ids, encoder_output + return input_ids, embeddings def after_active_multimodal_embeddings( self, @@ -274,21 +258,16 @@ def prepare_multimodal_inputs( if not context_params: return PreparedLlmInputs(input_ids=input_ids, inputs_embeds=None) - full_output = self._get_or_encode_multimodal_embeddings(context_params) + full_embeddings = self._get_or_encode_multimodal_embeddings(context_params) - input_ids, full_output = self.after_full_multimodal_embeddings( + input_ids, full_embeddings = self.after_full_multimodal_embeddings( input_ids=input_ids, multimodal_params=context_params, - encoder_output=full_output, + embeddings=full_embeddings, **forward_kwargs, ) - active_embeddings = self._find_active_multimodal_embeddings( - [full_output.embeddings], - input_ids=input_ids, - positions=positions, - multimodal_params=context_params, - ) + active_embeddings = find_input_mm_embeds([full_embeddings], list(context_params)) active_embeddings, extra_embeds = self.after_active_multimodal_embeddings( active_embeddings=active_embeddings, multimodal_params=context_params, @@ -319,49 +298,20 @@ def prepare_multimodal_inputs( def _get_or_encode_multimodal_embeddings( self, multimodal_params: Sequence[MultimodalParams], - ) -> MultimodalEncoderOutput: + ) -> torch.Tensor: """Return cached multimodal embeddings or run the encoder for misses. - Delegates cache lookup and gather behavior to - `get_multimodal_embeddings`, then validates the single primary tensor - contract for both encoded and cached-only paths. + Delegates cache lookup and gather behavior to `get_multimodal_embeddings`, then validates + the single tensor contract for both encoded and cached-only paths. 
""" - - def encoder_forward_fn(params: list[MultimodalParams]) -> list[torch.Tensor]: - encoder_output = self.encode_multimodal_inputs(params) - if not isinstance(encoder_output, MultimodalEncoderOutput): - raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.") - if not isinstance(encoder_output.embeddings, torch.Tensor): - raise TypeError("MultimodalEncoderOutput.embeddings must be a torch.Tensor.") - return [encoder_output.embeddings] - embeddings = get_multimodal_embeddings( - encoder_forward_fn=encoder_forward_fn, + encoder_forward_fn=self.encode_multimodal_inputs, multimodal_params=list(multimodal_params), ) - primary = self._require_primary_embedding(embeddings) - # Validate post-gather so cached-only paths (KV reuse, all-cached chunked - # prefill) are also checked, not just paths that ran the encoder. - self._validate_primary_embedding_rows(primary, multimodal_params) - return MultimodalEncoderOutput(embeddings=primary) - - def _find_active_multimodal_embeddings( - self, - multimodal_embeddings: list[torch.Tensor], - *, - input_ids: torch.Tensor, - positions: Optional[torch.Tensor], - multimodal_params: Sequence[MultimodalParams], - ) -> list[torch.Tensor]: - """Named internal stage for selecting active chunk multimodal rows. - - This initial template stage currently delegates to - `find_input_mm_embeds`. Model-specific behavior around slicing should - use `after_full_multimodal_embeddings` or - `after_active_multimodal_embeddings` so the common mixin sequence stays - centralized. - """ - return find_input_mm_embeds(multimodal_embeddings, list(multimodal_params)) + # Validate post-gather so cached-only paths (KV reuse, all-cached chunked prefill) are also + # checked, not just paths that ran the encoder. 
+ self._validate_embeddings(embeddings, multimodal_params) + return embeddings[0] def _fuse_multimodal_embeddings( self, @@ -403,40 +353,46 @@ def _fuse_multimodal_embeddings( return fused_input_ids, inputs_embeds, () @staticmethod - def _require_primary_embedding(embeddings: list[torch.Tensor]) -> torch.Tensor: - if len(embeddings) != 1: - raise ValueError( - "MultimodalModelMixin requires a single primary embedding tensor, " - f"got {len(embeddings)} tensors." - ) - return embeddings[0] - - @staticmethod - def _validate_primary_embedding_rows( - primary: torch.Tensor, + def _validate_embeddings( + embeddings: list[torch.Tensor], multimodal_params: Sequence[MultimodalParams], ) -> None: - """Validate gathered primary embedding row count against runtime metadata. + """Validate the gathered embedding tensor's row count against runtime metadata. Skipped if any param lacks `multimodal_runtime.total_embeds_in_request`, since the contract cannot be evaluated without complete metadata. """ + if len(embeddings) != 1: + raise ValueError( + f"MultimodalModelMixin requires a single embedding tensor, got {len(embeddings)} " + "tensors." + ) + + embeddings_tensor = embeddings[0] expected_rows = 0 + has_runtime_metadata = [] for param in multimodal_params: runtime = param.multimodal_runtime - if runtime is None or runtime.total_embeds_in_request is None: - logger.debug( - "Skipping multimodal embedding row-count validation: " - "runtime metadata missing or incomplete for at least one param." - ) - return - expected_rows += runtime.total_embeds_in_request + has_runtime = runtime is not None and runtime.total_embeds_in_request is not None + has_runtime_metadata.append(has_runtime) + if has_runtime: + expected_rows += runtime.total_embeds_in_request + + if any(has_runtime_metadata) and not all(has_runtime_metadata): + raise ValueError( + "Multimodal runtime metadata must be present for every param or none of them." 
+ ) + if not all(has_runtime_metadata): + logger.debug( + "Skipping multimodal embedding row-count validation: runtime metadata missing " + "for all params." + ) + return - actual_rows = primary.shape[0] + actual_rows = embeddings_tensor.shape[0] if actual_rows != expected_rows: raise ValueError( - "Multimodal embedding row count mismatch: " - f"expected {expected_rows}, got {actual_rows}." + f"Multimodal embedding row count mismatch: expected {expected_rows}, got {actual_rows}." ) @@ -539,9 +495,7 @@ def _dispatch_cross_iter_prefetch( for (req, _, _), p in zip(candidates, params_list): req.py_multimodal_data = p.multimodal_data encoder_output = model.encode_multimodal_inputs(params_list) - if not isinstance(encoder_output, MultimodalEncoderOutput): - raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.") - _cache_multimodal_embeddings(params_list, [encoder_output.embeddings]) + _cache_multimodal_embeddings(params_list, [encoder_output]) finally: # Stash the event on every candidate's durable LlmRequest (not the # per-iter `MultimodalParams`), since `_prepare_inputs` rebuilds the diff --git a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py index ab4389a84fe2..d8aef78a665a 100644 --- a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py +++ b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py @@ -191,10 +191,26 @@ def _cache_multimodal_embeddings( ) +def _normalize_encoder_embeddings( + encoder_embeddings: Union[torch.Tensor, List[torch.Tensor]], +) -> List[torch.Tensor]: + if isinstance(encoder_embeddings, torch.Tensor): + return [encoder_embeddings] + + if (not isinstance(encoder_embeddings, list) or not all( + isinstance(embedding, torch.Tensor) + for embedding in encoder_embeddings)): + raise TypeError( + "encoder_forward_fn must return a torch.Tensor or a list of torch.Tensor." 
+ ) + + return encoder_embeddings + + def get_multimodal_embeddings( encoder_forward_fn: Callable[ - [List[MultimodalParams]], - List[torch.Tensor], + ..., + torch.Tensor | List[torch.Tensor], ], multimodal_params: List[MultimodalParams], encoder_kwargs: Optional[Dict[str, Any]] = None, @@ -210,7 +226,8 @@ def get_multimodal_embeddings( Args: encoder_forward_fn: Callable that performs encoder forward pass. - Should accept List[MultimodalParams] and return List[torch.Tensor]. + Should accept List[MultimodalParams] and return either + a single torch.Tensor or List[torch.Tensor]. multimodal_params: All multimodal parameters in the batch. encoder_kwargs: Optional kwargs to pass to encoder_forward_fn. Returns: @@ -235,6 +252,7 @@ def get_multimodal_embeddings( kwargs = encoder_kwargs or {} encoder_embeddings = encoder_forward_fn(uncached_multimodal_params, **kwargs) + encoder_embeddings = _normalize_encoder_embeddings(encoder_embeddings) # TODO: support multiple multimodal modalities per request if len(encoder_embeddings) > 1: diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_cross_iter_prefetch.py b/tests/unittest/_torch/multimodal/test_mm_encoder_cross_iter_prefetch.py index e3aca6190bde..a0fce6edeb37 100644 --- a/tests/unittest/_torch/multimodal/test_mm_encoder_cross_iter_prefetch.py +++ b/tests/unittest/_torch/multimodal/test_mm_encoder_cross_iter_prefetch.py @@ -25,7 +25,6 @@ from tensorrt_llm._torch.models import modeling_multimodal_mixin as mm_mixin from tensorrt_llm._torch.models.modeling_multimodal_mixin import ( - MultimodalEncoderOutput, MultimodalModelMixin, _get_mm_aux_stream, maybe_prefetch_mm_encoder_for_next_iter, @@ -40,7 +39,7 @@ def __init__(self, hidden_size: int, tokens_per_image: int): self.encoder_call_count = 0 self.encoder_call_stream_id = None - def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> MultimodalEncoderOutput: + def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> torch.Tensor: 
self.encoder_call_count += 1 self.encoder_call_stream_id = torch.cuda.current_stream().cuda_stream pv = multimodal_params[0].multimodal_data["image"]["pixel_values"] @@ -52,7 +51,7 @@ def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> MultimodalEnc self._hidden_size, device=pv.device, ) - return MultimodalEncoderOutput(embeddings=embeddings) + return embeddings def _make_request(request_id: int, num_tokens: int) -> LlmRequest: diff --git a/tests/unittest/_torch/multimodal/test_multimodal_mixin.py b/tests/unittest/_torch/multimodal/test_multimodal_mixin.py index 8425094290cd..626c3493fa84 100644 --- a/tests/unittest/_torch/multimodal/test_multimodal_mixin.py +++ b/tests/unittest/_torch/multimodal/test_multimodal_mixin.py @@ -48,10 +48,28 @@ def encode_multimodal_inputs(self, multimodal_params): raise AssertionError("Tests use cached multimodal embeddings and should not encode.") +class TensorEncoderMultimodalModel(DummyMultimodalModel): + def __init__( + self, + embedding: Embedding, + mm_token_ids: torch.Tensor, + mm_embeds: torch.Tensor, + ): + super().__init__(embedding, mm_token_ids) + self.mm_embeds = mm_embeds + + def encode_multimodal_inputs(self, multimodal_params, **encoder_kwargs) -> torch.Tensor: + return self.mm_embeds + + def make_cached_multimodal_param(mm_embeds: torch.Tensor) -> MultimodalParams: return MultimodalParams(multimodal_data={"multimodal_embedding": mm_embeds}) +def make_raw_multimodal_param() -> MultimodalParams: + return MultimodalParams(multimodal_data={"image": {"pixel_values": torch.empty(1)}}) + + def test_cast_multimodal_encoder_dtype_keeps_meta_tensors_meta(): module = torch.nn.Linear(4, 4, device="meta") @@ -95,3 +113,37 @@ def test_prepare_multimodal_inputs_forwards_precomputed_indices(device): mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device), ) torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx])) + + +@pytest.mark.parametrize("device", ["cpu"] + (["cuda"] if 
torch.cuda.is_available() else [])) +def test_prepare_multimodal_inputs_accepts_tensor_encoder_output(device): + hidden = 8 + mm_token_id = 7 + emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device) + + input_ids = torch.tensor([0, mm_token_id, 1], dtype=torch.long, device=device) + text_idx = torch.tensor([0, 2], dtype=torch.long, device=device) + mm_idx = torch.tensor([1], dtype=torch.long, device=device) + mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device) + model = TensorEncoderMultimodalModel( + emb, + torch.tensor([mm_token_id], dtype=torch.long, device=device), + mm_emb, + ) + + out = model.prepare_multimodal_inputs( + input_ids=input_ids, + positions=None, + multimodal_params=[make_raw_multimodal_param()], + num_context_requests=1, + text_token_indices=text_idx, + mm_token_indices=mm_idx, + ) + + assert out.input_ids is None + assert out.inputs_embeds is not None + torch.testing.assert_close( + out.inputs_embeds[mm_idx], + mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device), + ) + torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))