Skip to content

Commit 733542c

Browse files
committed
[None][chore] Small cleanups to MultimodalModelMixin
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent 54dec4f commit 733542c

4 files changed

Lines changed: 112 additions & 95 deletions

File tree

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
2424
Mistral3Gate, MistralLarge3ForCausalLM)
2525
from tensorrt_llm._torch.models.modeling_multimodal_mixin import (
26-
MultimodalEncoderOutput, MultimodalModelMixin, PreparedLlmInputs)
26+
MultimodalModelMixin, PreparedLlmInputs)
2727
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
2828
_MULTIMODAL_ENV_NAME, _is_mm_disagg)
2929
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
@@ -706,9 +706,9 @@ def encode_multimodal_inputs(
706706
self,
707707
multimodal_params: Sequence[MultimodalParams],
708708
**encoder_kwargs: Any,
709-
) -> MultimodalEncoderOutput:
709+
) -> torch.Tensor:
710710
mm_embeds = self._vision_forward(list(multimodal_params))
711-
return MultimodalEncoderOutput(embeddings=mm_embeds[0])
711+
return mm_embeds[0]
712712

713713
def get_language_model_forward_kwargs(
714714
self,

tensorrt_llm/_torch/models/modeling_multimodal_mixin.py

Lines changed: 36 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,6 @@
2929
)
3030

3131

32-
@dataclass(frozen=True)
33-
class MultimodalEncoderOutput:
34-
"""Output produced by a model-owned multimodal encoder hook.
35-
36-
Contract:
37-
- `embeddings` contains all multimodal embedding rows for the supplied
38-
`multimodal_params`.
39-
- Rows are concatenated in the same order as `multimodal_params`.
40-
- Per-request row counts match `total_embeds_in_request` from runtime
41-
metadata when that metadata is available.
42-
- Special multimodal tokens occupy token positions but do not have rows in
43-
this tensor.
44-
45-
The single-tensor shape is required for chunked-prefill embedding reuse,
46-
which lets later chunks skip the encoder. See
47-
`modeling_multimodal_utils.py` for the caching machinery.
48-
"""
49-
50-
embeddings: torch.Tensor
51-
52-
5332
@dataclass(frozen=True)
5433
class PreparedLlmInputs:
5534
"""Prepared inputs returned by `MultimodalModelMixin`."""
@@ -71,8 +50,13 @@ def encode_multimodal_inputs(
7150
self,
7251
multimodal_params: Sequence[MultimodalParams],
7352
**encoder_kwargs: Any,
74-
) -> MultimodalEncoderOutput:
75-
"""Run model-specific multimodal encoder work."""
53+
) -> torch.Tensor:
54+
"""Run model-specific multimodal encoder work.
55+
56+
Returns the single primary multimodal embedding tensor for the supplied params. Rows are
57+
expected to be concatenated in request order, and special multimodal tokens occupy token
58+
positions but do not have rows here.
59+
"""
7660
raise NotImplementedError
7761

7862
@property
@@ -122,16 +106,16 @@ def after_full_multimodal_embeddings(
122106
*,
123107
input_ids: torch.Tensor,
124108
multimodal_params: Sequence[MultimodalParams],
125-
encoder_output: MultimodalEncoderOutput,
109+
embeddings: torch.Tensor,
126110
**forward_kwargs: Any,
127-
) -> tuple[torch.Tensor, MultimodalEncoderOutput]:
111+
) -> tuple[torch.Tensor, torch.Tensor]:
128112
"""Optional hook before active chunk rows are selected.
129113
130114
Runs after cache lookup or encoder execution has produced full
131115
per-request multimodal embeddings, but before the mixin selects rows
132116
active in the current forward chunk.
133117
"""
134-
return input_ids, encoder_output
118+
return input_ids, embeddings
135119

136120
def after_active_multimodal_embeddings(
137121
self,
@@ -179,24 +163,19 @@ def prepare_multimodal_inputs(
179163
multimodal_params=context_params,
180164
**forward_kwargs,
181165
)
182-
full_output = self._get_or_encode_multimodal_embeddings(
166+
full_embeddings = self._get_or_encode_multimodal_embeddings(
183167
context_params,
184168
**encoder_kwargs,
185169
)
186170

187-
input_ids, full_output = self.after_full_multimodal_embeddings(
171+
input_ids, full_embeddings = self.after_full_multimodal_embeddings(
188172
input_ids=input_ids,
189173
multimodal_params=context_params,
190-
encoder_output=full_output,
174+
embeddings=full_embeddings,
191175
**forward_kwargs,
192176
)
193177

194-
active_embeddings = self._find_active_multimodal_embeddings(
195-
[full_output.embeddings],
196-
input_ids=input_ids,
197-
positions=positions,
198-
multimodal_params=context_params,
199-
)
178+
active_embeddings = find_input_mm_embeds([full_embeddings], list(context_params))
200179
active_embeddings, extra_embeds = self.after_active_multimodal_embeddings(
201180
active_embeddings=active_embeddings,
202181
multimodal_params=context_params,
@@ -228,50 +207,21 @@ def _get_or_encode_multimodal_embeddings(
228207
self,
229208
multimodal_params: Sequence[MultimodalParams],
230209
**encoder_kwargs: Any,
231-
) -> MultimodalEncoderOutput:
210+
) -> torch.Tensor:
232211
"""Return cached multimodal embeddings or run the encoder for misses.
233212
234-
Delegates cache lookup and gather behavior to
235-
`get_multimodal_embeddings`, then validates the single primary tensor
236-
contract for both encoded and cached-only paths.
213+
Delegates cache lookup and gather behavior to `get_multimodal_embeddings`, then validates
214+
the single-tensor contract for both encoded and cached-only paths.
237215
"""
238-
239-
def encoder_forward_fn(params: list[MultimodalParams], **kwargs: Any) -> list[torch.Tensor]:
240-
encoder_output = self.encode_multimodal_inputs(params, **kwargs)
241-
if not isinstance(encoder_output, MultimodalEncoderOutput):
242-
raise TypeError("encode_multimodal_inputs must return MultimodalEncoderOutput.")
243-
if not isinstance(encoder_output.embeddings, torch.Tensor):
244-
raise TypeError("MultimodalEncoderOutput.embeddings must be a torch.Tensor.")
245-
return [encoder_output.embeddings]
246-
247216
embeddings = get_multimodal_embeddings(
248-
encoder_forward_fn=encoder_forward_fn,
217+
encoder_forward_fn=self.encode_multimodal_inputs,
249218
multimodal_params=list(multimodal_params),
250219
encoder_kwargs=encoder_kwargs,
251220
)
252-
primary = self._require_primary_embedding(embeddings)
253-
# Validate post-gather so cached-only paths (KV reuse, all-cached chunked
254-
# prefill) are also checked, not just paths that ran the encoder.
255-
self._validate_primary_embedding_rows(primary, multimodal_params)
256-
return MultimodalEncoderOutput(embeddings=primary)
257-
258-
def _find_active_multimodal_embeddings(
259-
self,
260-
multimodal_embeddings: list[torch.Tensor],
261-
*,
262-
input_ids: torch.Tensor,
263-
positions: Optional[torch.Tensor],
264-
multimodal_params: Sequence[MultimodalParams],
265-
) -> list[torch.Tensor]:
266-
"""Named internal stage for selecting active chunk multimodal rows.
267-
268-
This initial template stage currently delegates to
269-
`find_input_mm_embeds`. Model-specific behavior around slicing should
270-
use `after_full_multimodal_embeddings` or
271-
`after_active_multimodal_embeddings` so the common mixin sequence stays
272-
centralized.
273-
"""
274-
return find_input_mm_embeds(multimodal_embeddings, list(multimodal_params))
221+
# Validate post-gather so cached-only paths (KV reuse, all-cached chunked prefill) are also
222+
# checked, not just paths that ran the encoder.
223+
self._validate_embeddings(embeddings, multimodal_params)
224+
return embeddings[0]
275225

276226
def _fuse_multimodal_embeddings(
277227
self,
@@ -313,38 +263,35 @@ def _fuse_multimodal_embeddings(
313263
return fused_input_ids, inputs_embeds, ()
314264

315265
@staticmethod
316-
def _require_primary_embedding(embeddings: list[torch.Tensor]) -> torch.Tensor:
317-
if len(embeddings) != 1:
318-
raise ValueError(
319-
"MultimodalModelMixin requires a single primary embedding tensor, "
320-
f"got {len(embeddings)} tensors."
321-
)
322-
return embeddings[0]
323-
324-
@staticmethod
325-
def _validate_primary_embedding_rows(
326-
primary: torch.Tensor,
266+
def _validate_embeddings(
267+
embeddings: list[torch.Tensor],
327268
multimodal_params: Sequence[MultimodalParams],
328269
) -> None:
329-
"""Validate gathered primary embedding row count against runtime metadata.
270+
"""Validate the gathered embedding row count against runtime metadata.
330271
331272
Skipped if any param lacks `multimodal_runtime.total_embeds_in_request`, since the contract
332273
cannot be evaluated without complete metadata.
333274
"""
275+
if len(embeddings) != 1:
276+
raise ValueError(
277+
f"MultimodalModelMixin requires a single embedding tensor, got {len(embeddings)} "
278+
"tensors."
279+
)
280+
281+
embeddings_tensor = embeddings[0]
334282
expected_rows = 0
335283
for param in multimodal_params:
336284
runtime = param.multimodal_runtime
337285
if runtime is None or runtime.total_embeds_in_request is None:
338286
logger.debug(
339-
"Skipping multimodal embedding row-count validation: "
340-
"runtime metadata missing or incomplete for at least one param."
287+
"Skipping multimodal embedding row-count validation: runtime metadata missing "
288+
"or incomplete for at least one param."
341289
)
342290
return
343291
expected_rows += runtime.total_embeds_in_request
344292

345-
actual_rows = primary.shape[0]
293+
actual_rows = embeddings_tensor.shape[0]
346294
if actual_rows != expected_rows:
347295
raise ValueError(
348-
"Multimodal embedding row count mismatch: "
349-
f"expected {expected_rows}, got {actual_rows}."
296+
f"Multimodal embedding row count mismatch: expected {expected_rows}, got {actual_rows}."
350297
)

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,26 @@ def _cache_multimodal_embeddings(
196196
)
197197

198198

199+
def _normalize_encoder_embeddings(
200+
encoder_embeddings: Union[torch.Tensor, List[torch.Tensor]],
201+
) -> List[torch.Tensor]:
202+
if isinstance(encoder_embeddings, torch.Tensor):
203+
return [encoder_embeddings]
204+
205+
if (not isinstance(encoder_embeddings, list) or not all(
206+
isinstance(embedding, torch.Tensor)
207+
for embedding in encoder_embeddings)):
208+
raise TypeError(
209+
"encoder_forward_fn must return a torch.Tensor or a list of torch.Tensor."
210+
)
211+
212+
return encoder_embeddings
213+
214+
199215
def get_multimodal_embeddings(
200216
encoder_forward_fn: Callable[
201-
[List[MultimodalParams]],
202-
List[torch.Tensor],
217+
...,
218+
torch.Tensor | List[torch.Tensor],
203219
],
204220
multimodal_params: List[MultimodalParams],
205221
encoder_kwargs: Optional[Dict[str, Any]] = None,
@@ -215,7 +231,8 @@ def get_multimodal_embeddings(
215231
216232
Args:
217233
encoder_forward_fn: Callable that performs encoder forward pass.
218-
Should accept List[MultimodalParams] and return List[torch.Tensor].
234+
Should accept List[MultimodalParams] (plus any `encoder_kwargs`) and return either
235+
a single torch.Tensor or List[torch.Tensor].
219236
multimodal_params: All multimodal parameters in the batch.
220237
encoder_kwargs: Optional kwargs to pass to encoder_forward_fn.
221238
Returns:
@@ -233,6 +250,7 @@ def get_multimodal_embeddings(
233250
kwargs = encoder_kwargs or {}
234251
encoder_embeddings = encoder_forward_fn(uncached_multimodal_params,
235252
**kwargs)
253+
encoder_embeddings = _normalize_encoder_embeddings(encoder_embeddings)
236254

237255
# TODO: support multiple multimodal modalities per request
238256
if len(encoder_embeddings) > 1:

tests/unittest/_torch/multimodal/test_multimodal_mixin.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,28 @@ def encode_multimodal_inputs(self, multimodal_params, **encoder_kwargs):
4848
raise AssertionError("Tests use cached multimodal embeddings and should not encode.")
4949

5050

51+
class TensorEncoderMultimodalModel(DummyMultimodalModel):
52+
def __init__(
53+
self,
54+
embedding: Embedding,
55+
mm_token_ids: torch.Tensor,
56+
mm_embeds: torch.Tensor,
57+
):
58+
super().__init__(embedding, mm_token_ids)
59+
self.mm_embeds = mm_embeds
60+
61+
def encode_multimodal_inputs(self, multimodal_params, **encoder_kwargs) -> torch.Tensor:
62+
return self.mm_embeds
63+
64+
5165
def make_cached_multimodal_param(mm_embeds: torch.Tensor) -> MultimodalParams:
5266
return MultimodalParams(multimodal_data={"multimodal_embedding": mm_embeds})
5367

5468

69+
def make_raw_multimodal_param() -> MultimodalParams:
70+
return MultimodalParams(multimodal_data={"image": {"pixel_values": torch.empty(1)}})
71+
72+
5573
@pytest.mark.parametrize("device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
5674
def test_prepare_multimodal_inputs_forwards_precomputed_indices(device):
5775
hidden = 8
@@ -84,3 +102,37 @@ def test_prepare_multimodal_inputs_forwards_precomputed_indices(device):
84102
mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device),
85103
)
86104
torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))
105+
106+
107+
@pytest.mark.parametrize("device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
108+
def test_prepare_multimodal_inputs_accepts_tensor_encoder_output(device):
109+
hidden = 8
110+
mm_token_id = 7
111+
emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device)
112+
113+
input_ids = torch.tensor([0, mm_token_id, 1], dtype=torch.long, device=device)
114+
text_idx = torch.tensor([0, 2], dtype=torch.long, device=device)
115+
mm_idx = torch.tensor([1], dtype=torch.long, device=device)
116+
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
117+
model = TensorEncoderMultimodalModel(
118+
emb,
119+
torch.tensor([mm_token_id], dtype=torch.long, device=device),
120+
mm_emb,
121+
)
122+
123+
out = model.prepare_multimodal_inputs(
124+
input_ids=input_ids,
125+
positions=None,
126+
multimodal_params=[make_raw_multimodal_param()],
127+
num_context_requests=1,
128+
text_token_indices=text_idx,
129+
mm_token_indices=mm_idx,
130+
)
131+
132+
assert out.input_ids is None
133+
assert out.inputs_embeds is not None
134+
torch.testing.assert_close(
135+
out.inputs_embeds[mm_idx],
136+
mm_emb.to(dtype=out.inputs_embeds.dtype, device=out.inputs_embeds.device),
137+
)
138+
torch.testing.assert_close(out.inputs_embeds[text_idx], emb(input_ids[text_idx]))

0 commit comments

Comments
 (0)