@@ -123,27 +123,6 @@ def _run_on_aux_stream(aux_stream: torch.cuda.Stream) -> Iterator[torch.cuda.Eve
123123 exit_event .record ()
124124
125125
126- @dataclass (frozen = True )
127- class MultimodalEncoderOutput :
128- """Output produced by a model-owned multimodal encoder hook.
129-
130- Contract:
131- - `embeddings` contains all multimodal embedding rows for the supplied
132- `multimodal_params`.
133- - Rows are concatenated in the same order as `multimodal_params`.
134- - Per-request row counts match `total_embeds_in_request` from runtime
135- metadata when that metadata is available.
136- - Special multimodal tokens occupy token positions but do not have rows in
137- this tensor.
138-
139- The single-tensor shape is required for chunked-prefill embedding reuse,
140- which lets later chunks skip the encoder. See
141- `modeling_multimodal_utils.py` for the caching machinery.
142- """
143-
144- embeddings : torch .Tensor
145-
146-
147126@dataclass (frozen = True )
148127class PreparedLlmInputs :
149128 """Prepared inputs returned by `MultimodalModelMixin`."""
@@ -181,8 +160,13 @@ def convert(tensor: torch.Tensor) -> torch.Tensor:
181160 def encode_multimodal_inputs (
182161 self ,
183162 multimodal_params : Sequence [MultimodalParams ],
184- ) -> MultimodalEncoderOutput :
185- """Run model-specific multimodal encoder work."""
163+ ) -> torch .Tensor :
164+ """Run model-specific multimodal encoder work.
165+
166+ Returns the single primary multimodal embedding tensor for the supplied params. Rows are
167+ expected to be concatenated in request order, and special multimodal tokens occupy token
168+ positions but do not have rows here.
169+ """
186170 raise NotImplementedError
187171
188172 @property
@@ -222,16 +206,16 @@ def after_full_multimodal_embeddings(
222206 * ,
223207 input_ids : torch .Tensor ,
224208 multimodal_params : Sequence [MultimodalParams ],
225- encoder_output : MultimodalEncoderOutput ,
209+ embeddings : torch . Tensor ,
226210 ** forward_kwargs : Any ,
227- ) -> tuple [torch .Tensor , MultimodalEncoderOutput ]:
211+ ) -> tuple [torch .Tensor , torch . Tensor ]:
228212 """Optional hook before active chunk rows are selected.
229213
230214 Runs after cache lookup or encoder execution has produced full
231215 per-request multimodal embeddings, but before the mixin selects rows
232216 active in the current forward chunk.
233217 """
234- return input_ids , encoder_output
218+ return input_ids , embeddings
235219
236220 def after_active_multimodal_embeddings (
237221 self ,
@@ -274,21 +258,16 @@ def prepare_multimodal_inputs(
274258 if not context_params :
275259 return PreparedLlmInputs (input_ids = input_ids , inputs_embeds = None )
276260
277- full_output = self ._get_or_encode_multimodal_embeddings (context_params )
261+ full_embeddings = self ._get_or_encode_multimodal_embeddings (context_params )
278262
279- input_ids , full_output = self .after_full_multimodal_embeddings (
263+ input_ids , full_embeddings = self .after_full_multimodal_embeddings (
280264 input_ids = input_ids ,
281265 multimodal_params = context_params ,
282- encoder_output = full_output ,
266+ embeddings = full_embeddings ,
283267 ** forward_kwargs ,
284268 )
285269
286- active_embeddings = self ._find_active_multimodal_embeddings (
287- [full_output .embeddings ],
288- input_ids = input_ids ,
289- positions = positions ,
290- multimodal_params = context_params ,
291- )
270+ active_embeddings = find_input_mm_embeds ([full_embeddings ], list (context_params ))
292271 active_embeddings , extra_embeds = self .after_active_multimodal_embeddings (
293272 active_embeddings = active_embeddings ,
294273 multimodal_params = context_params ,
@@ -319,49 +298,20 @@ def prepare_multimodal_inputs(
319298 def _get_or_encode_multimodal_embeddings (
320299 self ,
321300 multimodal_params : Sequence [MultimodalParams ],
322- ) -> MultimodalEncoderOutput :
301+ ) -> torch . Tensor :
323302 """Return cached multimodal embeddings or run the encoder for misses.
324303
325- Delegates cache lookup and gather behavior to
326- `get_multimodal_embeddings`, then validates the single primary tensor
327- contract for both encoded and cached-only paths.
304+ Delegates cache lookup and gather behavior to `get_multimodal_embeddings`, then validates
305+ the single tensor contract for both encoded and cached-only paths.
328306 """
329-
330- def encoder_forward_fn (params : list [MultimodalParams ]) -> list [torch .Tensor ]:
331- encoder_output = self .encode_multimodal_inputs (params )
332- if not isinstance (encoder_output , MultimodalEncoderOutput ):
333- raise TypeError ("encode_multimodal_inputs must return MultimodalEncoderOutput." )
334- if not isinstance (encoder_output .embeddings , torch .Tensor ):
335- raise TypeError ("MultimodalEncoderOutput.embeddings must be a torch.Tensor." )
336- return [encoder_output .embeddings ]
337-
338307 embeddings = get_multimodal_embeddings (
339- encoder_forward_fn = encoder_forward_fn ,
308+ encoder_forward_fn = self . encode_multimodal_inputs ,
340309 multimodal_params = list (multimodal_params ),
341310 )
342- primary = self ._require_primary_embedding (embeddings )
343- # Validate post-gather so cached-only paths (KV reuse, all-cached chunked
344- # prefill) are also checked, not just paths that ran the encoder.
345- self ._validate_primary_embedding_rows (primary , multimodal_params )
346- return MultimodalEncoderOutput (embeddings = primary )
347-
348- def _find_active_multimodal_embeddings (
349- self ,
350- multimodal_embeddings : list [torch .Tensor ],
351- * ,
352- input_ids : torch .Tensor ,
353- positions : Optional [torch .Tensor ],
354- multimodal_params : Sequence [MultimodalParams ],
355- ) -> list [torch .Tensor ]:
356- """Named internal stage for selecting active chunk multimodal rows.
357-
358- This initial template stage currently delegates to
359- `find_input_mm_embeds`. Model-specific behavior around slicing should
360- use `after_full_multimodal_embeddings` or
361- `after_active_multimodal_embeddings` so the common mixin sequence stays
362- centralized.
363- """
364- return find_input_mm_embeds (multimodal_embeddings , list (multimodal_params ))
311+ # Validate post-gather so cached-only paths (KV reuse, all-cached chunked prefill) are also
312+ # checked, not just paths that ran the encoder.
313+ self ._validate_embeddings (embeddings , multimodal_params )
314+ return embeddings [0 ]
365315
366316 def _fuse_multimodal_embeddings (
367317 self ,
@@ -403,40 +353,46 @@ def _fuse_multimodal_embeddings(
403353 return fused_input_ids , inputs_embeds , ()
404354
405355 @staticmethod
406- def _require_primary_embedding (embeddings : list [torch .Tensor ]) -> torch .Tensor :
407- if len (embeddings ) != 1 :
408- raise ValueError (
409- "MultimodalModelMixin requires a single primary embedding tensor, "
410- f"got { len (embeddings )} tensors."
411- )
412- return embeddings [0 ]
413-
414- @staticmethod
415- def _validate_primary_embedding_rows (
416- primary : torch .Tensor ,
356+ def _validate_embeddings (
357+ embeddings : list [torch .Tensor ],
417358 multimodal_params : Sequence [MultimodalParams ],
418359 ) -> None :
419- """Validate gathered primary embedding row count against runtime metadata.
360+ """Validate the gathered embedding row count against runtime metadata.
420361
421362 Skipped only when every param lacks `multimodal_runtime.total_embeds_in_request`; metadata
422363 present for some params but not others raises `ValueError`.
423364 """
365+ if len (embeddings ) != 1 :
366+ raise ValueError (
367+ f"MultimodalModelMixin requires a single embedding tensor, got { len (embeddings )} "
368+ "tensors."
369+ )
370+
371+ embeddings_tensor = embeddings [0 ]
424372 expected_rows = 0
373+ has_runtime_metadata = []
425374 for param in multimodal_params :
426375 runtime = param .multimodal_runtime
427- if runtime is None or runtime .total_embeds_in_request is None :
428- logger .debug (
429- "Skipping multimodal embedding row-count validation: "
430- "runtime metadata missing or incomplete for at least one param."
431- )
432- return
433- expected_rows += runtime .total_embeds_in_request
376+ has_runtime = runtime is not None and runtime .total_embeds_in_request is not None
377+ has_runtime_metadata .append (has_runtime )
378+ if has_runtime :
379+ expected_rows += runtime .total_embeds_in_request
380+
381+ if any (has_runtime_metadata ) and not all (has_runtime_metadata ):
382+ raise ValueError (
383+ "Multimodal runtime metadata must be present for every param or none of them."
384+ )
385+ if not all (has_runtime_metadata ):
386+ logger .debug (
387+ "Skipping multimodal embedding row-count validation: runtime metadata missing "
388+ "for all params."
389+ )
390+ return
434391
435- actual_rows = primary .shape [0 ]
392+ actual_rows = embeddings_tensor .shape [0 ]
436393 if actual_rows != expected_rows :
437394 raise ValueError (
438- "Multimodal embedding row count mismatch: "
439- f"expected { expected_rows } , got { actual_rows } ."
395+ f"Multimodal embedding row count mismatch: expected { expected_rows } , got { actual_rows } ."
440396 )
441397
442398
@@ -539,8 +495,6 @@ def _dispatch_cross_iter_prefetch(
539495 for (req , _ , _ ), p in zip (candidates , params_list ):
540496 req .py_multimodal_data = p .multimodal_data
541497 encoder_output = model .encode_multimodal_inputs (params_list )
542- if not isinstance (encoder_output , MultimodalEncoderOutput ):
543- raise TypeError ("encode_multimodal_inputs must return MultimodalEncoderOutput." )
544498 _cache_multimodal_embeddings (params_list , [encoder_output .embeddings ])
545499 finally :
546500 # Stash the event on every candidate's durable LlmRequest (not the
0 commit comments