Skip to content

Commit d6e7baf

Browse files
committed
first slice chunked prefill
1 parent 75f5ddc commit d6e7baf

12 files changed

Lines changed: 653 additions & 28 deletions

File tree

lmdeploy/pytorch/engine/inputs_maker.py

Lines changed: 117 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
"""Engine-loop input construction for the LMDeploy PyTorch backend.
3+
4+
This module converts scheduler decisions into model-agent inputs. Most helpers
5+
build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the
6+
coordinator that chooses prefill/chunk/decode work, attaches per-forward
7+
metadata, dispatches it to the executor, and updates local running state.
8+
"""
29
import logging
310
from collections import defaultdict
411
from dataclasses import dataclass
@@ -243,6 +250,39 @@ def check_enable(self):
243250

244251

245252
class InputsMakerAsync:
253+
"""Coordinate prefill, decode, and long-context input dispatch.
254+
255+
``Scheduler`` owns admission, ordering, and cache/KV resources. This class
256+
consumes the scheduler result and builds tensors only after resources have
257+
been granted. Prefill-like work is represented by full ``ModelInputs``:
258+
prompt prefill, final long-context chunks, and eager non-final long chunks.
259+
Decode is represented by ``ModelInputsDelta`` and reuses persistent
260+
model-agent/strategy ``StepInputs`` that were created by earlier prefill and
261+
decode forwards.
262+
263+
``running_seqs`` is local engine-loop state, not the scheduler's source of
264+
truth. It tracks sequences already sent to the executor so this class can
265+
build decode deltas, evict invalid decode requests, and update the local
266+
view after outputs return. Every dispatched forward also carries the
267+
strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria
268+
expected by the model agent.
269+
270+
Long-context chunking is coordinated here because it spans scheduling
271+
policy and input construction. ``LongContextChunker`` tracks one active
272+
long prefill and selects model-safe chunk boundaries, including indivisible
273+
multimodal spans. Before tensors are created for each chunk, the scheduler
274+
reserves the chunk's KV ownership. Non-final chunks are eager chunk
275+
forwards with no user-visible output; the final chunk is treated as normal
276+
prefill so it can merge into persistent decode state.
277+
278+
The current first-slice chunked-prefill policy intentionally uses separate
279+
forwards instead of one mixed decode+prefill tensor batch. After a
280+
non-final chunk, runnable decode is preferred and remains on the existing
281+
delta/CUDAGraph path; at most one eager non-final long chunk is sent after
282+
decode gets a chance to run. Preserve chunk flags such as
283+
``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative
284+
decoding paths interpret them downstream.
285+
"""
246286

247287
def __init__(
248288
self,
@@ -272,6 +312,7 @@ def __init__(
272312

273313
# consecutive decode counter for prefill starvation prevention
274314
self._decode_count = 0
315+
self._last_forward_kind = None
275316

276317
# record for next forward.
277318
self.next_is_prefill = True
@@ -293,6 +334,38 @@ def _init_do_prefill(self, config: InputsMakerConfig):
293334
else:
294335
self.do_prefill = self.do_prefill_default
295336

337+
def _has_pending_last_long_context_chunk(self):
338+
"""Check whether a running long context has only its final chunk
339+
left."""
340+
return self.long_context_chunker.enabled() and self.long_context_chunker.is_last_chunk()
341+
342+
def _should_decode_before_long_context_chunk(self, prefill: bool):
    """Decide whether decode should run ahead of the next long-context chunk.

    Returns ``False`` when this engine has the ``Prefill`` role, when no
    sequences are tracked in ``running_seqs``, or when the chunker is not
    enabled. If only the last chunk remains, decode goes first exactly when
    ``prefill`` is falsy. Otherwise decode goes first only if the previously
    recorded forward kind was a non-final ``'long_context_chunk'``.
    """
    chunker = self.long_context_chunker
    # Short-circuit order matches the individual guards: role, running
    # sequences, then whether chunking is active at all.
    if (self.config.role == EngineRole.Prefill or len(self.running_seqs) == 0 or not chunker.enabled()):
        return False

    if chunker.is_last_chunk():
        return not prefill

    previous_kind = getattr(self, '_last_forward_kind', None)
    return previous_kind == 'long_context_chunk'
354+
355+
def _forward_kind(self, inputs: 'ModelInputs|None', delta: 'ModelInputsDelta|None'):
356+
"""Classify a queued forward for long-context interleaving policy."""
357+
if inputs is None:
358+
if delta is not None:
359+
return 'decode'
360+
return None
361+
if inputs.is_chunk and not inputs.is_last_chunk:
362+
return 'long_context_chunk'
363+
if inputs.is_chunk:
364+
return 'last_long_context_chunk'
365+
if inputs.is_decoding:
366+
return 'decode'
367+
return 'prefill'
368+
296369
def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs):
297370
"""Create vision model inputs."""
298371
batch_size = len(messages)
@@ -722,26 +795,41 @@ def __create_model_inputs(seqs):
722795
extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs)
723796
return inputs, delta, extra_inputs
724797

725-
def __create_inputs_chunk(running: 'SeqList'):
726-
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
798+
def __create_inputs_chunk(running: 'SeqList', chunk_size: int, multimodals: 'MultiModalInputs|None'):
727799
inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals)
728800
extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs)
729801
return inputs, extra_inputs
730802

803+
def __reserve_long_context_chunk(seq: 'SchedulerSequence', chunk_size: int, is_last_chunk: bool):
804+
if self.config.role == EngineRole.Prefill:
805+
prealloc_size = 0
806+
elif is_last_chunk:
807+
prealloc_size = self.engine_strategy.get_prealloc_size(True)
808+
else:
809+
prealloc_size = 0
810+
return scheduler.reserve_long_context_chunk(seq,
811+
chunk_size,
812+
prealloc_size=prealloc_size,
813+
is_last_chunk=is_last_chunk)
814+
731815
def __create_inputs_long_context_chunk():
732816
seq = self.long_context_chunker.seq
817+
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
818+
is_last_chunk = self.long_context_chunker.is_last_chunk()
819+
is_chunk_multimodal = self.long_context_chunker.has_multimodal
820+
if not __reserve_long_context_chunk(seq, chunk_size, is_last_chunk):
821+
return [], None, None, None
733822
running = [seq]
734-
has_multimodal = self.long_context_chunker.has_multimodal
735-
if self.long_context_chunker.is_last_chunk():
823+
if is_last_chunk:
736824
inputs, delta, extra_inputs = __create_model_inputs(running)
737825
inputs.is_chunk = True
738826
inputs.is_last_chunk = True
739827
self.long_context_chunker.clear()
740828
else:
741-
inputs, extra_inputs = __create_inputs_chunk(running)
829+
inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
742830
delta = None
743831
inputs.is_first_chunk = False
744-
inputs.is_chunk_multimodal = has_multimodal
832+
inputs.is_chunk_multimodal = is_chunk_multimodal
745833
return running, inputs, delta, extra_inputs
746834

747835
def __create_inputs_prefill():
@@ -770,7 +858,8 @@ def __create_inputs_prefill():
770858
self.long_context_chunker.clear()
771859
inputs, delta, extra_inputs = __create_model_inputs(running)
772860
else:
773-
inputs, extra_inputs = __create_inputs_chunk(running)
861+
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
862+
inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
774863
inputs.is_first_chunk = True
775864
inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal
776865
elif len(running) > 0:
@@ -783,13 +872,19 @@ def __create_inputs_prefill():
783872

784873
inputs = None
785874
delta = None
875+
running = []
876+
extra_inputs = None
786877
swap_in_map = {}
787878
swap_out_map = {}
879+
deferred_long_context_chunk = False
788880

789881
self.long_context_chunker.check_enable()
790882
if self.long_context_chunker.enabled():
791883
# long context chunking
792-
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
884+
if self._should_decode_before_long_context_chunk(prefill):
885+
deferred_long_context_chunk = True
886+
else:
887+
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
793888
elif prefill:
794889
# prefill
795890
(
@@ -801,17 +896,20 @@ def __create_inputs_prefill():
801896
swap_out_map,
802897
) = __create_inputs_prefill()
803898

804-
# reset decode count when non-decoding inputs are produced
805-
if inputs is not None and not inputs.is_decoding:
806-
self._decode_count = 0
807-
808899
# try decoding
809900
if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill:
810901
prefill = False
811902
delta, running, invalid_seqs = self.create_model_inputs_delta()
812903
self.to_evict_seqs = invalid_seqs
813904
extra_inputs = None
814905

906+
if inputs is None and delta is None and deferred_long_context_chunk and self.long_context_chunker.enabled():
907+
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
908+
909+
# reset decode count when non-decoding inputs are produced
910+
if inputs is not None and not inputs.is_decoding:
911+
self._decode_count = 0
912+
815913
# skip if enable empty
816914
if inputs is None and delta is None:
817915
return None
@@ -844,9 +942,10 @@ def do_prefill_pnode(self):
844942
def do_prefill_default(self):
845943
# decoding if no waiting
846944
scheduler = self.scheduler
945+
pending_last_chunk = self._has_pending_last_long_context_chunk()
847946

848947
# do decoding if not waiting
849-
if not scheduler.has_waiting():
948+
if not scheduler.has_waiting() and not pending_last_chunk:
850949
self._decode_count = 0
851950
return False
852951

@@ -861,6 +960,10 @@ def do_prefill_default(self):
861960
token_count += seq.num_token_ids
862961
if token_count >= self.config.max_prefill_token_num:
863962
return True
963+
if pending_last_chunk:
964+
token_count += self.long_context_chunker.seq.num_token_ids
965+
if token_count >= self.config.max_prefill_token_num:
966+
return True
864967

865968
# prefill if no enough running
866969
num_ready = scheduler.num_ready()
@@ -892,6 +995,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool
892995
session_ids = [seq.session_id for seq in next_running]
893996
logger.debug(f'Forward session_ids: {session_ids}')
894997
await self.executor.forward_async(forward_inputs)
998+
self._last_forward_kind = self._forward_kind(inputs, forward_inputs['delta'])
895999
self.scheduler.tick()
8961000
self.forward_inputs = forward_inputs
8971001
return forward_inputs, next_running

lmdeploy/pytorch/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,9 @@ class SchedulerSequence:
696696
meta: Any = None
697697
num_ignored_history: int = 0
698698
model_meta: dict[str, Any] = None
699+
# Exclusive absolute token limit for temporary KV ownership. Non-final
700+
# long-context chunks use this to allocate only the computed prefix.
701+
kv_token_limit: int | None = None
699702

700703
# For Disaggregation
701704
migration_request: None | MigrationRequest = None

lmdeploy/pytorch/paging/block_manager/default_block_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ class DefaultBlockManager(BaseBlockManager):
2525
@classmethod
def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):
    """Return how many additional blocks ``obj`` needs.

    The token count starts from ``obj.num_all_ids``, is capped by
    ``obj.kv_token_limit`` when that limit is set, and then grows by
    ``prealloc_size``. The result is the block count for that many tokens
    minus the blocks already in ``obj.logical_blocks``, never below zero.
    """
    total_tokens = obj.num_all_ids
    token_cap = obj.kv_token_limit
    capped_tokens = total_tokens if token_cap is None else min(total_tokens, token_cap)

    # Preallocation is added after the cap, so it is never clipped by it.
    blocks_needed = _div_up(capped_tokens + prealloc_size, obj.block_size)
    blocks_missing = blocks_needed - len(obj.logical_blocks)
    return max(0, blocks_missing)

lmdeploy/pytorch/paging/block_manager/window_block_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@ def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0):
4242
if obj.num_history_ids <= self.window_size:
4343
return super().num_required_blocks(obj, prealloc_size)
4444

45-
return super().num_required_blocks(obj, prealloc_size) - obj.num_ignored_history // obj.block_size
45+
# DefaultBlockManager applies kv_token_limit to the absolute token
46+
# count. Sliding-window accounting then subtracts already-dropped
47+
# history blocks so chunk-limited allocation grows only the retained
48+
# window.
49+
num_required_blocks = super().num_required_blocks(obj, prealloc_size)
50+
num_required_blocks -= obj.num_ignored_history // obj.block_size
51+
return max(0, num_required_blocks)
4652

4753
def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):
4854
"""Return if physical block can be allocated for given message."""

lmdeploy/pytorch/paging/block_trie.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,8 @@ def allocate(self, seq: SchedulerSequence):
10851085

10861086
num_matched = node.num_matched
10871087
num_valid_ids = seq.num_valid_ids
1088+
if seq.kv_token_limit is not None:
1089+
num_valid_ids = min(num_valid_ids, seq.kv_token_limit)
10881090

10891091
if num_matched + block_size > num_valid_ids:
10901092
return

0 commit comments

Comments
 (0)