diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 234ee67d14..55b3aad2a8 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -2,6 +2,7 @@ import asyncio import logging import time +from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional @@ -55,25 +56,32 @@ def clear(self): class RunableEventAsync: """Awaitable async runable event.""" - def __init__(self, scheduler: 'Scheduler'): + def __init__(self, scheduler: 'Scheduler', extra_runable_checker: Callable[[], bool] | None = None): self.scheduler = scheduler + self.extra_runable_checker = extra_runable_checker self.event = asyncio.Event() + def has_unfinished(self): + """Check whether scheduler or engine-local state has runnable work.""" + if self.scheduler.has_unfinished(): + return True + return self.extra_runable_checker is not None and self.extra_runable_checker() + async def wait(self): """Wait event.""" await self.event.wait() def set(self): """Set event.""" - if self.scheduler.has_unfinished(): + if self.has_unfinished(): self.event.set() else: self.event.clear() -def build_runable_event(scheduler: 'Scheduler'): +def build_runable_event(scheduler: 'Scheduler', extra_runable_checker: Callable[[], bool] | None = None): """Build runable event.""" - return RunableEventAsync(scheduler) + return RunableEventAsync(scheduler, extra_runable_checker) @dataclass @@ -128,7 +136,9 @@ def __init__(self, self.resp_queue = asyncio.Queue() self.forward_event = CounterEvent() self.migration_event = asyncio.Event() - self.has_runable_event = RunableEventAsync(self.scheduler) + # Active long-context chunks are owned by InputsMaker, not the + # scheduler WAITING/READY queues, so include them in the runnable gate. 
+ self.has_runable_event = RunableEventAsync(self.scheduler, self.inputs_maker.has_pending_long_context_chunk) # Sleep uses a small handshake with the scheduling loops: # 1. sleep() sets _sleep_requested and waits for main/migration drain events. # 2. main_loop and migration_loop reach safe boundaries, acknowledge @@ -383,13 +393,12 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): async def _main_loop_try_send_next_inputs(self): """Try send next inputs.""" - scheduler = self.scheduler - if not scheduler.has_unfinished(): + if not self.has_runable_event.has_unfinished(): await self.has_runable_event.wait() if self._sleep_requested: return None, None - scheduler.collect_migration_done() + self.scheduler.collect_migration_done() return await self.inputs_maker.send_next_inputs() @staticmethod diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index ab58826069..76a75ae008 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -1,4 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +"""Engine-loop input construction for the LMDeploy PyTorch backend. + +This module converts scheduler decisions into model-agent inputs. Most helpers +build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the +coordinator that chooses prefill/chunk/decode work, attaches per-forward +metadata, dispatches it to the executor, and updates local running state. 
+""" import logging from collections import defaultdict from dataclasses import dataclass @@ -8,6 +15,7 @@ import torch from torch.profiler import record_function +from lmdeploy.pytorch import envs as _envs from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.messages import MessageStatus from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, VisionModelInputs @@ -239,10 +247,47 @@ def check_enable(self): if not self.enabled(): return if self.seq.status != MessageStatus.RUNNING: + # A stopped long request no longer has a valid continuation. We do + # not send a cleanup-only worker forward here: normal prefill/decode + # ignore chunk carry, and the next first chunk resets carry before + # use. Avoiding a no-work forward also keeps DP ranks aligned. self.clear() class InputsMakerAsync: + """Coordinate prefill, decode, and long-context input dispatch. + + ``Scheduler`` owns admission, ordering, and cache/KV resources. This class + consumes the scheduler result and builds tensors only after resources have + been granted. Prefill-like work is represented by full ``ModelInputs``: + prompt prefill, final long-context chunks, and eager non-final long chunks. + Decode is represented by ``ModelInputsDelta`` and reuses persistent + model-agent/strategy ``StepInputs`` that were created by earlier prefill and + decode forwards. + + ``running_seqs`` is local engine-loop state, not the scheduler's source of + truth. It tracks sequences already sent to the executor so this class can + build decode deltas, evict invalid decode requests, and update the local + view after outputs return. Every dispatched forward also carries the + strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria + expected by the model agent. + + Long-context chunking is coordinated here because it spans scheduling + policy and input construction. 
``LongContextChunker`` tracks one active + long prefill and selects model-safe chunk boundaries, including indivisible + multimodal spans. Before tensors are created for each chunk, the scheduler + reserves the chunk's KV ownership. Non-final chunks are eager chunk + forwards with no user-visible output; the final chunk is treated as normal + prefill so it can merge into persistent decode state. + + The current first-slice chunked-prefill policy intentionally uses separate + forwards instead of one mixed decode+prefill tensor batch. After a + non-final chunk, runnable decode is preferred and remains on the existing + delta/CUDAGraph path; at most one eager non-final long chunk is sent after + decode gets a chance to run. Preserve chunk flags such as + ``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative + decoding paths interpret them downstream. + """ def __init__( self, @@ -272,6 +317,9 @@ def __init__( # consecutive decode counter for prefill starvation prevention self._decode_count = 0 + self._last_forward_kind = None + self._short_prefill_turns_since_long_chunk = 0 + self._short_prefill_turns_per_long_chunk = max(1, _envs.opt_ttft_short_turns) # record for next forward. 
self.next_is_prefill = True @@ -293,6 +341,48 @@ def _init_do_prefill(self, config: InputsMakerConfig): else: self.do_prefill = self.do_prefill_default + def _has_pending_last_long_context_chunk(self): + """Check whether a running long context has only its final chunk + left.""" + return self.long_context_chunker.enabled() and self.long_context_chunker.is_last_chunk() + + def has_pending_long_context_chunk(self): + """Check whether engine-local long-context chunk work can run.""" + self.long_context_chunker.check_enable() + return self.long_context_chunker.enabled() + + def _should_defer_long_context_chunk(self, prefill: bool): + """Check whether the active long-context chunk should yield this + loop.""" + if self.config.role == EngineRole.Prefill: + return False + if not self.long_context_chunker.enabled(): + return False + if self.long_context_chunker.is_last_chunk(): + if len(self.running_seqs) == 0: + return False + return not prefill + return getattr(self, '_last_forward_kind', None) == 'long_context_chunk' + + def _is_long_context_chunk_turn_due(self): + """Check if active long chunk should run before another short + prefill.""" + return self._short_prefill_turns_since_long_chunk >= self._short_prefill_turns_per_long_chunk + + def _forward_kind(self, inputs: 'ModelInputs|None', delta: 'ModelInputsDelta|None'): + """Classify a queued forward for long-context interleaving policy.""" + if inputs is None: + if delta is not None: + return 'decode' + return None + if inputs.is_chunk and not inputs.is_last_chunk: + return 'long_context_chunk' + if inputs.is_chunk: + return 'last_long_context_chunk' + if inputs.is_decoding: + return 'decode' + return 'prefill' + def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs): """Create vision model inputs.""" batch_size = len(messages) @@ -685,7 +775,7 @@ def update_running_seqs(self, running: 'SeqList', inputs: 'ModelInputs|None'): return is_decoding = inputs is None - if 
self.long_context_chunker.enabled() and not is_decoding: + if self.long_context_chunker.enabled() and not is_decoding and inputs.is_chunk: # long context chunk does not need to update running seqs self.long_context_chunker.update_step(inputs) return @@ -734,34 +824,53 @@ def __create_model_inputs(seqs): extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs) return inputs, delta, extra_inputs - def __create_inputs_chunk(running: 'SeqList'): - chunk_size, multimodals = self.long_context_chunker.next_chunk_size() + def __create_inputs_chunk(running: 'SeqList', chunk_size: int, multimodals: 'MultiModalInputs|None'): inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals) extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs) return inputs, extra_inputs + def __reserve_long_context_chunk(seq: 'SchedulerSequence', chunk_size: int, is_last_chunk: bool): + if self.config.role == EngineRole.Prefill: + prealloc_size = 0 + elif is_last_chunk: + prealloc_size = self.engine_strategy.get_prealloc_size(True) + else: + prealloc_size = 0 + return scheduler.reserve_long_context_chunk(seq, + chunk_size, + prealloc_size=prealloc_size, + is_last_chunk=is_last_chunk) + def __create_inputs_long_context_chunk(): seq = self.long_context_chunker.seq + chunk_size, multimodals = self.long_context_chunker.next_chunk_size() + is_last_chunk = self.long_context_chunker.is_last_chunk() + is_chunk_multimodal = self.long_context_chunker.has_multimodal + if not __reserve_long_context_chunk(seq, chunk_size, is_last_chunk): + return [], None, None, None running = [seq] - has_multimodal = self.long_context_chunker.has_multimodal - if self.long_context_chunker.is_last_chunk(): + if is_last_chunk: inputs, delta, extra_inputs = __create_model_inputs(running) inputs.is_chunk = True inputs.is_last_chunk = True self.long_context_chunker.clear() else: - inputs, extra_inputs = __create_inputs_chunk(running) + inputs, extra_inputs = 
__create_inputs_chunk(running, chunk_size, multimodals) delta = None inputs.is_first_chunk = False - inputs.is_chunk_multimodal = has_multimodal + inputs.is_chunk_multimodal = is_chunk_multimodal + self._short_prefill_turns_since_long_chunk = 0 return running, inputs, delta, extra_inputs - def __create_inputs_prefill(): + def __create_inputs_prefill(allow_long_prefill: bool = True, prefer_long_prefill: bool = False): if self.config.role == EngineRole.Prefill: prealloc_size = 0 else: prealloc_size = self.engine_strategy.get_prealloc_size(True) - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) + scheduler_output = scheduler.schedule(is_prefill=True, + prealloc_size=prealloc_size, + allow_long_prefill=allow_long_prefill, + prefer_long_prefill=prefer_long_prefill) running = scheduler_output.running swap_in_map = scheduler_output.swap_in_map swap_out_map = scheduler_output.swap_out_map @@ -782,28 +891,148 @@ def __create_inputs_prefill(): self.long_context_chunker.clear() inputs, delta, extra_inputs = __create_model_inputs(running) else: - inputs, extra_inputs = __create_inputs_chunk(running) + chunk_size, multimodals = self.long_context_chunker.next_chunk_size() + inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals) inputs.is_first_chunk = True inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal + self._short_prefill_turns_since_long_chunk = 0 elif len(running) > 0: # create inputs inputs, delta, extra_inputs = __create_model_inputs(running) return running, inputs, delta, extra_inputs, swap_in_map, swap_out_map + def __create_short_or_normal_prefill_turn(): + nonlocal attempted_short_or_normal_prefill + attempted_short_or_normal_prefill = True + result = __create_inputs_prefill(allow_long_prefill=False) + _, prefill_inputs, prefill_delta, _, _, _ = result + if prefill_inputs is not None or prefill_delta is not None: + self._short_prefill_turns_since_long_chunk += 1 + return result + + 
def __is_empty_forward(forward_inputs: 'ModelInputs|None', forward_delta: 'ModelInputsDelta|None'): + return forward_inputs is None and forward_delta is None + + def __try_active_long_context_chunk(): + nonlocal attempted_long_work + nonlocal active_long_chunk_blocked_by_kv + attempted_long_work = True + result = __create_inputs_long_context_chunk() + _, chunk_inputs, chunk_delta, _ = result + active_long_chunk_blocked_by_kv = __is_empty_forward(chunk_inputs, chunk_delta) + return result + + def __should_try_short_prefill_before_active_chunk(): + """Allow short/normal prefill quota before an active non-final + chunk.""" + if self.long_context_chunker.is_last_chunk(): + return False + if not scheduler.has_waiting(): + return False + return not self._is_long_context_chunk_turn_due() + + def __has_no_forward(): + return __is_empty_forward(inputs, delta) + + def __can_fallback_to_short_after_long_work(): + if not __has_no_forward(): + return False + if not attempted_long_work: + return False + if active_long_chunk_blocked_by_kv: + return False + if attempted_short_or_normal_prefill: + return False + return scheduler.has_waiting() + + def __can_try_short_prefill_after_defer(): + if not __has_no_forward(): + return False + if not deferred_long_context_chunk: + return False + if self._is_long_context_chunk_turn_due(): + return False + return scheduler.has_waiting() + + def __can_retry_deferred_active_chunk(): + return __has_no_forward() and deferred_long_context_chunk and self.long_context_chunker.enabled() + scheduler = self.scheduler logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') inputs = None delta = None + running = [] + extra_inputs = None swap_in_map = {} swap_out_map = {} - + deferred_long_context_chunk = False + attempted_long_work = False + attempted_short_or_normal_prefill = False + active_long_chunk_blocked_by_kv = False + + # Bounded opt-TTFT prefill policy: protect decode before continuing + # non-final long chunks, 
then allow a bounded number of short/normal + # prefill turns before forcing one long-work turn. A long-work turn + # continues the active chunker first, otherwise it admits one waiting + # long prefill through the scheduler. self.long_context_chunker.check_enable() if self.long_context_chunker.enabled(): - # long context chunking - running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk() + if self._should_defer_long_context_chunk(prefill): + deferred_long_context_chunk = True + elif __should_try_short_prefill_before_active_chunk(): + # After a decode turn, keep the short/normal prefill quota in + # front of active long chunks; otherwise decode -> long can + # repeat and small waiting requests remain gated by the active + # chunker even while the long-work turn is not due. + ( + running, + inputs, + delta, + extra_inputs, + swap_in_map, + swap_out_map, + ) = __create_short_or_normal_prefill_turn() + if __is_empty_forward(inputs, delta): + running, inputs, delta, extra_inputs = __try_active_long_context_chunk() + else: + running, inputs, delta, extra_inputs = __try_active_long_context_chunk() elif prefill: # prefill + has_waiting_long_prefill = scheduler.has_waiting_long_prefill() + if has_waiting_long_prefill and not self._is_long_context_chunk_turn_due(): + ( + running, + inputs, + delta, + extra_inputs, + swap_in_map, + swap_out_map, + ) = __create_short_or_normal_prefill_turn() + if __has_no_forward(): + ( + running, + inputs, + delta, + extra_inputs, + swap_in_map, + swap_out_map, + ) = __create_inputs_prefill(prefer_long_prefill=True) + else: + ( + running, + inputs, + delta, + extra_inputs, + swap_in_map, + swap_out_map, + ) = __create_inputs_prefill(prefer_long_prefill=has_waiting_long_prefill) + attempted_long_work = has_waiting_long_prefill + + # Waiting-long admission failure can still fall back to short prefills. 
+ # Active-long reservation failure means KV is pinned by running work; + # admit decode only so existing requests can drain blocks. + if __can_fallback_to_short_after_long_work(): ( running, inputs, @@ -811,11 +1040,7 @@ def __create_inputs_prefill(): extra_inputs, swap_in_map, swap_out_map, - ) = __create_inputs_prefill() - - # reset decode count when non-decoding inputs are produced - if inputs is not None and not inputs.is_decoding: - self._decode_count = 0 + ) = __create_short_or_normal_prefill_turn() # try decoding if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill: @@ -824,6 +1049,23 @@ def __create_inputs_prefill(): self.to_evict_seqs = invalid_seqs extra_inputs = None + if __can_try_short_prefill_after_defer(): + ( + running, + inputs, + delta, + extra_inputs, + swap_in_map, + swap_out_map, + ) = __create_short_or_normal_prefill_turn() + + if __can_retry_deferred_active_chunk(): + running, inputs, delta, extra_inputs = __try_active_long_context_chunk() + + # reset decode count when non-decoding inputs are produced + if inputs is not None and not inputs.is_decoding: + self._decode_count = 0 + # skip if enable empty if inputs is None and delta is None: return None @@ -858,11 +1100,14 @@ def do_prefill_pnode(self): def do_prefill_default(self): # decoding if no waiting scheduler = self.scheduler + pending_last_chunk = self._has_pending_last_long_context_chunk() # do decoding if not waiting - if not scheduler.has_waiting(): + if not scheduler.has_waiting() and not pending_last_chunk: self._decode_count = 0 return False + if pending_last_chunk: + return True # force prefill if too many consecutive decode rounds if self._decode_count >= self.config.prefill_interval: @@ -906,6 +1151,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool session_ids = [seq.session_id for seq in next_running] logger.debug(f'Forward session_ids: {session_ids}') await self.executor.forward_async(forward_inputs) + 
self._last_forward_kind = self._forward_kind(inputs, forward_inputs['delta']) self.scheduler.tick() self.forward_inputs = forward_inputs return forward_inputs, next_running diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f9c28e4159..88c20287e9 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -626,11 +626,11 @@ def _prepare_inputs_prefill( # for second round chat self.step_inputs.reindex(delta) - if inputs.is_first_chunk or not inputs.is_chunk: + if inputs.is_first_chunk: self._prev_chunk_output = None # check long context - if self._prev_chunk_output is not None: + if inputs.is_chunk and self._prev_chunk_output is not None: # update model metas model_metas = self._prev_chunk_output.get('model_metas') inputs.model_metas = model_metas diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index 9ec6dfa3ae..3ca8c24e85 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -70,6 +70,21 @@ def env_to_float( return value +def env_to_choice( + env_var: str, + default: str, + choices: set | list, +): + """Env to selected string.""" + value = os.getenv(env_var) + if value is None: + return default + value = value.lower().strip() + if value not in choices: + raise ValueError(f"Invalid environment variable '{env_var}={value}'. 
Allowed values: {choices}") + return value + + _ENVS = dict() @@ -173,6 +188,11 @@ def _patched_get_env( # fake capture flag for debug cudagraph padding behavior fake_capture = env_to_bool('LMDEPLOY_FAKE_CUDA_GRAPH_CAPTURE', False) + # opt-ttft + opt_ttft_policy = env_to_choice('LMDEPLOY_PT_TTFT_POLICY', 'size', {'fifo', 'size'}) + opt_ttft_short_turns = max(1, env_to_int('LMDEPLOY_PT_TTFT_SHORT_TURNS', 3)) + opt_ttft_aging_sec = env_to_float('LMDEPLOY_PT_TTFT_AGING_SEC', 2.0) + def get_all_envs(): """Get all environment variables.""" diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 00888bfdbb..49729668b2 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -698,6 +698,9 @@ class SchedulerSequence: meta: Any = None num_ignored_history: int = 0 model_meta: dict[str, Any] = None + # Exclusive absolute token limit for temporary KV ownership. Non-final + # long-context chunks use this to allocate only the computed prefix. + kv_token_limit: int | None = None # For Disaggregation migration_request: None | MigrationRequest = None diff --git a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py index c8f935900f..5d33ba665b 100644 --- a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py @@ -25,7 +25,10 @@ class DefaultBlockManager(BaseBlockManager): @classmethod def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0): """Get num required blocks.""" - num_tokens = obj.num_all_ids + prealloc_size + num_tokens = obj.num_all_ids + if obj.kv_token_limit is not None: + num_tokens = min(num_tokens, obj.kv_token_limit) + num_tokens += prealloc_size num_all_blocks = _div_up(num_tokens, obj.block_size) return max(0, num_all_blocks - len(obj.logical_blocks)) diff --git a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py 
b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py index cde3f90ac5..28f89e6144 100644 --- a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py @@ -42,7 +42,13 @@ def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0): if obj.num_history_ids <= self.window_size: return super().num_required_blocks(obj, prealloc_size) - return super().num_required_blocks(obj, prealloc_size) - obj.num_ignored_history // obj.block_size + # DefaultBlockManager applies kv_token_limit to the absolute token + # count. Sliding-window accounting then subtracts already-dropped + # history blocks so chunk-limited allocation grows only the retained + # window. + num_required_blocks = super().num_required_blocks(obj, prealloc_size) + num_required_blocks -= obj.num_ignored_history // obj.block_size + return max(0, num_required_blocks) def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0): """Return if physical block can be allocated for given message.""" diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py index c32770a7db..cb47e8a3fa 100644 --- a/lmdeploy/pytorch/paging/block_trie.py +++ b/lmdeploy/pytorch/paging/block_trie.py @@ -1085,6 +1085,8 @@ def allocate(self, seq: SchedulerSequence): num_matched = node.num_matched num_valid_ids = seq.num_valid_ids + if seq.kv_token_limit is not None: + num_valid_ids = min(num_valid_ids, seq.kv_token_limit) if num_matched + block_size > num_valid_ids: return diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 690a547e90..22c8fb711b 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -38,13 +38,16 @@ """ import logging +import time from collections import OrderedDict +from collections.abc import Callable from contextlib import contextmanager from dataclasses import dataclass from torch.profiler import 
record_function from lmdeploy.messages import EventType, ScheduleMetrics +from lmdeploy.pytorch import envs as _envs from lmdeploy.utils import get_logger from ..config import CacheConfig, SchedulerConfig @@ -70,6 +73,37 @@ class SchedulerOutput: copy_map: MapType +_PREFILL_GATE_SKIP = 'skip' +_PREFILL_GATE_BREAK = 'break' + + +@dataclass +class _PrefixMatchForPrefillGate: + """Tentative prefix match kept only because it passes a prefill gate.""" + + stats_snapshot: object + prefill_token_count: int + is_nonfinal_long_prefill: bool + + +@dataclass +class _PrefillGateCheck: + """Result of prefill-gate checks before final resource admission.""" + + prefix_match: _PrefixMatchForPrefillGate | None = None + rollback_action: str | None = None + reject_action: str | None = None + + +@dataclass(frozen=True) +class _PrefillReorderInfo: + """Immutable pre-admission metadata used only for waiting-list ordering.""" + + prefill_token_count: int + is_nonfinal_long_prefill: bool + estimated_long_chunks: int + + class Scheduler: """Tools to schedule next step. 
@@ -102,6 +136,8 @@ def __init__( self.seq_meta = seq_meta self.seq_manager = SequenceManager(seq_meta) self.scheduler_tick = 0 + self._long_prefill_policy = _envs.opt_ttft_policy + self._long_prefill_aging_seconds_per_chunk = max(0.001, _envs.opt_ttft_aging_sec) def tick(self): """Mark one scheduler progress step (once per forward dispatch).""" @@ -141,6 +177,7 @@ def _rollback_unscheduled_prefix_match(self, seq: SchedulerSequence, stats_snaps seq.state.free() elif seq.num_history_ids > 0: seq.set_step(0) + seq.kv_token_limit = None prefix_cache = seq.prefix_cache prefix_cache.last_shared_node = None prefix_cache.restore_state = -1 @@ -149,6 +186,96 @@ def _rollback_unscheduled_prefix_match(self, seq: SchedulerSequence, stats_snaps prefix_cache.match_start_step = -1 seq.cached_tokens = 0 + def _rollback_prefix_match_for_prefill_gate(self, seq: SchedulerSequence, stats_snapshot, reason: str): + """Rollback a prefix match tried only to re-check prefill gates.""" + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Rollback tentative prefix-cache gate match: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} reason={reason} num_history_ids={seq.num_history_ids} ' + f'restore_state={seq.prefix_cache.restore_state}') + self._rollback_unscheduled_prefix_match(seq, stats_snapshot) + + def _try_prefix_match_for_prefill_gate( + self, + seq: SchedulerSequence, + accept_match: Callable[[_PrefixMatchForPrefillGate], bool], + rollback_reason: str, + ): + """Tentatively match prefix cache before rejecting a prefill candidate. + + This helper is intentionally limited to pre-admission gates. It does not evict, allocate, acquire SSM restore + state, or publish cache state. The caller either continues into the normal admission path with the returned + match, or the helper rolls every match side effect back. 
+ """ + if not self.block_trie.enable: + return None + + stats_snapshot = self.block_trie.snapshot_stats() + self.block_trie.match(seq) + if self._prefix_hit_starts_middle_long_context_chunk(seq): + self._rollback_prefix_match_for_prefill_gate(seq, stats_snapshot, + 'long-context chunk starts after prefix hit') + return None + + prefix_match = _PrefixMatchForPrefillGate( + stats_snapshot=stats_snapshot, + prefill_token_count=self._prefill_admission_token_count(seq), + is_nonfinal_long_prefill=self._prefill_kv_token_limit(seq) is not None, + ) + if accept_match(prefix_match): + return prefix_match + + self._rollback_prefix_match_for_prefill_gate(seq, stats_snapshot, rollback_reason) + return None + + def _check_prefill_admission_gates(self, seq: SchedulerSequence, token_count: int, has_admitted: bool, + allow_long_prefill: bool): + """Check prefill policy gates before resource admission. + + A prefix-cache hit can shrink a request enough to pass a short-turn or + token-budget gate. When that happens, the returned prefix match is + still tentative; if later resource admission rolls it back, the caller + must reject this candidate with ``rollback_action`` for the current + scheduler turn. 
+ """ + prefill_token_count = self._prefill_admission_token_count(seq) + is_nonfinal_long_prefill = self._prefill_kv_token_limit(seq) is not None + prefix_match = None + rollback_action = None + + if is_nonfinal_long_prefill and not allow_long_prefill: + prefix_match = self._try_prefix_match_for_prefill_gate( + seq, + accept_match=lambda match: not match.is_nonfinal_long_prefill, + rollback_reason='still non-final long prefill on short turn') + if prefix_match is None: + return _PrefillGateCheck(reject_action=_PREFILL_GATE_SKIP) + prefill_token_count = prefix_match.prefill_token_count + rollback_action = _PREFILL_GATE_SKIP + + exceeds_token_budget = (has_admitted + and token_count + prefill_token_count > self.cache_config.max_prefill_token_num) + if exceeds_token_budget: + if prefix_match is None: + prefix_match = self._try_prefix_match_for_prefill_gate( + seq, + accept_match=lambda match: token_count + + match.prefill_token_count <= self.cache_config.max_prefill_token_num, + rollback_reason='still exceeds prefill token budget') + if prefix_match is not None: + prefill_token_count = prefix_match.prefill_token_count + rollback_action = _PREFILL_GATE_SKIP if not allow_long_prefill else _PREFILL_GATE_BREAK + + still_exceeds_token_budget = token_count + prefill_token_count > self.cache_config.max_prefill_token_num + if prefix_match is None or still_exceeds_token_budget: + if prefix_match is not None: + self._rollback_prefix_match_for_prefill_gate(seq, prefix_match.stats_snapshot, + 'still exceeds prefill token budget') + reject_action = _PREFILL_GATE_SKIP if not allow_long_prefill else _PREFILL_GATE_BREAK + return _PrefillGateCheck(reject_action=reject_action) + + return _PrefillGateCheck(prefix_match=prefix_match, + rollback_action=rollback_action) + @staticmethod def _finalize_prefix_cache_match(seq: SchedulerSequence): """Publish accepted cached-token count within the current prompt.""" @@ -177,13 +304,95 @@ def _prefix_hit_starts_middle_long_context_chunk(self, 
seq: SchedulerSequence): if seq.num_history_ids <= 0: return False + max_prefill_num = self._long_context_chunk_limit(seq) + return seq.num_token_ids > max_prefill_num + + def _long_context_chunk_limit(self, seq: SchedulerSequence): + """Return the token budget for one long-context chunk.""" max_prefill_num = self.cache_config.max_prefill_token_num mm_for_chunk_limit = seq.get_chunk_limit_multimodals() for value in mm_for_chunk_limit.values(): max_mm_size = max([v.end - v.start for v in value], default=0) max_prefill_num = max(max_prefill_num, max_mm_size) - return seq.num_token_ids > max_prefill_num + return max_prefill_num + + def _next_long_context_chunk_end(self, seq: SchedulerSequence, max_prefill_num: int | None = None): + """Return the exclusive absolute token end for the next chunk.""" + if max_prefill_num is None: + max_prefill_num = self._long_context_chunk_limit(seq) + chunk_size = min(seq.num_token_ids, max_prefill_num) + start = seq.num_history_ids + end = start + chunk_size + + input_mm = seq.get_input_multimodals() + if len(input_mm) == 0: + return end + + multimodal_data = [] + for modal_type, modal_datas in input_mm.items(): + multimodal_data += [(modal_type, data) for data in modal_datas] + multimodal_data = sorted(multimodal_data, key=lambda x: x[1].start) + + for _, data in multimodal_data: + assert data.start >= start, 'multimodal data should be sorted by start' + if data.start >= end: + break + if data.end > end: + end = data.start + break + + return end + + def _prefill_kv_token_limit(self, seq: SchedulerSequence): + """Limit KV allocation for a non-final long-context prefill chunk.""" + max_prefill_num = self._long_context_chunk_limit(seq) + if seq.num_token_ids <= max_prefill_num: + return None + return self._next_long_context_chunk_end(seq, max_prefill_num) + + def _prefill_admission_token_count(self, seq: SchedulerSequence): + """Return token budget cost for the next prefill or chunk.""" + kv_token_limit = 
self._prefill_kv_token_limit(seq) + if kv_token_limit is None: + return seq.num_token_ids + return max(0, kv_token_limit - seq.num_history_ids) + + def has_waiting_long_prefill(self): + """Whether a waiting request would need a non-final prefill chunk.""" + return any(self._prefill_kv_token_limit(seq) is not None for seq in self.waiting) + + def _prepare_prefill_allocation(self, seq: SchedulerSequence, prealloc_size: int): + """Apply chunk KV limit and return the effective prealloc size.""" + kv_token_limit = self._prefill_kv_token_limit(seq) + if kv_token_limit is None: + seq.kv_token_limit = None + return prealloc_size + + seq.kv_token_limit = kv_token_limit + return 0 + + def reserve_long_context_chunk(self, + seq: SchedulerSequence, + chunk_size: int, + prealloc_size: int = 0, + is_last_chunk: bool = False): + """Reserve KV blocks for the next chunk of a running long prefill.""" + old_kv_token_limit = seq.kv_token_limit + if is_last_chunk: + seq.kv_token_limit = None + else: + seq.kv_token_limit = seq.num_history_ids + chunk_size + prealloc_size = 0 + + evictable = self.hanging + self.waiting + if not self.eviction_helper.evict_for_seq(seq, evictable, prealloc_size): + seq.kv_token_limit = old_kv_token_limit + return False + + self.block_manager.allocate(seq, prealloc_size) + self.block_trie.allocate(seq) + return True @staticmethod def create_status_list_property(status: MessageStatus): @@ -286,7 +495,10 @@ def _reorder_migrating(): return migration_ready @record_function('schedule_prefill') - def _schedule_prefill(self, prealloc_size: int = 0): + def _schedule_prefill(self, + prealloc_size: int = 0, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): """Schedule for prefilling.""" max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() @@ -297,38 +509,157 @@ def _schedule_prefill(self, prealloc_size: int = 0): running: SeqList = [] token_count = 0 - def _to_running(seq: SchedulerSequence): + def 
_to_running(seq: SchedulerSequence, prefill_token_count: int): """To running.""" seq.state.activate() running.append(seq) nonlocal token_count - token_count += seq.num_token_ids + token_count += prefill_token_count - def __evict_for_seq(seq: SchedulerSequence, waiting): + def __evict_for_seq(seq: SchedulerSequence, waiting, evict_prealloc_size: int): """Evict until can append.""" from itertools import chain hanging = reversed(self.hanging) waiting = reversed(waiting) evictable = list(chain(hanging, waiting)) - return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) + return eviction_helper.evict_for_seq(seq, evictable, evict_prealloc_size) + + def __prepare_and_evict(seq: SchedulerSequence, waiting): + """Apply chunk allocation limits and evict for this prefill.""" + alloc_prealloc_size = self._prepare_prefill_allocation(seq, prealloc_size) + if __evict_for_seq(seq, waiting, alloc_prealloc_size): + return True, alloc_prealloc_size + seq.kv_token_limit = None + return False, alloc_prealloc_size + + reorder_info_cache: dict[int, _PrefillReorderInfo] = {} + + def _get_prefill_reorder_info(seq: SchedulerSequence): + """Return reorder-only info before prefix-cache side effects. + + Prefix-cache match/rollback mutates the remaining prompt. Keep this cache confined to waiting-list ordering + and recompute fresh values in the admission path below. 
+ """ + seq_key = id(seq) + info = reorder_info_cache.get(seq_key) + if info is not None: + return info + + chunk_limit = self._long_context_chunk_limit(seq) + if seq.num_token_ids <= chunk_limit: + info = _PrefillReorderInfo(prefill_token_count=seq.num_token_ids, + is_nonfinal_long_prefill=False, + estimated_long_chunks=1) + else: + kv_token_limit = self._next_long_context_chunk_end(seq, chunk_limit) + safe_chunk_limit = max(1, chunk_limit) + info = _PrefillReorderInfo( + prefill_token_count=max(0, kv_token_limit - seq.num_history_ids), + is_nonfinal_long_prefill=True, + estimated_long_chunks=max(1, (seq.num_token_ids + safe_chunk_limit - 1) // safe_chunk_limit), + ) + reorder_info_cache[seq_key] = info + return info + + def _long_prefill_priority_key_for_reorder(seq: SchedulerSequence, now: float): + """Prefer smaller long prompts, with age credit to avoid + starvation.""" + info = _get_prefill_reorder_info(seq) + wait_age = max(0.0, now - seq.arrive_time) + age_credit = int(wait_age // self._long_prefill_aging_seconds_per_chunk) + age_adjusted_chunks = info.estimated_long_chunks - age_credit + return age_adjusted_chunks, info.estimated_long_chunks, seq.arrive_time + + def _split_waiting_by_prefill_kind(waiting: SeqList): + """Split waiting requests into normal/final and non-final long + prefill.""" + normal_waiting: SeqList = [] + long_waiting: SeqList = [] + for seq in waiting: + if _get_prefill_reorder_info(seq).is_nonfinal_long_prefill: + long_waiting.append(seq) + else: + normal_waiting.append(seq) + return normal_waiting, long_waiting + + def _sort_normal_prefills(waiting: SeqList): + return sorted(waiting, key=lambda seq: (_get_prefill_reorder_info(seq).prefill_token_count, + seq.arrive_time)) + + def _sort_long_prefills_for_long_turn(waiting: SeqList): + if self._long_prefill_policy != 'size': + return waiting + now = time.perf_counter() + return sorted(waiting, key=lambda seq: _long_prefill_priority_key_for_reorder(seq, now)) + + def 
_reorder_waiting_for_long_turn(waiting: SeqList): + """Choose one long waiter, then fill the turn with normal + prefills.""" + normal_waiting, long_waiting = _split_waiting_by_prefill_kind(waiting) + if len(long_waiting) == 0: + return None + + long_waiting = _sort_long_prefills_for_long_turn(long_waiting) + normal_waiting = _sort_normal_prefills(normal_waiting) + return [long_waiting[0]] + normal_waiting + long_waiting[1:] + + def _reorder_waiting_for_short_turn(waiting: SeqList): + """Prioritize normal/final prefills while preserving long + waiters.""" + normal_waiting, long_waiting = _split_waiting_by_prefill_kind(waiting) + return _sort_normal_prefills(normal_waiting) + long_waiting def _reorder_waiting(): """Reorder waiting.""" - return sorted(self.waiting, key=lambda seq: seq.arrive_time) + waiting = sorted(self.waiting, key=lambda seq: seq.arrive_time) + if prefer_long_prefill: + # Long-work turns choose one long waiter first. The size policy + # only reorders this long lane; it is not global + # shortest-prefill-first admission. 
+ long_turn_waiting = _reorder_waiting_for_long_turn(waiting) + if long_turn_waiting is not None: + return long_turn_waiting + + if allow_long_prefill: + return waiting + + return _reorder_waiting_for_short_turn(waiting) num_waiting = self.seq_manager.num_sequences(MessageStatus.WAITING) if (len(running) >= max_batches or num_waiting == 0): return running, swap_in_map, swap_out_map, copy_map waiting = _reorder_waiting() + skipped_waiting: SeqList = [] while len(waiting) > 0 and len(running) < max_batches: seq = waiting.pop(0) - - if (len(running) > 0 and token_count + seq.num_token_ids > self.cache_config.max_prefill_token_num): + gate_check = self._check_prefill_admission_gates(seq, + token_count=token_count, + has_admitted=len(running) > 0, + allow_long_prefill=allow_long_prefill) + + def __reject_after_prefill_gate_match_rollback(): + """Reject if resource admission rolled back a gate-only hit.""" + if gate_check.prefix_match is None: + return False + if gate_check.rollback_action == _PREFILL_GATE_SKIP: + skipped_waiting.append(seq) + return True + return False + + if gate_check.reject_action is not None: + if gate_check.reject_action == _PREFILL_GATE_SKIP: + skipped_waiting.append(seq) + continue break + evictable_waiting = skipped_waiting + waiting + if self.block_trie.enable: - stats_snapshot = self.block_trie.snapshot_stats() + if gate_check.prefix_match is None: + stats_snapshot = self.block_trie.snapshot_stats() + else: + stats_snapshot = gate_check.prefix_match.stats_snapshot def __rollback_prefix_match(reason: str): if logger.isEnabledFor(logging.DEBUG): @@ -337,46 +668,74 @@ def __rollback_prefix_match(reason: str): f'restore_state={seq.prefix_cache.restore_state}') self._rollback_unscheduled_prefix_match(seq, stats_snapshot) - self.block_trie.match(seq) - if self._prefix_hit_starts_middle_long_context_chunk(seq): - __rollback_prefix_match('long-context chunk starts after prefix hit') + if gate_check.prefix_match is None: + self.block_trie.match(seq) + 
if self._prefix_hit_starts_middle_long_context_chunk(seq): + __rollback_prefix_match('long-context chunk starts after prefix hit') had_ssm_restore = self.is_ssm and seq.prefix_cache.restore_state >= 0 if not self._acquire_ssm_restore_if_needed(seq): __rollback_prefix_match('failed to acquire SSM restore checkpoint') + if gate_check.prefix_match is not None: + if __reject_after_prefill_gate_match_rollback(): + continue + break - if not __evict_for_seq(seq, waiting): + evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting) + if not evicted: if not had_ssm_restore: __rollback_prefix_match('eviction failed') + if __reject_after_prefill_gate_match_rollback(): + continue break # A matched SSM restore may be pinning the only checkpoint # state that eviction would otherwise free. Roll it back once # and retry eviction before declaring the sequence unschedulable. __rollback_prefix_match('eviction failed with pinned SSM restore') - if not __evict_for_seq(seq, waiting): + if __reject_after_prefill_gate_match_rollback(): + continue + if gate_check.prefix_match is not None: + break + evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting) + if not evicted: break # allocate session memory if self.is_ssm and not self._ensure_runtime_state_available(): __rollback_prefix_match('no runtime SSM state available') - if not __evict_for_seq(seq, waiting): + if __reject_after_prefill_gate_match_rollback(): + continue + if gate_check.prefix_match is not None: + break + evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting) + if not evicted: break if not self._ensure_runtime_state_available(): + seq.kv_token_limit = None break else: - if not __evict_for_seq(seq, waiting): + evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting) + if not evicted: break - self.block_manager.allocate(seq, prealloc_size) + # Prefix-cache matching can advance the sequence step and shrink + # the remaining prefill tail. 
Charge the admitted batch with the + # post-match/post-rollback cost, not the conservative pre-match + # estimate used to decide whether this sequence is worth trying. + prefill_token_count = self._prefill_admission_token_count(seq) + self.block_manager.allocate(seq, alloc_prealloc_size) if self.block_trie.enable: self.block_trie.allocate(seq) if self.is_ssm: self.state_manager.allocate(seq) if self.block_trie.enable: self._finish_prefix_cache_schedule(seq) - _to_running(seq) + _to_running(seq, prefill_token_count) seq.record_event(EventType.SCHEDULED) + if seq.kv_token_limit is not None: + break + return running, swap_in_map, swap_out_map, copy_map @record_function('schedule_decoding') @@ -433,10 +792,14 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): return self.ready[:self.scheduler_config.max_batches], swap_in_map, swap_out_map, copy_map - def schedule(self, is_prefill: bool, prealloc_size: int = 0): + def schedule(self, + is_prefill: bool, + prealloc_size: int = 0, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): """Schedule inputs for next steps.""" if is_prefill: - output = self._schedule_prefill(prealloc_size) + output = self._schedule_prefill(prealloc_size, allow_long_prefill, prefer_long_prefill) else: output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py index befa7e8ee4..834ca6cc62 100644 --- a/lmdeploy/pytorch/paging/seq_states/states.py +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -15,6 +15,7 @@ def _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'): seq.prefix_cache.last_shared_node = None seq.prefix_cache.match_start_step = -1 seq.cached_tokens = 0 + seq.kv_token_limit = None if seq.num_blocks > 0: scheduler.block_manager.free(seq) if seq.logical_state >= 0: diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py 
b/lmdeploy/pytorch/spec_decode/spec_agent.py index 1894328169..4670b7b98d 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -204,10 +204,9 @@ def _prepare_inputs_from_main(self, model_inputs: ModelInputs, extra_inputs: Ext history_lengths = model_inputs.history_lengths.clone() if not model_inputs.is_chunk: - # Dummy inputs are DP placeholders and should not disturb - # local long-context carry-over state. - if not model_inputs.is_dummy: - self._prev_chunk_last.clear() + # Non-chunk prefill/decode can be interleaved between long-context + # chunks. Keep pending chunk carry here; a new first chunk clears it + # explicitly, and the final chunk consumes it. # Case A: non-chunked — shift left by 1, place next_token at end input_ids = model_inputs.input_ids.clone() input_ids[:, :-1] = model_inputs.input_ids[:, 1:] diff --git a/tests/pytorch/engine/test_inputs_maker.py b/tests/pytorch/engine/test_inputs_maker.py index 667be525b3..e1e4cd3301 100644 --- a/tests/pytorch/engine/test_inputs_maker.py +++ b/tests/pytorch/engine/test_inputs_maker.py @@ -5,10 +5,12 @@ import pytest +import lmdeploy.pytorch.engine.inputs_maker as inputs_maker_module from lmdeploy.pytorch.disagg.config import EngineRole -from lmdeploy.pytorch.engine.engine_loop import EngineLoop +from lmdeploy.pytorch.engine.engine_loop import EngineLoop, RunableEventAsync from lmdeploy.pytorch.engine.inputs_maker import ( InputsMakerAsync, + InputsMakerConfig, LongContextChunker, _compact_state_prefix_cache_restore_offsets, _compact_state_prefix_cache_save_offsets, @@ -38,6 +40,10 @@ def __init__(self, self.return_logits = False self.return_routed_experts = False self.return_ce_loss = False + self.status = MessageStatus.RUNNING + + def set_step(self, step: int): + self.num_history_ids = step def get_input_multimodals(self): return self._input_multimodals @@ -60,12 +66,36 @@ def _state_seq(logical_state: int, restore_state: int = -1): class _FakeScheduler: - 
def __init__(self, running): + def __init__(self, running, waiting=None, num_ready=0, num_running=0): self.running = running - - def schedule(self, is_prefill: bool, prealloc_size: int): + self.waiting = waiting or [] + self._num_ready = num_ready + self._num_running = num_running + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + self.allow_long_prefill = allow_long_prefill + self.prefer_long_prefill = prefer_long_prefill return SimpleNamespace(running=self.running, swap_in_map={}, swap_out_map={}) + def reserve_long_context_chunk(self, seq, chunk_size: int, prealloc_size: int = 0, is_last_chunk: bool = False): + return True + + def has_waiting(self): + return len(self.waiting) > 0 + + def has_waiting_long_prefill(self): + return False + + def num_ready(self): + return self._num_ready + + def num_running(self): + return self._num_running + class _FakeEngineStrategy: @@ -88,6 +118,14 @@ def make_stopping_criteria(self, running): return None +def _fake_model_inputs(is_chunk: bool = False): + return SimpleNamespace(is_decoding=False, + is_chunk=is_chunk, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + + def test_engine_loop_skips_prefix_cache_publish_when_disabled(): class _DisabledBlockTrie: @@ -168,6 +206,93 @@ async def get_output_async(self): assert not block_trie.pinned +def test_engine_loop_treats_pending_long_context_chunk_as_runnable(): + events = [] + + class _Scheduler: + + def has_unfinished(self): + return False + + def collect_migration_done(self): + events.append('collect_migration_done') + + class _InputsMaker: + + def has_pending_long_context_chunk(self): + return True + + async def send_next_inputs(self): + events.append('send_next_inputs') + return 'forward_inputs', ['long-seq'] + + loop = EngineLoop.__new__(EngineLoop) + loop.scheduler = _Scheduler() + loop.inputs_maker = _InputsMaker() + loop.has_runable_event = 
RunableEventAsync(loop.scheduler, loop.inputs_maker.has_pending_long_context_chunk) + loop._sleep_requested = False + + result = asyncio.run(asyncio.wait_for(loop._main_loop_try_send_next_inputs(), timeout=1.0)) + + assert result == ('forward_inputs', ['long-seq']) + assert events == ['collect_migration_done', 'send_next_inputs'] + + +def _make_policy_maker(long_seq, decode_seq=None): + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = False + maker.scheduler = _FakeScheduler([]) + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.long_context_chunker.set_seq(long_seq) + maker.running_seqs = [] if decode_seq is None else [decode_seq] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker._last_forward_kind = None + maker._short_prefill_turns_since_long_chunk = 0 + maker._short_prefill_turns_per_long_chunk = 3 + return maker + + +def test_inputs_maker_reads_opt_ttft_short_turns_env(monkeypatch): + monkeypatch.setattr(inputs_maker_module._envs, 'opt_ttft_short_turns', 5) + scheduler = SimpleNamespace(cache_config=SimpleNamespace(block_size=16, kernel_block_size=16)) + config = InputsMakerConfig(max_batches=1, max_prefill_token_num=512, role=EngineRole.Decode) + + maker = InputsMakerAsync( + executor=SimpleNamespace(device_type='cpu'), + scheduler=scheduler, + adapter_manager=SimpleNamespace(), + engine_strategy=_FakeEngineStrategy(), + sampling_strategy=_FakeSamplingStrategy(), + model_agent_strategy=_FakeModelAgentStrategy(), + config=config, + ) + + assert maker._short_prefill_turns_per_long_chunk == 5 + + +def test_inputs_maker_clamps_opt_ttft_short_turns_to_one(monkeypatch): + monkeypatch.setattr(inputs_maker_module._envs, 'opt_ttft_short_turns', 0) + scheduler = 
SimpleNamespace(cache_config=SimpleNamespace(block_size=16, kernel_block_size=16)) + config = InputsMakerConfig(max_batches=1, max_prefill_token_num=512, role=EngineRole.Decode) + + maker = InputsMakerAsync( + executor=SimpleNamespace(device_type='cpu'), + scheduler=scheduler, + adapter_manager=SimpleNamespace(), + engine_strategy=_FakeEngineStrategy(), + sampling_strategy=_FakeSamplingStrategy(), + model_agent_strategy=_FakeModelAgentStrategy(), + config=config, + ) + + assert maker._short_prefill_turns_per_long_chunk == 1 + + def test_long_context_chunker_uses_cached_multimodal_size_for_chunk_limit(): image = _DummyMultiModal(start=512, end=5888) seq = _DummySeq( @@ -287,6 +412,7 @@ def test_long_context_final_chunk_preserves_multimodal_flag_for_spec_decoding(): all_multimodals={'image': [image]}, input_multimodals={}, ) + model_inputs = SimpleNamespace(is_decoding=False, is_chunk=False, is_first_chunk=False, @@ -320,6 +446,541 @@ def test_long_context_final_chunk_preserves_multimodal_flag_for_spec_decoding(): assert not maker.long_context_chunker.enabled() +def test_long_context_chunk_defers_to_decode_after_chunk_forward(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('long chunk should wait behind decode')) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + assert maker.to_evict_seqs == [] + + +def test_long_context_chunk_runs_after_decode_forward(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, 
all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=True, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'decode' + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not repeat')) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: model_inputs + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is model_inputs + assert forward_inputs['delta'] is None + assert not model_inputs.is_first_chunk + assert not model_inputs.is_last_chunk + + +def test_abandoned_long_context_chunk_is_dropped_without_cleanup_forward(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + maker = _make_policy_maker(long_seq) + long_seq.status = MessageStatus.STOPPED + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not run')) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('abandoned chunk should not continue')) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs is None + assert not maker.long_context_chunker.enabled() + + +def test_deferred_long_context_chunk_runs_when_decode_has_no_valid_seqs(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=True, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = _make_policy_maker(long_seq, decode_seq) + maker._decode_count = 3 + maker._last_forward_kind = 
'long_context_chunk' + maker.create_model_inputs_delta = lambda: (None, [], [decode_seq]) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: model_inputs + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is model_inputs + assert forward_inputs['delta'] is None + assert maker.to_evict_seqs == [decode_seq] + assert maker._decode_count == 0 + + +def test_long_context_chunk_falls_back_to_decode_when_chunk_reservation_fails(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'decode' + maker.scheduler.reserve_long_context_chunk = lambda *args, **kwargs: False + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('chunk inputs should not be created without KV reservation')) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + + +def test_last_long_context_chunk_waits_for_prefill_turn_with_decode_ready(): + long_seq = _DummySeq(history_ids=512, token_ids=256, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + assert maker.long_context_chunker.enabled() + + +def 
test_active_long_context_chunk_round_robin_does_not_starve_with_waiting_short_prefills(): + long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + short_seqs = [ + _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) for _ in range(4) + ] + short_batches = [[seq] for seq in short_seqs] + chunk_inputs = _fake_model_inputs(is_chunk=True) + + class _RoundRobinScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[object()]) + self.schedule_calls = 0 + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + assert not allow_long_prefill + assert not prefer_long_prefill + self.schedule_calls += 1 + running = short_batches.pop(0) + return SimpleNamespace(running=running, swap_in_map={}, swap_out_map={}) + + def has_waiting(self): + return len(short_batches) > 0 + + maker = _make_policy_maker(long_seq) + maker.scheduler = _RoundRobinScheduler() + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not run')) + maker.create_model_inputs = lambda seqs, is_prefill: _fake_model_inputs() + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: chunk_inputs + + first = maker._make_forward_inputs(prefill=True) + maker._last_forward_kind = 'prefill' + second = maker._make_forward_inputs(prefill=True) + maker._last_forward_kind = 'prefill' + third = maker._make_forward_inputs(prefill=True) + maker._last_forward_kind = 'prefill' + fourth = maker._make_forward_inputs(prefill=True) + + assert first['running'] == [short_seqs[0]] + assert not first['inputs'].is_chunk + assert second['running'] == [short_seqs[1]] + assert not second['inputs'].is_chunk + assert third['running'] == [short_seqs[2]] + assert not 
third['inputs'].is_chunk + assert fourth['running'] == [long_seq] + assert fourth['inputs'] is chunk_inputs + assert fourth['inputs'].is_chunk + assert not fourth['inputs'].is_last_chunk + assert maker.scheduler.schedule_calls == 3 + + +def test_active_long_context_chunk_obeys_short_prefill_quota_after_decode_turn(): + long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + short_seq = _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) + calls = [] + + class _ShortPrefillScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[short_seq]) + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + calls.append((allow_long_prefill, prefer_long_prefill)) + assert not allow_long_prefill + assert not prefer_long_prefill + return SimpleNamespace(running=[short_seq], swap_in_map={}, swap_out_map={}) + + def has_waiting(self): + return True + + maker = _make_policy_maker(long_seq) + maker.scheduler = _ShortPrefillScheduler() + maker._last_forward_kind = 'decode' + maker.create_model_inputs = lambda seqs, is_prefill: _fake_model_inputs() + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not run')) + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('long chunk should wait behind short prefill quota')) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['running'] == [short_seq] + assert not forward_inputs['inputs'].is_chunk + assert forward_inputs['delta'] is None + assert maker._short_prefill_turns_since_long_chunk == 1 + assert calls == [(False, False)] + + +def test_active_long_context_chunk_does_not_start_another_waiting_long_prefill(): + active_long_seq = _DummySeq(history_ids=0, 
token_ids=2048, all_multimodals={}, input_multimodals={}) + waiting_long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + waiting_long_seq.status = MessageStatus.WAITING + short_seqs = [ + _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) for _ in range(3) + ] + short_batches = [[seq] for seq in short_seqs] + chunk_inputs = _fake_model_inputs(is_chunk=True) + calls = [] + + class _ActiveLongScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[waiting_long_seq]) + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + calls.append((allow_long_prefill, prefer_long_prefill)) + assert not allow_long_prefill + assert not prefer_long_prefill + return SimpleNamespace(running=short_batches.pop(0), swap_in_map={}, swap_out_map={}) + + def has_waiting(self): + return True + + def has_waiting_long_prefill(self): + return True + + maker = _make_policy_maker(active_long_seq) + maker.scheduler = _ActiveLongScheduler() + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not run')) + maker.create_model_inputs = lambda seqs, is_prefill: _fake_model_inputs() + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: chunk_inputs + + first = maker._make_forward_inputs(prefill=True) + maker._last_forward_kind = 'prefill' + second = maker._make_forward_inputs(prefill=True) + maker._last_forward_kind = 'prefill' + third = maker._make_forward_inputs(prefill=True) + maker._last_forward_kind = 'prefill' + fourth = maker._make_forward_inputs(prefill=True) + + assert first['running'] == [short_seqs[0]] + assert second['running'] == [short_seqs[1]] + assert third['running'] == [short_seqs[2]] + assert fourth['running'] == 
[active_long_seq] + assert fourth['inputs'] is chunk_inputs + assert fourth['inputs'].is_chunk + assert not fourth['inputs'].is_last_chunk + assert waiting_long_seq.status == MessageStatus.WAITING + assert calls == [(False, False), (False, False), (False, False)] + + +def test_active_long_context_chunk_reservation_failure_blocks_short_prefill_and_drains_decode(): + active_long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + short_seq = _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + calls = [] + + class _ReserveFailScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[short_seq]) + + def reserve_long_context_chunk(self, + seq, + chunk_size: int, + prealloc_size: int = 0, + is_last_chunk: bool = False): + return False + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + calls.append((allow_long_prefill, prefer_long_prefill)) + raise AssertionError('short prefill should wait while active long chunk is KV-blocked') + + def has_waiting(self): + return True + + maker = _make_policy_maker(active_long_seq, decode_seq) + maker.scheduler = _ReserveFailScheduler() + maker._last_forward_kind = 'prefill' + maker._short_prefill_turns_since_long_chunk = maker._short_prefill_turns_per_long_chunk + maker.create_model_inputs = lambda seqs, is_prefill: (_ for _ in ()).throw( + AssertionError('short prefill should not run')) + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('chunk inputs should not be created without KV reservation')) + + forward_inputs = 
maker._make_forward_inputs(prefill=True) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + assert forward_inputs['running'] == [decode_seq] + assert maker.long_context_chunker.enabled() + assert maker.long_context_chunker.next_step == 0 + assert calls == [] + + +def test_active_long_context_chunk_reservation_failure_blocks_short_prefill_without_decode(): + active_long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + short_seq = _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) + calls = [] + + class _ReserveFailScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[short_seq]) + + def reserve_long_context_chunk(self, + seq, + chunk_size: int, + prealloc_size: int = 0, + is_last_chunk: bool = False): + return False + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + calls.append((allow_long_prefill, prefer_long_prefill)) + raise AssertionError('short prefill should wait while active long chunk is KV-blocked') + + def has_waiting(self): + return True + + maker = _make_policy_maker(active_long_seq) + maker.scheduler = _ReserveFailScheduler() + maker._last_forward_kind = 'prefill' + maker._short_prefill_turns_since_long_chunk = maker._short_prefill_turns_per_long_chunk + maker.create_model_inputs = lambda seqs, is_prefill: (_ for _ in ()).throw( + AssertionError('short prefill should not run')) + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not run')) + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('chunk inputs should not be created without KV reservation')) + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs is None + assert 
maker.long_context_chunker.enabled() + assert maker.long_context_chunker.next_step == 0 + assert calls == [] + + +def test_waiting_long_context_first_chunk_gets_round_robin_turn_after_short_prefills(): + long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + short_seqs = [ + _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) for _ in range(3) + ] + chunk_inputs = _fake_model_inputs(is_chunk=True) + calls = [] + + class _RoundRobinScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[object()]) + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + calls.append((allow_long_prefill, prefer_long_prefill)) + if prefer_long_prefill: + return SimpleNamespace(running=[long_seq], swap_in_map={}, swap_out_map={}) + running = [short_seqs[len(calls) - 1]] + return SimpleNamespace(running=running, swap_in_map={}, swap_out_map={}) + + def has_waiting(self): + return True + + def has_waiting_long_prefill(self): + return True + + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = False + maker.scheduler = _RoundRobinScheduler() + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.running_seqs = [] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker._last_forward_kind = None + maker._short_prefill_turns_since_long_chunk = 0 + maker._short_prefill_turns_per_long_chunk = 3 + maker.create_model_inputs = lambda seqs, is_prefill: _fake_model_inputs() + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: chunk_inputs + + first = 
maker._make_forward_inputs(prefill=True) + second = maker._make_forward_inputs(prefill=True) + third = maker._make_forward_inputs(prefill=True) + fourth = maker._make_forward_inputs(prefill=True) + + assert first['running'] == [short_seqs[0]] + assert not first['inputs'].is_chunk + assert second['running'] == [short_seqs[1]] + assert not second['inputs'].is_chunk + assert third['running'] == [short_seqs[2]] + assert not third['inputs'].is_chunk + assert fourth['running'] == [long_seq] + assert fourth['inputs'] is chunk_inputs + assert fourth['inputs'].is_first_chunk + assert calls == [(False, False), (False, False), (False, False), (True, True)] + + +def test_waiting_long_context_admission_failure_falls_back_to_short_prefill(): + long_seq = _DummySeq(history_ids=0, token_ids=2048, all_multimodals={}, input_multimodals={}) + short_seq = _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) + calls = [] + + class _WaitingLongFailScheduler(_FakeScheduler): + + def __init__(self): + super().__init__([], waiting=[long_seq, short_seq]) + + def schedule(self, + is_prefill: bool, + prealloc_size: int, + allow_long_prefill: bool = True, + prefer_long_prefill: bool = False): + calls.append((allow_long_prefill, prefer_long_prefill)) + if prefer_long_prefill: + return SimpleNamespace(running=[], swap_in_map={}, swap_out_map={}) + assert not allow_long_prefill + return SimpleNamespace(running=[short_seq], swap_in_map={}, swap_out_map={}) + + def has_waiting(self): + return True + + def has_waiting_long_prefill(self): + return True + + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = False + maker.scheduler = _WaitingLongFailScheduler() + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + 
maker.running_seqs = [] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker._last_forward_kind = None + maker._short_prefill_turns_since_long_chunk = 3 + maker._short_prefill_turns_per_long_chunk = 3 + maker.create_model_inputs = lambda seqs, is_prefill: _fake_model_inputs() + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not run')) + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('long prefill should not create chunk inputs after admission failure')) + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs['running'] == [short_seq] + assert not forward_inputs['inputs'].is_chunk + assert forward_inputs['delta'] is None + assert calls == [(True, True), (False, False)] + + +def test_normal_prefill_can_update_running_while_long_chunker_is_active(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + short_seq = _DummySeq(history_ids=0, token_ids=16, all_multimodals={}, input_multimodals={}) + model_inputs = _fake_model_inputs() + maker = _make_policy_maker(long_seq) + + maker.update_running_seqs([short_seq], model_inputs) + + assert maker.running_seqs == [short_seq] + assert maker.long_context_chunker.enabled() + assert maker.long_context_chunker.next_step == 0 + + +def test_last_long_context_chunk_runs_as_prefill_on_prefill_turn(): + image = _DummyMultiModal(start=600, end=700) + long_seq = _DummySeq(history_ids=512, + token_ids=256, + all_multimodals={'image': [image]}, + input_multimodals={'image': [image]}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = _make_policy_maker(long_seq, decode_seq) + 
maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs = lambda seqs, is_prefill: model_inputs + maker.create_model_inputs_delta_valid_only = lambda: (None, [decode_seq], []) + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs['inputs'] is model_inputs + assert model_inputs.is_chunk + assert model_inputs.is_last_chunk + assert model_inputs.is_chunk_multimodal + assert not maker.long_context_chunker.enabled() + + +def test_do_prefill_default_forces_pending_last_chunk_prefill(): + long_seq = _DummySeq(history_ids=512, token_ids=256, all_multimodals={}, input_multimodals={}) + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode, + max_prefill_token_num=512, + max_batches=1, + prefill_interval=100) + maker.scheduler = _FakeScheduler([], num_ready=1, num_running=1) + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.long_context_chunker.set_seq(long_seq) + maker._decode_count = 0 + + assert maker.do_prefill_default() + + def test_state_prefix_cache_restore_offsets_are_compact(): messages = [_state_seq(4, 11), _state_seq(5, -1), _state_seq(6, 13)] diff --git a/tests/pytorch/engine/test_model_agent.py b/tests/pytorch/engine/test_model_agent.py index 4b92d925dc..568610fd03 100644 --- a/tests/pytorch/engine/test_model_agent.py +++ b/tests/pytorch/engine/test_model_agent.py @@ -41,6 +41,46 @@ def _make_agent_with_queues(): return agent +def test_prepare_inputs_prefill_keeps_chunk_model_metas_across_interleaved_prefill(): + from lmdeploy.pytorch.engine.model_agent.agent import BaseModelAgent + + agent = BaseModelAgent.__new__(BaseModelAgent) + prev_output = {'model_metas': [{'chunk': 1}]} + agent._prev_chunk_output = prev_output + + normal_prefill = SimpleNamespace(is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + model_metas=[{ + 'normal': 1 + }]) + + agent._prepare_inputs_prefill(normal_prefill, delta=None) + + 
assert agent._prev_chunk_output is prev_output + assert normal_prefill.model_metas == [{'normal': 1}] + + middle_chunk = SimpleNamespace(is_chunk=True, is_first_chunk=False, is_last_chunk=False, model_metas=None) + + agent._prepare_inputs_prefill(middle_chunk, delta=None) + + assert middle_chunk.model_metas == [{'chunk': 1}] + assert agent._prev_chunk_output is prev_output + + +def test_prepare_inputs_prefill_final_chunk_consumes_chunk_model_metas(): + from lmdeploy.pytorch.engine.model_agent.agent import BaseModelAgent + + agent = BaseModelAgent.__new__(BaseModelAgent) + agent._prev_chunk_output = {'model_metas': [{'chunk': 1}]} + final_chunk = SimpleNamespace(is_chunk=True, is_first_chunk=False, is_last_chunk=True, model_metas=None) + + agent._prepare_inputs_prefill(final_chunk, delta=None) + + assert final_chunk.model_metas == [{'chunk': 1}] + assert agent._prev_chunk_output is None + + class TestDrainQueues: def test_drain_empty_queues(self): diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index b08116d7f4..25611d565c 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -343,3 +343,29 @@ def test_win_alloc(self, scheduler, block_mgr, num_gpu_blocks, window_size): block_table = block_mgr.get_block_table(msg) assert block_table is None or len(block_table) == 2 block_mgr.free(msg) + + def test_win_alloc_respects_kv_token_limit(self, scheduler, block_mgr, num_gpu_blocks, window_size): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + + token_ids = torch.tensor([1] * (window_size * 3)) + msg = sess.add_sequence(token_ids) + + msg.kv_token_limit = window_size + block_mgr.allocate(msg) + assert len(block_mgr.get_block_table(msg)) == 2 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 + + msg.set_step(window_size) + msg.kv_token_limit = window_size + block_size + block_mgr.allocate(msg) + assert 
len(block_mgr.get_block_table(msg)) == 3 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3 + + msg.set_step(window_size + block_size) + msg.kv_token_limit = window_size + block_size * 2 + assert block_mgr.num_required_blocks(msg) == 1 + block_mgr.allocate(msg) + assert len(block_mgr.get_block_table(msg)) == 3 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3 + assert msg.num_ignored_history == block_size diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index d91b217be2..b34d53c87e 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -1,6 +1,9 @@ +import time + import pytest import torch +import lmdeploy.pytorch.paging.scheduler as scheduler_module from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest from lmdeploy.pytorch.engine.inputs_maker import _compact_state_prefix_cache_save_offsets @@ -221,13 +224,13 @@ def test_state_manager_caps_runtime_count_even_with_extra_free_slots(): manager.allocate_state() -def _make_ssm_scheduler(max_batch_size: int = 1, prefix_cache_state_budget: int = 0): +def _make_ssm_scheduler(max_batch_size: int = 1, prefix_cache_state_budget: int = 0, num_gpu_blocks: int = 16): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy block_size = 16 cache_config = CacheConfig(max_batches=max_batch_size, block_size=block_size, num_cpu_blocks=4, - num_gpu_blocks=16, + num_gpu_blocks=num_gpu_blocks, enable_prefix_caching=True, num_state_caches=max_batch_size + 1 + prefix_cache_state_budget, prefix_cache_state_budget=prefix_cache_state_budget, @@ -515,6 +518,242 @@ def test_scheduler_publishes_cached_tokens_for_accepted_prefix_hit(): assert seq.prefix_cache.match_start_step == -1 +def test_scheduler_recomputes_prefill_budget_after_prefix_hit(): + from lmdeploy.pytorch.strategies.ar.sequence import 
ARSequenceStrategy + block_size = 16 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=2, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=8, + max_prefill_token_num=block_size, + enable_prefix_caching=True) + scheduler_config = SchedulerConfig(max_batches=2, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + cached = scheduler.add_session(0).add_sequence([1] * block_size + [2]) + scheduler.schedule(is_prefill=True) + cached.state.stop() + + cache_hit_tail = scheduler.add_session(1).add_sequence([1] * block_size + [3]) + short = scheduler.add_session(2).add_sequence([4]) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [cache_hit_tail, short] + assert cache_hit_tail.num_history_ids == block_size + assert cache_hit_tail.num_token_ids == 1 + assert short.status == MessageStatus.READY + + +def _make_prefix_cache_scheduler(max_batches: int = 2, max_prefill_token_num: int = 16): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 16 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=max_batches, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=8, + max_prefill_token_num=max_prefill_token_num, + enable_prefix_caching=True) + scheduler_config = SchedulerConfig(max_batches=max_batches, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def test_scheduler_short_turn_uses_prefix_hit_to_admit_long_looking_sibling(): + scheduler, block_size = _make_prefix_cache_scheduler(max_batches=2, max_prefill_token_num=16) + + cached = scheduler.add_session(0).add_sequence([1] * 
block_size) + scheduler.schedule(is_prefill=True) + cached.state.stop() + + short = scheduler.add_session(1).add_sequence([4]) + cache_hit_tail = scheduler.add_session(2).add_sequence([1] * block_size + [3]) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [short, cache_hit_tail] + assert cache_hit_tail.num_history_ids == block_size + assert cache_hit_tail.num_token_ids == 1 + assert cache_hit_tail.cached_tokens == block_size + + +def test_scheduler_budget_gate_uses_prefix_hit_to_admit_sibling(): + scheduler, block_size = _make_prefix_cache_scheduler(max_batches=2, max_prefill_token_num=16) + + cached = scheduler.add_session(0).add_sequence([1] * block_size) + scheduler.schedule(is_prefill=True) + cached.state.stop() + + almost_full = scheduler.add_session(1).add_sequence([4] * (block_size - 1)) + cache_hit_tail = scheduler.add_session(2).add_sequence([1] * block_size + [3]) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [almost_full, cache_hit_tail] + assert cache_hit_tail.num_history_ids == block_size + assert cache_hit_tail.num_token_ids == 1 + + +def test_scheduler_reorder_cache_stays_order_only_after_prefix_hit(): + scheduler, block_size = _make_prefix_cache_scheduler(max_batches=2, max_prefill_token_num=16) + + cached = scheduler.add_session(0).add_sequence([1] * block_size) + scheduler.schedule(is_prefill=True) + cached.state.stop() + + cache_hit_tail = scheduler.add_session(1).add_sequence([1] * block_size + [3]) + normal = scheduler.add_session(2).add_sequence([4] * (block_size - 1)) + + output = scheduler.schedule(is_prefill=True, prefer_long_prefill=True) + + assert output.running == [cache_hit_tail, normal] + assert cache_hit_tail.num_history_ids == block_size + assert cache_hit_tail.num_token_ids == 1 + assert cache_hit_tail.cached_tokens == block_size + assert normal.status == MessageStatus.READY + + +def 
test_scheduler_rolls_back_prefix_match_for_prefill_gate_when_tail_still_exceeds_budget(): + scheduler, block_size = _make_prefix_cache_scheduler(max_batches=2, max_prefill_token_num=16) + + cached = scheduler.add_session(0).add_sequence([1] * block_size) + scheduler.schedule(is_prefill=True) + cached.state.stop() + + full = scheduler.add_session(1).add_sequence([4] * block_size) + cache_hit_tail = scheduler.add_session(2).add_sequence([1] * block_size + [3]) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [full] + assert cache_hit_tail.status == MessageStatus.WAITING + assert cache_hit_tail.num_history_ids == 0 + assert cache_hit_tail.cached_tokens == 0 + assert cache_hit_tail.prefix_cache.last_shared_node is None + assert cache_hit_tail.prefix_cache.match_start_step == -1 + + +def test_scheduler_rolls_back_prefix_match_for_prefill_gate_that_still_needs_long_chunk(): + scheduler, block_size = _make_prefix_cache_scheduler(max_batches=1, max_prefill_token_num=16) + + cached = scheduler.add_session(0).add_sequence([1] * block_size) + scheduler.schedule(is_prefill=True) + cached.state.stop() + scheduler.block_trie.stats.reset() + + still_long = scheduler.add_session(1).add_sequence([1] * block_size + [3] * (block_size + 1)) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [] + assert still_long.status == MessageStatus.WAITING + assert still_long.num_history_ids == 0 + assert still_long.cached_tokens == 0 + assert still_long.prefix_cache.last_shared_node is None + assert still_long.prefix_cache.match_start_step == -1 + assert scheduler.block_trie.stats.num_query_tokens == 0 + assert scheduler.block_trie.stats.num_hit_tokens == 0 + + +def test_ssm_scheduler_rolls_back_prefix_match_for_prefill_gate_without_pinning_restore_state(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1) + scheduler.cache_config.max_prefill_token_num = 
scheduler.seq_meta.block_size + block_size = scheduler.seq_meta.block_size + node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + scheduler.block_trie.stats.reset() + + still_long = scheduler.add_session(100).add_sequence([1] * block_size * 2 + [3] * (block_size + 1)) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [] + assert still_long.status == MessageStatus.WAITING + assert still_long.num_history_ids == 0 + assert still_long.cached_tokens == 0 + assert still_long.prefix_cache.last_shared_node is None + assert still_long.prefix_cache.restore_state == -1 + assert still_long.prefix_cache.restore_node is None + assert not still_long.prefix_cache.restore_state_acquired + assert node.state_idx == state_idx + assert node.state_ref_count == 0 + assert scheduler.block_trie.stats.num_query_tokens == 0 + assert scheduler.block_trie.stats.num_hit_tokens == 0 + + +def test_ssm_scheduler_rejects_prefix_match_for_prefill_gate_after_pinned_restore_rollback(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1, num_gpu_blocks=2) + scheduler.cache_config.max_prefill_token_num = scheduler.seq_meta.block_size + block_size = scheduler.seq_meta.block_size + node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + scheduler.block_trie.stats.reset() + + cache_hit_tail = scheduler.add_session(100).add_sequence([1] * block_size * 2 + [3]) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [] + assert cache_hit_tail.status == MessageStatus.WAITING + assert cache_hit_tail.num_history_ids == 0 + assert cache_hit_tail.num_token_ids == block_size * 2 + 1 + assert cache_hit_tail.num_blocks == 0 + assert cache_hit_tail.kv_token_limit is None + assert cache_hit_tail.logical_state == -1 + assert cache_hit_tail.cached_tokens == 0 + assert cache_hit_tail.prefix_cache.last_shared_node is None + assert 
cache_hit_tail.prefix_cache.restore_state == -1 + assert cache_hit_tail.prefix_cache.restore_node is None + assert not cache_hit_tail.prefix_cache.restore_state_acquired + assert node.state_idx == state_idx + assert node.state_ready + assert node.state_ref_count == 0 + assert scheduler.block_trie.stats.num_query_tokens == 0 + assert scheduler.block_trie.stats.num_hit_tokens == 0 + + +def test_ssm_scheduler_rejects_prefix_match_for_prefill_gate_after_runtime_state_rollback(monkeypatch): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1, num_gpu_blocks=4) + scheduler.cache_config.max_prefill_token_num = scheduler.seq_meta.block_size + block_size = scheduler.seq_meta.block_size + node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + ensure_results = iter([False, True]) + + def _ensure_runtime_state_available_once_then_succeed(): + return next(ensure_results) + + monkeypatch.setattr(scheduler, '_ensure_runtime_state_available', _ensure_runtime_state_available_once_then_succeed) + scheduler.block_trie.stats.reset() + + cache_hit_tail = scheduler.add_session(100).add_sequence([1] * block_size * 2 + [3]) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [] + assert cache_hit_tail.status == MessageStatus.WAITING + assert cache_hit_tail.num_history_ids == 0 + assert cache_hit_tail.num_token_ids == block_size * 2 + 1 + assert cache_hit_tail.num_blocks == 0 + assert cache_hit_tail.kv_token_limit is None + assert cache_hit_tail.logical_state == -1 + assert cache_hit_tail.cached_tokens == 0 + assert cache_hit_tail.prefix_cache.last_shared_node is None + assert cache_hit_tail.prefix_cache.restore_state == -1 + assert cache_hit_tail.prefix_cache.restore_node is None + assert not cache_hit_tail.prefix_cache.restore_state_acquired + assert node.state_idx == state_idx + assert node.state_ready + assert node.state_ref_count == 0 + assert 
scheduler.block_trie.stats.num_query_tokens == 0 + assert scheduler.block_trie.stats.num_hit_tokens == 0 + + def test_scheduler_reports_zero_cached_tokens_for_prefix_miss(): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy block_size = 16 @@ -613,6 +852,296 @@ def test_scheduler_excludes_recompute_eviction_prefix_hits_from_stats(): assert scheduler.block_trie.stats.num_hit_tokens == 0 +def _make_scheduler_for_decode_growth(num_gpu_blocks: int = 2): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 4 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=2, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=num_gpu_blocks, + max_prefill_token_num=block_size * 4) + scheduler_config = SchedulerConfig(max_batches=2, + max_session_len=64, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def _make_scheduler_for_long_context_chunks(num_gpu_blocks: int = 6): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 4 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=2, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=num_gpu_blocks, + max_prefill_token_num=block_size * 2) + scheduler_config = SchedulerConfig(max_batches=2, + max_session_len=64, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def _make_ssm_scheduler_for_long_context_chunks(num_gpu_blocks: int = 2): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 4 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = 
CacheConfig(max_batches=1, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=num_gpu_blocks, + max_prefill_token_num=block_size * 2, + num_state_caches=2, + states_shapes=[((1, ), torch.float32)]) + scheduler_config = SchedulerConfig(max_batches=1, + max_session_len=64, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def test_schedule_running_reclaims_waiting_blocks_for_decode_growth(): + scheduler, block_size = _make_scheduler_for_decode_growth(num_gpu_blocks=2) + decode = scheduler.add_session(100).add_sequence([1] * block_size) + waiting = scheduler.add_session(101).add_sequence([2] * block_size) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [decode, waiting] + scheduler.activate_seqs([decode]) + waiting.state.evict() + assert decode.status == MessageStatus.RUNNING + assert waiting.status == MessageStatus.WAITING + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + valid_mask = scheduler.schedule_running([decode], num_required_tokens=1, prealloc_size=1) + + assert valid_mask == [True] + assert decode.status == MessageStatus.RUNNING + assert decode.num_blocks == 2 + assert waiting.status == MessageStatus.WAITING + assert waiting.num_blocks == 0 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_schedule_running_keeps_other_running_sequence_when_decode_growth_fails(): + scheduler, block_size = _make_scheduler_for_decode_growth(num_gpu_blocks=2) + decode = scheduler.add_session(100).add_sequence([1] * block_size) + long_chunk = scheduler.add_session(101).add_sequence([2] * block_size) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [decode, long_chunk] + scheduler.activate_seqs([decode, long_chunk]) + assert decode.status == MessageStatus.RUNNING + assert long_chunk.status == MessageStatus.RUNNING + assert 
scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + valid_mask = scheduler.schedule_running([decode], num_required_tokens=1, prealloc_size=1) + + assert valid_mask == [False] + assert decode.status == MessageStatus.WAITING + assert long_chunk.status == MessageStatus.RUNNING + assert long_chunk.num_blocks == 1 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_schedule_prefill_allocates_only_first_long_context_chunk(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=2) + long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 4)) + + output = scheduler.schedule(is_prefill=True, prealloc_size=1) + + assert output.running == [long_seq] + assert long_seq.status == MessageStatus.READY + assert long_seq.kv_token_limit == block_size * 2 + assert long_seq.num_blocks == 2 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_schedule_prefill_short_only_skips_long_waiter_without_mutation(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=8) + head_long = scheduler.add_session(100).add_sequence([1] * (block_size * 4)) + short_a = scheduler.add_session(101).add_sequence([2] * (block_size // 2)) + short_b = scheduler.add_session(102).add_sequence([3] * (block_size // 2)) + + output = scheduler.schedule(is_prefill=True, allow_long_prefill=False) + + assert output.running == [short_a, short_b] + assert head_long.status == MessageStatus.WAITING + assert head_long.num_blocks == 0 + assert head_long.kv_token_limit is None + assert short_a.status == MessageStatus.READY + assert short_b.status == MessageStatus.READY + + short_a.session.remove_sequence(short_a) + short_b.session.remove_sequence(short_b) + next_output = scheduler.schedule(is_prefill=True) + + assert next_output.running == [head_long] + assert head_long.status == MessageStatus.READY + assert head_long.kv_token_limit == block_size * 2 + assert head_long.num_blocks == 2 + + +def 
test_schedule_prefill_prefer_long_admits_oldest_long_waiter_first(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=8) + short_a = scheduler.add_session(100).add_sequence([1] * (block_size // 2)) + old_long = scheduler.add_session(101).add_sequence([2] * (block_size * 4)) + short_b = scheduler.add_session(102).add_sequence([3] * (block_size // 2)) + new_long = scheduler.add_session(103).add_sequence([4] * (block_size * 4)) + + assert scheduler.has_waiting_long_prefill() + + output = scheduler.schedule(is_prefill=True, prefer_long_prefill=True) + + assert output.running == [old_long] + assert old_long.status == MessageStatus.READY + assert old_long.kv_token_limit == block_size * 2 + assert old_long.num_blocks == 2 + assert short_a.status == MessageStatus.WAITING + assert short_a.num_blocks == 0 + assert short_b.status == MessageStatus.WAITING + assert short_b.num_blocks == 0 + assert new_long.status == MessageStatus.WAITING + assert new_long.num_blocks == 0 + assert new_long.kv_token_limit is None + + +def test_scheduler_reads_opt_ttft_env(monkeypatch): + monkeypatch.setattr(scheduler_module._envs, 'opt_ttft_policy', 'fifo') + monkeypatch.setattr(scheduler_module._envs, 'opt_ttft_aging_sec', 0.25) + + scheduler, _ = _make_scheduler_for_long_context_chunks(num_gpu_blocks=8) + + assert scheduler._long_prefill_policy == 'fifo' + assert scheduler._long_prefill_aging_seconds_per_chunk == 0.25 + + +def test_schedule_prefill_prefer_long_fifo_policy_keeps_oldest_huge_waiter_first(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=8) + scheduler._long_prefill_policy = 'fifo' + now = time.perf_counter() + huge_long = scheduler.add_session(100).add_sequence([1] * (block_size * 16)) + huge_long.arrive_time = now - 1.0 + moderate_long = scheduler.add_session(101).add_sequence([2] * (block_size * 4)) + moderate_long.arrive_time = now + + output = scheduler.schedule(is_prefill=True, prefer_long_prefill=True) + + 
assert output.running == [huge_long]
+    assert huge_long.status == MessageStatus.READY
+    assert huge_long.kv_token_limit == block_size * 2
+    assert huge_long.num_blocks == 2
+    assert moderate_long.status == MessageStatus.WAITING
+    assert moderate_long.num_blocks == 0
+    assert moderate_long.kv_token_limit is None
+
+
+def test_schedule_prefill_prefer_long_admits_smaller_long_waiter_first():
+    scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=8)
+    now = time.perf_counter()
+    huge_long = scheduler.add_session(100).add_sequence([1] * (block_size * 16))  # 16 blocks of tokens
+    huge_long.arrive_time = now - 1.0  # arrived one second before moderate_long
+    moderate_long = scheduler.add_session(101).add_sequence([2] * (block_size * 4))  # 4 blocks of tokens
+    moderate_long.arrive_time = now
+    short = scheduler.add_session(102).add_sequence([3] * (block_size // 2))  # half a block of tokens
+
+    output = scheduler.schedule(is_prefill=True, prefer_long_prefill=True)
+
+    assert output.running == [moderate_long]  # the smaller long prompt is admitted despite arriving later
+    assert moderate_long.status == MessageStatus.READY
+    assert moderate_long.kv_token_limit == block_size * 2  # capped at two blocks of tokens
+    assert moderate_long.num_blocks == 2
+    assert huge_long.status == MessageStatus.WAITING  # the older, larger prompt is left untouched
+    assert huge_long.num_blocks == 0
+    assert huge_long.kv_token_limit is None
+    assert short.status == MessageStatus.WAITING  # the short prompt is not scheduled in this pass either
+    assert short.num_blocks == 0
+
+
+def test_schedule_prefill_prefer_long_ages_huge_long_waiter():
+    scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=8)
+    scheduler._long_prefill_aging_seconds_per_chunk = 0.01  # presumably an aging rate; confirm
+    now = time.perf_counter()
+    huge_long = scheduler.add_session(100).add_sequence([1] * (block_size * 16))  # 16 blocks of tokens
+    huge_long.arrive_time = now - 1.0  # arrived one second before moderate_long
+    moderate_long = scheduler.add_session(101).add_sequence([2] * (block_size * 4))  # 4 blocks of tokens
+    moderate_long.arrive_time = now
+
+    output = scheduler.schedule(is_prefill=True, prefer_long_prefill=True)
+
+    assert output.running == [huge_long]  # with the override above, the older, larger prompt is scheduled
+    assert huge_long.status == MessageStatus.READY
+    assert huge_long.kv_token_limit == block_size * 2  # capped at two blocks of tokens
+    assert huge_long.num_blocks == 2
+    assert moderate_long.status == MessageStatus.WAITING
+    assert moderate_long.num_blocks == 0
+    assert moderate_long.kv_token_limit is None
+
+
+def test_schedule_prefill_reapplies_chunk_limit_after_ssm_state_rollback():
+    scheduler, block_size = _make_ssm_scheduler_for_long_context_chunks(num_gpu_blocks=2)
+    long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 4))  # 4 blocks of tokens
+
+    ensure_results = iter([False, True])  # first call reports False, second call reports True
+
+    def _ensure_runtime_state_available_once_then_succeed():
+        return next(ensure_results)
+
+    scheduler._ensure_runtime_state_available = _ensure_runtime_state_available_once_then_succeed  # install stub
+
+    output = scheduler.schedule(is_prefill=True, prealloc_size=1)
+
+    assert output.running == [long_seq]
+    assert long_seq.status == MessageStatus.READY
+    assert long_seq.kv_token_limit == block_size * 2  # the two-block limit is in place at the end
+    assert long_seq.num_blocks == 2
+
+
+def test_reserve_long_context_chunk_grows_one_chunk_at_a_time():
+    scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=6)
+    long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 5))  # 5 blocks of tokens
+
+    output = scheduler.schedule(is_prefill=True, prealloc_size=1)
+    assert output.running == [long_seq]
+    assert long_seq.kv_token_limit == block_size * 2  # initial reservation: two blocks of tokens
+    assert long_seq.num_blocks == 2
+
+    scheduler.activate_seqs([long_seq])
+    long_seq.set_step(block_size * 2)  # presumably marks the first two blocks as processed; confirm
+
+    assert scheduler.reserve_long_context_chunk(long_seq, block_size * 2)
+    assert long_seq.status == MessageStatus.RUNNING
+    assert long_seq.kv_token_limit == block_size * 4  # limit grew by the requested block_size * 2
+    assert long_seq.num_blocks == 4
+
+    long_seq.set_step(block_size * 4)
+
+    assert scheduler.reserve_long_context_chunk(long_seq, block_size, prealloc_size=1, is_last_chunk=True)
+    assert long_seq.kv_token_limit is None  # no limit remains after the is_last_chunk=True reservation
+    assert long_seq.num_blocks == 6
+    assert scheduler.block_manager.get_num_free_gpu_blocks() == 0  # no free GPU blocks remain
+
+
+def test_reserve_long_context_chunk_failure_preserves_committed_prefix():
+    scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=2)
+    long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 4))  # 4 blocks of tokens
+
+    output = scheduler.schedule(is_prefill=True)
+    assert output.running == [long_seq]
+    scheduler.activate_seqs([long_seq])
+    long_seq.set_step(block_size * 2)
+
+    assert not scheduler.reserve_long_context_chunk(long_seq, block_size * 2)  # reservation must be refused
+    assert long_seq.status == MessageStatus.RUNNING  # status is still RUNNING after the refusal
+    assert long_seq.kv_token_limit == block_size * 2  # limit is still two blocks of tokens
+    assert long_seq.num_blocks == 2  # the sequence still holds two blocks
+
+
 def test_scheduler_rolls_back_prefix_hit_that_would_start_long_context_chunk_from_middle():
     from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
     block_size = 16
diff --git a/tests/pytorch/spec_decode/test_spec_agent.py b/tests/pytorch/spec_decode/test_spec_agent.py
index d121bed2a3..2742a0b31e 100644
--- a/tests/pytorch/spec_decode/test_spec_agent.py
+++ b/tests/pytorch/spec_decode/test_spec_agent.py
@@ -2,7 +2,9 @@
 
 import torch
 
-from lmdeploy.pytorch.spec_decode.spec_agent import _expand_sampling_inputs
+from lmdeploy.pytorch.model_inputs import DPMeta, ModelInputs
+from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent, _expand_sampling_inputs
+from lmdeploy.pytorch.strategies.ar_spec.model_agent import ARSpecExtraInputs
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -312,3 +314,94 @@ def test_slice_sampling_inputs_prefill():
     sampling_inputs = SamplingInputs(max_top_k=1, batch_size=2)
     result = _slice_sampling_inputs(sampling_inputs, 1)
     assert result is sampling_inputs
+
+
+def _model_inputs(input_ids,  # test helper: ModelInputs for one sequence given as a list of token ids
+                  *,
+                  is_decoding=False,
+                  is_chunk=False,
+                  is_first_chunk=False,
+                  is_last_chunk=False,
+                  dp_meta=None):
+    input_ids = torch.tensor([input_ids])  # wrap in an outer list: one row holding all tokens
+    seq_length = torch.tensor([input_ids.size(1)])
+    history_lengths = torch.tensor([0])
+    max_q_seqlen = input_ids.size(1)  # reused below for the kv length fields as well
+    return ModelInputs(
+        input_ids=input_ids,
+        seq_length=seq_length,
+        history_lengths=history_lengths,
+        block_offsets=torch.zeros(1, 1, dtype=torch.int32),
+        is_decoding=is_decoding,
+        num_ignored_history=torch.zeros(1, dtype=torch.long),
+        max_q_seqlen=max_q_seqlen,
+        max_kv_seqlen=max_q_seqlen,
+        sum_kv_seqlen=max_q_seqlen,
+        is_chunk=is_chunk,
+        is_first_chunk=is_first_chunk,
+        is_last_chunk=is_last_chunk,
+        dp_meta=dp_meta,
+    )
+
+
+def _extra(hidden_values):  # test helper: ARSpecExtraInputs built from nested lists of hidden values
+    hidden_states = torch.tensor([hidden_values], dtype=torch.float32)  # adds a leading dimension of size 1
+    return ARSpecExtraInputs(
+        target_hidden_states=hidden_states,
+        next_token_ids=torch.tensor([99]),
+        last_token_indices=torch.tensor([hidden_states.size(1) - 1]),  # final index along dim 1
+    )
+
+
+def test_prepare_inputs_from_main_keeps_chunk_carry_across_decode():
+    agent = SpecModelAgent.__new__(SpecModelAgent)  # bypass __init__; only attributes set here exist
+    agent._prev_chunk_last = {}
+
+    first_chunk = _model_inputs([10, 11, 12], is_chunk=True, is_first_chunk=True)
+    agent._prepare_inputs_from_main(first_chunk, _extra([[1, 10], [2, 20], [3, 30]]))
+    saved_first_chunk_last = agent._prev_chunk_last['hidden_states'].clone()  # snapshot for later comparison
+
+    decode = _model_inputs([90, 91, 92], is_decoding=True)
+    agent._prepare_inputs_from_main(decode, _extra([[9, 90], [8, 80], [7, 70]]))
+
+    assert torch.equal(agent._prev_chunk_last['hidden_states'], saved_first_chunk_last)  # decode left it alone
+
+    middle_chunk = _model_inputs([20, 21, 22], is_chunk=True)
+    draft_inputs, _ = agent._prepare_inputs_from_main(middle_chunk, _extra([[4, 40], [5, 50], [6, 60]]))
+
+    assert torch.equal(draft_inputs.target_hidden_states[:, :1], saved_first_chunk_last)  # carry sits at index 0
+    assert torch.equal(agent._prev_chunk_last['hidden_states'], torch.tensor([[[6., 60.]]]))  # last row of chunk
+
+
+def test_prepare_inputs_from_main_keeps_chunk_carry_across_interleaved_prefill():
+    agent = SpecModelAgent.__new__(SpecModelAgent)  # bypass __init__; only attributes set here exist
+    saved = torch.ones(1, 1, 2)
+    agent._prev_chunk_last = {'hidden_states': saved.clone()}  # pre-existing carry stored as a copy
+
+    prefill = _model_inputs([10, 11, 12])  # neither is_chunk nor is_decoding is set
+    agent._prepare_inputs_from_main(prefill, _extra([[1, 10], [2, 20], [3, 30]]))
+
+    torch.testing.assert_close(agent._prev_chunk_last['hidden_states'], saved)  # carry is unchanged
+
+
+def test_prepare_inputs_from_main_first_chunk_clears_stale_chunk_carry():
+    agent = SpecModelAgent.__new__(SpecModelAgent)  # bypass __init__; only attributes set here exist
+    agent._prev_chunk_last = {'hidden_states': torch.ones(1, 1, 2)}  # stale carry from an earlier request
+
+    first_chunk = _model_inputs([10, 11, 12], is_chunk=True, is_first_chunk=True)
+    agent._prepare_inputs_from_main(first_chunk, _extra([[1, 10], [2, 20], [3, 30]]))
+
+    torch.testing.assert_close(agent._prev_chunk_last['hidden_states'], torch.tensor([[[3., 30.]]]))  # new last row
+
+
+def test_prepare_inputs_from_main_keeps_chunk_carry_for_dp_local_decode_global_prefill():
+    agent = SpecModelAgent.__new__(SpecModelAgent)  # bypass __init__; only attributes set here exist
+    saved = torch.ones(1, 1, 2)
+    agent._prev_chunk_last = {'hidden_states': saved.clone()}
+    agent.proposer = _DummyProposer()  # _DummyProposer is not defined in the visible part of this patch
+
+    dp_meta = DPMeta(dp_batches=[1, 1], dp_is_decoding=False)
+    inputs = _model_inputs([90, 91, 92], is_decoding=True, dp_meta=dp_meta)  # is_decoding=True, dp_is_decoding=False
+    agent._prepare_inputs_from_main(inputs, _extra([[9, 90], [8, 80], [7, 70]]))
+
+    assert torch.equal(agent._prev_chunk_last['hidden_states'], saved)  # carry is unchanged