11# Copyright (c) OpenMMLab. All rights reserved.
2+ """Engine-loop input construction for the LMDeploy PyTorch backend.
3+
4+ This module converts scheduler decisions into model-agent inputs. Most helpers
5+ build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the
6+ coordinator that chooses prefill/chunk/decode work, attaches per-forward
7+ metadata, dispatches it to the executor, and updates local running state.
8+ """
29import logging
310from collections import defaultdict
411from dataclasses import dataclass
@@ -247,6 +254,39 @@ def check_enable(self):
247254
248255
249256class InputsMakerAsync :
257+ """Coordinate prefill, decode, and long-context input dispatch.
258+
259+ ``Scheduler`` owns admission, ordering, and cache/KV resources. This class
260+ consumes the scheduler result and builds tensors only after resources have
261+ been granted. Prefill-like work is represented by full ``ModelInputs``:
262+ prompt prefill, final long-context chunks, and eager non-final long chunks.
263+ Decode is represented by ``ModelInputsDelta`` and reuses persistent
264+ model-agent/strategy ``StepInputs`` that were created by earlier prefill and
265+ decode forwards.
266+
267+ ``running_seqs`` is local engine-loop state, not the scheduler's source of
268+ truth. It tracks sequences already sent to the executor so this class can
269+ build decode deltas, evict invalid decode requests, and update the local
270+ view after outputs return. Every dispatched forward also carries the
271+ strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria
272+ expected by the model agent.
273+
274+ Long-context chunking is coordinated here because it spans scheduling
275+ policy and input construction. ``LongContextChunker`` tracks one active
276+ long prefill and selects model-safe chunk boundaries, including indivisible
277+ multimodal spans. Before tensors are created for each chunk, the scheduler
278+ reserves the chunk's KV ownership. Non-final chunks are eager chunk
279+ forwards with no user-visible output; the final chunk is treated as normal
280+ prefill so it can merge into persistent decode state.
281+
282+ The current first-slice chunked-prefill policy intentionally uses separate
283+ forwards instead of one mixed decode+prefill tensor batch. After a
284+ non-final chunk, runnable decode is preferred and remains on the existing
285+ delta/CUDAGraph path; at most one eager non-final long chunk is sent after
286+ decode gets a chance to run. Preserve chunk flags such as
287+ ``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative
288+ decoding paths interpret them downstream.
289+ """
250290
251291 def __init__ (
252292 self ,
@@ -276,6 +316,7 @@ def __init__(
276316
277317 # consecutive decode counter for prefill starvation prevention
278318 self ._decode_count = 0
319+ self ._last_forward_kind = None
279320
280321 # record for next forward.
281322 self .next_is_prefill = True
@@ -297,6 +338,38 @@ def _init_do_prefill(self, config: InputsMakerConfig):
297338 else :
298339 self .do_prefill = self .do_prefill_default
299340
341+ def _has_pending_last_long_context_chunk (self ):
342+ """Check whether a running long context has only its final chunk
343+ left."""
344+ return self .long_context_chunker .enabled () and self .long_context_chunker .is_last_chunk ()
345+
346+ def _should_decode_before_long_context_chunk (self , prefill : bool ):
347+ """Prefer decode when a long-context chunk should not monopolize the
348+ loop."""
349+ if self .config .role == EngineRole .Prefill :
350+ return False
351+ if len (self .running_seqs ) == 0 :
352+ return False
353+ if not self .long_context_chunker .enabled ():
354+ return False
355+ if self .long_context_chunker .is_last_chunk ():
356+ return not prefill
357+ return getattr (self , '_last_forward_kind' , None ) == 'long_context_chunk'
358+
359+ def _forward_kind (self , inputs : 'ModelInputs|None' , delta : 'ModelInputsDelta|None' ):
360+ """Classify a queued forward for long-context interleaving policy."""
361+ if inputs is None :
362+ if delta is not None :
363+ return 'decode'
364+ return None
365+ if inputs .is_chunk and not inputs .is_last_chunk :
366+ return 'long_context_chunk'
367+ if inputs .is_chunk :
368+ return 'last_long_context_chunk'
369+ if inputs .is_decoding :
370+ return 'decode'
371+ return 'prefill'
372+
300373 def _create_vision_model_inputs (self , messages : 'SeqList' , model_inputs : ModelInputs ):
301374 """Create vision model inputs."""
302375 batch_size = len (messages )
@@ -723,25 +796,41 @@ def __create_model_inputs(seqs):
723796 extra_inputs = self .model_agent_strategy .make_extra_inputs (seqs , inputs )
724797 return inputs , delta , extra_inputs
725798
726- def __create_inputs_chunk (running : 'SeqList' ):
727- chunk_size , multimodals = self .long_context_chunker .next_chunk_size ()
799+ def __create_inputs_chunk (running : 'SeqList' , chunk_size : int , multimodals : 'MultiModalInputs|None' ):
728800 inputs = self .create_model_inputs_long_context (running [0 ], chunk_size , multimodals )
729801 extra_inputs = self .model_agent_strategy .make_extra_inputs (running , inputs )
730802 return inputs , extra_inputs
731803
804+ def __reserve_long_context_chunk (seq : 'SchedulerSequence' , chunk_size : int , is_last_chunk : bool ):
805+ if self .config .role == EngineRole .Prefill :
806+ prealloc_size = 0
807+ elif is_last_chunk :
808+ prealloc_size = self .engine_strategy .get_prealloc_size (True )
809+ else :
810+ prealloc_size = 0
811+ return scheduler .reserve_long_context_chunk (seq ,
812+ chunk_size ,
813+ prealloc_size = prealloc_size ,
814+ is_last_chunk = is_last_chunk )
815+
732816 def __create_inputs_long_context_chunk ():
733817 seq = self .long_context_chunker .seq
818+ chunk_size , multimodals = self .long_context_chunker .next_chunk_size ()
819+ is_last_chunk = self .long_context_chunker .is_last_chunk ()
820+ is_chunk_multimodal = self .long_context_chunker .has_multimodal
821+ if not __reserve_long_context_chunk (seq , chunk_size , is_last_chunk ):
822+ return [], None , None , None
734823 running = [seq ]
735- if self . long_context_chunker . is_last_chunk () :
824+ if is_last_chunk :
736825 inputs , delta , extra_inputs = __create_model_inputs (running )
737826 inputs .is_chunk = True
738827 inputs .is_last_chunk = True
739828 self .long_context_chunker .clear ()
740829 else :
741- inputs , extra_inputs = __create_inputs_chunk (running )
830+ inputs , extra_inputs = __create_inputs_chunk (running , chunk_size , multimodals )
742831 delta = None
743832 inputs .is_first_chunk = False
744- inputs .is_chunk_multimodal = self . long_context_chunker . has_multimodal
833+ inputs .is_chunk_multimodal = is_chunk_multimodal
745834 return running , inputs , delta , extra_inputs
746835
747836 def __create_inputs_prefill ():
@@ -770,7 +859,8 @@ def __create_inputs_prefill():
770859 self .long_context_chunker .clear ()
771860 inputs , delta , extra_inputs = __create_model_inputs (running )
772861 else :
773- inputs , extra_inputs = __create_inputs_chunk (running )
862+ chunk_size , multimodals = self .long_context_chunker .next_chunk_size ()
863+ inputs , extra_inputs = __create_inputs_chunk (running , chunk_size , multimodals )
774864 inputs .is_first_chunk = True
775865 inputs .is_chunk_multimodal = self .long_context_chunker .has_multimodal
776866 elif len (running ) > 0 :
@@ -783,13 +873,19 @@ def __create_inputs_prefill():
783873
784874 inputs = None
785875 delta = None
876+ running = []
877+ extra_inputs = None
786878 swap_in_map = {}
787879 swap_out_map = {}
880+ deferred_long_context_chunk = False
788881
789882 self .long_context_chunker .check_enable ()
790883 if self .long_context_chunker .enabled ():
791884 # long context chunking
792- running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
885+ if self ._should_decode_before_long_context_chunk (prefill ):
886+ deferred_long_context_chunk = True
887+ else :
888+ running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
793889 elif prefill :
794890 # prefill
795891 (
@@ -801,17 +897,20 @@ def __create_inputs_prefill():
801897 swap_out_map ,
802898 ) = __create_inputs_prefill ()
803899
804- # reset decode count when non-decoding inputs are produced
805- if inputs is not None and not inputs .is_decoding :
806- self ._decode_count = 0
807-
808900 # try decoding
809901 if inputs is None and len (self .running_seqs ) > 0 and self .config .role != EngineRole .Prefill :
810902 prefill = False
811903 delta , running , invalid_seqs = self .create_model_inputs_delta ()
812904 self .to_evict_seqs = invalid_seqs
813905 extra_inputs = None
814906
907+ if inputs is None and delta is None and deferred_long_context_chunk and self .long_context_chunker .enabled ():
908+ running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
909+
910+ # reset decode count when non-decoding inputs are produced
911+ if inputs is not None and not inputs .is_decoding :
912+ self ._decode_count = 0
913+
815914 # skip if enable empty
816915 if inputs is None and delta is None :
817916 return None
@@ -844,9 +943,10 @@ def do_prefill_pnode(self):
844943 def do_prefill_default (self ):
845944 # decoding if no waiting
846945 scheduler = self .scheduler
946+ pending_last_chunk = self ._has_pending_last_long_context_chunk ()
847947
848948 # do decoding if not waiting
849- if not scheduler .has_waiting ():
949+ if not scheduler .has_waiting () and not pending_last_chunk :
850950 self ._decode_count = 0
851951 return False
852952
@@ -861,6 +961,10 @@ def do_prefill_default(self):
861961 token_count += seq .num_token_ids
862962 if token_count >= self .config .max_prefill_token_num :
863963 return True
964+ if pending_last_chunk :
965+ token_count += self .long_context_chunker .seq .num_token_ids
966+ if token_count >= self .config .max_prefill_token_num :
967+ return True
864968
865969 # prefill if no enough running
866970 num_ready = scheduler .num_ready ()
@@ -892,6 +996,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool
892996 session_ids = [seq .session_id for seq in next_running ]
893997 logger .debug (f'Forward session_ids: { session_ids } ' )
894998 await self .executor .forward_async (forward_inputs )
999+ self ._last_forward_kind = self ._forward_kind (inputs , forward_inputs ['delta' ])
8951000 self .forward_inputs = forward_inputs
8961001 return forward_inputs , next_running
8971002
0 commit comments