# Copyright (c) OpenMMLab. All rights reserved.
"""Engine-loop input construction for the LMDeploy PyTorch backend.

This module converts scheduler decisions into model-agent inputs. Most helpers
build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the
coordinator that chooses prefill/chunk/decode work, attaches per-forward
metadata, dispatches it to the executor, and updates local running state.
"""
29import logging
310from collections import defaultdict
411from dataclasses import dataclass
@@ -243,6 +250,39 @@ def check_enable(self):
243250
244251
245252class InputsMakerAsync :
    """Coordinate prefill, decode, and long-context input dispatch.

    ``Scheduler`` owns admission, ordering, and cache/KV resources. This class
    consumes the scheduler result and builds tensors only after resources have
    been granted. Prefill-like work is represented by full ``ModelInputs``:
    prompt prefill, final long-context chunks, and eager non-final long chunks.
    Decode is represented by ``ModelInputsDelta`` and reuses persistent
    model-agent/strategy ``StepInputs`` that were created by earlier prefill and
    decode forwards.

    ``running_seqs`` is local engine-loop state, not the scheduler's source of
    truth. It tracks sequences already sent to the executor so this class can
    build decode deltas, evict invalid decode requests, and update the local
    view after outputs return. Every dispatched forward also carries the
    strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria
    expected by the model agent.

    Long-context chunking is coordinated here because it spans scheduling
    policy and input construction. ``LongContextChunker`` tracks one active
    long prefill and selects model-safe chunk boundaries, including indivisible
    multimodal spans. Before tensors are created for each chunk, the scheduler
    reserves the chunk's KV ownership. Non-final chunks are eager chunk
    forwards with no user-visible output; the final chunk is treated as normal
    prefill so it can merge into persistent decode state.

    The current first-slice chunked-prefill policy intentionally uses separate
    forwards instead of one mixed decode+prefill tensor batch. After a
    non-final chunk, runnable decode is preferred and remains on the existing
    delta/CUDAGraph path; at most one eager non-final long chunk is sent after
    decode gets a chance to run. Preserve chunk flags such as
    ``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative
    decoding paths interpret them downstream.
    """
246286
247287 def __init__ (
248288 self ,
@@ -272,6 +312,7 @@ def __init__(
272312
273313 # consecutive decode counter for prefill starvation prevention
274314 self ._decode_count = 0
315+ self ._last_forward_kind = None
275316
276317 # record for next forward.
277318 self .next_is_prefill = True
@@ -293,6 +334,38 @@ def _init_do_prefill(self, config: InputsMakerConfig):
293334 else :
294335 self .do_prefill = self .do_prefill_default
295336
337+ def _has_pending_last_long_context_chunk (self ):
338+ """Check whether a running long context has only its final chunk
339+ left."""
340+ return self .long_context_chunker .enabled () and self .long_context_chunker .is_last_chunk ()
341+
def _should_decode_before_long_context_chunk(self, prefill: bool):
    """Decide whether decode should run ahead of the next long-context chunk.

    Args:
        prefill: Whether the caller currently intends to run prefill work.

    Returns:
        ``False`` when this engine has the ``Prefill`` role, when no
        sequences are in ``running_seqs``, or when chunking is not enabled.
        When only the final chunk is left, returns ``not prefill``.
        Otherwise returns ``True`` only if the previously recorded forward
        kind was ``'long_context_chunk'``.
    """
    chunker = self.long_context_chunker

    # Decode is not an option at all in these situations.
    cannot_decode = self.config.role == EngineRole.Prefill or len(self.running_seqs) == 0
    if cannot_decode or not chunker.enabled():
        return False

    if chunker.is_last_chunk():
        # The final chunk goes through the prefill path, so it only yields
        # to decode when the caller did not ask for prefill.
        return not prefill

    # Tolerate instances where the attribute was never initialised.
    previous_kind = getattr(self, '_last_forward_kind', None)
    return previous_kind == 'long_context_chunk'
354+
355+ def _forward_kind (self , inputs : 'ModelInputs|None' , delta : 'ModelInputsDelta|None' ):
356+ """Classify a queued forward for long-context interleaving policy."""
357+ if inputs is None :
358+ if delta is not None :
359+ return 'decode'
360+ return None
361+ if inputs .is_chunk and not inputs .is_last_chunk :
362+ return 'long_context_chunk'
363+ if inputs .is_chunk :
364+ return 'last_long_context_chunk'
365+ if inputs .is_decoding :
366+ return 'decode'
367+ return 'prefill'
368+
296369 def _create_vision_model_inputs (self , messages : 'SeqList' , model_inputs : ModelInputs ):
297370 """Create vision model inputs."""
298371 batch_size = len (messages )
@@ -722,26 +795,41 @@ def __create_model_inputs(seqs):
722795 extra_inputs = self .model_agent_strategy .make_extra_inputs (seqs , inputs )
723796 return inputs , delta , extra_inputs
724797
def __create_inputs_chunk(running: 'SeqList', chunk_size: int, multimodals: 'MultiModalInputs|None'):
    """Build model inputs and strategy extras for one long-context chunk.

    The chunk boundary is chosen by the caller and passed in, so this helper
    does not query the chunker itself.

    Args:
        running: Sequence list; only ``running[0]`` is turned into inputs.
        chunk_size: Chunk size forwarded to
            ``create_model_inputs_long_context``.
        multimodals: Multimodal inputs for this chunk, or ``None``.

    Returns:
        A ``(inputs, extra_inputs)`` tuple.
    """
    inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals)
    extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs)
    return inputs, extra_inputs
730802
def __reserve_long_context_chunk(seq: 'SchedulerSequence', chunk_size: int, is_last_chunk: bool):
    """Ask the scheduler to reserve resources for the next long-context chunk.

    Args:
        seq: The sequence being prefilled in chunks.
        chunk_size: Size of the chunk about to be built.
        is_last_chunk: Whether this is the final chunk of the sequence.

    Returns:
        The result of ``scheduler.reserve_long_context_chunk``.
    """
    # Only the final chunk on a non-Prefill engine requests preallocation;
    # every other combination reserves with a preallocation size of zero.
    prealloc_size = 0
    if self.config.role != EngineRole.Prefill and is_last_chunk:
        prealloc_size = self.engine_strategy.get_prealloc_size(True)

    return scheduler.reserve_long_context_chunk(
        seq,
        chunk_size,
        prealloc_size=prealloc_size,
        is_last_chunk=is_last_chunk,
    )
814+
def __create_inputs_long_context_chunk():
    """Build the next forward for the active long-context sequence.

    The chunk is first reserved with the scheduler; tensors are only created
    once that reservation succeeds.

    Returns:
        ``(running, inputs, delta, extra_inputs)``. When the reservation is
        refused the result is ``([], None, None, None)`` so the caller can
        fall back to other work.
    """
    chunker = self.long_context_chunker
    seq = chunker.seq
    chunk_size, multimodals = chunker.next_chunk_size()
    is_last_chunk = chunker.is_last_chunk()
    # Read before the last-chunk branch below calls ``clear()`` on the chunker.
    is_chunk_multimodal = chunker.has_multimodal

    if not __reserve_long_context_chunk(seq, chunk_size, is_last_chunk):
        return [], None, None, None

    running = [seq]
    if is_last_chunk:
        # The final chunk goes through the regular model-input builder and is
        # then flagged as a chunk.
        inputs, delta, extra_inputs = __create_model_inputs(running)
        inputs.is_chunk = True
        inputs.is_last_chunk = True
        chunker.clear()
    else:
        inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
        delta = None

    # NOTE(review): the nesting of the two assignments below is inferred; they
    # are applied to both branches here -- confirm against the real file.
    inputs.is_first_chunk = False
    inputs.is_chunk_multimodal = is_chunk_multimodal
    return running, inputs, delta, extra_inputs
746834
747835 def __create_inputs_prefill ():
@@ -770,7 +858,8 @@ def __create_inputs_prefill():
770858 self .long_context_chunker .clear ()
771859 inputs , delta , extra_inputs = __create_model_inputs (running )
772860 else :
773- inputs , extra_inputs = __create_inputs_chunk (running )
861+ chunk_size , multimodals = self .long_context_chunker .next_chunk_size ()
862+ inputs , extra_inputs = __create_inputs_chunk (running , chunk_size , multimodals )
774863 inputs .is_first_chunk = True
775864 inputs .is_chunk_multimodal = self .long_context_chunker .has_multimodal
776865 elif len (running ) > 0 :
@@ -783,13 +872,19 @@ def __create_inputs_prefill():
783872
784873 inputs = None
785874 delta = None
875+ running = []
876+ extra_inputs = None
786877 swap_in_map = {}
787878 swap_out_map = {}
879+ deferred_long_context_chunk = False
788880
789881 self .long_context_chunker .check_enable ()
790882 if self .long_context_chunker .enabled ():
791883 # long context chunking
792- running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
884+ if self ._should_decode_before_long_context_chunk (prefill ):
885+ deferred_long_context_chunk = True
886+ else :
887+ running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
793888 elif prefill :
794889 # prefill
795890 (
@@ -801,17 +896,20 @@ def __create_inputs_prefill():
801896 swap_out_map ,
802897 ) = __create_inputs_prefill ()
803898
804- # reset decode count when non-decoding inputs are produced
805- if inputs is not None and not inputs .is_decoding :
806- self ._decode_count = 0
807-
808899 # try decoding
809900 if inputs is None and len (self .running_seqs ) > 0 and self .config .role != EngineRole .Prefill :
810901 prefill = False
811902 delta , running , invalid_seqs = self .create_model_inputs_delta ()
812903 self .to_evict_seqs = invalid_seqs
813904 extra_inputs = None
814905
906+ if inputs is None and delta is None and deferred_long_context_chunk and self .long_context_chunker .enabled ():
907+ running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
908+
909+ # reset decode count when non-decoding inputs are produced
910+ if inputs is not None and not inputs .is_decoding :
911+ self ._decode_count = 0
912+
815913 # skip if enable empty
816914 if inputs is None and delta is None :
817915 return None
@@ -844,9 +942,10 @@ def do_prefill_pnode(self):
844942 def do_prefill_default (self ):
845943 # decoding if no waiting
846944 scheduler = self .scheduler
945+ pending_last_chunk = self ._has_pending_last_long_context_chunk ()
847946
848947 # do decoding if not waiting
849- if not scheduler .has_waiting ():
948+ if not scheduler .has_waiting () and not pending_last_chunk :
850949 self ._decode_count = 0
851950 return False
852951
@@ -861,6 +960,10 @@ def do_prefill_default(self):
861960 token_count += seq .num_token_ids
862961 if token_count >= self .config .max_prefill_token_num :
863962 return True
963+ if pending_last_chunk :
964+ token_count += self .long_context_chunker .seq .num_token_ids
965+ if token_count >= self .config .max_prefill_token_num :
966+ return True
864967
865968 # prefill if no enough running
866969 num_ready = scheduler .num_ready ()
@@ -892,6 +995,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool
892995 session_ids = [seq .session_id for seq in next_running ]
893996 logger .debug (f'Forward session_ids: { session_ids } ' )
894997 await self .executor .forward_async (forward_inputs )
998+ self ._last_forward_kind = self ._forward_kind (inputs , forward_inputs ['delta' ])
895999 self .scheduler .tick ()
8961000 self .forward_inputs = forward_inputs
8971001 return forward_inputs , next_running
0 commit comments