Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 115 additions & 13 deletions lmdeploy/pytorch/engine/inputs_maker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Engine-loop input construction for the LMDeploy PyTorch backend.

This module converts scheduler decisions into model-agent inputs. Most helpers
build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the
coordinator that chooses prefill/chunk/decode work, attaches per-forward
metadata, dispatches it to the executor, and updates local running state.
"""
import logging
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -243,6 +250,39 @@ def check_enable(self):


class InputsMakerAsync:
"""Coordinate prefill, decode, and long-context input dispatch.

``Scheduler`` owns admission, ordering, and cache/KV resources. This class
consumes the scheduler result and builds tensors only after resources have
been granted. Prefill-like work is represented by full ``ModelInputs``:
prompt prefill, final long-context chunks, and eager non-final long chunks.
Decode is represented by ``ModelInputsDelta`` and reuses persistent
model-agent/strategy ``StepInputs`` that were created by earlier prefill and
decode forwards.

``running_seqs`` is local engine-loop state, not the scheduler's source of
truth. It tracks sequences already sent to the executor so this class can
build decode deltas, evict invalid decode requests, and update the local
view after outputs return. Every dispatched forward also carries the
strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria
expected by the model agent.

Long-context chunking is coordinated here because it spans scheduling
policy and input construction. ``LongContextChunker`` tracks one active
long prefill and selects model-safe chunk boundaries, including indivisible
multimodal spans. Before tensors are created for each chunk, the scheduler
reserves the chunk's KV ownership. Non-final chunks are eager chunk
forwards with no user-visible output; the final chunk is treated as normal
prefill so it can merge into persistent decode state.

The current first-slice chunked-prefill policy intentionally uses separate
forwards instead of one mixed decode+prefill tensor batch. After a
non-final chunk, runnable decode is preferred and remains on the existing
delta/CUDAGraph path; at most one eager non-final long chunk is sent after
decode gets a chance to run. Preserve chunk flags such as
``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative
decoding paths interpret them downstream.
"""

def __init__(
self,
Expand Down Expand Up @@ -272,6 +312,7 @@ def __init__(

# consecutive decode counter for prefill starvation prevention
self._decode_count = 0
self._last_forward_kind = None

# record for next forward.
self.next_is_prefill = True
Expand All @@ -293,6 +334,38 @@ def _init_do_prefill(self, config: InputsMakerConfig):
else:
self.do_prefill = self.do_prefill_default

def _has_pending_last_long_context_chunk(self):
"""Check whether a running long context has only its final chunk
left."""
return self.long_context_chunker.enabled() and self.long_context_chunker.is_last_chunk()

def _should_decode_before_long_context_chunk(self, prefill: bool):
"""Prefer decode when a long-context chunk should not monopolize the
loop."""
if self.config.role == EngineRole.Prefill:
return False
if len(self.running_seqs) == 0:
return False
if not self.long_context_chunker.enabled():
return False
if self.long_context_chunker.is_last_chunk():
return not prefill
return getattr(self, '_last_forward_kind', None) == 'long_context_chunk'

def _forward_kind(self, inputs: 'ModelInputs|None', delta: 'ModelInputsDelta|None'):
"""Classify a queued forward for long-context interleaving policy."""
if inputs is None:
if delta is not None:
return 'decode'
return None
if inputs.is_chunk and not inputs.is_last_chunk:
return 'long_context_chunk'
if inputs.is_chunk:
return 'last_long_context_chunk'
if inputs.is_decoding:
return 'decode'
return 'prefill'

def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs):
"""Create vision model inputs."""
batch_size = len(messages)
Expand Down Expand Up @@ -734,26 +807,41 @@ def __create_model_inputs(seqs):
extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs)
return inputs, delta, extra_inputs

def __create_inputs_chunk(running: 'SeqList'):
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
def __create_inputs_chunk(running: 'SeqList', chunk_size: int, multimodals: 'MultiModalInputs|None'):
inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals)
extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs)
return inputs, extra_inputs

def __reserve_long_context_chunk(seq: 'SchedulerSequence', chunk_size: int, is_last_chunk: bool):
if self.config.role == EngineRole.Prefill:
prealloc_size = 0
elif is_last_chunk:
prealloc_size = self.engine_strategy.get_prealloc_size(True)
else:
prealloc_size = 0
return scheduler.reserve_long_context_chunk(seq,
chunk_size,
prealloc_size=prealloc_size,
is_last_chunk=is_last_chunk)

def __create_inputs_long_context_chunk():
seq = self.long_context_chunker.seq
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
is_last_chunk = self.long_context_chunker.is_last_chunk()
is_chunk_multimodal = self.long_context_chunker.has_multimodal
if not __reserve_long_context_chunk(seq, chunk_size, is_last_chunk):
return [], None, None, None
running = [seq]
has_multimodal = self.long_context_chunker.has_multimodal
if self.long_context_chunker.is_last_chunk():
if is_last_chunk:
inputs, delta, extra_inputs = __create_model_inputs(running)
inputs.is_chunk = True
inputs.is_last_chunk = True
self.long_context_chunker.clear()
else:
inputs, extra_inputs = __create_inputs_chunk(running)
inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
delta = None
inputs.is_first_chunk = False
inputs.is_chunk_multimodal = has_multimodal
inputs.is_chunk_multimodal = is_chunk_multimodal
return running, inputs, delta, extra_inputs

def __create_inputs_prefill():
Expand Down Expand Up @@ -782,7 +870,8 @@ def __create_inputs_prefill():
self.long_context_chunker.clear()
inputs, delta, extra_inputs = __create_model_inputs(running)
else:
inputs, extra_inputs = __create_inputs_chunk(running)
chunk_size, multimodals = self.long_context_chunker.next_chunk_size()
inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals)
inputs.is_first_chunk = True
inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal
elif len(running) > 0:
Expand All @@ -795,13 +884,19 @@ def __create_inputs_prefill():

inputs = None
delta = None
running = []
extra_inputs = None
swap_in_map = {}
swap_out_map = {}
deferred_long_context_chunk = False

self.long_context_chunker.check_enable()
if self.long_context_chunker.enabled():
# long context chunking
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
if self._should_decode_before_long_context_chunk(prefill):
deferred_long_context_chunk = True
else:
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
elif prefill:
# prefill
(
Expand All @@ -813,17 +908,20 @@ def __create_inputs_prefill():
swap_out_map,
) = __create_inputs_prefill()

# reset decode count when non-decoding inputs are produced
if inputs is not None and not inputs.is_decoding:
self._decode_count = 0

# try decoding
if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill:
prefill = False
delta, running, invalid_seqs = self.create_model_inputs_delta()
self.to_evict_seqs = invalid_seqs
extra_inputs = None

if inputs is None and delta is None and deferred_long_context_chunk and self.long_context_chunker.enabled():
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()

# reset decode count when non-decoding inputs are produced
if inputs is not None and not inputs.is_decoding:
self._decode_count = 0

# skip if enable empty
if inputs is None and delta is None:
return None
Expand Down Expand Up @@ -858,11 +956,14 @@ def do_prefill_pnode(self):
def do_prefill_default(self):
# decoding if no waiting
scheduler = self.scheduler
pending_last_chunk = self._has_pending_last_long_context_chunk()

# do decoding if not waiting
if not scheduler.has_waiting():
if not scheduler.has_waiting() and not pending_last_chunk:
self._decode_count = 0
return False
if pending_last_chunk:
return True

# force prefill if too many consecutive decode rounds
if self._decode_count >= self.config.prefill_interval:
Expand Down Expand Up @@ -906,6 +1007,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool
session_ids = [seq.session_id for seq in next_running]
logger.debug(f'Forward session_ids: {session_ids}')
await self.executor.forward_async(forward_inputs)
self._last_forward_kind = self._forward_kind(inputs, forward_inputs['delta'])
self.scheduler.tick()
self.forward_inputs = forward_inputs
return forward_inputs, next_running
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,9 @@ class SchedulerSequence:
meta: Any = None
num_ignored_history: int = 0
model_meta: dict[str, Any] = None
# Exclusive absolute token limit for temporary KV ownership. Non-final
# long-context chunks use this to allocate only the computed prefix.
kv_token_limit: int | None = None

# For Disaggregation
migration_request: None | MigrationRequest = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class DefaultBlockManager(BaseBlockManager):
@classmethod
def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):
"""Get num required blocks."""
num_tokens = obj.num_all_ids + prealloc_size
num_tokens = obj.num_all_ids
if obj.kv_token_limit is not None:
num_tokens = min(num_tokens, obj.kv_token_limit)
num_tokens += prealloc_size

num_all_blocks = _div_up(num_tokens, obj.block_size)
return max(0, num_all_blocks - len(obj.logical_blocks))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0):
if obj.num_history_ids <= self.window_size:
return super().num_required_blocks(obj, prealloc_size)

return super().num_required_blocks(obj, prealloc_size) - obj.num_ignored_history // obj.block_size
# DefaultBlockManager applies kv_token_limit to the absolute token
# count. Sliding-window accounting then subtracts already-dropped
# history blocks so chunk-limited allocation grows only the retained
# window.
num_required_blocks = super().num_required_blocks(obj, prealloc_size)
num_required_blocks -= obj.num_ignored_history // obj.block_size
return max(0, num_required_blocks)

def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):
"""Return if physical block can be allocated for given message."""
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/paging/block_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,8 @@ def allocate(self, seq: SchedulerSequence):

num_matched = node.num_matched
num_valid_ids = seq.num_valid_ids
if seq.kv_token_limit is not None:
num_valid_ids = min(num_valid_ids, seq.kv_token_limit)

if num_matched + block_size > num_valid_ids:
return
Expand Down
Loading
Loading