|
47 | 47 | from ..distributed.communicator import ReduceOp |
48 | 48 | from ..expert_statistic import ExpertStatistic |
49 | 49 | from ..models.modeling_llama import Llama4ForConditionalGeneration |
| 50 | +from ..models.modeling_multimodal_mixin import \ |
| 51 | + maybe_prefetch_mm_encoder_for_next_iter |
50 | 52 | from ..models.modeling_utils import DecoderModelForCausalLM |
51 | 53 | from ..modules.decoder_layer import DecoderLayer |
52 | 54 | from ..speculative.drafter import Drafter |
@@ -3025,6 +3027,8 @@ def _executor_loop(self): |
3025 | 3027 | self.dwdp_manager.prefetch_first_layers() |
3026 | 3028 | batch_outputs = self._forward_step(scheduled_batch) |
3027 | 3029 |
|
| 3030 | + self._maybe_prefetch_next_iter_mm_encoders(scheduled_batch) |
| 3031 | + |
3028 | 3032 | guided_decoder_failed_requests = None |
3029 | 3033 | if self.guided_decoder is not None: |
3030 | 3034 | guided_decoder_failed_requests = self.guided_decoder.execute( |
@@ -3150,7 +3154,7 @@ def _handle_control_request(self): |
3150 | 3154 | def _sync_and_process_resource_governor_queue(self): |
3151 | 3155 | """Synchronize and process resource governor requests across all ranks. |
3152 | 3156 |
|
3153 | | - Only called when ``_resource_governor_enabled`` is ``True``. |
| 3157 | + Only called when ``_resource_governor_enabled`` is ``True``. |
3154 | 3158 | Uses a two-phase broadcast: first broadcast the count (a single int), |
3155 | 3159 | then broadcast the actual requests only when count > 0. This avoids |
3156 | 3160 | serializing and deserializing an empty Python list on every iteration. |
@@ -3441,6 +3445,8 @@ def _executor_loop_overlap(self): |
3441 | 3445 | scheduled_batch, previous_tensors_device, |
3442 | 3446 | num_accepted_tokens_device) |
3443 | 3447 |
|
| 3448 | + self._maybe_prefetch_next_iter_mm_encoders(scheduled_batch) |
| 3449 | + |
3444 | 3450 | if self.previous_batch is not None and should_process_previous_batch: |
3445 | 3451 | self._update_requests(self.previous_batch.sample_state) |
3446 | 3452 |
|
@@ -4729,6 +4735,61 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0): |
4729 | 4735 | req.state = LlmRequestState.DISAGG_TRANS_ERROR |
4730 | 4736 | self._check_cache_transfer_errors("generation requests") |
4731 | 4737 |
|
def _maybe_prefetch_next_iter_mm_encoders(
        self, scheduled_batch: ScheduledRequests) -> None:
    """Best-effort hook for cross-iteration multimodal (MM) encoder prefetch.

    Intended to be called immediately after `_forward_step`, so that the
    side-stream encoder work can overlap current-iteration sampling in the
    non-overlap loop and previous-batch `_update_requests` in the overlap
    loop.

    This method only gathers the inputs and delegates to
    `maybe_prefetch_mm_encoder_for_next_iter`:

    * every request in `self.active_requests` whose state is
      `LlmRequestState.CONTEXT_INIT` is offered as a candidate;
    * the ids of requests in `scheduled_batch` (and, when
      `self.previous_batch` is set, in the previous batch as well) are
      passed as `in_flight_request_ids` so the helper can exclude them;
    * at most one prefetch is requested per call (`max_prefetch=1`).

    Excluding in-flight requests, honouring the
    `TLLM_MM_SIDE_STREAM_MAX_AHEAD` outstanding-ahead cap, and checking
    that the model is a `MultimodalModelMixin` subclass are not done here;
    they are expected to be handled inside the helper (NOTE(review):
    confirm against the helper's implementation). Per the original design
    note, the helper runs the encoder on a side CUDA stream and stashes
    results into `request.py_multimodal_data`; the next iteration's
    `_prepare_inputs` then picks up the cached embedding.

    Shared between `_executor_loop` (non-overlap) and
    `_executor_loop_overlap`. `self.previous_batch` is expected to be None
    in non-overlap mode, in which case only the just-scheduled batch
    contributes in-flight ids.

    Args:
        scheduled_batch: The batch that was just submitted to
            `_forward_step`; its requests are treated as in flight.

    Returns:
        None. Failures inside the helper are logged and swallowed.
    """
    model = getattr(self.model_engine, "model", None)
    if model is None:
        return

    # Check for candidates first: in steady-state decode there are usually
    # no CONTEXT_INIT requests, and this hook runs on every executor
    # iteration, so avoid building the in-flight id sets for nothing.
    pending = [
        r for r in self.active_requests
        if r.state == LlmRequestState.CONTEXT_INIT
    ]
    if not pending:
        return

    in_flight = {r.py_request_id for r in scheduled_batch.all_requests()}
    if self.previous_batch is not None:
        # Overlap mode: the previous batch is still being processed, so
        # its requests count as in flight too.
        in_flight.update(
            r.py_request_id
            for r in self.previous_batch.scheduled_requests.all_requests())

    try:
        maybe_prefetch_mm_encoder_for_next_iter(
            model=model,
            pending_requests=pending,
            in_flight_request_ids=in_flight,
            max_prefetch=1,
        )
    except Exception:
        # Speculative prefetch is best-effort and must never crash the
        # executor loop. On failure, `py_mm_encoder_event` is not stamped,
        # so the next iteration's `_prepare_inputs` falls back to the
        # standard in-iter encode path (which re-runs `to_device` and the
        # encoder unconditionally when no cached embedding is present).
        logger.warning(
            "Cross-iter MM encoder prefetch failed; falling back to "
            f"in-iter encode.\n{traceback.format_exc()}")
4732 | 4793 | def _forward_step( |
4733 | 4794 | self, |
4734 | 4795 | scheduled_requests: ScheduledRequests, |
|
0 commit comments