Skip to content

Commit 4a8b7af

Browse files
authored
[None][feat] Side-stream for MM encoder (#14322)
* Why?

  Multimodal context requests currently run their encoder only after they are scheduled. This can keep the next request's image encoding on the critical path even when the executor already has independent GPU work from the current iteration that the encoding could overlap with.

* What?

  Add an opt-in cross-iteration prefetch path gated by `TLLM_MM_SIDE_STREAM_MAX_AHEAD`. The executor picks pending multimodal context requests that are not in flight, moves their inputs to CUDA, and runs the encoder on an auxiliary stream. This leverages the recently added `MultimodalEncoderMixin`.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent c25fa74 commit 4a8b7af

9 files changed

Lines changed: 639 additions & 32 deletions

File tree

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def __init__(
612612
# NOTE: attn_backend: Pixtral head size not always divisible by 128
613613
vision_model_config = self._get_sub_model_config(model_config_cp,
614614
"vision_config",
615-
attn_backend="VANILLA",
615+
attn_backend="TRTLLM",
616616
quant_config=None)
617617

618618
self._vision_tower = modeling_pixtral.PixtralVisionModel(
@@ -705,7 +705,6 @@ def infer_max_seq_len(self) -> int:
705705
def encode_multimodal_inputs(
706706
self,
707707
multimodal_params: Sequence[MultimodalParams],
708-
**encoder_kwargs: Any,
709708
) -> MultimodalEncoderOutput:
710709
mm_embeds = self._vision_forward(list(multimodal_params))
711710
return MultimodalEncoderOutput(embeddings=mm_embeds[0])

tensorrt_llm/_torch/models/modeling_multimodal_mixin.py

Lines changed: 280 additions & 28 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,13 @@ def get_multimodal_embeddings(
219219
if not multimodal_params:
220220
return []
221221

222+
# Wait before touching tensors produced on the MM side stream. Do not
223+
# clear the event here; repeated stream-side waits are cheap, and leaving
224+
# the event field untouched avoids races if a caller accidentally reuses it.
225+
for param in multimodal_params:
226+
if param.encoder_event is not None:
227+
torch.cuda.current_stream().wait_event(param.encoder_event)
228+
222229
# Step 1: Find uncached multimodal params that need encoder processing
223230
uncached_multimodal_params = _get_uncached_multimodal_params(
224231
multimodal_params)

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,13 @@ def __init__(
686686
and encoder_output_len is None):
687687
encoder_output_len = len(encoder_input_tokens)
688688
kwargs["encoder_output_len"] = encoder_output_len
689+
690+
# Cross-iter MM encoder prefetch event: stamped by the side-stream
691+
# producer in `modeling_multimodal_mixin._dispatch_cross_iter_prefetch`
692+
# and moved onto `MultimodalParams.encoder_event` (then cleared on the request) in
693+
# `model_engine._prepare_inputs` when the request is next scheduled; the wait itself is issued downstream.
694+
self.py_mm_encoder_event: Optional[torch.cuda.Event] = None
695+
689696
if llm_request is not None:
690697
super().__init__(llm_request)
691698
else:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3037,6 +3037,17 @@ def append_cross_attention_state(request: LlmRequest,
30373037
multimodal_data=request.py_multimodal_data,
30383038
multimodal_runtime=py_multimodal_runtime,
30393039
input_ids_start_offset=context_start_idx)
3040+
# Transfer any cross-iter MM encoder prefetch event stamped on the request onto the
3041+
# freshly-built MultimodalParams. The downstream consume site reads it from the wrapper,
3042+
# not from the request.
3043+
# NOTE: the prefetch producer always writes the cached embedding into
3044+
# `py_multimodal_data` before stamping the event, so whenever the event is present,
3045+
# `has_content()` below is `True` and the wrapper reaches the consume site that waits on
3046+
# it.
3047+
mm_encoder_event = request.py_mm_encoder_event
3048+
if mm_encoder_event is not None:
3049+
multimodal_params.encoder_event = mm_encoder_event
3050+
request.py_mm_encoder_event = None
30403051
if multimodal_params.has_content():
30413052
# TODO: Visit later to decide the appropriate position of sending multimodal data & selectively sending multimodal data
30423053
multimodal_params.to_device("multimodal_data",

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from ..distributed.communicator import ReduceOp
4848
from ..expert_statistic import ExpertStatistic
4949
from ..models.modeling_llama import Llama4ForConditionalGeneration
50+
from ..models.modeling_multimodal_mixin import \
51+
maybe_prefetch_mm_encoder_for_next_iter
5052
from ..models.modeling_utils import DecoderModelForCausalLM
5153
from ..modules.decoder_layer import DecoderLayer
5254
from ..speculative.drafter import Drafter
@@ -3025,6 +3027,8 @@ def _executor_loop(self):
30253027
self.dwdp_manager.prefetch_first_layers()
30263028
batch_outputs = self._forward_step(scheduled_batch)
30273029

3030+
self._maybe_prefetch_next_iter_mm_encoders(scheduled_batch)
3031+
30283032
guided_decoder_failed_requests = None
30293033
if self.guided_decoder is not None:
30303034
guided_decoder_failed_requests = self.guided_decoder.execute(
@@ -3150,7 +3154,7 @@ def _handle_control_request(self):
31503154
def _sync_and_process_resource_governor_queue(self):
31513155
"""Synchronize and process resource governor requests across all ranks.
31523156
3153-
Only called when ``_resource_governor_enabled`` is ``True``.
3157+
Only called when ``_resource_governor_enabled`` is `True`.
31543158
Uses a two-phase broadcast: first broadcast the count (a single int),
31553159
then broadcast the actual requests only when count > 0. This avoids
31563160
serializing and deserializing an empty Python list on every iteration.
@@ -3441,6 +3445,8 @@ def _executor_loop_overlap(self):
34413445
scheduled_batch, previous_tensors_device,
34423446
num_accepted_tokens_device)
34433447

3448+
self._maybe_prefetch_next_iter_mm_encoders(scheduled_batch)
3449+
34443450
if self.previous_batch is not None and should_process_previous_batch:
34453451
self._update_requests(self.previous_batch.sample_state)
34463452

@@ -4729,6 +4735,61 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):
47294735
req.state = LlmRequestState.DISAGG_TRANS_ERROR
47304736
self._check_cache_transfer_errors("generation requests")
47314737

4738+
def _maybe_prefetch_next_iter_mm_encoders(
4739+
self, scheduled_batch: ScheduledRequests) -> None:
4740+
"""Best-effort hook for cross-iter MM encoder prefetch.
4741+
4742+
Called immediately after `_forward_step`, so the side-stream encoder
4743+
work can overlap current-iteration sampling in the non-overlap loop and
4744+
previous-batch `_update_requests` in the overlap loop. No-op unless
4745+
`TLLM_MM_SIDE_STREAM_MAX_AHEAD` is positive and the model is a
4746+
`MultimodalModelMixin` subclass.
4747+
4748+
Walks `active_requests` for context-init candidates that are NOT
4749+
in the just-scheduled batch (and, in overlap mode, not in the
4750+
previous batch either) and dispatches one of them, subject to the
4751+
outstanding-ahead cap in `maybe_prefetch_mm_encoder_for_next_iter`.
4752+
That helper runs the encoder on a side CUDA stream and stashes
4753+
results back into `request.py_multimodal_data`. The next iteration's
4754+
`_prepare_inputs` then picks up the cached embedding and the mixin
4755+
consume site waits on the recorded CUDA event.
4756+
4757+
Shared between `_executor_loop` (non-overlap) and
4758+
`_executor_loop_overlap`. `self.previous_batch` is always None in
4759+
non-overlap mode, so the second union term is a no-op there.
4760+
"""
4761+
model = getattr(self.model_engine, "model", None)
4762+
if model is None:
4763+
return
4764+
in_flight = {r.py_request_id for r in scheduled_batch.all_requests()}
4765+
if self.previous_batch is not None:
4766+
in_flight |= {
4767+
r.py_request_id
4768+
for r in self.previous_batch.scheduled_requests.all_requests()
4769+
}
4770+
pending = [
4771+
r for r in self.active_requests
4772+
if r.state == LlmRequestState.CONTEXT_INIT
4773+
]
4774+
if not pending:
4775+
return
4776+
try:
4777+
maybe_prefetch_mm_encoder_for_next_iter(
4778+
model=model,
4779+
pending_requests=pending,
4780+
in_flight_request_ids=in_flight,
4781+
max_prefetch=1,
4782+
)
4783+
except Exception:
4784+
# Speculative prefetch is best-effort and must never crash the
4785+
# executor loop. On failure, `py_mm_encoder_event` is not stamped,
4786+
# so the next iteration's `_prepare_inputs` falls back to the
4787+
# standard in-iter encode path (which re-runs `to_device` and the
4788+
# encoder unconditionally when no cached embedding is present).
4789+
logger.warning(
4790+
f"Cross-iter MM encoder prefetch failed; falling back to "
4791+
f"in-iter encode.\n{traceback.format_exc()}")
4792+
47324793
def _forward_step(
47334794
self,
47344795
scheduled_requests: ScheduledRequests,

tensorrt_llm/inputs/multimodal.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,13 @@ class MultimodalParams:
503503
multimodal_data: Optional[Dict[str, Any]] = field(default_factory=dict)
504504
multimodal_runtime: Optional[MultimodalRuntimeData] = None
505505
input_ids_start_offset: int = 0
506+
# CUDA event recorded on a side stream by the MM encoder prefetch path.
507+
# When set, the consume site in `get_multimodal_embeddings` issues a
508+
# `wait_event` on the current stream before reading cached embeddings.
509+
# Always `None` unless `TLLM_MM_SIDE_STREAM_MAX_AHEAD` is positive and a prefetch ran.
510+
encoder_event: Optional[torch.cuda.Event] = field(default=None,
511+
repr=False,
512+
compare=False)
506513

507514
def __post_init__(self):
508515
"""Ensure default values are properly set."""

0 commit comments

Comments
 (0)