Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/models/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
Mistral3Gate, MistralLarge3ForCausalLM)
from tensorrt_llm._torch.models.modeling_multimodal_mixin import (
MultimodalEncoderOutput, MultimodalModelMixin, PreparedLlmInputs)
MultimodalModelMixin, PreparedLlmInputs)
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
_MULTIMODAL_ENV_NAME, _is_mm_disagg)
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
Expand Down Expand Up @@ -705,9 +705,9 @@ def infer_max_seq_len(self) -> int:
def encode_multimodal_inputs(
self,
multimodal_params: Sequence[MultimodalParams],
) -> MultimodalEncoderOutput:
) -> torch.Tensor:
mm_embeds = self._vision_forward(list(multimodal_params))
return MultimodalEncoderOutput(embeddings=mm_embeds[0])
return mm_embeds[0]

def get_language_model_forward_kwargs(
self,
Expand Down
148 changes: 51 additions & 97 deletions tensorrt_llm/_torch/models/modeling_multimodal_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,27 +123,6 @@ def _run_on_aux_stream(aux_stream: torch.cuda.Stream) -> Iterator[torch.cuda.Eve
exit_event.record()


@dataclass(frozen=True)
class MultimodalEncoderOutput:
"""Output produced by a model-owned multimodal encoder hook.

Contract:
- `embeddings` contains all multimodal embedding rows for the supplied
`multimodal_params`.
- Rows are concatenated in the same order as `multimodal_params`.
- Per-request row counts match `total_embeds_in_request` from runtime
metadata when that metadata is available.
- Special multimodal tokens occupy token positions but do not have rows in
this tensor.

The single-tensor shape is required for chunked-prefill embedding reuse,
which lets later chunks skip the encoder. See
`modeling_multimodal_utils.py` for the caching machinery.
"""

embeddings: torch.Tensor


@dataclass(frozen=True)
class PreparedLlmInputs:
"""Prepared inputs returned by `MultimodalModelMixin`."""
Expand Down Expand Up @@ -181,8 +160,13 @@ def convert(tensor: torch.Tensor) -> torch.Tensor:
def encode_multimodal_inputs(
self,
multimodal_params: Sequence[MultimodalParams],
) -> MultimodalEncoderOutput:
"""Run model-specific multimodal encoder work."""
) -> torch.Tensor:
"""Run model-specific multimodal encoder work.

Returns the single primary multimodal embedding tensor for the supplied params. Rows are
expected to be concatenated in request order, and special multimodal tokens occupy token
positions but do not have rows here.
Comment thread
moraxu marked this conversation as resolved.
"""
raise NotImplementedError

@property
Expand Down Expand Up @@ -222,16 +206,16 @@ def after_full_multimodal_embeddings(
*,
input_ids: torch.Tensor,
multimodal_params: Sequence[MultimodalParams],
encoder_output: MultimodalEncoderOutput,
embeddings: torch.Tensor,
**forward_kwargs: Any,
) -> tuple[torch.Tensor, MultimodalEncoderOutput]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Optional hook before active chunk rows are selected.

Runs after cache lookup or encoder execution has produced full
per-request multimodal embeddings, but before the mixin selects rows
active in the current forward chunk.
"""
return input_ids, encoder_output
return input_ids, embeddings

def after_active_multimodal_embeddings(
self,
Expand Down Expand Up @@ -274,21 +258,16 @@ def prepare_multimodal_inputs(
if not context_params:
return PreparedLlmInputs(input_ids=input_ids, inputs_embeds=None)

full_output = self._get_or_encode_multimodal_embeddings(context_params)
full_embeddings = self._get_or_encode_multimodal_embeddings(context_params)

input_ids, full_output = self.after_full_multimodal_embeddings(
input_ids, full_embeddings = self.after_full_multimodal_embeddings(
input_ids=input_ids,
multimodal_params=context_params,
encoder_output=full_output,
embeddings=full_embeddings,
**forward_kwargs,
)

active_embeddings = self._find_active_multimodal_embeddings(
[full_output.embeddings],
input_ids=input_ids,
positions=positions,
multimodal_params=context_params,
)
active_embeddings = find_input_mm_embeds([full_embeddings], list(context_params))
active_embeddings, extra_embeds = self.after_active_multimodal_embeddings(
active_embeddings=active_embeddings,
multimodal_params=context_params,
Expand Down Expand Up @@ -319,49 +298,20 @@ def prepare_multimodal_inputs(
def _get_or_encode_multimodal_embeddings(
self,
multimodal_params: Sequence[MultimodalParams],
) -> MultimodalEncoderOutput:
) -> torch.Tensor:
"""Return cached multimodal embeddings or run the encoder for misses.

Delegates cache lookup and gather behavior to
`get_multimodal_embeddings`, then validates the single primary tensor
contract for both encoded and cached-only paths.
Delegates cache lookup and gather behavior to `get_multimodal_embeddings`, then validates
the single tensor contract for both encoded and cached-only paths.
"""

def encoder_forward_fn(params: list[MultimodalParams]) -> list[torch.Tensor]:
encoder_output = self.encode_multimodal_inputs(params)
if not isinstance(encoder_output, MultimodalEncoderOutput):
raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.")
if not isinstance(encoder_output.embeddings, torch.Tensor):
raise TypeError("MultimodalEncoderOutput.embeddings must be a torch.Tensor.")
return [encoder_output.embeddings]

embeddings = get_multimodal_embeddings(
encoder_forward_fn=encoder_forward_fn,
encoder_forward_fn=self.encode_multimodal_inputs,
multimodal_params=list(multimodal_params),
)
primary = self._require_primary_embedding(embeddings)
# Validate post-gather so cached-only paths (KV reuse, all-cached chunked
# prefill) are also checked, not just paths that ran the encoder.
self._validate_primary_embedding_rows(primary, multimodal_params)
return MultimodalEncoderOutput(embeddings=primary)

def _find_active_multimodal_embeddings(
self,
multimodal_embeddings: list[torch.Tensor],
*,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor],
multimodal_params: Sequence[MultimodalParams],
) -> list[torch.Tensor]:
"""Named internal stage for selecting active chunk multimodal rows.

This initial template stage currently delegates to
`find_input_mm_embeds`. Model-specific behavior around slicing should
use `after_full_multimodal_embeddings` or
`after_active_multimodal_embeddings` so the common mixin sequence stays
centralized.
"""
return find_input_mm_embeds(multimodal_embeddings, list(multimodal_params))
# Validate post-gather so cached-only paths (KV reuse, all-cached chunked prefill) are also
# checked, not just paths that ran the encoder.
self._validate_embeddings(embeddings, multimodal_params)
return embeddings[0]

def _fuse_multimodal_embeddings(
self,
Expand Down Expand Up @@ -403,40 +353,46 @@ def _fuse_multimodal_embeddings(
return fused_input_ids, inputs_embeds, ()

@staticmethod
def _require_primary_embedding(embeddings: list[torch.Tensor]) -> torch.Tensor:
if len(embeddings) != 1:
raise ValueError(
"MultimodalModelMixin requires a single primary embedding tensor, "
f"got {len(embeddings)} tensors."
)
return embeddings[0]

@staticmethod
def _validate_primary_embedding_rows(
primary: torch.Tensor,
def _validate_embeddings(
embeddings: list[torch.Tensor],
multimodal_params: Sequence[MultimodalParams],
) -> None:
"""Validate gathered primary embedding row count against runtime metadata.
"""Validate gathered embeddings embedding row count against runtime metadata.

Skipped if any param lacks `multimodal_runtime.total_embeds_in_request`, since the contract
cannot be evaluated without complete metadata.
"""
if len(embeddings) != 1:
raise ValueError(
f"MultimodalModelMixin requires a single embedding tensor, got {len(embeddings)} "
"tensors."
)

embeddings_tensor = embeddings[0]
expected_rows = 0
has_runtime_metadata = []
for param in multimodal_params:
runtime = param.multimodal_runtime
if runtime is None or runtime.total_embeds_in_request is None:
logger.debug(
"Skipping multimodal embedding row-count validation: "
"runtime metadata missing or incomplete for at least one param."
)
return
expected_rows += runtime.total_embeds_in_request
has_runtime = runtime is not None and runtime.total_embeds_in_request is not None
has_runtime_metadata.append(has_runtime)
if has_runtime:
expected_rows += runtime.total_embeds_in_request

if any(has_runtime_metadata) and not all(has_runtime_metadata):
raise ValueError(
"Multimodal runtime metadata must be present for every param or none of them."
)
if not all(has_runtime_metadata):
logger.debug(
"Skipping multimodal embedding row-count validation: runtime metadata missing "
"for all params."
)
return

actual_rows = primary.shape[0]
actual_rows = embeddings_tensor.shape[0]
if actual_rows != expected_rows:
raise ValueError(
"Multimodal embedding row count mismatch: "
f"expected {expected_rows}, got {actual_rows}."
f"Multimodal embedding row count mismatch: expected {expected_rows}, got {actual_rows}."
)


Expand Down Expand Up @@ -539,9 +495,7 @@ def _dispatch_cross_iter_prefetch(
for (req, _, _), p in zip(candidates, params_list):
req.py_multimodal_data = p.multimodal_data
encoder_output = model.encode_multimodal_inputs(params_list)
if not isinstance(encoder_output, MultimodalEncoderOutput):
raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.")
_cache_multimodal_embeddings(params_list, [encoder_output.embeddings])
_cache_multimodal_embeddings(params_list, [encoder_output])
finally:
# Stash the event on every candidate's durable LlmRequest (not the
# per-iter `MultimodalParams`), since `_prepare_inputs` rebuilds the
Expand Down
24 changes: 21 additions & 3 deletions tensorrt_llm/_torch/models/modeling_multimodal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,26 @@ def _cache_multimodal_embeddings(
)


def _normalize_encoder_embeddings(
encoder_embeddings: Union[torch.Tensor, List[torch.Tensor]],
) -> List[torch.Tensor]:
if isinstance(encoder_embeddings, torch.Tensor):
return [encoder_embeddings]

if (not isinstance(encoder_embeddings, list) or not all(
isinstance(embedding, torch.Tensor)
for embedding in encoder_embeddings)):
raise TypeError(
"encoder_forward_fn must return a torch.Tensor or a list of torch.Tensor."
)

return encoder_embeddings


def get_multimodal_embeddings(
encoder_forward_fn: Callable[
[List[MultimodalParams]],
List[torch.Tensor],
...,
torch.Tensor | List[torch.Tensor],
],
multimodal_params: List[MultimodalParams],
encoder_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -210,7 +226,8 @@ def get_multimodal_embeddings(

Args:
encoder_forward_fn: Callable that performs encoder forward pass.
Should accept List[MultimodalParams] and return List[torch.Tensor].
Should accept List[MultimodalParams] and return either
a single torch.Tensor or List[torch.Tensor].
multimodal_params: All multimodal parameters in the batch.
encoder_kwargs: Optional kwargs to pass to encoder_forward_fn.
Returns:
Expand All @@ -235,6 +252,7 @@ def get_multimodal_embeddings(
kwargs = encoder_kwargs or {}
encoder_embeddings = encoder_forward_fn(uncached_multimodal_params,
**kwargs)
encoder_embeddings = _normalize_encoder_embeddings(encoder_embeddings)

# TODO: support multiple multimodal modalities per request
if len(encoder_embeddings) > 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from tensorrt_llm._torch.models import modeling_multimodal_mixin as mm_mixin
from tensorrt_llm._torch.models.modeling_multimodal_mixin import (
MultimodalEncoderOutput,
MultimodalModelMixin,
_get_mm_aux_stream,
maybe_prefetch_mm_encoder_for_next_iter,
Expand All @@ -40,7 +39,7 @@ def __init__(self, hidden_size: int, tokens_per_image: int):
self.encoder_call_count = 0
self.encoder_call_stream_id = None

def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> MultimodalEncoderOutput:
def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> torch.Tensor:
self.encoder_call_count += 1
self.encoder_call_stream_id = torch.cuda.current_stream().cuda_stream
pv = multimodal_params[0].multimodal_data["image"]["pixel_values"]
Expand All @@ -52,7 +51,7 @@ def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> MultimodalEnc
self._hidden_size,
device=pv.device,
)
return MultimodalEncoderOutput(embeddings=embeddings)
return embeddings


def _make_request(request_id: int, num_tokens: int) -> LlmRequest:
Expand Down
52 changes: 52 additions & 0 deletions tests/unittest/_torch/multimodal/test_multimodal_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,28 @@ def encode_multimodal_inputs(self, multimodal_params):
raise AssertionError("Tests use cached multimodal embeddings and should not encode.")


class TensorEncoderMultimodalModel(DummyMultimodalModel):
def __init__(
self,
embedding: Embedding,
mm_token_ids: torch.Tensor,
mm_embeds: torch.Tensor,
):
super().__init__(embedding, mm_token_ids)
self.mm_embeds = mm_embeds

def encode_multimodal_inputs(self, multimodal_params, **encoder_kwargs) -> torch.Tensor:
return self.mm_embeds


def make_cached_multimodal_param(mm_embeds: torch.Tensor) -> MultimodalParams:
return MultimodalParams(multimodal_data={"multimodal_embedding": mm_embeds})


def make_raw_multimodal_param() -> MultimodalParams:
return MultimodalParams(multimodal_data={"image": {"pixel_values": torch.empty(1)}})


def test_cast_multimodal_encoder_dtype_keeps_meta_tensors_meta():
module = torch.nn.Linear(4, 4, device="meta")

Expand Down Expand Up @@ -95,3 +113,37 @@ def test_prepare_multimodal_inputs_forwards_precomputed_indices(device):
mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device),
)
torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))


@pytest.mark.parametrize("device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
def test_prepare_multimodal_inputs_accepts_tensor_encoder_output(device):
hidden = 8
mm_token_id = 7
emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device)

input_ids = torch.tensor([0, mm_token_id, 1], dtype=torch.long, device=device)
text_idx = torch.tensor([0, 2], dtype=torch.long, device=device)
mm_idx = torch.tensor([1], dtype=torch.long, device=device)
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
model = TensorEncoderMultimodalModel(
emb,
torch.tensor([mm_token_id], dtype=torch.long, device=device),
mm_emb,
)

out = model.prepare_multimodal_inputs(
input_ids=input_ids,
positions=None,
multimodal_params=[make_raw_multimodal_param()],
num_context_requests=1,
text_token_indices=text_idx,
mm_token_indices=mm_idx,
)

assert out.input_ids is None
assert out.inputs_embeds is not None
torch.testing.assert_close(
out.inputs_embeds[mm_idx],
mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device),
)
torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))
Loading