diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index ab58826069..f83ecd233f 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -1,4 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +"""Engine-loop input construction for the LMDeploy PyTorch backend. + +This module converts scheduler decisions into model-agent inputs. Most helpers +build tensor fields for full-batch ``ModelInputs``; ``InputsMakerAsync`` is the +coordinator that chooses prefill/chunk/decode work, attaches per-forward +metadata, dispatches it to the executor, and updates local running state. +""" import logging from collections import defaultdict from dataclasses import dataclass @@ -243,6 +250,39 @@ def check_enable(self): class InputsMakerAsync: + """Coordinate prefill, decode, and long-context input dispatch. + + ``Scheduler`` owns admission, ordering, and cache/KV resources. This class + consumes the scheduler result and builds tensors only after resources have + been granted. Prefill-like work is represented by full ``ModelInputs``: + prompt prefill, final long-context chunks, and eager non-final long chunks. + Decode is represented by ``ModelInputsDelta`` and reuses persistent + model-agent/strategy ``StepInputs`` that were created by earlier prefill and + decode forwards. + + ``running_seqs`` is local engine-loop state, not the scheduler's source of + truth. It tracks sequences already sent to the executor so this class can + build decode deltas, evict invalid decode requests, and update the local + view after outputs return. Every dispatched forward also carries the + strategy-specific ``extra_inputs``, sampling inputs, and stopping criteria + expected by the model agent. + + Long-context chunking is coordinated here because it spans scheduling + policy and input construction. ``LongContextChunker`` tracks one active + long prefill and selects model-safe chunk boundaries, including indivisible + multimodal spans. Before tensors are created for each chunk, the scheduler + reserves the chunk's KV ownership. Non-final chunks are eager chunk + forwards with no user-visible output; the final chunk is treated as normal + prefill so it can merge into persistent decode state. + + The current first-slice chunked-prefill policy intentionally uses separate + forwards instead of one mixed decode+prefill tensor batch. After a + non-final chunk, runnable decode is preferred and remains on the existing + delta/CUDAGraph path; at most one eager non-final long chunk is sent after + decode gets a chance to run. Preserve chunk flags such as + ``is_chunk_multimodal`` and ``is_last_chunk`` because VLM and speculative + decoding paths interpret them downstream. + """ def __init__( self, @@ -272,6 +312,7 @@ def __init__( # consecutive decode counter for prefill starvation prevention self._decode_count = 0 + self._last_forward_kind = None # record for next forward. self.next_is_prefill = True @@ -293,6 +334,38 @@ def _init_do_prefill(self, config: InputsMakerConfig): else: self.do_prefill = self.do_prefill_default + def _has_pending_last_long_context_chunk(self): + """Check whether a running long context has only its final chunk + left.""" + return self.long_context_chunker.enabled() and self.long_context_chunker.is_last_chunk() + + def _should_decode_before_long_context_chunk(self, prefill: bool): + """Prefer decode when a long-context chunk should not monopolize the + loop.""" + if self.config.role == EngineRole.Prefill: + return False + if len(self.running_seqs) == 0: + return False + if not self.long_context_chunker.enabled(): + return False + if self.long_context_chunker.is_last_chunk(): + return not prefill + return getattr(self, '_last_forward_kind', None) == 'long_context_chunk' + + def _forward_kind(self, inputs: 'ModelInputs|None', delta: 'ModelInputsDelta|None'): + """Classify a queued forward for long-context interleaving policy.""" + if inputs is None: + if delta is not None: + return 'decode' + return None + if inputs.is_chunk and not inputs.is_last_chunk: + return 'long_context_chunk' + if inputs.is_chunk: + return 'last_long_context_chunk' + if inputs.is_decoding: + return 'decode' + return 'prefill' + def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs): """Create vision model inputs.""" batch_size = len(messages) @@ -734,26 +807,41 @@ def __create_model_inputs(seqs): extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs) return inputs, delta, extra_inputs - def __create_inputs_chunk(running: 'SeqList'): - chunk_size, multimodals = self.long_context_chunker.next_chunk_size() + def __create_inputs_chunk(running: 'SeqList', chunk_size: int, multimodals: 'MultiModalInputs|None'): inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals) extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs) return inputs, extra_inputs + def __reserve_long_context_chunk(seq: 'SchedulerSequence', chunk_size: int, is_last_chunk: bool): + if self.config.role == EngineRole.Prefill: + prealloc_size = 0 + elif is_last_chunk: + prealloc_size = self.engine_strategy.get_prealloc_size(True) + else: + prealloc_size = 0 + return scheduler.reserve_long_context_chunk(seq, + chunk_size, + prealloc_size=prealloc_size, + is_last_chunk=is_last_chunk) + def __create_inputs_long_context_chunk(): seq = self.long_context_chunker.seq + chunk_size, multimodals = self.long_context_chunker.next_chunk_size() + is_last_chunk = self.long_context_chunker.is_last_chunk() + is_chunk_multimodal = self.long_context_chunker.has_multimodal + if not __reserve_long_context_chunk(seq, chunk_size, is_last_chunk): + return [], None, None, None running = [seq] - has_multimodal = self.long_context_chunker.has_multimodal - if self.long_context_chunker.is_last_chunk(): + if is_last_chunk: inputs, delta, extra_inputs = __create_model_inputs(running) inputs.is_chunk = True inputs.is_last_chunk = True self.long_context_chunker.clear() else: - inputs, extra_inputs = __create_inputs_chunk(running) + inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals) delta = None inputs.is_first_chunk = False - inputs.is_chunk_multimodal = has_multimodal + inputs.is_chunk_multimodal = is_chunk_multimodal return running, inputs, delta, extra_inputs def __create_inputs_prefill(): @@ -782,7 +870,8 @@ def __create_inputs_prefill(): self.long_context_chunker.clear() inputs, delta, extra_inputs = __create_model_inputs(running) else: - inputs, extra_inputs = __create_inputs_chunk(running) + chunk_size, multimodals = self.long_context_chunker.next_chunk_size() + inputs, extra_inputs = __create_inputs_chunk(running, chunk_size, multimodals) inputs.is_first_chunk = True inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal elif len(running) > 0: @@ -795,13 +884,19 @@ def __create_inputs_prefill(): inputs = None delta = None + running = [] + extra_inputs = None swap_in_map = {} swap_out_map = {} + deferred_long_context_chunk = False self.long_context_chunker.check_enable() if self.long_context_chunker.enabled(): # long context chunking - running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk() + if self._should_decode_before_long_context_chunk(prefill): + deferred_long_context_chunk = True + else: + running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk() elif prefill: # prefill ( @@ -813,10 +908,6 @@ def __create_inputs_prefill(): swap_out_map, ) = __create_inputs_prefill() - # reset decode count when non-decoding inputs are produced - if inputs is not None and not inputs.is_decoding: - self._decode_count = 0 - # try decoding if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill: prefill = False @@ -824,6 +915,13 @@ def __create_inputs_prefill(): self.to_evict_seqs = invalid_seqs extra_inputs = None + if inputs is None and delta is None and deferred_long_context_chunk and self.long_context_chunker.enabled(): + running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk() + + # reset decode count when non-decoding inputs are produced + if inputs is not None and not inputs.is_decoding: + self._decode_count = 0 + # skip if enable empty if inputs is None and delta is None: return None @@ -858,11 +956,14 @@ def do_prefill_pnode(self): def do_prefill_default(self): # decoding if no waiting scheduler = self.scheduler + pending_last_chunk = self._has_pending_last_long_context_chunk() # do decoding if not waiting - if not scheduler.has_waiting(): + if not scheduler.has_waiting() and not pending_last_chunk: self._decode_count = 0 return False + if pending_last_chunk: + return True # force prefill if too many consecutive decode rounds if self._decode_count >= self.config.prefill_interval: @@ -906,6 +1007,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool session_ids = [seq.session_id for seq in next_running] logger.debug(f'Forward session_ids: {session_ids}') await self.executor.forward_async(forward_inputs) + self._last_forward_kind = self._forward_kind(inputs, forward_inputs['delta']) self.scheduler.tick() self.forward_inputs = forward_inputs return forward_inputs, next_running diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 00888bfdbb..49729668b2 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -698,6 +698,9 @@ class SchedulerSequence: meta: Any = None num_ignored_history: int = 0 model_meta: dict[str, Any] = None + # Exclusive absolute token limit for temporary KV ownership. Non-final + # long-context chunks use this to allocate only the computed prefix. + kv_token_limit: int | None = None # For Disaggregation migration_request: None | MigrationRequest = None diff --git a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py index c8f935900f..5d33ba665b 100644 --- a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py @@ -25,7 +25,10 @@ class DefaultBlockManager(BaseBlockManager): @classmethod def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0): """Get num required blocks.""" - num_tokens = obj.num_all_ids + prealloc_size + num_tokens = obj.num_all_ids + if obj.kv_token_limit is not None: + num_tokens = min(num_tokens, obj.kv_token_limit) + num_tokens += prealloc_size num_all_blocks = _div_up(num_tokens, obj.block_size) return max(0, num_all_blocks - len(obj.logical_blocks)) diff --git a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py index cde3f90ac5..28f89e6144 100644 --- a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py @@ -42,7 +42,13 @@ def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0): if obj.num_history_ids <= self.window_size: return super().num_required_blocks(obj, prealloc_size) - return super().num_required_blocks(obj, prealloc_size) - obj.num_ignored_history // obj.block_size + # DefaultBlockManager applies kv_token_limit to the absolute token + # count. Sliding-window accounting then subtracts already-dropped + # history blocks so chunk-limited allocation grows only the retained + # window. + num_required_blocks = super().num_required_blocks(obj, prealloc_size) + num_required_blocks -= obj.num_ignored_history // obj.block_size + return max(0, num_required_blocks) def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0): """Return if physical block can be allocated for given message.""" diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py index c32770a7db..cb47e8a3fa 100644 --- a/lmdeploy/pytorch/paging/block_trie.py +++ b/lmdeploy/pytorch/paging/block_trie.py @@ -1085,6 +1085,8 @@ def allocate(self, seq: SchedulerSequence): num_matched = node.num_matched num_valid_ids = seq.num_valid_ids + if seq.kv_token_limit is not None: + num_valid_ids = min(num_valid_ids, seq.kv_token_limit) if num_matched + block_size > num_valid_ids: return diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 690a547e90..8cb6c21e77 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -141,6 +141,7 @@ def _rollback_unscheduled_prefix_match(self, seq: SchedulerSequence, stats_snaps seq.state.free() elif seq.num_history_ids > 0: seq.set_step(0) + seq.kv_token_limit = None prefix_cache = seq.prefix_cache prefix_cache.last_shared_node = None prefix_cache.restore_state = -1 @@ -177,13 +178,83 @@ def _prefix_hit_starts_middle_long_context_chunk(self, seq: SchedulerSequence): if seq.num_history_ids <= 0: return False + max_prefill_num = self._long_context_chunk_limit(seq) + return seq.num_token_ids > max_prefill_num + + def _long_context_chunk_limit(self, seq: SchedulerSequence): + """Return the token budget for one long-context chunk.""" max_prefill_num = self.cache_config.max_prefill_token_num mm_for_chunk_limit = seq.get_chunk_limit_multimodals() for value in mm_for_chunk_limit.values(): max_mm_size = max([v.end - v.start for v in value], default=0) max_prefill_num = max(max_prefill_num, max_mm_size) - return seq.num_token_ids > max_prefill_num + return max_prefill_num + + def _next_long_context_chunk_end(self, seq: SchedulerSequence): + """Return the exclusive absolute token end for the next chunk.""" + max_prefill_num = self._long_context_chunk_limit(seq) + chunk_size = min(seq.num_token_ids, max_prefill_num) + start = seq.num_history_ids + end = start + chunk_size + + input_mm = seq.get_input_multimodals() + if len(input_mm) == 0: + return end + + multimodal_data = [] + for modal_type, modal_datas in input_mm.items(): + multimodal_data += [(modal_type, data) for data in modal_datas] + multimodal_data = sorted(multimodal_data, key=lambda x: x[1].start) + + for _, data in multimodal_data: + assert data.start >= start, 'multimodal data should be sorted by start' + if data.start >= end: + break + if data.end > end: + end = data.start + break + + return end + + def _prefill_kv_token_limit(self, seq: SchedulerSequence): + """Limit KV allocation for a non-final long-context prefill chunk.""" + max_prefill_num = self._long_context_chunk_limit(seq) + if seq.num_token_ids <= max_prefill_num: + return None + return self._next_long_context_chunk_end(seq) + + def _prepare_prefill_allocation(self, seq: SchedulerSequence, prealloc_size: int): + """Apply chunk KV limit and return the effective prealloc size.""" + kv_token_limit = self._prefill_kv_token_limit(seq) + if kv_token_limit is None: + seq.kv_token_limit = None + return prealloc_size + + seq.kv_token_limit = kv_token_limit + return 0 + + def reserve_long_context_chunk(self, + seq: SchedulerSequence, + chunk_size: int, + prealloc_size: int = 0, + is_last_chunk: bool = False): + """Reserve KV blocks for the next chunk of a running long prefill.""" + old_kv_token_limit = seq.kv_token_limit + if is_last_chunk: + seq.kv_token_limit = None + else: + seq.kv_token_limit = seq.num_history_ids + chunk_size + prealloc_size = 0 + + evictable = self.hanging + self.waiting + if not self.eviction_helper.evict_for_seq(seq, evictable, prealloc_size): + seq.kv_token_limit = old_kv_token_limit + return False + + self.block_manager.allocate(seq, prealloc_size) + self.block_trie.allocate(seq) + return True @staticmethod def create_status_list_property(status: MessageStatus): @@ -304,13 +375,21 @@ def _to_running(seq: SchedulerSequence): nonlocal token_count token_count += seq.num_token_ids - def __evict_for_seq(seq: SchedulerSequence, waiting): + def __evict_for_seq(seq: SchedulerSequence, waiting, evict_prealloc_size: int): """Evict until can append.""" from itertools import chain hanging = reversed(self.hanging) waiting = reversed(waiting) evictable = list(chain(hanging, waiting)) - return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) + return eviction_helper.evict_for_seq(seq, evictable, evict_prealloc_size) + + def __prepare_and_evict(seq: SchedulerSequence, waiting): + """Apply chunk allocation limits and evict for this prefill.""" + alloc_prealloc_size = self._prepare_prefill_allocation(seq, prealloc_size) + if __evict_for_seq(seq, waiting, alloc_prealloc_size): + return True, alloc_prealloc_size + seq.kv_token_limit = None + return False, alloc_prealloc_size def _reorder_waiting(): """Reorder waiting.""" @@ -345,7 +424,8 @@ def __rollback_prefix_match(reason: str): if not self._acquire_ssm_restore_if_needed(seq): __rollback_prefix_match('failed to acquire SSM restore checkpoint') - if not __evict_for_seq(seq, waiting): + evicted, alloc_prealloc_size = __prepare_and_evict(seq, waiting) + if not evicted: if not had_ssm_restore: __rollback_prefix_match('eviction failed') break @@ -353,20 +433,24 @@ def __rollback_prefix_match(reason: str): # state that eviction would otherwise free. Roll it back once # and retry eviction before declaring the sequence unschedulable. __rollback_prefix_match('eviction failed with pinned SSM restore') - if not __evict_for_seq(seq, waiting): + evicted, alloc_prealloc_size = __prepare_and_evict(seq, waiting) + if not evicted: break # allocate session memory if self.is_ssm and not self._ensure_runtime_state_available(): __rollback_prefix_match('no runtime SSM state available') - if not __evict_for_seq(seq, waiting): + evicted, alloc_prealloc_size = __prepare_and_evict(seq, waiting) + if not evicted: break if not self._ensure_runtime_state_available(): + seq.kv_token_limit = None break else: - if not __evict_for_seq(seq, waiting): + evicted, alloc_prealloc_size = __prepare_and_evict(seq, waiting) + if not evicted: break - self.block_manager.allocate(seq, prealloc_size) + self.block_manager.allocate(seq, alloc_prealloc_size) if self.block_trie.enable: self.block_trie.allocate(seq) if self.is_ssm: diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py index befa7e8ee4..834ca6cc62 100644 --- a/lmdeploy/pytorch/paging/seq_states/states.py +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -15,6 +15,7 @@ def _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'): seq.prefix_cache.last_shared_node = None seq.prefix_cache.match_start_step = -1 seq.cached_tokens = 0 + seq.kv_token_limit = None if seq.num_blocks > 0: scheduler.block_manager.free(seq) if seq.logical_state >= 0: diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 1894328169..13d8e05f7f 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -204,9 +204,11 @@ def _prepare_inputs_from_main(self, model_inputs: ModelInputs, extra_inputs: Ext history_lengths = model_inputs.history_lengths.clone() if not model_inputs.is_chunk: - # Dummy inputs are DP placeholders and should not disturb - # local long-context carry-over state. - if not model_inputs.is_dummy: + local_is_decoding = model_inputs.is_decoding + if not model_inputs.is_dummy and not local_is_decoding: + # Non-chunk prefill starts an independent stream. Dummy DP + # placeholders and interleaved decode must not clear a pending + # long-chunk carry. self._prev_chunk_last.clear() # Case A: non-chunked — shift left by 1, place next_token at end input_ids = model_inputs.input_ids.clone() diff --git a/tests/pytorch/engine/test_inputs_maker.py b/tests/pytorch/engine/test_inputs_maker.py index 667be525b3..9236c8e1ce 100644 --- a/tests/pytorch/engine/test_inputs_maker.py +++ b/tests/pytorch/engine/test_inputs_maker.py @@ -38,6 +38,10 @@ def __init__(self, self.return_logits = False self.return_routed_experts = False self.return_ce_loss = False + self.status = MessageStatus.RUNNING + + def set_step(self, step: int): + self.num_history_ids = step def get_input_multimodals(self): return self._input_multimodals @@ -60,12 +64,27 @@ def _state_seq(logical_state: int, restore_state: int = -1): class _FakeScheduler: - def __init__(self, running): + def __init__(self, running, waiting=None, num_ready=0, num_running=0): self.running = running + self.waiting = waiting or [] + self._num_ready = num_ready + self._num_running = num_running def schedule(self, is_prefill: bool, prealloc_size: int): return SimpleNamespace(running=self.running, swap_in_map={}, swap_out_map={}) + def reserve_long_context_chunk(self, seq, chunk_size: int, prealloc_size: int = 0, is_last_chunk: bool = False): + return True + + def has_waiting(self): + return len(self.waiting) > 0 + + def num_ready(self): + return self._num_ready + + def num_running(self): + return self._num_running + class _FakeEngineStrategy: @@ -168,6 +187,23 @@ async def get_output_async(self): assert not block_trie.pinned +def _make_policy_maker(long_seq, decode_seq=None): + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = False + maker.scheduler = _FakeScheduler([]) + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.long_context_chunker.set_seq(long_seq) + maker.running_seqs = [] if decode_seq is None else [decode_seq] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker._last_forward_kind = None + return maker + + def test_long_context_chunker_uses_cached_multimodal_size_for_chunk_limit(): image = _DummyMultiModal(start=512, end=5888) seq = _DummySeq( @@ -287,6 +323,7 @@ def test_long_context_final_chunk_preserves_multimodal_flag_for_spec_decoding(): all_multimodals={'image': [image]}, input_multimodals={}, ) + model_inputs = SimpleNamespace(is_decoding=False, is_chunk=False, is_first_chunk=False, @@ -320,6 +357,139 @@ def test_long_context_final_chunk_preserves_multimodal_flag_for_spec_decoding(): assert not maker.long_context_chunker.enabled() +def test_long_context_chunk_defers_to_decode_after_chunk_forward(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('long chunk should wait behind decode')) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + assert maker.to_evict_seqs == [] + + +def test_long_context_chunk_runs_after_decode_forward(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=True, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'decode' + maker.create_model_inputs_delta = lambda: (_ for _ in ()).throw(AssertionError('decode should not repeat')) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: model_inputs + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is model_inputs + assert forward_inputs['delta'] is None + assert not model_inputs.is_first_chunk + assert not model_inputs.is_last_chunk + + +def test_deferred_long_context_chunk_runs_when_decode_has_no_valid_seqs(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=True, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = _make_policy_maker(long_seq, decode_seq) + maker._decode_count = 3 + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (None, [], [decode_seq]) + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: model_inputs + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is model_inputs + assert forward_inputs['delta'] is None + assert maker.to_evict_seqs == [decode_seq] + assert maker._decode_count == 0 + + +def test_long_context_chunk_falls_back_to_decode_when_chunk_reservation_fails(): + long_seq = _DummySeq(history_ids=0, token_ids=1024, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'decode' + maker.scheduler.reserve_long_context_chunk = lambda *args, **kwargs: False + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + maker.create_model_inputs_long_context = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError('chunk inputs should not be created without KV reservation')) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + + +def test_last_long_context_chunk_waits_for_prefill_turn_with_decode_ready(): + long_seq = _DummySeq(history_ids=512, token_ids=256, all_multimodals={}, input_multimodals={}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + delta = SimpleNamespace(is_decoding=True) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs_delta = lambda: (delta, [decode_seq], []) + + forward_inputs = maker._make_forward_inputs(prefill=False) + + assert forward_inputs['inputs'] is None + assert forward_inputs['delta'] is delta + assert maker.long_context_chunker.enabled() + + +def test_last_long_context_chunk_runs_as_prefill_on_prefill_turn(): + image = _DummyMultiModal(start=600, end=700) + long_seq = _DummySeq(history_ids=512, + token_ids=256, + all_multimodals={'image': [image]}, + input_multimodals={'image': [image]}) + decode_seq = _DummySeq(history_ids=0, token_ids=1, all_multimodals={}, input_multimodals={}) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = _make_policy_maker(long_seq, decode_seq) + maker._last_forward_kind = 'long_context_chunk' + maker.create_model_inputs = lambda seqs, is_prefill: model_inputs + maker.create_model_inputs_delta_valid_only = lambda: (None, [decode_seq], []) + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs['inputs'] is model_inputs + assert model_inputs.is_chunk + assert model_inputs.is_last_chunk + assert model_inputs.is_chunk_multimodal + assert not maker.long_context_chunker.enabled() + + +def test_do_prefill_default_forces_pending_last_chunk_prefill(): + long_seq = _DummySeq(history_ids=512, token_ids=256, all_multimodals={}, input_multimodals={}) + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode, + max_prefill_token_num=512, + max_batches=1, + prefill_interval=100) + maker.scheduler = _FakeScheduler([], num_ready=1, num_running=1) + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.long_context_chunker.set_seq(long_seq) + maker._decode_count = 0 + + assert maker.do_prefill_default() + + def test_state_prefix_cache_restore_offsets_are_compact(): messages = [_state_seq(4, 11), _state_seq(5, -1), _state_seq(6, 13)] diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index b08116d7f4..25611d565c 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -343,3 +343,29 @@ def test_win_alloc(self, scheduler, block_mgr, num_gpu_blocks, window_size): block_table = block_mgr.get_block_table(msg) assert block_table is None or len(block_table) == 2 block_mgr.free(msg) + + def test_win_alloc_respects_kv_token_limit(self, scheduler, block_mgr, num_gpu_blocks, window_size): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + + token_ids = torch.tensor([1] * (window_size * 3)) + msg = sess.add_sequence(token_ids) + + msg.kv_token_limit = window_size + block_mgr.allocate(msg) + assert len(block_mgr.get_block_table(msg)) == 2 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 + + msg.set_step(window_size) + msg.kv_token_limit = window_size + block_size + block_mgr.allocate(msg) + assert len(block_mgr.get_block_table(msg)) == 3 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3 + + msg.set_step(window_size + block_size) + msg.kv_token_limit = window_size + block_size * 2 + assert block_mgr.num_required_blocks(msg) == 1 + block_mgr.allocate(msg) + assert len(block_mgr.get_block_table(msg)) == 3 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3 + assert msg.num_ignored_history == block_size diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index d91b217be2..7bfa19df05 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -613,6 +613,175 @@ def test_scheduler_excludes_recompute_eviction_prefix_hits_from_stats(): assert scheduler.block_trie.stats.num_hit_tokens == 0 +def _make_scheduler_for_decode_growth(num_gpu_blocks: int = 2): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 4 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=2, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=num_gpu_blocks, + max_prefill_token_num=block_size * 4) + scheduler_config = SchedulerConfig(max_batches=2, + max_session_len=64, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def _make_scheduler_for_long_context_chunks(num_gpu_blocks: int = 6): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 4 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=2, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=num_gpu_blocks, + max_prefill_token_num=block_size * 2) + scheduler_config = SchedulerConfig(max_batches=2, + max_session_len=64, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def _make_ssm_scheduler_for_long_context_chunks(num_gpu_blocks: int = 2): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 4 + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + cache_config = CacheConfig(max_batches=1, + block_size=block_size, + num_cpu_blocks=0, + num_gpu_blocks=num_gpu_blocks, + max_prefill_token_num=block_size * 2, + num_state_caches=2, + states_shapes=[((1, ), torch.float32)]) + scheduler_config = SchedulerConfig(max_batches=1, + max_session_len=64, + max_request_output_len=64, + eviction_type='recompute') + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + return scheduler, block_size + + +def test_schedule_running_reclaims_waiting_blocks_for_decode_growth(): + scheduler, block_size = _make_scheduler_for_decode_growth(num_gpu_blocks=2) + decode = scheduler.add_session(100).add_sequence([1] * block_size) + waiting = scheduler.add_session(101).add_sequence([2] * block_size) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [decode, waiting] + scheduler.activate_seqs([decode]) + waiting.state.evict() + assert decode.status == MessageStatus.RUNNING + assert waiting.status == MessageStatus.WAITING + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + valid_mask = scheduler.schedule_running([decode], num_required_tokens=1, prealloc_size=1) + + assert valid_mask == [True] + assert decode.status == MessageStatus.RUNNING + assert decode.num_blocks == 2 + assert waiting.status == MessageStatus.WAITING + assert waiting.num_blocks == 0 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_schedule_running_keeps_other_running_sequence_when_decode_growth_fails(): + scheduler, block_size = _make_scheduler_for_decode_growth(num_gpu_blocks=2) + decode = scheduler.add_session(100).add_sequence([1] * block_size) + long_chunk = scheduler.add_session(101).add_sequence([2] * block_size) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [decode, long_chunk] + scheduler.activate_seqs([decode, long_chunk]) + assert decode.status == MessageStatus.RUNNING + assert long_chunk.status == MessageStatus.RUNNING + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + valid_mask = scheduler.schedule_running([decode], num_required_tokens=1, prealloc_size=1) + + assert valid_mask == [False] + assert decode.status == MessageStatus.WAITING + assert long_chunk.status == MessageStatus.RUNNING + assert long_chunk.num_blocks == 1 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_schedule_prefill_allocates_only_first_long_context_chunk(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=2) + long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 4)) + + output = scheduler.schedule(is_prefill=True, prealloc_size=1) + + assert output.running == [long_seq] + assert long_seq.status == MessageStatus.READY + assert long_seq.kv_token_limit == block_size * 2 + assert long_seq.num_blocks == 2 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_schedule_prefill_reapplies_chunk_limit_after_ssm_state_rollback(): + scheduler, block_size = _make_ssm_scheduler_for_long_context_chunks(num_gpu_blocks=2) + long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 4)) + + ensure_results = iter([False, True]) + + def _ensure_runtime_state_available_once_then_succeed(): + return next(ensure_results) + + scheduler._ensure_runtime_state_available = _ensure_runtime_state_available_once_then_succeed + + output = scheduler.schedule(is_prefill=True, prealloc_size=1) + + assert output.running == [long_seq] + assert long_seq.status == MessageStatus.READY + assert long_seq.kv_token_limit == block_size * 2 + assert long_seq.num_blocks == 2 + + +def test_reserve_long_context_chunk_grows_one_chunk_at_a_time(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=6) + long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 5)) + + output = scheduler.schedule(is_prefill=True, prealloc_size=1) + assert output.running == [long_seq] + assert long_seq.kv_token_limit == block_size * 2 + assert long_seq.num_blocks == 2 + + scheduler.activate_seqs([long_seq]) + long_seq.set_step(block_size * 2) + + assert scheduler.reserve_long_context_chunk(long_seq, block_size * 2) + assert long_seq.status == MessageStatus.RUNNING + assert long_seq.kv_token_limit == block_size * 4 + assert long_seq.num_blocks == 4 + + long_seq.set_step(block_size * 4) + + assert scheduler.reserve_long_context_chunk(long_seq, block_size, prealloc_size=1, is_last_chunk=True) + assert long_seq.kv_token_limit is None + assert long_seq.num_blocks == 6 + assert scheduler.block_manager.get_num_free_gpu_blocks() == 0 + + +def test_reserve_long_context_chunk_failure_preserves_committed_prefix(): + scheduler, block_size = _make_scheduler_for_long_context_chunks(num_gpu_blocks=2) + long_seq = scheduler.add_session(100).add_sequence([1] * (block_size * 4)) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [long_seq] + scheduler.activate_seqs([long_seq]) + long_seq.set_step(block_size * 2) + + assert not scheduler.reserve_long_context_chunk(long_seq, block_size * 2) + assert long_seq.status == MessageStatus.RUNNING + assert long_seq.kv_token_limit == block_size * 2 + assert long_seq.num_blocks == 2 + + def test_scheduler_rolls_back_prefix_hit_that_would_start_long_context_chunk_from_middle(): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy block_size = 16 diff --git a/tests/pytorch/spec_decode/test_spec_agent.py b/tests/pytorch/spec_decode/test_spec_agent.py index d121bed2a3..2713208a9c 100644 --- a/tests/pytorch/spec_decode/test_spec_agent.py +++ b/tests/pytorch/spec_decode/test_spec_agent.py @@ -2,7 +2,9 @@ import torch -from lmdeploy.pytorch.spec_decode.spec_agent import _expand_sampling_inputs +from lmdeploy.pytorch.model_inputs import DPMeta, ModelInputs +from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent, _expand_sampling_inputs +from lmdeploy.pytorch.strategies.ar_spec.model_agent import ARSpecExtraInputs device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -312,3 +314,83 @@ def test_slice_sampling_inputs_prefill(): sampling_inputs = SamplingInputs(max_top_k=1, batch_size=2) result = _slice_sampling_inputs(sampling_inputs, 1) assert result is sampling_inputs + + +def _model_inputs(input_ids, + *, + is_decoding=False, + is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + dp_meta=None): + input_ids = torch.tensor([input_ids]) + seq_length = torch.tensor([input_ids.size(1)]) + history_lengths = torch.tensor([0]) + max_q_seqlen = input_ids.size(1) + return ModelInputs( + input_ids=input_ids, + seq_length=seq_length, + history_lengths=history_lengths, + block_offsets=torch.zeros(1, 1, dtype=torch.int32), + is_decoding=is_decoding, + num_ignored_history=torch.zeros(1, dtype=torch.long), + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_q_seqlen, + sum_kv_seqlen=max_q_seqlen, + is_chunk=is_chunk, + is_first_chunk=is_first_chunk, + is_last_chunk=is_last_chunk, + dp_meta=dp_meta, + ) + + +def _extra(hidden_values): + hidden_states = torch.tensor([hidden_values], dtype=torch.float32) + return ARSpecExtraInputs( + target_hidden_states=hidden_states, + next_token_ids=torch.tensor([99]), + last_token_indices=torch.tensor([hidden_states.size(1) - 1]), + ) + + +def test_prepare_inputs_from_main_keeps_chunk_carry_across_decode(): + agent = SpecModelAgent.__new__(SpecModelAgent) + agent._prev_chunk_last = {} + + first_chunk = _model_inputs([10, 11, 12], is_chunk=True, is_first_chunk=True) + agent._prepare_inputs_from_main(first_chunk, _extra([[1, 10], [2, 20], [3, 30]])) + saved_first_chunk_last = agent._prev_chunk_last['hidden_states'].clone() + + decode = _model_inputs([90, 91, 92], is_decoding=True) + agent._prepare_inputs_from_main(decode, _extra([[9, 90], [8, 80], [7, 70]])) + + assert torch.equal(agent._prev_chunk_last['hidden_states'], saved_first_chunk_last) + + middle_chunk = _model_inputs([20, 21, 22], is_chunk=True) + draft_inputs, _ = agent._prepare_inputs_from_main(middle_chunk, _extra([[4, 40], [5, 50], [6, 60]])) + + assert torch.equal(draft_inputs.target_hidden_states[:, :1], saved_first_chunk_last) + assert torch.equal(agent._prev_chunk_last['hidden_states'], torch.tensor([[[6., 60.]]])) + + +def test_prepare_inputs_from_main_clears_chunk_carry_on_non_chunk_prefill(): + agent = SpecModelAgent.__new__(SpecModelAgent) + agent._prev_chunk_last = {'hidden_states': torch.ones(1, 1, 2)} + + prefill = _model_inputs([10, 11, 12]) + agent._prepare_inputs_from_main(prefill, _extra([[1, 10], [2, 20], [3, 30]])) + + assert agent._prev_chunk_last == {} + + +def test_prepare_inputs_from_main_keeps_chunk_carry_for_dp_local_decode_global_prefill(): + agent = SpecModelAgent.__new__(SpecModelAgent) + saved = torch.ones(1, 1, 2) + agent._prev_chunk_last = {'hidden_states': saved.clone()} + agent.proposer = _DummyProposer() + + dp_meta = DPMeta(dp_batches=[1, 1], dp_is_decoding=False) + inputs = _model_inputs([90, 91, 92], is_decoding=True, dp_meta=dp_meta) + agent._prepare_inputs_from_main(inputs, _extra([[9, 90], [8, 80], [7, 70]])) + + assert torch.equal(agent._prev_chunk_last['hidden_states'], saved)