Skip to content

Commit 50bc81f

Browse files
committed
first slice chunked prefill
1 parent 49f42e1 commit 50bc81f

12 files changed

Lines changed: 649 additions & 26 deletions

File tree

lmdeploy/pytorch/engine/inputs_maker.py

Lines changed: 117 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
"""Engine-loop input construction for the LMDeploy PyTorch backend.
3+
4+
This module converts scheduler decisions into model-agent inputs. Most helpers
5+
build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the
6+
coordinator that chooses prefill/chunk/decode work, attaches per-forward
7+
metadata, dispatches it to the executor, and updates local running state.
8+
"""
29
import logging
310
from collections import defaultdict
411
from dataclasses import dataclass
@@ -247,6 +254,39 @@ def check_enable(self):
247254

248255

249256
class InputsMakerAsync:
257+
"""Coordinate prefill, decode, and long-context input dispatch.
258+
259+
``Scheduler`` owns admission, ordering, and cache/KV resources. This class
260+
consumes the scheduler result and builds tensors only after resources have
261+
been granted. Prefill-like work is represented by full ``ModelInputs``:
262+
prompt prefill, final long-context chunks, and eager non-final long chunks.
263+
Decode is represented by ``ModelInputsDelta`` and reuses persistent
264+
model-agent/strategy ``StepInputs`` that were created by earlier prefill and
265+
decode forwards.
266+
267+
``running_seqs`` is local engine-loop state, not the scheduler's source of
268+
truth. It tracks sequences already sent to the executor so this class can
269+
build decode deltas, evict invalid decode requests, and update the local
270+
view after outputs return. Every dispatched forward also carries the
271+
strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria
272+
expected by the model agent.
273+
274+
Long-context chunking is coordinated here because it spans scheduling
275+
policy and input construction. ``LongContextChunker`` tracks one active
276+
long prefill and selects model-safe chunk boundaries, including indivisible
277+
multimodal spans. Before tensors are created for each chunk, the scheduler
278+
reserves the chunk's KV ownership. Non-final chunks are eager chunk
279+
forwards with no user-visible output; the final chunk is treated as normal
280+
prefill so it can merge into persistent decode state.
281+
282+
The current first-slice chunked-prefill policy intentionally uses separate
283+
forwards instead of one mixed decode+prefill tensor batch. After a
284+
non-final chunk, runnable decode is preferred and remains on the existing
285+
delta/CUDAGraph path; at most one eager non-final long chunk is sent after
286+
decode gets a chance to run. Preserve chunk flags such as
287+
``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative
288+
decoding paths interpret them downstream.
289+
"""
250290

251291
def __init__(
252292
self,
@@ -276,6 +316,7 @@ def __init__(
276316

277317
# consecutive decode counter for prefill starvation prevention
278318
self._decode_count = 0
319+
self._last_forward_kind = None
279320

280321
# record for next forward.
281322
self.next_is_prefill = True
@@ -297,6 +338,38 @@ def _init_do_prefill(self, config: InputsMakerConfig):
297338
else:
298339
self.do_prefill = self.do_prefill_default
299340

341+
def _has_pending_last_long_context_chunk(self):
342+
"""Check whether a running long context has only its final chunk
343+
left."""
344+
return self.long_context_chunker.enabled() and self.long_context_chunker.is_last_chunk()
345+
346+
def _should_decode_before_long_context_chunk(self, prefill: bool):
347+
"""Prefer decode when a long-context chunk should not monopolize the
348+
loop."""
349+
if self.config.role == EngineRole.Prefill:
350+
return False
351+
if len(self.running_seqs) == 0:
352+
return False
353+
if not self.long_context_chunker.enabled():
354+
return False
355+
if self.long_context_chunker.is_last_chunk():
356+
return not prefill
357+
return getattr(self, '_last_forward_kind', None) == 'long_context_chunk'
358+
359+
def _forward_kind(self, inputs: 'ModelInputs|None', delta: 'ModelInputsDelta|None'):
360+
"""Classify a queued forward for long-context interleaving policy."""
361+
if inputs is None:
362+
if delta is not None:
363+
return 'decode'
364+
return None
365+
if inputs.is_chunk and not inputs.is_last_chunk:
366+
return 'long_context_chunk'
367+
if inputs.is_chunk:
368+
return 'last_long_context_chunk'
369+
if inputs.is_decoding:
370+
return 'decode'
371+
return 'prefill'
372+
300373
def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs):
301374
"""Create vision model inputs."""
302375
batch_size = len(messages)
@@ -723,25 +796,41 @@ def __create_model_inputs(seqs):
723796
extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs)
724797
return inputs, delta, extra_inputs
725798

726-
def __create_inputs_chunk(running: 'SeqList'):
727-
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
799+
def __create_inputs_chunk(running: 'SeqList', chunk_size: int, multimodals: 'MultiModalInputs|None'):
728800
inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals)
729801
extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs)
730802
return inputs, extra_inputs
731803

804+
def __reserve_long_context_chunk(seq: 'SchedulerSequence', chunk_size: int, is_last_chunk: bool):
805+
if self.config.role == EngineRole.Prefill:
806+
prealloc_size = 0
807+
elif is_last_chunk:
808+
prealloc_size = self.engine_strategy.get_prealloc_size(True)
809+
else:
810+
prealloc_size = 0
811+
return scheduler.reserve_long_context_chunk(seq,
812+
chunk_size,
813+
prealloc_size=prealloc_size,
814+
is_last_chunk=is_last_chunk)
815+
732816
def __create_inputs_long_context_chunk():
733817
seq = self.long_context_chunker.seq
818+
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
819+
is_last_chunk = self.long_context_chunker.is_last_chunk()
820+
is_chunk_multimodal = self.long_context_chunker.has_multimodal
821+
if not __reserve_long_context_chunk(seq, chunk_size, is_last_chunk):
822+
return [], None, None, None
734823
running = [seq]
735-
if self.long_context_chunker.is_last_chunk():
824+
if is_last_chunk:
736825
inputs, delta, extra_inputs = __create_model_inputs(running)
737826
inputs.is_chunk = True
738827
inputs.is_last_chunk = True
739828
self.long_context_chunker.clear()
740829
else:
741-
inputs, extra_inputs = __create_inputs_chunk(running)
830+
inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
742831
delta = None
743832
inputs.is_first_chunk = False
744-
inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal
833+
inputs.is_chunk_multimodal = is_chunk_multimodal
745834
return running, inputs, delta, extra_inputs
746835

747836
def __create_inputs_prefill():
@@ -770,7 +859,8 @@ def __create_inputs_prefill():
770859
self.long_context_chunker.clear()
771860
inputs, delta, extra_inputs = __create_model_inputs(running)
772861
else:
773-
inputs, extra_inputs = __create_inputs_chunk(running)
862+
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
863+
inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
774864
inputs.is_first_chunk = True
775865
inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal
776866
elif len(running) > 0:
@@ -783,13 +873,19 @@ def __create_inputs_prefill():
783873

784874
inputs = None
785875
delta = None
876+
running = []
877+
extra_inputs = None
786878
swap_in_map = {}
787879
swap_out_map = {}
880+
deferred_long_context_chunk = False
788881

789882
self.long_context_chunker.check_enable()
790883
if self.long_context_chunker.enabled():
791884
# long context chunking
792-
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
885+
if self._should_decode_before_long_context_chunk(prefill):
886+
deferred_long_context_chunk = True
887+
else:
888+
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
793889
elif prefill:
794890
# prefill
795891
(
@@ -801,17 +897,20 @@ def __create_inputs_prefill():
801897
swap_out_map,
802898
) = __create_inputs_prefill()
803899

804-
# reset decode count when non-decoding inputs are produced
805-
if inputs is not None and not inputs.is_decoding:
806-
self._decode_count = 0
807-
808900
# try decoding
809901
if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill:
810902
prefill = False
811903
delta, running, invalid_seqs = self.create_model_inputs_delta()
812904
self.to_evict_seqs = invalid_seqs
813905
extra_inputs = None
814906

907+
if inputs is None and delta is None and deferred_long_context_chunk and self.long_context_chunker.enabled():
908+
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
909+
910+
# reset decode count when non-decoding inputs are produced
911+
if inputs is not None and not inputs.is_decoding:
912+
self._decode_count = 0
913+
815914
# skip if enable empty
816915
if inputs is None and delta is None:
817916
return None
@@ -844,9 +943,10 @@ def do_prefill_pnode(self):
844943
def do_prefill_default(self):
845944
# decode if nothing is waiting
846945
scheduler = self.scheduler
946+
pending_last_chunk = self._has_pending_last_long_context_chunk()
847947

848948
# do decoding if nothing is waiting and no final long-context chunk is pending
849-
if not scheduler.has_waiting():
949+
if not scheduler.has_waiting() and not pending_last_chunk:
850950
self._decode_count = 0
851951
return False
852952

@@ -861,6 +961,10 @@ def do_prefill_default(self):
861961
token_count += seq.num_token_ids
862962
if token_count >= self.config.max_prefill_token_num:
863963
return True
964+
if pending_last_chunk:
965+
token_count += self.long_context_chunker.seq.num_token_ids
966+
if token_count >= self.config.max_prefill_token_num:
967+
return True
864968

865969
# prefill if not enough running
866970
num_ready = scheduler.num_ready()
@@ -892,6 +996,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool
892996
session_ids = [seq.session_id for seq in next_running]
893997
logger.debug(f'Forward session_ids: {session_ids}')
894998
await self.executor.forward_async(forward_inputs)
999+
self._last_forward_kind = self._forward_kind(inputs, forward_inputs['delta'])
8951000
self.forward_inputs = forward_inputs
8961001
return forward_inputs, next_running
8971002

lmdeploy/pytorch/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,9 @@ class SchedulerSequence:
684684
meta: Any = None
685685
num_ignored_history: int = 0
686686
model_meta: dict[str, Any] = None
687+
# Exclusive absolute token limit for temporary KV ownership. Non-final
688+
# long-context chunks use this to allocate only the computed prefix.
689+
kv_token_limit: int | None = None
687690

688691
# For Disaggregation
689692
migration_request: None | MigrationRequest = None

lmdeploy/pytorch/paging/block_manager/default_block_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ class DefaultBlockManager(BaseBlockManager):
2525
@classmethod
2626
def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):
2727
"""Get num required blocks."""
28-
num_tokens = obj.num_all_ids + prealloc_size
28+
num_tokens = obj.num_all_ids
29+
if obj.kv_token_limit is not None:
30+
num_tokens = min(num_tokens, obj.kv_token_limit)
31+
num_tokens += prealloc_size
2932

3033
num_all_blocks = _div_up(num_tokens, obj.block_size)
3134
return max(0, num_all_blocks - len(obj.logical_blocks))

lmdeploy/pytorch/paging/block_manager/window_block_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@ def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0):
4242
if obj.num_history_ids <= self.window_size:
4343
return super().num_required_blocks(obj, prealloc_size)
4444

45-
return super().num_required_blocks(obj, prealloc_size) - obj.num_ignored_history // obj.block_size
45+
# DefaultBlockManager applies kv_token_limit to the absolute token
46+
# count. Sliding-window accounting then subtracts already-dropped
47+
# history blocks so chunk-limited allocation grows only the retained
48+
# window.
49+
num_required_blocks = super().num_required_blocks(obj, prealloc_size)
50+
num_required_blocks -= obj.num_ignored_history // obj.block_size
51+
return max(0, num_required_blocks)
4652

4753
def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):
4854
"""Return if physical block can be allocated for given message."""

lmdeploy/pytorch/paging/block_trie.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,8 @@ def allocate(self, seq: SchedulerSequence):
989989

990990
num_matched = node.num_matched
991991
num_valid_ids = seq.num_valid_ids
992+
if seq.kv_token_limit is not None:
993+
num_valid_ids = min(num_valid_ids, seq.kv_token_limit)
992994

993995
if num_matched + block_size > num_valid_ids:
994996
return

0 commit comments

Comments
 (0)