Skip to content

Commit 425d79d

Browse files
committed
[None][chore] Small cleanups to MultimodalModelMixin
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent 2a18bd4 commit 425d79d

5 files changed

Lines changed: 128 additions & 105 deletions

File tree

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
2424
Mistral3Gate, MistralLarge3ForCausalLM)
2525
from tensorrt_llm._torch.models.modeling_multimodal_mixin import (
26-
MultimodalEncoderOutput, MultimodalModelMixin, PreparedLlmInputs)
26+
MultimodalModelMixin, PreparedLlmInputs)
2727
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
2828
_MULTIMODAL_ENV_NAME, _is_mm_disagg)
2929
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
@@ -705,9 +705,9 @@ def infer_max_seq_len(self) -> int:
705705
def encode_multimodal_inputs(
706706
self,
707707
multimodal_params: Sequence[MultimodalParams],
708-
) -> MultimodalEncoderOutput:
708+
) -> torch.Tensor:
709709
mm_embeds = self._vision_forward(list(multimodal_params))
710-
return MultimodalEncoderOutput(embeddings=mm_embeds[0])
710+
return mm_embeds[0]
711711

712712
def get_language_model_forward_kwargs(
713713
self,

tensorrt_llm/_torch/models/modeling_multimodal_mixin.py

Lines changed: 50 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -123,27 +123,6 @@ def _run_on_aux_stream(aux_stream: torch.cuda.Stream) -> Iterator[torch.cuda.Eve
123123
exit_event.record()
124124

125125

126-
@dataclass(frozen=True)
127-
class MultimodalEncoderOutput:
128-
"""Output produced by a model-owned multimodal encoder hook.
129-
130-
Contract:
131-
- `embeddings` contains all multimodal embedding rows for the supplied
132-
`multimodal_params`.
133-
- Rows are concatenated in the same order as `multimodal_params`.
134-
- Per-request row counts match `total_embeds_in_request` from runtime
135-
metadata when that metadata is available.
136-
- Special multimodal tokens occupy token positions but do not have rows in
137-
this tensor.
138-
139-
The single-tensor shape is required for chunked-prefill embedding reuse,
140-
which lets later chunks skip the encoder. See
141-
`modeling_multimodal_utils.py` for the caching machinery.
142-
"""
143-
144-
embeddings: torch.Tensor
145-
146-
147126
@dataclass(frozen=True)
148127
class PreparedLlmInputs:
149128
"""Prepared inputs returned by `MultimodalModelMixin`."""
@@ -181,8 +160,13 @@ def convert(tensor: torch.Tensor) -> torch.Tensor:
181160
def encode_multimodal_inputs(
182161
self,
183162
multimodal_params: Sequence[MultimodalParams],
184-
) -> MultimodalEncoderOutput:
185-
"""Run model-specific multimodal encoder work."""
163+
) -> torch.Tensor:
164+
"""Run model-specific multimodal encoder work.
165+
166+
Returns the single primary multimodal embedding tensor for the supplied params. Rows are
167+
expected to be concatenated in request order, and special multimodal tokens occupy token
168+
positions but do not have rows here.
169+
"""
186170
raise NotImplementedError
187171

188172
@property
@@ -222,16 +206,16 @@ def after_full_multimodal_embeddings(
222206
*,
223207
input_ids: torch.Tensor,
224208
multimodal_params: Sequence[MultimodalParams],
225-
encoder_output: MultimodalEncoderOutput,
209+
embeddings: torch.Tensor,
226210
**forward_kwargs: Any,
227-
) -> tuple[torch.Tensor, MultimodalEncoderOutput]:
211+
) -> tuple[torch.Tensor, torch.Tensor]:
228212
"""Optional hook before active chunk rows are selected.
229213
230214
Runs after cache lookup or encoder execution has produced full
231215
per-request multimodal embeddings, but before the mixin selects rows
232216
active in the current forward chunk.
233217
"""
234-
return input_ids, encoder_output
218+
return input_ids, embeddings
235219

236220
def after_active_multimodal_embeddings(
237221
self,
@@ -274,21 +258,16 @@ def prepare_multimodal_inputs(
274258
if not context_params:
275259
return PreparedLlmInputs(input_ids=input_ids, inputs_embeds=None)
276260

277-
full_output = self._get_or_encode_multimodal_embeddings(context_params)
261+
full_embeddings = self._get_or_encode_multimodal_embeddings(context_params)
278262

279-
input_ids, full_output = self.after_full_multimodal_embeddings(
263+
input_ids, full_embeddings = self.after_full_multimodal_embeddings(
280264
input_ids=input_ids,
281265
multimodal_params=context_params,
282-
encoder_output=full_output,
266+
embeddings=full_embeddings,
283267
**forward_kwargs,
284268
)
285269

286-
active_embeddings = self._find_active_multimodal_embeddings(
287-
[full_output.embeddings],
288-
input_ids=input_ids,
289-
positions=positions,
290-
multimodal_params=context_params,
291-
)
270+
active_embeddings = find_input_mm_embeds([full_embeddings], list(context_params))
292271
active_embeddings, extra_embeds = self.after_active_multimodal_embeddings(
293272
active_embeddings=active_embeddings,
294273
multimodal_params=context_params,
@@ -319,49 +298,20 @@ def prepare_multimodal_inputs(
319298
def _get_or_encode_multimodal_embeddings(
320299
self,
321300
multimodal_params: Sequence[MultimodalParams],
322-
) -> MultimodalEncoderOutput:
301+
) -> torch.Tensor:
323302
"""Return cached multimodal embeddings or run the encoder for misses.
324303
325-
Delegates cache lookup and gather behavior to
326-
`get_multimodal_embeddings`, then validates the single primary tensor
327-
contract for both encoded and cached-only paths.
304+
Delegates cache lookup and gather behavior to `get_multimodal_embeddings`, then validates
305+
the single tensor contract for both encoded and cached-only paths.
328306
"""
329-
330-
def encoder_forward_fn(params: list[MultimodalParams]) -> list[torch.Tensor]:
331-
encoder_output = self.encode_multimodal_inputs(params)
332-
if not isinstance(encoder_output, MultimodalEncoderOutput):
333-
raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.")
334-
if not isinstance(encoder_output.embeddings, torch.Tensor):
335-
raise TypeError("MultimodalEncoderOutput.embeddings must be a torch.Tensor.")
336-
return [encoder_output.embeddings]
337-
338307
embeddings = get_multimodal_embeddings(
339-
encoder_forward_fn=encoder_forward_fn,
308+
encoder_forward_fn=self.encode_multimodal_inputs,
340309
multimodal_params=list(multimodal_params),
341310
)
342-
primary = self._require_primary_embedding(embeddings)
343-
# Validate post-gather so cached-only paths (KV reuse, all-cached chunked
344-
# prefill) are also checked, not just paths that ran the encoder.
345-
self._validate_primary_embedding_rows(primary, multimodal_params)
346-
return MultimodalEncoderOutput(embeddings=primary)
347-
348-
def _find_active_multimodal_embeddings(
349-
self,
350-
multimodal_embeddings: list[torch.Tensor],
351-
*,
352-
input_ids: torch.Tensor,
353-
positions: Optional[torch.Tensor],
354-
multimodal_params: Sequence[MultimodalParams],
355-
) -> list[torch.Tensor]:
356-
"""Named internal stage for selecting active chunk multimodal rows.
357-
358-
This initial template stage currently delegates to
359-
`find_input_mm_embeds`. Model-specific behavior around slicing should
360-
use `after_full_multimodal_embeddings` or
361-
`after_active_multimodal_embeddings` so the common mixin sequence stays
362-
centralized.
363-
"""
364-
return find_input_mm_embeds(multimodal_embeddings, list(multimodal_params))
311+
# Validate post-gather so cached-only paths (KV reuse, all-cached chunked prefill) are also
312+
# checked, not just paths that ran the encoder.
313+
self._validate_embeddings(embeddings, multimodal_params)
314+
return embeddings[0]
365315

366316
def _fuse_multimodal_embeddings(
367317
self,
@@ -403,40 +353,46 @@ def _fuse_multimodal_embeddings(
403353
return fused_input_ids, inputs_embeds, ()
404354

405355
@staticmethod
406-
def _require_primary_embedding(embeddings: list[torch.Tensor]) -> torch.Tensor:
407-
if len(embeddings) != 1:
408-
raise ValueError(
409-
"MultimodalModelMixin requires a single primary embedding tensor, "
410-
f"got {len(embeddings)} tensors."
411-
)
412-
return embeddings[0]
413-
414-
@staticmethod
415-
def _validate_primary_embedding_rows(
416-
primary: torch.Tensor,
356+
def _validate_embeddings(
357+
embeddings: list[torch.Tensor],
417358
multimodal_params: Sequence[MultimodalParams],
418359
) -> None:
419-
"""Validate gathered primary embedding row count against runtime metadata.
360+
"""Validate the gathered embedding row count against runtime metadata.
420361
421362
Skipped only if every param lacks `multimodal_runtime.total_embeds_in_request`; raises
422363
`ValueError` if the metadata is present for some params but missing for others.
423364
"""
365+
if len(embeddings) != 1:
366+
raise ValueError(
367+
f"MultimodalModelMixin requires a single embedding tensor, got {len(embeddings)} "
368+
"tensors."
369+
)
370+
371+
embeddings_tensor = embeddings[0]
424372
expected_rows = 0
373+
has_runtime_metadata = []
425374
for param in multimodal_params:
426375
runtime = param.multimodal_runtime
427-
if runtime is None or runtime.total_embeds_in_request is None:
428-
logger.debug(
429-
"Skipping multimodal embedding row-count validation: "
430-
"runtime metadata missing or incomplete for at least one param."
431-
)
432-
return
433-
expected_rows += runtime.total_embeds_in_request
376+
has_runtime = runtime is not None and runtime.total_embeds_in_request is not None
377+
has_runtime_metadata.append(has_runtime)
378+
if has_runtime:
379+
expected_rows += runtime.total_embeds_in_request
380+
381+
if any(has_runtime_metadata) and not all(has_runtime_metadata):
382+
raise ValueError(
383+
"Multimodal runtime metadata must be present for every param or none of them."
384+
)
385+
if not all(has_runtime_metadata):
386+
logger.debug(
387+
"Skipping multimodal embedding row-count validation: runtime metadata missing "
388+
"for all params."
389+
)
390+
return
434391

435-
actual_rows = primary.shape[0]
392+
actual_rows = embeddings_tensor.shape[0]
436393
if actual_rows != expected_rows:
437394
raise ValueError(
438-
"Multimodal embedding row count mismatch: "
439-
f"expected {expected_rows}, got {actual_rows}."
395+
f"Multimodal embedding row count mismatch: expected {expected_rows}, got {actual_rows}."
440396
)
441397

442398

@@ -539,8 +495,6 @@ def _dispatch_cross_iter_prefetch(
539495
for (req, _, _), p in zip(candidates, params_list):
540496
req.py_multimodal_data = p.multimodal_data
541497
encoder_output = model.encode_multimodal_inputs(params_list)
542-
if not isinstance(encoder_output, MultimodalEncoderOutput):
543-
raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.")
544498
_cache_multimodal_embeddings(params_list, [encoder_output.embeddings])
545499
finally:
546500
# Stash the event on every candidate's durable LlmRequest (not the

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,26 @@ def _cache_multimodal_embeddings(
191191
)
192192

193193

194+
def _normalize_encoder_embeddings(
195+
encoder_embeddings: Union[torch.Tensor, List[torch.Tensor]],
196+
) -> List[torch.Tensor]:
197+
if isinstance(encoder_embeddings, torch.Tensor):
198+
return [encoder_embeddings]
199+
200+
if (not isinstance(encoder_embeddings, list) or not all(
201+
isinstance(embedding, torch.Tensor)
202+
for embedding in encoder_embeddings)):
203+
raise TypeError(
204+
"encoder_forward_fn must return a torch.Tensor or a list of torch.Tensor."
205+
)
206+
207+
return encoder_embeddings
208+
209+
194210
def get_multimodal_embeddings(
195211
encoder_forward_fn: Callable[
196-
[List[MultimodalParams]],
197-
List[torch.Tensor],
212+
...,
213+
torch.Tensor | List[torch.Tensor],
198214
],
199215
multimodal_params: List[MultimodalParams],
200216
encoder_kwargs: Optional[Dict[str, Any]] = None,
@@ -210,7 +226,8 @@ def get_multimodal_embeddings(
210226
211227
Args:
212228
encoder_forward_fn: Callable that performs encoder forward pass.
213-
Should accept List[MultimodalParams] and return List[torch.Tensor].
229+
Should accept List[MultimodalParams] and return either
230+
a single torch.Tensor or List[torch.Tensor].
214231
multimodal_params: All multimodal parameters in the batch.
215232
encoder_kwargs: Optional kwargs to pass to encoder_forward_fn.
216233
Returns:
@@ -235,6 +252,7 @@ def get_multimodal_embeddings(
235252
kwargs = encoder_kwargs or {}
236253
encoder_embeddings = encoder_forward_fn(uncached_multimodal_params,
237254
**kwargs)
255+
encoder_embeddings = _normalize_encoder_embeddings(encoder_embeddings)
238256

239257
# TODO: support multiple multimodal modalities per request
240258
if len(encoder_embeddings) > 1:

tests/unittest/_torch/multimodal/test_mm_encoder_cross_iter_prefetch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from tensorrt_llm._torch.models import modeling_multimodal_mixin as mm_mixin
2727
from tensorrt_llm._torch.models.modeling_multimodal_mixin import (
28-
MultimodalEncoderOutput,
2928
MultimodalModelMixin,
3029
_get_mm_aux_stream,
3130
maybe_prefetch_mm_encoder_for_next_iter,
@@ -40,7 +39,7 @@ def __init__(self, hidden_size: int, tokens_per_image: int):
4039
self.encoder_call_count = 0
4140
self.encoder_call_stream_id = None
4241

43-
def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> MultimodalEncoderOutput:
42+
def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> torch.Tensor:
4443
self.encoder_call_count += 1
4544
self.encoder_call_stream_id = torch.cuda.current_stream().cuda_stream
4645
pv = multimodal_params[0].multimodal_data["image"]["pixel_values"]
@@ -52,7 +51,7 @@ def encode_multimodal_inputs(self, multimodal_params, **kwargs) -> MultimodalEnc
5251
self._hidden_size,
5352
device=pv.device,
5453
)
55-
return MultimodalEncoderOutput(embeddings=embeddings)
54+
return embeddings
5655

5756

5857
def _make_request(request_id: int, num_tokens: int) -> LlmRequest:

tests/unittest/_torch/multimodal/test_multimodal_mixin.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,28 @@ def encode_multimodal_inputs(self, multimodal_params):
4848
raise AssertionError("Tests use cached multimodal embeddings and should not encode.")
4949

5050

51+
class TensorEncoderMultimodalModel(DummyMultimodalModel):
52+
def __init__(
53+
self,
54+
embedding: Embedding,
55+
mm_token_ids: torch.Tensor,
56+
mm_embeds: torch.Tensor,
57+
):
58+
super().__init__(embedding, mm_token_ids)
59+
self.mm_embeds = mm_embeds
60+
61+
def encode_multimodal_inputs(self, multimodal_params, **encoder_kwargs) -> torch.Tensor:
62+
return self.mm_embeds
63+
64+
5165
def make_cached_multimodal_param(mm_embeds: torch.Tensor) -> MultimodalParams:
5266
return MultimodalParams(multimodal_data={"multimodal_embedding": mm_embeds})
5367

5468

69+
def make_raw_multimodal_param() -> MultimodalParams:
70+
return MultimodalParams(multimodal_data={"image": {"pixel_values": torch.empty(1)}})
71+
72+
5573
def test_cast_multimodal_encoder_dtype_keeps_meta_tensors_meta():
5674
module = torch.nn.Linear(4, 4, device="meta")
5775

@@ -95,3 +113,37 @@ def test_prepare_multimodal_inputs_forwards_precomputed_indices(device):
95113
mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device),
96114
)
97115
torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))
116+
117+
118+
@pytest.mark.parametrize("device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
119+
def test_prepare_multimodal_inputs_accepts_tensor_encoder_output(device):
120+
hidden = 8
121+
mm_token_id = 7
122+
emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device)
123+
124+
input_ids = torch.tensor([0, mm_token_id, 1], dtype=torch.long, device=device)
125+
text_idx = torch.tensor([0, 2], dtype=torch.long, device=device)
126+
mm_idx = torch.tensor([1], dtype=torch.long, device=device)
127+
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
128+
model = TensorEncoderMultimodalModel(
129+
emb,
130+
torch.tensor([mm_token_id], dtype=torch.long, device=device),
131+
mm_emb,
132+
)
133+
134+
out = model.prepare_multimodal_inputs(
135+
input_ids=input_ids,
136+
positions=None,
137+
multimodal_params=[make_raw_multimodal_param()],
138+
num_context_requests=1,
139+
text_token_indices=text_idx,
140+
mm_token_indices=mm_idx,
141+
)
142+
143+
assert out.input_ids is None
144+
assert out.inputs_embeds is not None
145+
torch.testing.assert_close(
146+
out.inputs_embeds[mm_idx],
147+
mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device),
148+
)
149+
torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))

0 commit comments

Comments
 (0)