2929)
3030
3131
32- @dataclass (frozen = True )
33- class MultimodalEncoderOutput :
34- """Output produced by a model-owned multimodal encoder hook.
35-
36- Contract:
37- - `embeddings` contains all multimodal embedding rows for the supplied
38- `multimodal_params`.
39- - Rows are concatenated in the same order as `multimodal_params`.
40- - Per-request row counts match `total_embeds_in_request` from runtime
41- metadata when that metadata is available.
42- - Special multimodal tokens occupy token positions but do not have rows in
43- this tensor.
44-
45- The single-tensor shape is required for chunked-prefill embedding reuse,
46- which lets later chunks skip the encoder. See
47- `modeling_multimodal_utils.py` for the caching machinery.
48- """
49-
50- embeddings : torch .Tensor
51-
52-
5332@dataclass (frozen = True )
5433class PreparedLlmInputs :
5534 """Prepared inputs returned by `MultimodalModelMixin`."""
@@ -71,8 +50,13 @@ def encode_multimodal_inputs(
7150 self ,
7251 multimodal_params : Sequence [MultimodalParams ],
7352 ** encoder_kwargs : Any ,
74- ) -> MultimodalEncoderOutput :
75- """Run model-specific multimodal encoder work."""
53+ ) -> torch .Tensor :
54+ """Run model-specific multimodal encoder work.
55+
56+ Returns the single primary multimodal embedding tensor for the supplied params. Rows are
57+ expected to be concatenated in request order, and special multimodal tokens occupy token
58+ positions but do not have rows here.
59+ """
7660 raise NotImplementedError
7761
7862 @property
@@ -122,16 +106,16 @@ def after_full_multimodal_embeddings(
122106 * ,
123107 input_ids : torch .Tensor ,
124108 multimodal_params : Sequence [MultimodalParams ],
125- encoder_output : MultimodalEncoderOutput ,
109+ embeddings : torch . Tensor ,
126110 ** forward_kwargs : Any ,
127- ) -> tuple [torch .Tensor , MultimodalEncoderOutput ]:
111+ ) -> tuple [torch .Tensor , torch . Tensor ]:
128112 """Optional hook before active chunk rows are selected.
129113
130114 Runs after cache lookup or encoder execution has produced full
131115 per-request multimodal embeddings, but before the mixin selects rows
132116 active in the current forward chunk.
133117 """
134- return input_ids , encoder_output
118+ return input_ids , embeddings
135119
136120 def after_active_multimodal_embeddings (
137121 self ,
@@ -179,24 +163,19 @@ def prepare_multimodal_inputs(
179163 multimodal_params = context_params ,
180164 ** forward_kwargs ,
181165 )
182- full_output = self ._get_or_encode_multimodal_embeddings (
166+ full_embeddings = self ._get_or_encode_multimodal_embeddings (
183167 context_params ,
184168 ** encoder_kwargs ,
185169 )
186170
187- input_ids , full_output = self .after_full_multimodal_embeddings (
171+ input_ids , full_embeddings = self .after_full_multimodal_embeddings (
188172 input_ids = input_ids ,
189173 multimodal_params = context_params ,
190- encoder_output = full_output ,
174+ embeddings = full_embeddings ,
191175 ** forward_kwargs ,
192176 )
193177
194- active_embeddings = self ._find_active_multimodal_embeddings (
195- [full_output .embeddings ],
196- input_ids = input_ids ,
197- positions = positions ,
198- multimodal_params = context_params ,
199- )
178+ active_embeddings = find_input_mm_embeds ([full_embeddings ], list (context_params ))
200179 active_embeddings , extra_embeds = self .after_active_multimodal_embeddings (
201180 active_embeddings = active_embeddings ,
202181 multimodal_params = context_params ,
@@ -228,50 +207,21 @@ def _get_or_encode_multimodal_embeddings(
228207 self ,
229208 multimodal_params : Sequence [MultimodalParams ],
230209 ** encoder_kwargs : Any ,
231- ) -> MultimodalEncoderOutput :
210+ ) -> torch . Tensor :
232211 """Return cached multimodal embeddings or run the encoder for misses.
233212
234- Delegates cache lookup and gather behavior to
235- `get_multimodal_embeddings`, then validates the single primary tensor
236- contract for both encoded and cached-only paths.
213+ Delegates cache lookup and gather behavior to `get_multimodal_embeddings`, then validates
214+ the single tensor contract for both encoded and cached-only paths.
237215 """
238-
239- def encoder_forward_fn (params : list [MultimodalParams ], ** kwargs : Any ) -> list [torch .Tensor ]:
240- encoder_output = self .encode_multimodal_inputs (params , ** kwargs )
241- if not isinstance (encoder_output , MultimodalEncoderOutput ):
242- raise TypeError ("encode_multimodal_inputs must return MultimodalEncoderOutput." )
243- if not isinstance (encoder_output .embeddings , torch .Tensor ):
244- raise TypeError ("MultimodalEncoderOutput.embeddings must be a torch.Tensor." )
245- return [encoder_output .embeddings ]
246-
247216 embeddings = get_multimodal_embeddings (
248- encoder_forward_fn = encoder_forward_fn ,
217+ encoder_forward_fn = self . encode_multimodal_inputs ,
249218 multimodal_params = list (multimodal_params ),
250219 encoder_kwargs = encoder_kwargs ,
251220 )
252- primary = self ._require_primary_embedding (embeddings )
253- # Validate post-gather so cached-only paths (KV reuse, all-cached chunked
254- # prefill) are also checked, not just paths that ran the encoder.
255- self ._validate_primary_embedding_rows (primary , multimodal_params )
256- return MultimodalEncoderOutput (embeddings = primary )
257-
258- def _find_active_multimodal_embeddings (
259- self ,
260- multimodal_embeddings : list [torch .Tensor ],
261- * ,
262- input_ids : torch .Tensor ,
263- positions : Optional [torch .Tensor ],
264- multimodal_params : Sequence [MultimodalParams ],
265- ) -> list [torch .Tensor ]:
266- """Named internal stage for selecting active chunk multimodal rows.
267-
268- This initial template stage currently delegates to
269- `find_input_mm_embeds`. Model-specific behavior around slicing should
270- use `after_full_multimodal_embeddings` or
271- `after_active_multimodal_embeddings` so the common mixin sequence stays
272- centralized.
273- """
274- return find_input_mm_embeds (multimodal_embeddings , list (multimodal_params ))
221+ # Validate post-gather so cached-only paths (KV reuse, all-cached chunked prefill) are also
222+ # checked, not just paths that ran the encoder.
223+ self ._validate_embeddings (embeddings , multimodal_params )
224+ return embeddings [0 ]
275225
276226 def _fuse_multimodal_embeddings (
277227 self ,
@@ -313,38 +263,35 @@ def _fuse_multimodal_embeddings(
313263 return fused_input_ids , inputs_embeds , ()
314264
315265 @staticmethod
316- def _require_primary_embedding (embeddings : list [torch .Tensor ]) -> torch .Tensor :
317- if len (embeddings ) != 1 :
318- raise ValueError (
319- "MultimodalModelMixin requires a single primary embedding tensor, "
320- f"got { len (embeddings )} tensors."
321- )
322- return embeddings [0 ]
323-
324- @staticmethod
325- def _validate_primary_embedding_rows (
326- primary : torch .Tensor ,
266+ def _validate_embeddings (
267+ embeddings : list [torch .Tensor ],
327268 multimodal_params : Sequence [MultimodalParams ],
328269 ) -> None :
329- """Validate gathered primary embedding row count against runtime metadata.
270+ """Validate the gathered embedding row count against runtime metadata.
330271
331272 Skipped if any param lacks `multimodal_runtime.total_embeds_in_request`, since the contract
332273 cannot be evaluated without complete metadata.
333274 """
275+ if len (embeddings ) != 1 :
276+ raise ValueError (
277+ f"MultimodalModelMixin requires a single embedding tensor, got { len (embeddings )} "
278+ "tensors."
279+ )
280+
281+ embeddings_tensor = embeddings [0 ]
334282 expected_rows = 0
335283 for param in multimodal_params :
336284 runtime = param .multimodal_runtime
337285 if runtime is None or runtime .total_embeds_in_request is None :
338286 logger .debug (
339- "Skipping multimodal embedding row-count validation: "
340- "runtime metadata missing or incomplete for at least one param."
287+ "Skipping multimodal embedding row-count validation: runtime metadata missing "
288+ "or incomplete for at least one param."
341289 )
342290 return
343291 expected_rows += runtime .total_embeds_in_request
344292
345- actual_rows = primary .shape [0 ]
293+ actual_rows = embeddings_tensor .shape [0 ]
346294 if actual_rows != expected_rows :
347295 raise ValueError (
348- "Multimodal embedding row count mismatch: "
349- f"expected { expected_rows } , got { actual_rows } ."
296+ f"Multimodal embedding row count mismatch: expected { expected_rows } , got { actual_rows } ."
350297 )
0 commit comments