diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index 1dceba3c5a..26c412bffb 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -54,6 +54,8 @@ def add_parser_chat(): ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) ArgumentHelper.dllm_block_length(pt_group) + ArgumentHelper.prefix_cache_state_budget(pt_group) + ArgumentHelper.prefix_cache_decode_state_interval(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index a1c4b583bf..01ac1d44f1 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -107,6 +107,8 @@ def add_parser_api_server(): ArgumentHelper.enable_return_routed_experts(pt_group) ArgumentHelper.distributed_executor_backend(pt_group) ArgumentHelper.kernel_block_size(pt_group) + ArgumentHelper.prefix_cache_state_budget(pt_group) + ArgumentHelper.prefix_cache_decode_state_interval(pt_group) # common engine args disable_vision_encoder = ArgumentHelper.disable_vision_encoder(pt_group) @@ -234,6 +236,8 @@ def api_server(args): session_len=args.session_len, adapters=adapters, enable_prefix_caching=args.enable_prefix_caching, + prefix_cache_state_budget=args.prefix_cache_state_budget, + prefix_cache_decode_state_interval=args.prefix_cache_decode_state_interval, device_type=args.device, quant_policy=args.quant_policy, eager_mode=args.eager_mode, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index a5b0f9de33..1dda62a7e8 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -585,6 +585,30 @@ def enable_prefix_caching(parser): default=False, help='Enable cache and match prefix') + @staticmethod + def prefix_cache_state_budget(parser): + """Add argument prefix_cache_state_budget to parser.""" + + return parser.add_argument('--prefix-cache-state-budget', + type=int, + default=0, + help='Extra SSM state-cache slots budgeted for prefix-cache checkpoints. 
' + '0 adds no extra slots, but checkpoints may borrow idle runtime state slots. ' + 'Only used by the PyTorch engine.') + + @staticmethod + def prefix_cache_decode_state_interval(parser): + """Add argument prefix_cache_decode_state_interval to parser.""" + + return parser.add_argument('--prefix-cache-decode-state-interval', + type=int, + default=0, + help='Token interval for SSM decode-state prefix-cache checkpoints. ' + '0 disables decode checkpoint saves while keeping prefill/chunk checkpoints. ' + 'Use a positive multiple of block size only for long SSM decoding where later ' + 'requests can reuse decode prefixes; smaller values improve hit granularity ' + 'but use more checkpoint memory and copy work. Only used by the PyTorch engine.') + @staticmethod def num_tokens_per_iter(parser): return parser.add_argument('--num-tokens-per-iter', diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 39a41263f6..f13a3fa3ff 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -367,6 +367,17 @@ class PytorchEngineConfig: max_batch_size is always captured. thread_safe: thread safe engine instance. enable_prefix_caching: Enable token match and sharing caches. + prefix_cache_state_budget: Extra SSM state-cache slots budgeted for + prefix-cache checkpoints. 0 adds no extra slots, but SSM + checkpoints may still borrow idle runtime state slots. + prefix_cache_decode_state_interval: Token interval for SSM decode + state checkpoints. 0 disables decode-state checkpoint saves; prefill + and chunk checkpoints may still be saved. Keep 0 unless the workload + has long SSM decoding and repeated continuations that can reuse + decode checkpoints. Smaller positive values create more hit points + but use more checkpoint memory and copy work; larger values reduce + overhead but make decode-prefix hits less likely. Positive values + must be multiples of the cache block size. 
device_type: The inference device type, options ['cuda'] eager_mode: Enable "eager" mode or not custom_module_map: nn module map customized by users. Once @@ -428,6 +439,8 @@ class PytorchEngineConfig: cudagraph_capture_batch_sizes: list[int] | None = None thread_safe: bool = False enable_prefix_caching: bool = False + prefix_cache_state_budget: int = 0 + prefix_cache_decode_state_interval: int = 0 device_type: str = 'cuda' eager_mode: bool = False custom_module_map: dict[str, str] = None @@ -472,6 +485,8 @@ def __post_init__(self): assert self.max_prefill_token_num >= 0, \ 'invalid max_prefill_token_num' assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks' + assert self.prefix_cache_state_budget >= 0, 'invalid prefix_cache_state_budget' + assert self.prefix_cache_decode_state_interval >= 0, 'invalid prefix_cache_decode_state_interval' try: self.quant_policy = QuantPolicy(self.quant_policy) except ValueError as e: @@ -485,6 +500,9 @@ def __post_init__(self): (f'block_size must be >= kernel_block_size and an integer multiple ' f'of kernel_block_size, but got block_size {self.block_size} ' f'and kernel_block_size {self.kernel_block_size}') + if self.prefix_cache_decode_state_interval > 0: + assert self.prefix_cache_decode_state_interval % self.block_size == 0, ( + 'prefix_cache_decode_state_interval must be a multiple of block_size') if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']: assert False, \ 'kv cache quantization only works for CUDA and ASCEND.' diff --git a/lmdeploy/metrics/loggers.py b/lmdeploy/metrics/loggers.py index 69e24ede4f..70d3834e74 100644 --- a/lmdeploy/metrics/loggers.py +++ b/lmdeploy/metrics/loggers.py @@ -191,6 +191,11 @@ def __init__(self, model_name: str, max_model_len: int, dp_rank: int = 0): documentation='GPU KV-cache usage. 
1 means 100 percent usage.', labelnames=labelnames).labels(*labelvalues) + self.gauge_prefix_cache_hit_rate = prometheus_client.Gauge( + name='lmdeploy:prefix_cache_hit_rate', + documentation='Prefix-cache hit rate. 1 means 100 percent of queried prefix tokens hit.', + labelnames=labelnames).labels(*labelvalues) + # # Counters # @@ -359,6 +364,7 @@ def record_schedule(self, stats: SchedulerStats) -> None: self.gauge_scheduler_running.set(stats.num_running_reqs) self.gauge_scheduler_waiting.set(stats.num_waiting_reqs) self.gauge_gpu_cache_usage.set(stats.gpu_cache_usage) + self.gauge_prefix_cache_hit_rate.set(stats.prefix_cache_hit_rate) def record_iteration(self, stats: IterationStats) -> None: """Report token-related metrics to prometheus.""" diff --git a/lmdeploy/pytorch/block.py b/lmdeploy/pytorch/block.py index cd4267890d..00088ac1e6 100644 --- a/lmdeploy/pytorch/block.py +++ b/lmdeploy/pytorch/block.py @@ -24,7 +24,6 @@ def __init__(self, blocks: np.ndarray = None): assert blocks.ndim == 1 self._blocks = blocks self._num_real = len(blocks) - self.last_shared_node = None def reserve(self, size: int): """Reserve cache size.""" @@ -67,7 +66,6 @@ def resize(self, num_blocks: int): def reset(self): """reset.""" self.resize(0) - self.last_shared_node = None def clone(self): """Clone logical blocks.""" diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 13dbe61ebd..cb280dbfa8 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -121,6 +121,8 @@ class CacheConfig: quant_policy: QuantPolicy = QuantPolicy.NONE device_type: str = 'cuda' num_state_caches: int = None + prefix_cache_state_budget: int = 0 + prefix_cache_decode_state_interval: int = 0 states_shapes: list[tuple] = field(default_factory=list) # reserved blocks for dummy inputs, init to 0 for unit test. 
@@ -132,11 +134,16 @@ class CacheConfig: def __post_init__(self): """Post init.""" + assert self.prefix_cache_state_budget >= 0, 'invalid prefix_cache_state_budget' + assert self.prefix_cache_decode_state_interval >= 0, 'invalid prefix_cache_decode_state_interval' if self.window_size > 1 and self.enable_prefix_caching: logger.warning('Prefix caching is not available for window attention.') self.enable_prefix_caching = False if self.kernel_block_size == -1: self.kernel_block_size = self.block_size + if self.prefix_cache_decode_state_interval > 0: + assert self.prefix_cache_decode_state_interval % self.block_size == 0, ( + 'prefix_cache_decode_state_interval must be a multiple of block_size') self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( self.cudagraph_capture_batch_sizes, self.max_batches) diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 9dc99c2bea..6dd6ffb6d3 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -2,7 +2,9 @@ # modify from: https://github.com/vllm-project/vllm import json import math +from collections.abc import Sequence from dataclasses import dataclass +from operator import index as as_index import torch @@ -565,3 +567,79 @@ def get_cache_state_size(state_shapes: list[tuple[tuple[int], torch.dtype]]) -> def state_caches(self): """State caches.""" return self._state_caches + + @staticmethod + def _index_list(idx: int | Sequence[int]): + """Normalize host-side cache indices.""" + if isinstance(idx, torch.Tensor): + raise TypeError('State cache copy indices must be host integers, not torch.Tensor.') + if isinstance(idx, (str, bytes)): + raise TypeError('State cache copy indices must be an int or a sequence of ints.') + try: + return [as_index(idx)] + except TypeError: + pass + if not isinstance(idx, Sequence): + raise TypeError('State cache copy indices must be an int or a sequence of ints.') + if any(isinstance(item, 
torch.Tensor) for item in idx): + raise TypeError('State cache copy indices must be host integers, not torch.Tensor.') + return [as_index(item) for item in idx] + + @staticmethod + def _validate_index_bounds(indices: Sequence[int], num_caches: int): + """Check normalized cache indices are valid state slots.""" + for idx in indices: + if idx < 0 or idx >= num_caches: + raise ValueError(f'State cache index {idx} is out of range [0, {num_caches}).') + + @staticmethod + def _copy_ranges(src_list: list[int], dst_list: list[int]): + """Yield contiguous copy ranges as (src_start, dst_start, length).""" + pairs = sorted(zip(src_list, dst_list)) + if len(pairs) == 0: + return + start_src = prev_src = pairs[0][0] + start_dst = prev_dst = pairs[0][1] + length = 1 + for src, dst in pairs[1:]: + if src == prev_src + 1 and dst == prev_dst + 1: + prev_src = src + prev_dst = dst + length += 1 + continue + yield start_src, start_dst, length + start_src = prev_src = src + start_dst = prev_dst = dst + length = 1 + yield start_src, start_dst, length + + def copy_caches(self, src_idx: int | Sequence[int], dst_idx: int | Sequence[int]): + """Copy state cache slots. + + This is the low-level primitive needed by SSM prefix caching: a frozen + state checkpoint can be copied into a newly allocated runtime slot + before the next forward. 
+ """ + if len(self._state_caches) <= 0: + return + + src_list = self._index_list(src_idx) + dst_list = self._index_list(dst_idx) + if len(src_list) != len(dst_list): + raise ValueError('src_idx and dst_idx must have the same number of elements.') + if len(src_list) == 0: + return + num_caches = self.mem_pool.size(0) + self._validate_index_bounds(src_list, num_caches) + self._validate_index_bounds(dst_list, num_caches) + dst_set = set(dst_list) + if len(dst_set) != len(dst_list): + raise ValueError('dst_idx must not contain duplicate entries.') + if not set(src_list).isdisjoint(dst_set): + raise ValueError('src_idx and dst_idx must not overlap for stream-ordered state copies.') + + for src, dst, length in self._copy_ranges(src_list, dst_list): + if length == 1: + self.mem_pool[dst].copy_(self.mem_pool[src], non_blocking=True) + else: + self.mem_pool[dst:dst + length].copy_(self.mem_pool[src:src + length], non_blocking=True) diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index 6b4644d1e0..b35d8f51ed 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -75,6 +75,8 @@ def build_cache_config(engine_config: PytorchEngineConfig): max_prefill_token_num=engine_config.max_prefill_token_num, cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes, enable_prefix_caching=engine_config.enable_prefix_caching, + prefix_cache_state_budget=engine_config.prefix_cache_state_budget, + prefix_cache_decode_state_interval=engine_config.prefix_cache_decode_state_interval, quant_policy=engine_config.quant_policy, device_type=engine_config.device_type, migration_backend=engine_config.migration_backend, diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 6363592bfa..1e07a626bb 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -23,6 +23,7 @@ from ..adapter.adapter import AdapterManager from 
..config import CacheConfig, ModelConfig from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode +from ..multimodal.data_type import ensure_multimodal_content_hashes from ..paging import Scheduler from ..strategies import build_strategy_factory from .base import EngineBase @@ -412,6 +413,8 @@ def _on_add_message(self, reqs: list[Request], **kwargs): input_ids = result.input_ids input_multimodals = result.input_multimodals + if self.cache_config.enable_prefix_caching: + input_multimodals = ensure_multimodal_content_hashes(input_multimodals) req_data['token_ids'] = input_ids req_data['input_multimodals'] = input_multimodals diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 10cb1885c6..143f9632a3 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -305,6 +305,7 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): seq = running[0] seq.append_routed_experts(all_routed_experts) seq.append_logits(logits) + self.scheduler.block_trie.cache_routed_experts_for_seq(seq) return dict() new_token_timestamp = batched_outputs.new_token_timestamp @@ -318,6 +319,7 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): batched_outputs=batched_outputs, model_inputs=model_inputs, delta=delta) + self.scheduler.block_trie.cache_routed_experts(running) # generate output outputs: dict[int, InferOutput] = dict() @@ -378,6 +380,42 @@ async def _main_loop_try_send_next_inputs(self): scheduler.collect_migration_done() return await self.inputs_maker.send_next_inputs() + @staticmethod + def _has_state_checkpoint_save(model_inputs: 'ModelInputs | None', delta: 'ModelInputsDelta | None'): + """Check whether the current forward reserved SSM checkpoints.""" + return ((model_inputs is not None and model_inputs.state_prefix_cache_save_offsets is not None) + or (delta is not None and delta.state_prefix_cache_save_offsets is not None)) + + async def _prefetch_next_inputs(self): + 
"""Collect migration completions before prefetching the next batch.""" + self.scheduler.collect_migration_done() + return await self.inputs_maker.prefetch_next_inputs() + + def _publish_forward_prefix_cache(self, running: 'SeqList', has_state_checkpoint_save: bool): + """Publish per-forward prefix-cache ownership before prefetching.""" + if not self.scheduler.block_trie.enable: + return + if has_state_checkpoint_save: + self.scheduler.block_trie.commit_state_checkpoints(running, acquire_save_ref=True) + self.scheduler.block_trie.release_state_checkpoint_restores(running) + + def _release_forward_prefix_cache_saves(self, running: 'SeqList'): + """Release producer refs after the forward output/event boundary.""" + if not self.scheduler.block_trie.enable: + return + self.scheduler.block_trie.release_state_checkpoint_saves(running) + + def _finish_forward_output(self, + out: 'BatchedOutputs | None', + running: 'SeqList', + model_inputs: 'ModelInputs | None', + delta: 'ModelInputsDelta | None'): + """Publish outputs.""" + if out is None: + return + step_outputs = self._make_infer_outputs(out, running=running, model_inputs=model_inputs, delta=delta) + self.resp_queue.put_nowait(step_outputs) + async def _main_loop_get_outputs( self, running: 'SeqList', @@ -387,18 +425,19 @@ async def _main_loop_get_outputs( model_inputs = forward_inputs['inputs'] delta = forward_inputs['delta'] self.inputs_maker.update_running_seqs(running, model_inputs) - - # try prefetch inputs - self.scheduler.collect_migration_done() - forward_inputs, next_running = await self.inputs_maker.prefetch_next_inputs() - - # send output + has_state_checkpoint_save = self._has_state_checkpoint_save(model_inputs, delta) + + # ModelAgent executes queued forwards in send order. 
Once the current + # input is queued, matched checkpoints can be published before waiting + # for GPU output; save checkpoints keep a producer ref until the output + # event boundary so prefetch cannot evict/reuse their destination slots. + self._publish_forward_prefix_cache(running, has_state_checkpoint_save) + forward_inputs, next_running = await self._prefetch_next_inputs() out = await self.executor.get_output_async() - if out is not None: - step_outputs = self._make_infer_outputs(out, running=running, model_inputs=model_inputs, delta=delta) - self.resp_queue.put_nowait(step_outputs) - # out might come from shared memory, need to explicitly delete to release memory in time - del out + self._release_forward_prefix_cache_saves(running) + self._finish_forward_output(out, running, model_inputs, delta) + # out might come from shared memory, need to explicitly delete to release memory in time + del out return forward_inputs, next_running diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index db4ad9008b..5b2de08acd 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -5,6 +5,7 @@ from typing import Any, NamedTuple from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig +from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch from lmdeploy.pytorch.engine.cache_engine import CacheEngine @@ -45,6 +46,12 @@ def __init__(self, # do not support sliding window prefix caching logger.warning('Sliding window prefix caching is not supported.') cache_config.enable_prefix_caching = False + if specdecode_config is not None and cache_config.enable_prefix_caching: + logger.warning('Speculative decoding prefix caching is not supported.') + 
cache_config.enable_prefix_caching = False + if cache_config.role != EngineRole.Hybrid and cache_config.enable_prefix_caching: + logger.warning('PD prefix caching is not supported.') + cache_config.enable_prefix_caching = False self.model_config = model_config self.cache_config = cache_config self.backend_config = backend_config @@ -238,18 +245,16 @@ def _get_state_cache_mem(self): num_state_caches = cache_config.num_state_caches if num_state_caches is None: - # add more caches for eviction + # One state slot is reserved for system use. Active sequences need + # max_batches runtime slots; prefix-cache checkpoints use an + # explicitly configured extra budget. # TODO: Share memory between state cache and pageable cache - num_state_caches = int(cache_config.max_batches + 1) + num_state_caches = int(cache_config.max_batches + 1 + cache_config.prefix_cache_state_budget) cache_config.num_state_caches = num_state_caches mems = StateCacheEngine.get_cache_state_size(cache_config.states_shapes) mems *= num_state_caches - if cache_config.enable_prefix_caching: - cache_config.enable_prefix_caching = False - logger.warning('Prefix caching has not been support for state space model.') - return mems def _sync_spec_cache_block_size(self) -> None: diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 1717fcdb1a..3f6babc35f 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -41,6 +41,33 @@ def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): return torch.as_tensor(out, dtype=dtype) +def _compact_state_prefix_cache_restore_offsets(messages: list['SchedulerSequence']): + """Build compact SSM restore src/dst index tensors.""" + src_offsets = [] + dst_offsets = [] + for msg in messages: + state_idx = msg.prefix_cache.restore_state + if state_idx >= 0: + src_offsets.append(state_idx) + dst_offsets.append(msg.logical_state) + if len(src_offsets) == 0: + return None, None + return 
tuple(src_offsets), tuple(dst_offsets) + + +def _compact_state_prefix_cache_save_offsets(messages: list['SchedulerSequence'], save_state_offsets: list[int]): + """Build compact SSM save src/dst index tensors.""" + src_offsets = [] + dst_offsets = [] + for msg, state_idx in zip(messages, save_state_offsets): + if state_idx >= 0: + src_offsets.append(msg.logical_state) + dst_offsets.append(state_idx) + if len(src_offsets) == 0: + return None, None + return tuple(src_offsets), tuple(dst_offsets) + + @dataclass class InputsMakerConfig: """Input maker config. @@ -82,7 +109,14 @@ def from_engine(engine: 'Engine'): class LongContextChunker: - """Long context chunker.""" + """Split a single long prefill into model-safe chunks. + + Multimodal spans are indivisible, so a span larger than + ``max_prefill_token_num`` temporarily raises the chunk limit. Prefix-cache + restore can skip over the span itself, but the enlarged limit still needs + to be derived from the whole request history so the remaining text tail is + chunked the same way as the no-cache path. 
+ """ def __init__(self, max_prefill_token_num: int): self.max_prefill_token_num = max_prefill_token_num @@ -99,23 +133,24 @@ def is_long_context(self, seq: 'SchedulerSequence'): return seq.num_token_ids > self.max_prefill_token_num def set_seq(self, seq: 'SchedulerSequence'): - """Set seq.""" + """Set the sequence currently being chunked.""" self.seq = seq self.next_step = seq.num_history_ids - # fill multimodals - # if image size exceeds max_prefill_token_num, enlarge it max_prefill_num = self.max_prefill_token_num - mm = seq.get_input_multimodals() + input_mm = seq.get_input_multimodals() + mm_for_chunk_limit = seq.get_chunk_limit_multimodals() self.multimodals = defaultdict(list) + for value in mm_for_chunk_limit.values(): + max_mm_size = max([v.end - v.start for v in value], default=0) + max_prefill_num = max(max_prefill_num, max_mm_size) + has_multimodal = False - for key, value in mm.items(): - # sorted by start + for key, value in input_mm.items(): + # Only remaining multimodals are emitted by next_chunk_size(). value = sorted(value, key=lambda x: x.start) self.multimodals[key] = value - max_mm_size = max([v.end - v.start for v in value], default=0) - max_prefill_num = max(max_prefill_num, max_mm_size) has_multimodal = has_multimodal or len(value) > 0 @@ -135,7 +170,7 @@ def multimodal_iter(self): yield modal_type, data def next_chunk_size(self): - """Get chunk size.""" + """Get the next chunk size and its remaining multimodal payloads.""" seq = self.seq if seq is None: return 0, None @@ -158,7 +193,8 @@ def next_chunk_size(self): if mm.end > end: # | start ... mm.start ... end ... mm.end | - # assume multimodals not overlap + # Do not split a multimodal span; recompute from its start in + # the next chunk instead. 
end = mm.start break @@ -426,6 +462,28 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool): if self.config.is_ssm: state_offsets = torch.tensor([msg.logical_state for msg in messages]) model_inputs.state_offsets = state_offsets + if (self.cache_config.enable_prefix_caching + and any(msg.prefix_cache.restore_state >= 0 for msg in messages)): + # Pin restore checkpoints while the forward copies them into + # runtime state slots; otherwise checkpoint eviction could race + # with input prefetching for the next batch. + self.scheduler.block_trie.acquire_state_checkpoint_restores(messages) + if any(msg.prefix_cache.restore_state >= 0 and not msg.prefix_cache.restore_state_acquired + for msg in messages): + raise RuntimeError('Failed to acquire SSM prefix-cache restore checkpoint.') + restore_src_offsets, restore_dst_offsets = _compact_state_prefix_cache_restore_offsets(messages) + model_inputs.state_prefix_cache_offsets = restore_src_offsets + model_inputs.state_prefix_cache_dst_offsets = restore_dst_offsets + if self.cache_config.enable_prefix_caching and not is_decoding: + # Prefill saves publish only after model_forward has copied the + # runtime state to these reserved checkpoint offsets. 
+ save_state_offsets = [ + self.scheduler.block_trie.reserve_state_checkpoint_for_seq(msg) for msg in messages + ] + save_src_offsets, save_dst_offsets = _compact_state_prefix_cache_save_offsets(messages, + save_state_offsets) + model_inputs.state_prefix_cache_save_src_offsets = save_src_offsets + model_inputs.state_prefix_cache_save_offsets = save_dst_offsets if self.config.use_mrope: mrope_pos_ids = [msg.mrope_pos_ids for msg in messages] @@ -489,6 +547,21 @@ def create_model_inputs_long_context(self, # ssm if self.config.is_ssm: model_inputs.state_offsets = torch.tensor([seq.logical_state]) + if self.cache_config.enable_prefix_caching and seq.prefix_cache.restore_state >= 0: + # Long-context chunks use the same restore pinning contract as + # normal prefill batches. + self.scheduler.block_trie.acquire_state_checkpoint_restore_for_seq(seq) + if not seq.prefix_cache.restore_state_acquired: + raise RuntimeError('Failed to acquire SSM prefix-cache restore checkpoint.') + model_inputs.state_prefix_cache_offsets = (seq.prefix_cache.restore_state, ) + model_inputs.state_prefix_cache_dst_offsets = (seq.logical_state, ) + if self.cache_config.enable_prefix_caching: + # Save at the exact state step produced by this chunk forward. 
+ checkpoint_step = seq.num_history_ids + chunk_size + save_state = self.scheduler.block_trie.reserve_state_checkpoint_for_seq(seq, step=checkpoint_step) + if save_state >= 0: + model_inputs.state_prefix_cache_save_src_offsets = (seq.logical_state, ) + model_inputs.state_prefix_cache_save_offsets = (save_state, ) # mrope if self.config.use_mrope: @@ -543,6 +616,18 @@ def create_model_inputs_delta(self): sum_kv_seqlen=sum_kv_seqlen, num_ignored_history=num_ignored_history, ) + decode_state_interval = self.cache_config.prefix_cache_decode_state_interval + if (self.cache_config.enable_prefix_caching and self.config.is_ssm and decode_state_interval > 0 + and not self.spec_decoding and num_decode_tokens == 1): + save_state_offsets = [ + self.scheduler.block_trie.reserve_decode_state_checkpoint_for_seq(seq, decode_state_interval) + for seq in valid_seqs + ] + if any(state_idx >= 0 for state_idx in save_state_offsets): + save_src_offsets, save_dst_offsets = _compact_state_prefix_cache_save_offsets(valid_seqs, + save_state_offsets) + output.state_prefix_cache_save_src_offsets = save_src_offsets + output.state_prefix_cache_save_offsets = save_dst_offsets return output, valid_seqs, invalid_seqs @@ -646,6 +731,7 @@ def __create_inputs_chunk(running: 'SeqList'): def __create_inputs_long_context_chunk(): seq = self.long_context_chunker.seq running = [seq] + has_multimodal = self.long_context_chunker.has_multimodal if self.long_context_chunker.is_last_chunk(): inputs, delta, extra_inputs = __create_model_inputs(running) inputs.is_chunk = True @@ -655,7 +741,7 @@ def __create_inputs_long_context_chunk(): inputs, extra_inputs = __create_inputs_chunk(running) delta = None inputs.is_first_chunk = False - inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal + inputs.is_chunk_multimodal = has_multimodal return running, inputs, delta, extra_inputs def __create_inputs_prefill(): @@ -674,9 +760,19 @@ def __create_inputs_prefill(): if len(running) == 1 and 
self.long_context_chunker.is_long_context(running[0]): # set long context chunker self.long_context_chunker.set_seq(running[0]) - inputs, extra_inputs = __create_inputs_chunk(running) - inputs.is_first_chunk = True - inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal + if self.long_context_chunker.is_last_chunk(): + # A prefix-cache restore can skip past a large multimodal + # span, leaving a tail that fits the multimodal-expanded + # chunk limit. Treat it as normal prefill so the model sees + # the same single tail chunk as the no-cache path. Do not + # set chunk flags here: spec decoding uses them as a + # cross-chunk carry protocol. + self.long_context_chunker.clear() + inputs, delta, extra_inputs = __create_model_inputs(running) + else: + inputs, extra_inputs = __create_inputs_chunk(running) + inputs.is_first_chunk = True + inputs.is_chunk_multimodal = self.long_context_chunker.has_multimodal elif len(running) > 0: # create inputs inputs, delta, extra_inputs = __create_model_inputs(running) diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index ff38daafc0..6683426b35 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -168,6 +168,14 @@ def model_forward( ) with ctx_mgr.context(context): + if (not inputs.is_dummy and inputs.state_offsets is not None + and inputs.state_prefix_cache_offsets is not None): + # Restore frozen SSM prefix state into this request's runtime + # slot on the forward stream. The input maker already + # compacted valid src/dst pairs on CPU, so no CUDA boolean + # indexing/nonzero synchronization is needed here. 
+ state_cache_engine.copy_caches(inputs.state_prefix_cache_offsets, + inputs.state_prefix_cache_dst_offsets) model_metas = model.update_model_metas( past_key_values=cache_engine.gpu_cache, @@ -183,6 +191,13 @@ def model_forward( # InternVL-3.5-Flash will change the seqlen, model_metas during forward if getattr(context, 'is_model_meta_updated', False): model_metas = context.model_metas + if (not inputs.is_dummy and inputs.state_offsets is not None + and inputs.state_prefix_cache_save_offsets is not None): + # Save the post-forward runtime state into reserved checkpoint + # slots. The scheduler publishes these slots only after the + # executor output boundary confirms the copy was enqueued. + state_cache_engine.copy_caches(inputs.state_prefix_cache_save_src_offsets, + inputs.state_prefix_cache_save_offsets) output['model_metas'] = model_metas output['seq_length'] = context.q_seqlens[:len(inputs.seq_length)] # for draft model reuse diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 0b7195f2dd..c72aa40a5d 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -10,7 +10,7 @@ from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest -from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs, make_multimodal_content_hash from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality @@ -43,6 +43,60 @@ def move_position(self, offset: int = 0): return self +@dataclass(frozen=True) +class PrefixCacheMeta: + """Multimodal span identity used by prefix-cache block keys. + + Placeholder token ids alone are not enough for VLM prefix caching: two + requests can contain the same image placeholder tokens backed by different + image/video content. The trie key therefore includes every overlapping + span's modality and stable content hash. 
+ """ + + start: int + end: int + modality: str + content_hash: str + + +@dataclass +class PrefixCacheState: + """Per-sequence prefix-cache bookkeeping. + + ``metas`` and ``block_extra_hashes`` are persistent request metadata used + when constructing multimodal-aware trie keys. The restore/save fields are + transient scheduler state for SSM checkpoints: a matched frozen state is + pinned before forward, and pending save slots are published only after the + model has copied runtime state into them. ``last_shared_node`` is the + deepest trie node already shared by this sequence; ``BlockTrie.match()`` + writes it and ``BlockTrie.allocate()`` continues inserting new full blocks + from it. ``match_start_step`` remembers the sequence step before a + tentative prefix-cache match so long-context chunking can distinguish + current-turn cached multimodal spans from older session history. + ``suppress_match_stats`` is set while replaying work after recompute + eviction; cache reuse may still happen, but it should not affect the public + prefix-cache hit-rate metric. 
+ """ + + metas: list[PrefixCacheMeta] = field(default_factory=list) + block_extra_hashes: dict[int, tuple] = field(default_factory=dict, repr=False) + num_indexed_metas: int = 0 + last_shared_node: Any = field(default=None, repr=False) + restore_state: int = -1 + restore_state_acquired: bool = False + restore_node: Any = field(default=None, repr=False) + save_state: int = -1 + save_step: int = 0 + save_is_decode: bool = False + save_node: Any = field(default=None, repr=False) + save_state_acquired: bool = False + save_acquired_state: int = -1 + save_acquired_node: Any = field(default=None, repr=False) + decode_state_node: Any = field(default=None, repr=False) + match_start_step: int = -1 + suppress_match_stats: bool = False + + @dataclass class SamplingParam: """Sampling parameter.""" @@ -629,12 +683,15 @@ class SchedulerSequence: history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds) history_embeddings: HistoryEmbeddings = field(default_factory=HistoryEmbeddings) history_multimodals: HistoryMultiModals = field(default_factory=HistoryMultiModals) + prefix_cache: PrefixCacheState = field(default_factory=PrefixCacheState) num_new_tokens: int = 0 sampling_param: SamplingParam = field(default_factory=SamplingParam) logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks) logical_state: int = -1 adapter_name: str = None arrive_time: float = 0.0 + input_start_pos: int = 0 + input_end_pos: int = 0 output_start_pos: int = 0 meta: Any = None num_ignored_history: int = 0 @@ -657,6 +714,9 @@ class SchedulerSequence: # mrope history_mrope_pos_ids: HistoryMropePosIds = field(default_factory=HistoryMropePosIds) + # Prefix-cache tokens accepted by the scheduler that are present in the current request prompt. 
+ cached_tokens: int = 0 + def __post_init__(self): """Post init.""" self._seq_meta: SequenceMeta = self.session.seq_meta @@ -820,6 +880,71 @@ def get_input_multimodals(self): end = self.num_all_ids return self.history_multimodals.get_datas(start, end) + def get_chunk_limit_multimodals(self): + """Get multimodals that should affect long-context chunk size.""" + input_multimodals = self.get_input_multimodals() + match_start = self.prefix_cache.match_start_step + if match_start >= 0 and self.num_history_ids > match_start: + return self.history_multimodals.get_datas(match_start, self.num_all_ids) + return input_multimodals + + def get_prefix_cache_extra_hashes(self, start: int, end: int): + """Get canonical multimodal identity entries for a token range. + + The common caller asks for a full block, but partial ranges are used when verifying sparse SSM checkpoint + candidates. Returning only overlapping spans keeps text-only blocks unchanged while making blocks that touch + multimodal placeholders content-aware. + """ + prefix_cache = self.prefix_cache + if len(prefix_cache.metas) == 0: + return () + + if prefix_cache.num_indexed_metas != len(prefix_cache.metas): + self._index_prefix_cache_metas() + start_block = start // self.block_size + end_block = (max(start, end - 1)) // self.block_size + if start_block == end_block: + extras = prefix_cache.block_extra_hashes.get(start_block, ()) + if start % self.block_size == 0 and end - start == self.block_size: + # Full-block lookup is the hot path; the indexed tuple already + # contains exactly the spans that overlap this block. 
+ return extras + return tuple(extra for extra in extras if extra[0] < end and start < extra[1]) + + extras = [] + for block_id in range(start_block, end_block + 1): + extras.extend(prefix_cache.block_extra_hashes.get(block_id, ())) + extras = [extra for extra in set(extras) if extra[0] < end and start < extra[1]] + return tuple(sorted(extras)) + + def clamp_prefix_cache_match_step(self, step: int): + """Clamp a prefix-cache match so forward never starts inside a span. + + Multimodal processors expect an image/video span to be consumed as a whole. If a candidate cache hit would stop + in the middle of such a span, rewind to the span start and then to a block boundary. Rounding a later span + start down can itself land inside an earlier span when multimodal spans are close together, so keep rewinding + until the final block boundary is outside every span. + """ + if step <= 0: + return step + + spans = [(meta.start, meta.end) for meta in self.prefix_cache.metas] + spans.extend((emb.start, emb.end) for emb in self.history_embeddings.embeddings) + if len(spans) == 0: + return (step // self.block_size) * self.block_size + + clamped = step + while clamped > 0: + next_step = clamped + for start, end in spans: + if start < next_step < end: + next_step = min(next_step, start) + next_step = (next_step // self.block_size) * self.block_size + if next_step == clamped: + break + clamped = next_step + return clamped + def record_event( self, event_type: EventType, @@ -842,8 +967,54 @@ def _update_multimodals(self, multimodals: MultiModalInputs): if multimodals is None: return multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids) + if self.session.scheduler.cache_config.enable_prefix_caching: + self._update_prefix_cache_metas(multimodals) self.history_multimodals.add_inputs(multimodals) + def _update_prefix_cache_metas(self, multimodals: MultiModalInputs): + """Record multimodal span identities for future trie keying.""" + for modal_datas in 
multimodals.values(): + for modal_data in modal_datas: + modality = modal_data.modality + if isinstance(modality, enum.Enum): + modality = modality.value + content_hash = modal_data.content_hash + if content_hash is None: + # Most request paths precompute the hash after model + # preprocessing. Keep this fallback for unit tests and + # defensive correctness if a processor omits it. + content_hash = make_multimodal_content_hash(modal_data.data, modal_data.meta, + modal_data.mrope_pos_ids) + self.prefix_cache.metas.append( + PrefixCacheMeta(start=modal_data.start, + end=modal_data.end, + modality=str(modality), + content_hash=str(content_hash))) + + def _index_prefix_cache_metas(self): + """Build the lazy block -> multimodal identity index. + + The trie asks for block keys many times during match/allocation, so we pay the span-to-block indexing cost once + per newly appended metadata entry instead of scanning all multimodal spans for every block. + """ + prefix_cache = self.prefix_cache + block_size = self.block_size + new_metas = prefix_cache.metas[prefix_cache.num_indexed_metas:] + if len(new_metas) == 0: + return + + for meta in new_metas: + if meta.end <= meta.start: + continue + extra = (meta.start, meta.end, meta.modality, meta.content_hash) + start_block = meta.start // block_size + end_block = (meta.end - 1) // block_size + for block_id in range(start_block, end_block + 1): + extras = list(prefix_cache.block_extra_hashes.get(block_id, ())) + extras.append(extra) + prefix_cache.block_extra_hashes[block_id] = tuple(sorted(extras)) + prefix_cache.num_indexed_metas = len(prefix_cache.metas) + def _update_mrope_pos_ids(self): """Update mrope pos ids.""" if not self._seq_meta.use_mrope: diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 590683431b..270bc34b0e 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from collections.abc import Sequence from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, Any @@ -141,6 +142,9 @@ class ModelInputsDelta: is_decoding: bool = True # sliding window num_ignored_history: torch.Tensor | None = None + # Compact SSM prefix-cache checkpoint save pairs for decode forwards. + state_prefix_cache_save_src_offsets: Sequence[int] | None = None + state_prefix_cache_save_offsets: Sequence[int] | None = None @property def seq_length(self): @@ -193,7 +197,16 @@ class ModelInputs: dp_meta: DPMeta | None = None enable_microbatch: bool = False is_dummy: bool = False + # Runtime SSM state slot ids for each sequence in the batch. state_offsets: torch.Tensor | None = None + # Frozen checkpoint slot ids to restore from before forward. Compact, no sentinels. + state_prefix_cache_offsets: Sequence[int] | None = None + # Runtime state slot ids to restore into before forward. Compact, no sentinels. + state_prefix_cache_dst_offsets: Sequence[int] | None = None + # Runtime state slot ids to save from after forward. Compact, no sentinels. + state_prefix_cache_save_src_offsets: Sequence[int] | None = None + # Reserved checkpoint slot ids to save into after forward. Compact, no sentinels. 
+ state_prefix_cache_save_offsets: Sequence[int] | None = None target_hidden_states: torch.Tensor | None = None target_position_ids: torch.Tensor | None = None target_inputs_embeds: torch.Tensor | None = None @@ -223,6 +236,10 @@ def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None): history_lengths=self.history_lengths + step_seqlens, max_kv_seqlen=self.max_kv_seqlen + self.max_q_seqlen, sum_kv_seqlen=self.sum_kv_seqlen + self.max_q_seqlen * self.seq_length.numel(), + state_prefix_cache_offsets=None, + state_prefix_cache_dst_offsets=None, + state_prefix_cache_save_src_offsets=None, + state_prefix_cache_save_offsets=None, mrope_pos_ids=mrope_pos_ids, ) diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py index e5292cc5c5..f778c2aeb5 100644 --- a/lmdeploy/pytorch/multimodal/data_type.py +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import enum +import hashlib from dataclasses import dataclass, fields from typing import Any import numpy as np +import torch from torch import Tensor from lmdeploy.vl.constants import Modality @@ -10,6 +13,46 @@ NestedTensor = Tensor | list[Tensor] +def _hash_multimodal_value(hasher: 'hashlib._Hash', value: Any): + """Update a hash with a deterministic multimodal value representation.""" + if isinstance(value, Tensor): + tensor = value.detach().cpu().contiguous() + hasher.update(f'tensor:{tensor.dtype}:{tuple(tensor.shape)}:'.encode()) + hasher.update(tensor.view(torch.uint8).numpy().tobytes()) + elif isinstance(value, np.ndarray): + array = np.ascontiguousarray(value) + hasher.update(f'ndarray:{array.dtype}:{array.shape}:'.encode()) + hasher.update(array.tobytes()) + elif isinstance(value, dict): + hasher.update(b'dict:{') + for key in sorted(value, key=lambda x: repr(x)): + _hash_multimodal_value(hasher, key) + hasher.update(b':') + _hash_multimodal_value(hasher, value[key]) + hasher.update(b',') + 
hasher.update(b'}') + elif isinstance(value, (list, tuple)): + hasher.update(f'{type(value).__name__}:['.encode()) + for item in value: + _hash_multimodal_value(hasher, item) + hasher.update(b',') + hasher.update(b']') + elif isinstance(value, enum.Enum): + _hash_multimodal_value(hasher, value.value) + else: + hasher.update(f'{type(value).__name__}:{repr(value)}'.encode()) + + +def make_multimodal_content_hash(data: Any, meta: dict[str, Any] | None, + mrope_pos_ids: np.ndarray | None = None) -> str: + """Create a stable content hash for prefix-cache multimodal matching.""" + hasher = hashlib.sha256() + _hash_multimodal_value(hasher, data) + _hash_multimodal_value(hasher, meta) + _hash_multimodal_value(hasher, mrope_pos_ids) + return hasher.hexdigest() + + @dataclass class MultiModalData: data: NestedTensor @@ -21,6 +64,8 @@ class MultiModalData: # for qwen-vl mrope_pos_ids: np.ndarray | None = None + content_hash: str | None = None + def __post_init__(self): if self.end is None: self.end = self.start @@ -56,3 +101,16 @@ def to_device(self, device: str, non_blocking: bool = False): MultiModalInputs = dict[str, list[MultiModalData]] + + +def ensure_multimodal_content_hashes(input_mms: MultiModalInputs | None): + """Populate missing multimodal content hashes in-place.""" + if input_mms is None: + return input_mms + + for modal_datas in input_mms.values(): + for modal_data in modal_datas: + if modal_data.content_hash is None: + modal_data.content_hash = make_multimodal_content_hash(modal_data.data, modal_data.meta, + modal_data.mrope_pos_ids) + return input_mms diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py index d20aa665d2..c32770a7db 100644 --- a/lmdeploy/pytorch/paging/block_trie.py +++ b/lmdeploy/pytorch/paging/block_trie.py @@ -1,14 +1,86 @@ # Copyright (c) OpenMMLab. All rights reserved. +"""Prefix-cache trie ownership and lifecycle. 
+ +``BlockTrie`` owns reusable prefix identity, trie-owned KV block references, +optional SSM state checkpoints, and optional routed-expert replay data. Read +this module together with ``Scheduler._schedule_prefill``, +``InputsMaker.create_model_inputs*``, ``model_forward``, and +``EngineLoop._publish_forward_prefix_cache``. + +Pipeline summary: + +1. The scheduler calls ``match()`` before eviction/allocation. A match mutates + sequence state tentatively: it may append shared KV blocks, advance + ``seq.num_history_ids``, set SSM restore metadata, and replay routed experts. +2. If scheduling later fails, the scheduler rolls that tentative match back. + If it succeeds, ``block_manager.allocate()`` gives blocks for the uncached + suffix, and ``allocate()`` attaches newly completed full blocks to the trie. +3. Text/VLM matching walks trie blocks by adapter root. Each block key is + token ids plus multimodal extra hashes; matches are clamped so forward never + starts inside a multimodal span. +4. SSM matching cannot reuse KV alone. It uses sparse ready checkpoint lookup, + verifies the full ancestor chain, then asks ``ModelAgent`` to copy the + frozen checkpoint state into the request runtime state on the forward stream. +5. SSM checkpoint saves are reserved here, copied by ``ModelAgent`` after + forward, and published by ``EngineLoop`` once the producer forward is queued. + Producer/restore refcounts pin checkpoint slots across async stream-ordering + windows. + +SSM checkpoint detail: + +* ``seq.prefix_cache.last_shared_node`` stores the deepest trie node already + shared by the sequence. ``match()`` writes it, rollback/free clears it, and + ``allocate()`` continues inserting newly computed full blocks from it. +* ``StateManager`` owns one state-cache pool split by role: active requests use + runtime slots stored on ``seq.logical_state``; prefix-cache checkpoints use + slots stored on trie ``Node.state_idx``. 
A trie node may own KV only, KV plus + an unready checkpoint reservation, or KV plus a ready checkpoint. +* Saving a checkpoint starts from an already-attached block-aligned trie node. + ``reserve_state_checkpoint_for_seq()`` records ``save_state``, ``save_step``, + ``save_node``, and ``save_is_decode`` on ``seq.prefix_cache``. Prefill and + long-context chunks save at the produced chunk end; decode saves are optional + and bounded by ``prefix_cache_decode_state_interval``. +* ``InputsMaker`` converts those pending saves into compact host integer + src/dst pairs. ``ModelAgent`` then copies ``runtime_state -> checkpoint`` on + the model forward stream after the model has produced the new SSM state. + ``EngineLoop`` calls ``commit_state_checkpoint_for_seq()`` after the forward + is queued; only then does ``state_ready`` become true and the sparse + checkpoint index become matchable. The producing forward holds a producer ref + until the output/event boundary, so this early visibility cannot make the + destination slot evictable before the save copy reaches the forward stream. + Abandoned reservations are discarded. +* Matching a SSM prefix never walks KV blocks as the source of truth. + ``_match_state_checkpoint()`` searches ready checkpoint steps, filters by + ``(adapter, step, last_block_hash)``, then verifies every ancestor block's + tokens and multimodal extra hashes before mutating the sequence. A hit + appends trie-owned KV blocks, advances ``seq.num_history_ids``, records + ``restore_state``/``restore_node``, and may replay routed experts. +* Restore is two-phase. The scheduler/input maker pins the ready checkpoint by + incrementing ``state_ref_count``. ``ModelAgent`` copies + ``checkpoint -> runtime_state`` before the suffix forward. ``EngineLoop`` + releases the pin once the copy has been queued, so LRU eviction cannot reuse + the checkpoint source slot too early. +* Checkpoint eviction is state-only LRU over ready, unpinned nodes. 
KV leaf + eviction also releases any checkpoint owned by that leaf. A KV match without + an exact ready SSM checkpoint is intentionally a miss. +""" + +import enum import heapq +import logging +import time from dataclasses import dataclass import numpy as np from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.utils import get_logger from ..config import CacheConfig from .block_manager import BaseBlockManager +logger = get_logger('lmdeploy') + @dataclass class PrefixCacheStats: @@ -20,18 +92,58 @@ def reset(self): self.num_query_tokens = 0 self.num_hit_tokens = 0 + def copy(self): + """Copy stats for tentative-match rollback.""" + return PrefixCacheStats(num_query_tokens=self.num_query_tokens, num_hit_tokens=self.num_hit_tokens) + def hit_rate(self): return 0.0 if self.num_query_tokens <= 0 else float(self.num_hit_tokens) / self.num_query_tokens +class StateCheckpointVerifyStatus(enum.Enum): + """Outcome of sparse SSM checkpoint verification.""" + HIT = enum.auto() + REQUEST_MISMATCH = enum.auto() + STALE_INDEX_ENTRY = enum.auto() + STALE_CHECKPOINT = enum.auto() + + class Node: - """Node of block trie.""" + """One full-token-block edge in the prefix-cache trie. - def __init__(self, hash_key: int, block: int, tokens: np.ndarray, num_matched: int = 0): + ``extra_hashes`` augments the token block key with VLM content identity. + ``state_idx`` / ``state_ready`` / ``state_ref_count`` are optional SSM + state-checkpoint ownership fields; they are meaningful only when the cache + config has state shapes. ``state_ready`` controls whether the checkpoint + has been published and may be matched. ``state_ref_count`` pins a ready + checkpoint while a restore copy may still read it or a producer save copy + may still write it, so LRU eviction or checkpoint reuse cannot overwrite + the slot too early. 
+ """ + + def __init__(self, + hash_key: int, + block: int, + tokens: np.ndarray, + num_matched: int = 0, + extra_hashes: tuple = (), + state_idx: int = -1, + state_ready: bool = False, + state_ref_count: int = 0, + state_access_time: float = 0.0, + routed_experts: np.ndarray = None, + adapter_name: str = None): self.hash_key = hash_key self.block = block self.tokens = tokens self.num_matched = num_matched + self.extra_hashes = extra_hashes + self.state_idx = state_idx + self.state_ready = state_ready + self.state_ref_count = state_ref_count + self.state_access_time = state_access_time + self.routed_experts = routed_experts + self.adapter_name = adapter_name self.children: dict[int, Node] = dict() self._parent: Node = None @@ -55,43 +167,839 @@ def __le__(self, other): return True +@dataclass +class StateCheckpointVerifyResult: + """Verified checkpoint candidate details.""" + status: StateCheckpointVerifyStatus + reason: str = '' + matched_blocks: list[int] | None = None + matched_nodes: list[Node] | None = None + + class BlockTrie: """Block trie for prefix caching.""" - def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager): + def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager, state_manager=None): self.block_manager = block_manager self.cache_config = cache_config self.allocator = self.block_manager.allocator + self.state_manager = state_manager self.block_size = cache_config.block_size self.enable = self.cache_config.enable_prefix_caching + self.requires_state_checkpoint = state_manager is not None and len(cache_config.states_shapes) > 0 # caches with different adapter should not be shared. self._roots: dict[str, Node] = dict() self.leaves: set[Node] = set() + # SSM checkpoints are sparse. The trie still owns KV blocks, but ready + # recurrent-state snapshots are indexed only at selected exact steps. 
+ self._state_checkpoint_index: dict[tuple, list[Node]] = dict() + self._state_checkpoint_steps: dict[str, set[int]] = dict() self.stats = PrefixCacheStats() def hit_rate(self): """Get hit rate.""" return self.stats.hit_rate() + def snapshot_stats(self): + """Snapshot prefix-cache stats before a tentative match.""" + if not self.enable: + return None + return self.stats.copy() + + def restore_stats(self, snapshot: PrefixCacheStats | None): + """Restore prefix-cache stats for an unused tentative match.""" + if snapshot is None: + return + self.stats.num_query_tokens = snapshot.num_query_tokens + self.stats.num_hit_tokens = snapshot.num_hit_tokens + + def _record_match_stats(self, seq: SchedulerSequence, query_tokens: int, hit_tokens: int = 0): + """Record a user-visible prefix-cache match attempt.""" + if seq.prefix_cache.suppress_match_stats: + return + self.stats.num_query_tokens += query_tokens + self.stats.num_hit_tokens += hit_tokens + def get_root(self, adapter_name: str): """Get root by adapter name.""" if adapter_name not in self._roots: - self._roots[adapter_name] = Node(-1, -1, None) + self._roots[adapter_name] = Node(-1, -1, None, adapter_name=adapter_name) return self._roots[adapter_name] + @staticmethod + def _get_block_extra_hashes(seq: SchedulerSequence, start: int, end: int): + """Get multimodal identity entries that belong in a block key.""" + return seq.get_prefix_cache_extra_hashes(start, end) + + @staticmethod + def _make_key(tokens: np.ndarray, extra_hashes: tuple): + """Make the trie lookup key from tokens plus multimodal identity.""" + return hash(('random', tuple(tokens), extra_hashes)) + + @staticmethod + def _match_node(node: Node, tokens: np.ndarray, extra_hashes: tuple): + """Check the exact key payload after the hash-table lookup.""" + return np.array_equal(tokens, node.tokens) and extra_hashes == node.extra_hashes + + @staticmethod + def _get_routed_experts_for_range(seq: SchedulerSequence, start: int, end: int): + """Get a copy of 
routed experts for a full token range, if present.""" + if not seq.return_routed_experts: + return None + all_routed_experts = seq.all_routed_experts + if all_routed_experts is None: + return None + if len(all_routed_experts) < seq.num_history_ids or len(all_routed_experts) < end: + return None + routed_experts = all_routed_experts.get_real() + if routed_experts is None or len(routed_experts) < end: + return None + return routed_experts[start:end].copy() + + def _try_cache_node_routed_experts(self, node: Node, seq: SchedulerSequence, start: int, end: int): + """Attach routed experts to a trie node when a sequence has them.""" + if node.routed_experts is not None: + return + routed_experts = self._get_routed_experts_for_range(seq, start, end) + if routed_experts is not None and len(routed_experts) == end - start: + node.routed_experts = routed_experts + + def _append_matched_routed_experts(self, seq: SchedulerSequence, nodes: list[Node], start: int): + """Replay cached routed experts for a matched trie range.""" + if not seq.return_routed_experts or len(nodes) == 0: + return + if len(seq.all_routed_experts) != start: + return + + expert_slices = [] + for node in nodes: + routed_experts = node.routed_experts + if routed_experts is None or len(routed_experts) != self.block_size: + return + expert_slices.append(routed_experts) + + seq.append_routed_experts(np.concatenate(expert_slices, axis=0).copy()) + + def cache_routed_experts_for_seq(self, seq: SchedulerSequence): + """Enrich attached trie nodes with routed experts from a sequence.""" + if not self.enable or not seq.return_routed_experts: + return + node = seq.prefix_cache.last_shared_node + while node is not None and node.parent is not None: + end = node.num_matched + start = end - self.block_size + self._try_cache_node_routed_experts(node, seq, start, end) + node = node.parent + + def cache_routed_experts(self, seqs: list[SchedulerSequence]): + """Enrich trie nodes with routed experts from multiple sequences.""" 
+ if not self.enable: + return + for seq in seqs: + self.cache_routed_experts_for_seq(seq) + + def _make_state_checkpoint_lookup_key(self, seq: SchedulerSequence, step: int): + """Make the sparse SSM checkpoint lookup key for a sequence prefix. + + The last block key is only a filter into the sparse index. Candidate nodes are still verified by walking the + full ancestor chain so hash collisions or stale index entries cannot produce a false state hit. + """ + start = step - self.block_size + end = step + tokens = seq.history_cache[start:end] + extra_hashes = self._get_block_extra_hashes(seq, start, end) + return (seq.adapter_name, step, self._make_key(tokens, extra_hashes)) + + @staticmethod + def _make_state_checkpoint_node_key(node: Node): + """Make the sparse SSM checkpoint lookup key for a trie node.""" + return (node.adapter_name, node.num_matched, node.hash_key) + + def _index_state_checkpoint(self, node: Node): + """Add a ready state checkpoint to the sparse SSM index.""" + if node.state_idx < 0 or not node.state_ready: + raise RuntimeError('Cannot index an unready SSM prefix-cache checkpoint.') + if not self._is_attached_node(node): + raise RuntimeError('Cannot index a detached SSM prefix-cache checkpoint node.') + key = self._make_state_checkpoint_node_key(node) + nodes = self._state_checkpoint_index.setdefault(key, []) + if not any(indexed_node is node for indexed_node in nodes): + nodes.append(node) + steps = self._state_checkpoint_steps.setdefault(node.adapter_name, set()) + steps.add(node.num_matched) + + def _refresh_state_checkpoint_step(self, adapter_name: str, step: int): + """Drop an adapter step when no indexed checkpoint still owns it.""" + steps = self._state_checkpoint_steps.get(adapter_name) + if steps is None or step not in steps: + return + has_step = any(key[0] == adapter_name and key[1] == step for key in self._state_checkpoint_index) + if not has_step: + steps.remove(step) + if len(steps) == 0: + 
self._state_checkpoint_steps.pop(adapter_name) + + def _remove_state_checkpoint_index_entry(self, node: Node, key: tuple): + """Remove a node from one sparse-index bucket.""" + nodes = self._state_checkpoint_index.get(key) + if nodes is None: + return False + + old_len = len(nodes) + nodes[:] = [indexed_node for indexed_node in nodes if indexed_node is not node] + if len(nodes) == old_len: + return False + if len(nodes) == 0: + self._state_checkpoint_index.pop(key) + self._refresh_state_checkpoint_step(key[0], key[1]) + return True + + def _unindex_state_checkpoint(self, node: Node): + """Remove a state checkpoint from every sparse-index bucket.""" + removed = False + for key in list(self._state_checkpoint_index): + removed = self._remove_state_checkpoint_index_entry(node, key) or removed + return removed + + def reserve_state_checkpoint(self, node: Node): + """Reserve a state-cache slot owned by a trie node. + + Reusing a ready slot means replacing the checkpoint for the same node, which is allowed only while no restore + copy has it pinned. If the shared state pool is full, evict an old unpinned checkpoint without removing the + trie/KV node itself. 
+ """ + if not self.requires_state_checkpoint or node.parent is None: + return -1 + if node.state_ready: + if node.state_ref_count > 0: + return -1 + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Replace SSM prefix-cache checkpoint: adapter={node.adapter_name} ' + f'step={node.num_matched} state_idx={node.state_idx}') + self._unindex_state_checkpoint(node) + elif node.state_idx >= 0: + return -1 + if node.state_idx < 0: + if self.state_manager.get_num_free_checkpoint() == 0 and self.evict_state_checkpoints(1) == 0: + return -1 + node.state_idx = self.state_manager.allocate_checkpoint_state() + node.state_ready = False + return node.state_idx + + def _clear_pending_state_checkpoint(self, seq: SchedulerSequence): + """Clear pending checkpoint save metadata from a sequence.""" + prefix_cache = seq.prefix_cache + prefix_cache.save_state = -1 + prefix_cache.save_step = 0 + prefix_cache.save_is_decode = False + prefix_cache.save_node = None + + @staticmethod + def _clear_save_checkpoint_ref(seq: SchedulerSequence): + """Clear an in-flight producer checkpoint ref from a sequence.""" + prefix_cache = seq.prefix_cache + prefix_cache.save_state_acquired = False + prefix_cache.save_acquired_state = -1 + prefix_cache.save_acquired_node = None + + def discard_state_checkpoint_for_seq(self, seq: SchedulerSequence): + """Discard an unpublished state checkpoint reservation for a sequence. + + Reservations happen before forward. If the executor fails to produce output, or the sequence is rescheduled + before the copy is committed, the unready state slot must be released rather than becoming matchable. 
+ """ + prefix_cache = seq.prefix_cache + state_idx = prefix_cache.save_state + node = prefix_cache.save_node + is_decode = prefix_cache.save_is_decode + self._clear_pending_state_checkpoint(seq) + if state_idx < 0: + return False + if self._is_unpublished_state_checkpoint_reservation(node, state_idx): + if is_decode and prefix_cache.decode_state_node is node: + prefix_cache.decode_state_node = None + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Discard SSM prefix-cache checkpoint reservation: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={node.num_matched} state_idx={state_idx} ' + f'is_decode={is_decode}') + self.release_state_checkpoint(node) + return True + return False + + def _get_state_checkpoint_node_for_seq(self, seq: SchedulerSequence, step: int): + """Get the trie node that exactly represents a sequence checkpoint + step.""" + node = seq.prefix_cache.last_shared_node + while node is not None and node.num_matched > step: + node = node.parent + if node is None or node.parent is None or node.num_matched != step: + return None + return node + + @staticmethod + def _is_attached_node(node: Node): + """Check whether a node is still attached to the trie.""" + parent = node.parent + return parent is not None and parent.children.get(node.hash_key) is node + + @staticmethod + def _is_attached_leaf(node: Node): + """Check whether a node is a current attached trie leaf.""" + return BlockTrie._is_attached_node(node) and len(node.children) == 0 + + @staticmethod + def _is_evict_candidate_leaf(node: Node): + """Check whether a leaf-set entry can be considered by KV eviction.""" + return (node.block >= 0 and len(node.children) == 0 + and (node.parent is None or BlockTrie._is_attached_node(node))) + + def reserve_state_checkpoint_for_seq(self, + seq: SchedulerSequence, + step: int = None, + is_decode: bool = False): + """Reserve a state checkpoint slot for an exact trie step. 
+ + SSM prefix hits are valid only when KV blocks and recurrent state refer to the same prefix. Therefore saves are + limited to block-aligned, multimodal-safe steps that already have an attached trie node. + """ + self.discard_state_checkpoint_for_seq(seq) + + if not self.enable or not self.requires_state_checkpoint: + return -1 + + if step is None: + step = seq.num_valid_ids + if step <= 0 or step % self.block_size != 0: + return -1 + if step > seq.num_valid_ids: + return -1 + if seq.clamp_prefix_cache_match_step(step) != step: + return -1 + + node = self._get_state_checkpoint_node_for_seq(seq, step) + if node is None: + return -1 + if node.state_ready: + return -1 + + try: + state_idx = self.reserve_state_checkpoint(node) + except RuntimeError as e: + if 'No free states' not in str(e): + raise + return -1 + if state_idx < 0: + return -1 + + prefix_cache = seq.prefix_cache + prefix_cache.save_state = state_idx + prefix_cache.save_step = step + prefix_cache.save_is_decode = is_decode + prefix_cache.save_node = node + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Reserve SSM prefix-cache checkpoint: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={step} state_idx={state_idx} is_decode={is_decode}') + return state_idx + + def reserve_decode_state_checkpoint_for_seq(self, + seq: SchedulerSequence, + interval: int, + step: int = None): + """Reserve a bounded decode checkpoint for a sequence. + + Decode checkpoints are opt-in and replaceable: keep at most one ready + decode checkpoint per sequence so long generations do not consume the + whole checkpoint budget. The previous ready checkpoint is released + only after the new step is proven eligible. 
+ """ + if step is None: + step = seq.num_valid_ids + if interval <= 0 or step % interval != 0: + return -1 + if not self.enable or not self.requires_state_checkpoint: + return -1 + if step <= 0 or step % self.block_size != 0: + return -1 + if step > seq.num_valid_ids: + return -1 + if seq.clamp_prefix_cache_match_step(step) != step: + return -1 + node = self._get_state_checkpoint_node_for_seq(seq, step) + if node is None or node.state_ready: + return -1 + if node.state_idx >= 0: + return -1 + + prefix_cache = seq.prefix_cache + old_node = prefix_cache.decode_state_node + if old_node is not None and old_node.state_idx < 0: + prefix_cache.decode_state_node = None + old_node = None + if old_node is not None: + if self._is_same_ready_decode_state_checkpoint(old_node, step): + return -1 + if old_node.state_ref_count > 0: + return -1 + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Release previous decode SSM prefix-cache checkpoint: ' + f'session_id={seq.session_id} seq_id={seq.seq_id} ' + f'old_step={old_node.num_matched} old_state_idx={old_node.state_idx} ' + f'new_step={step}') + self.release_state_checkpoint(old_node) + prefix_cache.decode_state_node = None + + return self.reserve_state_checkpoint_for_seq(seq, step=step, is_decode=True) + + def mark_state_checkpoint_ready(self, node: Node): + """Mark a node-owned state checkpoint as ready for SSM matching.""" + if node.state_idx < 0: + raise RuntimeError('Cannot mark an unreserved state checkpoint as ready.') + if node.state_ref_count != 0: + raise RuntimeError('Cannot publish a pinned SSM prefix-cache checkpoint.') + if not self._is_attached_node(node): + raise RuntimeError('Cannot publish a detached SSM prefix-cache checkpoint node.') + if node.state_ready: + self._unindex_state_checkpoint(node) + node.state_ready = True + node.state_access_time = time.perf_counter() + self._index_state_checkpoint(node) + + @staticmethod + def _is_same_ready_decode_state_checkpoint(node: Node, step: int): + """Check 
whether a decode checkpoint for this exact step is ready.""" + return node.num_matched == step and node.state_ready + + def _state_checkpoint_commit_invalid_reason(self, node: Node | None, state_idx: int, save_step: int): + """Return why a pending checkpoint commit is invalid, or ``None``.""" + if node is None: + return 'missing node' + if not self._is_attached_node(node): + return 'detached node' + if node.state_idx != state_idx: + return f'state changed: current={node.state_idx}' + if node.num_matched != save_step: + return f'step changed: current={node.num_matched}' + return None + + @staticmethod + def _is_unpublished_state_checkpoint_reservation(node: Node | None, state_idx: int): + """Check whether an invalid commit still owns an unready + reservation.""" + return node is not None and node.state_idx == state_idx and not node.state_ready + + @staticmethod + def _is_ready_state_checkpoint(node: Node | None, state_idx: int): + """Check whether a node owns a ready checkpoint state slot.""" + return node is not None and node.state_idx == state_idx and node.state_ready + + @staticmethod + def _is_ready_state_checkpoint_node(node: Node): + """Check whether a node has any ready checkpoint state slot.""" + return node.state_idx >= 0 and node.state_ready + + @staticmethod + def _has_state_checkpoint_ref(node: Node | None, state_idx: int): + """Check whether a sequence still owns a checkpoint ref on this + node.""" + return node is not None and node.state_idx == state_idx and node.state_ref_count > 0 + + @staticmethod + def _is_evictable_state_checkpoint(node: Node): + """Check whether a ready checkpoint may be evicted by LRU.""" + return node.state_idx >= 0 and node.state_ready and node.state_ref_count == 0 + + @staticmethod + def _is_pinned_state_checkpoint(node: Node): + """Check whether a checkpoint may still be read by an async restore.""" + return node.state_ref_count > 0 + + def _release_invalid_state_checkpoint_reservation(self, + seq: SchedulerSequence, + node: 
Node | None, + state_idx: int, + is_decode: bool): + """Release an invalid pending save only if it still owns the slot.""" + if not self._is_unpublished_state_checkpoint_reservation(node, state_idx): + return + if is_decode and seq.prefix_cache.decode_state_node is node: + seq.prefix_cache.decode_state_node = None + self.release_state_checkpoint(node) + + def _acquire_state_checkpoint_save_for_seq(self, seq: SchedulerSequence, node: Node, state_idx: int): + """Pin a just-published checkpoint until its producer forward + completes.""" + prefix_cache = seq.prefix_cache + if prefix_cache.save_state_acquired: + raise RuntimeError('SSM prefix-cache save checkpoint already has an in-flight producer ref.') + if not self._is_ready_state_checkpoint(node, state_idx): + return False + node.state_ref_count += 1 + node.state_access_time = time.perf_counter() + prefix_cache.save_state_acquired = True + prefix_cache.save_acquired_state = state_idx + prefix_cache.save_acquired_node = node + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Acquire SSM prefix-cache save checkpoint: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={node.num_matched} state_idx={state_idx} ' + f'ref_count={node.state_ref_count}') + return True + + def commit_state_checkpoint_for_seq(self, seq: SchedulerSequence, acquire_save_ref: bool = False): + """Publish a sequence state checkpoint. + + When ``acquire_save_ref`` is true, the checkpoint becomes matchable as + soon as the producer forward is queued, but remains pinned until the + output/event boundary confirms the stream has passed the save copy. + + Commit validates the remembered node directly. This matters for decode saves because the sequence may have + advanced by one sampled token before the output boundary publishes the checkpoint. 
+ """ + prefix_cache = seq.prefix_cache + state_idx = prefix_cache.save_state + save_step = prefix_cache.save_step + is_decode = prefix_cache.save_is_decode + node = prefix_cache.save_node + self._clear_pending_state_checkpoint(seq) + if state_idx < 0: + return False + + invalid_reason = self._state_checkpoint_commit_invalid_reason(node, state_idx, save_step) + if invalid_reason is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Drop invalid SSM prefix-cache checkpoint commit: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={save_step} state_idx={state_idx} ' + f'is_decode={is_decode} reason={invalid_reason}') + self._release_invalid_state_checkpoint_reservation(seq, node, state_idx, is_decode) + return False + + self.mark_state_checkpoint_ready(node) + if is_decode: + prefix_cache.decode_state_node = node + if acquire_save_ref: + self._acquire_state_checkpoint_save_for_seq(seq, node, state_idx) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Commit SSM prefix-cache checkpoint: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={save_step} state_idx={state_idx} is_decode={is_decode}') + return True + + def commit_state_checkpoints(self, seqs: list[SchedulerSequence], acquire_save_ref: bool = False): + """Publish pending sequence state checkpoints.""" + if not self.enable: + return + for seq in seqs: + self.commit_state_checkpoint_for_seq(seq, acquire_save_ref=acquire_save_ref) + + def acquire_state_checkpoint_restore_for_seq(self, seq: SchedulerSequence): + """Pin a matched state checkpoint until its restore copy has + completed.""" + prefix_cache = seq.prefix_cache + if prefix_cache.restore_state < 0 or prefix_cache.restore_state_acquired: + return False + node = prefix_cache.restore_node + if not self._is_ready_state_checkpoint(node, prefix_cache.restore_state): + return False + node.state_ref_count += 1 + node.state_access_time = time.perf_counter() + prefix_cache.restore_state_acquired = True + if 
logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Acquire SSM prefix-cache restore checkpoint: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={node.num_matched} state_idx={node.state_idx} ' + f'ref_count={node.state_ref_count}') + return True + + def acquire_state_checkpoint_restores(self, seqs: list[SchedulerSequence]): + """Pin matched state checkpoints for a batch.""" + for seq in seqs: + self.acquire_state_checkpoint_restore_for_seq(seq) + + @staticmethod + def _release_state_checkpoint_ref(node: Node | None, state_idx: int, err_msg: str): + """Release one checkpoint ref held by a sequence.""" + if not BlockTrie._has_state_checkpoint_ref(node, state_idx): + raise RuntimeError(err_msg) + node.state_ref_count -= 1 + return node + + def release_state_checkpoint_restore_for_seq(self, seq: SchedulerSequence): + """Release a state checkpoint pinned for restore.""" + prefix_cache = seq.prefix_cache + if not prefix_cache.restore_state_acquired: + return False + node = self._release_state_checkpoint_ref( + prefix_cache.restore_node, + prefix_cache.restore_state, + 'Acquired SSM prefix-cache restore checkpoint lost its node reference.', + ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Release SSM prefix-cache restore checkpoint: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={node.num_matched} state_idx={node.state_idx} ' + f'ref_count={node.state_ref_count}') + prefix_cache.restore_state = -1 + prefix_cache.restore_node = None + prefix_cache.restore_state_acquired = False + return True + + def release_state_checkpoint_restores(self, seqs: list[SchedulerSequence]): + """Release state checkpoints pinned for a batch restore.""" + if not self.enable: + return + for seq in seqs: + self.release_state_checkpoint_restore_for_seq(seq) + + def release_state_checkpoint_save_for_seq(self, seq: SchedulerSequence): + """Release a checkpoint pinned for its producer save copy.""" + prefix_cache = seq.prefix_cache + if not 
prefix_cache.save_state_acquired: + return False + node = self._release_state_checkpoint_ref( + prefix_cache.save_acquired_node, prefix_cache.save_acquired_state, + 'Acquired SSM prefix-cache save checkpoint lost its node reference.') + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Release SSM prefix-cache save checkpoint: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} step={node.num_matched} state_idx={node.state_idx} ' + f'ref_count={node.state_ref_count}') + self._clear_save_checkpoint_ref(seq) + return True + + def release_state_checkpoint_saves(self, seqs: list[SchedulerSequence]): + """Release producer refs held by a batch of saved checkpoints.""" + if not self.enable: + return + for seq in seqs: + self.release_state_checkpoint_save_for_seq(seq) + + def release_state_checkpoint(self, node: Node): + """Release a node-owned state checkpoint while keeping KV ownership.""" + if node.state_ref_count > 0: + raise RuntimeError('Cannot release a pinned SSM prefix-cache checkpoint.') + if node.state_idx < 0: + if node.state_ready: + self._unindex_state_checkpoint(node) + node.state_ready = False + node.state_ref_count = 0 + node.state_access_time = 0.0 + return + if node.state_ready: + self._unindex_state_checkpoint(node) + self.state_manager.free_checkpoint_state(node.state_idx) + node.state_idx = -1 + node.state_ready = False + node.state_ref_count = 0 + node.state_access_time = 0.0 + + def evict_state_checkpoints(self, max_num_states: int): + """Evict ready SSM state checkpoints without removing KV trie nodes.""" + if not self.requires_state_checkpoint or max_num_states <= 0: + return 0 + + candidates = [] + seen_nodes = set() + for nodes in self._state_checkpoint_index.values(): + for node in nodes: + node_id = id(node) + if node_id in seen_nodes: + continue + seen_nodes.add(node_id) + if self._is_evictable_state_checkpoint(node): + candidates.append((node.state_access_time, node)) + heapq.heapify(candidates) + + evicted = 0 + while 
len(candidates) > 0 and evicted < max_num_states: + _, node = heapq.heappop(candidates) + if not self._is_evictable_state_checkpoint(node): + continue + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Evict SSM prefix-cache checkpoint: adapter={node.adapter_name} ' + f'step={node.num_matched} state_idx={node.state_idx}') + self.release_state_checkpoint(node) + evicted += 1 + return evicted + + def _get_node_blocks(self, node: Node): + """Get trie nodes from root to a target node.""" + nodes = [] + while node is not None and node.parent is not None: + nodes.append(node) + node = node.parent + nodes.reverse() + return nodes + + def _drop_stale_state_checkpoint_index_entry(self, node: Node, key: tuple, reason: str): + """Remove a bad sparse-index entry without releasing a valid node.""" + removed = self._remove_state_checkpoint_index_entry(node, key) + if removed and logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Drop stale SSM prefix-cache checkpoint index entry: adapter={key[0]} ' + f'step={key[1]} node_adapter={node.adapter_name} ' + f'node_step={node.num_matched} state_idx={node.state_idx} reason={reason}') + return removed + + def _release_stale_state_checkpoint_candidate(self, node: Node, reason: str): + """Release a globally stale checkpoint candidate if it is unpinned.""" + if self._is_pinned_state_checkpoint(node): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Skip pinned stale SSM prefix-cache checkpoint candidate: ' + f'adapter={node.adapter_name} step={node.num_matched} ' + f'state_idx={node.state_idx} ref_count={node.state_ref_count} ' + f'reason={reason}') + return False + + state_idx = node.state_idx + state_ready = node.state_ready + self._unindex_state_checkpoint(node) + if state_idx >= 0: + self.state_manager.free_checkpoint_state(state_idx) + node.state_idx = -1 + node.state_ready = False + node.state_ref_count = 0 + node.state_access_time = 0.0 + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Release stale SSM 
prefix-cache checkpoint candidate: ' + f'adapter={node.adapter_name} step={node.num_matched} ' + f'state_idx={state_idx} was_ready={state_ready} reason={reason}') + return state_idx >= 0 or state_ready + + def _verify_state_checkpoint_node(self, seq: SchedulerSequence, node: Node, index_key: tuple): + """Verify a sparse SSM checkpoint candidate exactly. + + Matching only the sparse index key is not enough: we require every + ancestor block to match tokens and multimodal extra hashes before + restoring the frozen recurrent state. + """ + if not self._is_ready_state_checkpoint_node(node): + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.STALE_CHECKPOINT, + reason='checkpoint is not ready') + + step = node.num_matched + if step <= 0: + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.STALE_CHECKPOINT, + reason=f'invalid checkpoint step: {step}') + + nodes = self._get_node_blocks(node) + if len(nodes) * self.block_size != step: + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.STALE_CHECKPOINT, + reason='checkpoint ancestor chain is detached') + for block_node in nodes: + if not self._is_attached_node(block_node): + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.STALE_CHECKPOINT, + reason='checkpoint ancestor link is stale') + + node_key = self._make_state_checkpoint_node_key(node) + if index_key != node_key: + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.STALE_INDEX_ENTRY, + reason='checkpoint is indexed under a stale key') + + if node.adapter_name != seq.adapter_name: + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.STALE_INDEX_ENTRY, + reason='checkpoint adapter differs from lookup adapter') + + max_step = ((seq.num_valid_ids - 1) // self.block_size) * self.block_size + if step > max_step: + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.REQUEST_MISMATCH, + reason='checkpoint is longer than this request') + if seq.clamp_prefix_cache_match_step(step) != 
step: + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.REQUEST_MISMATCH, + reason='checkpoint would stop inside a multimodal span') + + matched_blocks = [] + for idx, block_node in enumerate(nodes): + start = idx * self.block_size + end = start + self.block_size + tokens = seq.history_cache[start:end] + extra_hashes = self._get_block_extra_hashes(seq, start, end) + if not self._match_node(block_node, tokens, extra_hashes): + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.REQUEST_MISMATCH, + reason=f'block payload mismatch at block {idx}') + matched_blocks.append(block_node.block) + + return StateCheckpointVerifyResult(StateCheckpointVerifyStatus.HIT, + matched_blocks=matched_blocks, + matched_nodes=nodes) + + def _match_state_checkpoint(self, seq: SchedulerSequence): + """Match SSM prefixes through sparse ready-checkpoint lookup. + + KV-only reuse is unsafe for SSM models, so this path reports a hit only if a ready recurrent-state checkpoint + exists at the exact matched step. 
+ """ + seq.prefix_cache.restore_state = -1 + seq.prefix_cache.restore_node = None + + init_curr = seq.prefix_cache.last_shared_node + if init_curr is None: + init_curr = self.get_root(seq.adapter_name) + init_num_matched = init_curr.num_matched + + max_step = ((seq.num_valid_ids - 1) // self.block_size) * self.block_size + steps = self._state_checkpoint_steps.get(seq.adapter_name, ()) + for step in sorted((step for step in steps if init_num_matched < step <= max_step), reverse=True): + if seq.clamp_prefix_cache_match_step(step) != step: + continue + key = self._make_state_checkpoint_lookup_key(seq, step) + for node in tuple(self._state_checkpoint_index.get(key, ())): + match_result = self._verify_state_checkpoint_node(seq, node, key) + if match_result.status != StateCheckpointVerifyStatus.HIT: + if match_result.status == StateCheckpointVerifyStatus.STALE_INDEX_ENTRY: + self._drop_stale_state_checkpoint_index_entry(node, key, match_result.reason) + elif match_result.status == StateCheckpointVerifyStatus.STALE_CHECKPOINT: + self._release_stale_state_checkpoint_candidate(node, match_result.reason) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Reject SSM prefix-cache checkpoint candidate: ' + f'session_id={seq.session_id} seq_id={seq.seq_id} step={step} ' + f'state_idx={node.state_idx} status={match_result.status.name} ' + f'reason={match_result.reason}') + continue + + matched_blocks = match_result.matched_blocks + matched_nodes = match_result.matched_nodes + matched_nodes = matched_nodes[init_num_matched // self.block_size:] + matched_blocks = np.array(matched_blocks[init_num_matched // self.block_size:]) + self.allocator.update_access_time(matched_blocks) + self.allocator.add_ref_count(matched_blocks, 1) + seq.logical_blocks.append(matched_blocks) + seq.set_step(step) + self._append_matched_routed_experts(seq, matched_nodes, init_num_matched) + seq.prefix_cache.restore_state = node.state_idx + seq.prefix_cache.restore_node = node + 
seq.prefix_cache.last_shared_node = node + self._record_match_stats(seq, + query_tokens=seq.num_all_ids - init_num_matched, + hit_tokens=step - init_num_matched) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'SSM prefix-cache hit: session_id={seq.session_id} seq_id={seq.seq_id} ' + f'init_step={init_num_matched} matched_step={step} ' + f'state_idx={node.state_idx}') + return + + seq.prefix_cache.last_shared_node = init_curr + self._record_match_stats(seq, query_tokens=seq.num_all_ids - init_num_matched) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'SSM prefix-cache miss: session_id={seq.session_id} seq_id={seq.seq_id} ' + f'init_step={init_num_matched} max_step={max_step} ready_steps={len(steps)}') + def match(self, seq: SchedulerSequence): - """Match sequence and cache.""" + """Match reusable prefix blocks for a sequence. + + Text/VLM models walk the trie block by block. SSM models delegate to the sparse checkpoint matcher above + because a KV block match without an exact recurrent-state snapshot must be treated as a miss. 
+ """ if not self.enable: return + seq.prefix_cache.match_start_step = seq.num_history_ids + seq.prefix_cache.restore_state = -1 + seq.prefix_cache.restore_node = None + if self.requires_state_checkpoint: + self._match_state_checkpoint(seq) + return block_size = self.block_size matched_blocks = [] - logical_blocks = seq.logical_blocks - curr: Node = getattr(logical_blocks, 'last_shared_node', None) + curr: Node = seq.prefix_cache.last_shared_node if curr is None: curr = self.get_root(seq.adapter_name) + init_curr = curr init_num_matched = curr.num_matched num_matched = curr.num_matched @@ -101,43 +1009,79 @@ def __match_success(node: Node): curr = node num_matched += block_size + matched_nodes: list[Node] = [] + while num_matched + block_size < seq.num_valid_ids: - curr_tokens = seq.history_cache[num_matched:num_matched + block_size] + start = num_matched + end = num_matched + block_size + curr_tokens = seq.history_cache[start:end] + extra_hashes = self._get_block_extra_hashes(seq, start, end) - key = hash(('random', tuple(curr_tokens))) + key = self._make_key(curr_tokens, extra_hashes) if key not in curr.children: break child = curr.children[key] - if not np.array_equal(curr_tokens, child.tokens): + if not self._match_node(child, curr_tokens, extra_hashes): break + matched_nodes.append(child) __match_success(child) + def __clamp_match_step(match_step: int): + nonlocal curr, num_matched, matched_blocks, matched_nodes + match_step = max(init_num_matched, match_step) + if match_step >= num_matched: + return + # If a candidate hit stopped inside a multimodal span, drop any + # blocks beyond the clamped safe boundary before acquiring refs. 
+ keep = (match_step - init_num_matched) // block_size + matched_nodes = matched_nodes[:keep] + matched_blocks = matched_blocks[:keep] + if keep > 0: + curr = matched_nodes[-1] + num_matched = curr.num_matched + else: + curr = init_curr + num_matched = init_num_matched + + clamped_num_matched = seq.clamp_prefix_cache_match_step(num_matched) + unclamped_num_matched = num_matched + __clamp_match_step(clamped_num_matched) + if len(matched_blocks) > 0: matched_blocks = np.array(matched_blocks) self.allocator.update_access_time(matched_blocks) self.allocator.add_ref_count(matched_blocks, 1) seq.logical_blocks.append(matched_blocks) seq.set_step(num_matched) + self._append_matched_routed_experts(seq, matched_nodes, init_num_matched) + if self.requires_state_checkpoint: + seq.prefix_cache.restore_state = curr.state_idx # record prefix hit - self.stats.num_query_tokens += seq.num_all_ids - init_num_matched - self.stats.num_hit_tokens += num_matched - init_num_matched + self._record_match_stats(seq, + query_tokens=seq.num_all_ids - init_num_matched, + hit_tokens=num_matched - init_num_matched) - seq.logical_blocks.last_shared_node = curr + seq.prefix_cache.last_shared_node = curr + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Prefix-cache match: session_id={seq.session_id} seq_id={seq.seq_id} ' + f'init_step={init_num_matched} matched_step={num_matched} ' + f'candidate_step={unclamped_num_matched} ' + f'clamped={clamped_num_matched != unclamped_num_matched}') def allocate(self, seq: SchedulerSequence): - """allocate.""" + """Attach newly allocated full blocks to the prefix-cache trie.""" if not self.enable: return block_size = self.block_size logical_blocks = seq.logical_blocks - node: Node = getattr(logical_blocks, 'last_shared_node', None) + node: Node = seq.prefix_cache.last_shared_node if node is None: node = self.get_root(seq.adapter_name) - logical_blocks.last_shared_node = node + seq.prefix_cache.last_shared_node = node num_matched = node.num_matched 
num_valid_ids = seq.num_valid_ids @@ -146,33 +1090,46 @@ def allocate(self, seq: SchedulerSequence): return if len(node.children) == 0 and node.parent is not None: - self.leaves.remove(node) + self.leaves.discard(node) block_id = num_matched // block_size blocks = [] free_blocks = [] while num_matched + block_size <= num_valid_ids: - curr_tokens = seq.history_cache[num_matched:num_matched + block_size] + start = num_matched + end = num_matched + block_size + curr_tokens = seq.history_cache[start:end] + extra_hashes = self._get_block_extra_hashes(seq, start, end) block = logical_blocks[block_id] - hash_key = hash(('random', tuple(curr_tokens))) + hash_key = self._make_key(curr_tokens, extra_hashes) parent = node if hash_key in parent.children: child = parent.children[hash_key] - if not np.array_equal(curr_tokens, child.tokens): + if not self._match_node(child, curr_tokens, extra_hashes): break + # Another sequence inserted the same key before us. Reuse the + # trie-owned block and release this sequence's duplicate block. 
node = child + self._try_cache_node_routed_experts(node, seq, start, end) free_blocks.append(block) logical_blocks[block_id] = node.block else: - node = Node(hash_key=hash_key, block=block, tokens=curr_tokens, num_matched=num_matched + block_size) + routed_experts = self._get_routed_experts_for_range(seq, start, end) + node = Node(hash_key=hash_key, + block=block, + tokens=curr_tokens, + num_matched=num_matched + block_size, + extra_hashes=extra_hashes, + routed_experts=routed_experts, + adapter_name=seq.adapter_name) node.parent = parent blocks.append(node.block) num_matched += block_size block_id += 1 - logical_blocks.last_shared_node = node + seq.prefix_cache.last_shared_node = node if node.parent is not None and len(node.children) == 0: # ignore root self.leaves.add(node) @@ -187,14 +1144,32 @@ def evict(self, max_num_blocks: int): return 0 def __remove_leaf(leaves, evicted_blocks): - _, leaf = heapq.heappop(leaves) + while len(leaves) > 0: + _, leaf = heapq.heappop(leaves) + if leaf not in self.leaves: + continue + if not self._is_evict_candidate_leaf(leaf): + self.leaves.discard(leaf) + continue + if int(self.allocator.get_ref_count(leaf.block)) != 1: + continue + break + else: + return False, None + evicted_blocks.append(leaf.block) + self.release_state_checkpoint(leaf) parent = leaf.parent - leaf.parent = None - self.leaves.remove(leaf) - return parent + if parent is not None: + leaf.parent = None + self.leaves.discard(leaf) + return True, parent def __add_leaf(leaves, parent): + if not self._is_attached_leaf(parent): + return + if parent in self.leaves: + return self.leaves.add(parent) if self.allocator.get_ref_count(parent.block) == 1: access_time = self.allocator.get_access_time(parent.block) @@ -204,7 +1179,11 @@ def __add_leaf(leaves, parent): return 0 evicted_blocks = [] - leaves = list(self.leaves) + leaves = list(leaf for leaf in self.leaves if self._is_evict_candidate_leaf(leaf)) + if len(leaves) != len(self.leaves): + 
self.leaves.intersection_update(leaves) + if len(leaves) == 0: + return 0 # filter ref-cnt == 1 (trie own one block ref) leave_blocks = np.array(list(leaf.block for leaf in leaves)) @@ -221,8 +1200,10 @@ def __add_leaf(leaves, parent): heapq.heapify(leaves) while len(leaves) > 0 and len(evicted_blocks) < max_num_blocks: - parent = __remove_leaf(leaves, evicted_blocks) - if parent.parent is None: + removed, parent = __remove_leaf(leaves, evicted_blocks) + if not removed: + break + if parent is None or parent.parent is None: # ignore root continue if len(parent.children) == 0: diff --git a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py index 4976eddb20..8e125e56cb 100644 --- a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py +++ b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py @@ -34,6 +34,8 @@ def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: list[Sc if evict_seq.num_blocks == 0: continue + if block_trie.enable: + evict_seq.prefix_cache.suppress_match_stats = True evict_seq.state.free() num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks()) if num_req <= 0: @@ -56,12 +58,19 @@ def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: list[Sc return success def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: list[SchedulerSequence], prealloc_size: int): - """Evict seqs.""" + """Evict blocks and checkpoint states for an SSM sequence. + + SSM scheduling needs both KV blocks and a runtime state slot. Before evicting live sequences, try dropping old + unpinned checkpoints because they are cheaper to recompute than an active request. 
+ """ block_manager = self.block_manager state_manager = self.state_manager block_trie = self.block_trie num_required_blocks = block_manager.num_required_blocks(seq, prealloc_size) - has_free_state = state_manager.get_num_free() > 0 + has_free_state = state_manager.get_num_free_runtime() > 0 + if block_trie.enable and not has_free_state: + block_trie.evict_state_checkpoints(1) + has_free_state = state_manager.get_num_free_runtime() > 0 if has_free_state and block_manager.get_num_free_gpu_blocks() >= num_required_blocks: return True @@ -75,8 +84,13 @@ def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: list[SchedulerS continue # free sequence + if block_trie.enable: + evict_seq.prefix_cache.suppress_match_stats = True evict_seq.state.free() - has_free_state = True + has_free_state = state_manager.get_num_free_runtime() > 0 + if block_trie.enable and not has_free_state: + block_trie.evict_state_checkpoints(1) + has_free_state = state_manager.get_num_free_runtime() > 0 num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks()) if num_req <= 0: success = True diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 9063e32172..690a547e90 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -1,6 +1,43 @@ # Copyright (c) OpenMMLab. All rights reserved. # modify from: https://github.com/vllm-project/vllm - +"""Request scheduling and prefix-cache side-effect boundaries. + +The scheduler is the first owner of prefix-cache side effects. In prefill, +``BlockTrie.match()`` is intentionally called before eviction and allocation so +the scheduler can account for reused KV/state. That match is tentative: +rollback is required if long-context chunking, checkpoint pinning, KV eviction, +or runtime state allocation means the request cannot safely run now. + +Successful prefill scheduling keeps this order: + +1. 
``block_trie.match(seq)`` mutates sequence state to skip a cached prefix. +2. eviction and SSM runtime-state availability are checked. +3. ``block_manager.allocate(seq)`` allocates missing KV blocks. +4. ``block_trie.allocate(seq)`` publishes newly allocated full blocks. +5. For SSM, downstream input/model/engine code restores and saves checkpoint + states; the scheduler only owns resource decisions and rollback. + +SSM scheduling detail: + +* ``block_trie.match(seq)`` may find a ready checkpoint and record + ``seq.prefix_cache.restore_state`` before the request owns a runtime state. + The scheduler must treat that as tentative until KV blocks and one runtime + state slot are guaranteed. +* A matched restore checkpoint can be pinned before eviction so checkpoint LRU + cannot free the source slot. If that pin prevents eviction from finding + enough resources, the scheduler rolls the match back, releases the pin, and + retries eviction once without the tentative hit. +* Runtime state availability is checked after KV eviction because old unpinned + checkpoints may be dropped to free state-cache slots. If no runtime slot can + be recovered, the tentative prefix hit is rolled back and the request waits. +* ``state_manager.allocate(seq)`` assigns the request runtime state only after + ``block_manager.allocate(seq)`` and ``block_trie.allocate(seq)`` succeed. + Later, ``InputsMaker`` may reserve checkpoint saves for the exact produced + step; scheduler code does not perform state-cache tensor copies or publish + checkpoint readiness. 
+""" + +import logging from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass @@ -54,9 +91,9 @@ def __init__( # For Disaggregation self.locked_sessions: dict[int, SchedulerSession] = OrderedDict() - self.block_manager = build_block_manager(cache_config) - self.block_trie = BlockTrie(self.cache_config, self.block_manager) self.state_manager = build_state_manager(self.cache_config) + self.block_manager = build_block_manager(cache_config) + self.block_trie = BlockTrie(self.cache_config, self.block_manager, self.state_manager) self.is_ssm = len(self.cache_config.states_shapes) > 0 self.eviction_helper = build_eviction_helper(self, self.scheduler_config.eviction_type) @@ -70,6 +107,84 @@ def tick(self): """Mark one scheduler progress step (once per forward dispatch).""" self.scheduler_tick += 1 + def _ensure_runtime_state_available(self): + """Make one state-cache slot available for an SSM runtime state. + + Runtime states and frozen checkpoints share the same state-cache pool. Scheduling a request is more important + than keeping an old checkpoint, so unpinned checkpoints are evicted before we give up. + """ + if not self.is_ssm: + return True + if self.state_manager.get_num_free_runtime() > 0: + return True + self.block_trie.evict_state_checkpoints(1) + return self.state_manager.get_num_free_runtime() > 0 + + def _acquire_ssm_restore_if_needed(self, seq: SchedulerSequence): + """Pin a matched SSM checkpoint before scheduler-side eviction.""" + if not self.is_ssm or seq.prefix_cache.restore_state < 0: + return True + return self.block_trie.acquire_state_checkpoint_restore_for_seq(seq) + + def _rollback_unscheduled_prefix_match(self, seq: SchedulerSequence, stats_snapshot=None): + """Drop a tentative prefix match that will not be used now. + + ``block_trie.match()`` mutates sequence state immediately: it advances + the history step, appends shared blocks, and may pin a restore node. 
+ If later eviction or state allocation fails, undo those side effects so + the waiting sequence can be scheduled cleanly in a later round. + """ + self.block_trie.restore_stats(stats_snapshot) + if self.is_ssm: + self.block_trie.release_state_checkpoint_restore_for_seq(seq) + if seq.num_blocks > 0 or seq.logical_state >= 0: + seq.state.free() + elif seq.num_history_ids > 0: + seq.set_step(0) + prefix_cache = seq.prefix_cache + prefix_cache.last_shared_node = None + prefix_cache.restore_state = -1 + prefix_cache.restore_node = None + prefix_cache.restore_state_acquired = False + prefix_cache.match_start_step = -1 + seq.cached_tokens = 0 + + @staticmethod + def _finalize_prefix_cache_match(seq: SchedulerSequence): + """Publish accepted cached-token count within the current prompt.""" + match_start = seq.prefix_cache.match_start_step + if match_start < 0: + seq.cached_tokens = 0 + return + cached_start = match_start + cached_end = seq.num_history_ids + prompt_start = seq.input_start_pos + prompt_end = seq.input_end_pos + seq.cached_tokens = max(0, min(cached_end, prompt_end) - max(cached_start, prompt_start)) + + @staticmethod + def _finish_prefix_cache_schedule(seq: SchedulerSequence): + """Publish match side effects after the sequence is accepted to run.""" + prefix_cache = seq.prefix_cache + if prefix_cache.suppress_match_stats: + seq.cached_tokens = 0 + prefix_cache.suppress_match_stats = False + return + Scheduler._finalize_prefix_cache_match(seq) + + def _prefix_hit_starts_middle_long_context_chunk(self, seq: SchedulerSequence): + """Check whether a prefix hit would start chunking from the middle.""" + if seq.num_history_ids <= 0: + return False + + max_prefill_num = self.cache_config.max_prefill_token_num + mm_for_chunk_limit = seq.get_chunk_limit_multimodals() + for value in mm_for_chunk_limit.values(): + max_mm_size = max([v.end - v.start for v in value], default=0) + max_prefill_num = max(max_prefill_num, max_mm_size) + + return seq.num_token_ids > 
max_prefill_num + @staticmethod def create_status_list_property(status: MessageStatus): """Create status list property.""" @@ -159,12 +274,13 @@ def _reorder_migrating(): max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() while len(migration_waiting) > 0 and len(migration_ready) < max_batches: seq = migration_waiting.pop(0) - self.block_trie.match(migration_waiting) + self.block_trie.match(seq) if not __evict_for_seq(seq, migration_waiting): break # allocate session memory self.block_manager.allocate(seq) + self._finish_prefix_cache_schedule(seq) _to_running(seq) return migration_ready @@ -211,16 +327,52 @@ def _reorder_waiting(): if (len(running) > 0 and token_count + seq.num_token_ids > self.cache_config.max_prefill_token_num): break - self.block_trie.match(seq) - - if not __evict_for_seq(seq, waiting): - break - - # allocate session memory + if self.block_trie.enable: + stats_snapshot = self.block_trie.snapshot_stats() + + def __rollback_prefix_match(reason: str): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f'Rollback tentative prefix-cache match: session_id={seq.session_id} ' + f'seq_id={seq.seq_id} reason={reason} num_history_ids={seq.num_history_ids} ' + f'restore_state={seq.prefix_cache.restore_state}') + self._rollback_unscheduled_prefix_match(seq, stats_snapshot) + + self.block_trie.match(seq) + if self._prefix_hit_starts_middle_long_context_chunk(seq): + __rollback_prefix_match('long-context chunk starts after prefix hit') + + had_ssm_restore = self.is_ssm and seq.prefix_cache.restore_state >= 0 + if not self._acquire_ssm_restore_if_needed(seq): + __rollback_prefix_match('failed to acquire SSM restore checkpoint') + + if not __evict_for_seq(seq, waiting): + if not had_ssm_restore: + __rollback_prefix_match('eviction failed') + break + # A matched SSM restore may be pinning the only checkpoint + # state that eviction would otherwise free. 
Roll it back once + # and retry eviction before declaring the sequence unschedulable. + __rollback_prefix_match('eviction failed with pinned SSM restore') + if not __evict_for_seq(seq, waiting): + break + + # allocate session memory + if self.is_ssm and not self._ensure_runtime_state_available(): + __rollback_prefix_match('no runtime SSM state available') + if not __evict_for_seq(seq, waiting): + break + if not self._ensure_runtime_state_available(): + break + else: + if not __evict_for_seq(seq, waiting): + break self.block_manager.allocate(seq, prealloc_size) - self.block_trie.allocate(seq) + if self.block_trie.enable: + self.block_trie.allocate(seq) if self.is_ssm: self.state_manager.allocate(seq) + if self.block_trie.enable: + self._finish_prefix_cache_schedule(seq) _to_running(seq) seq.record_event(EventType.SCHEDULED) diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py index deac04f4b2..befa7e8ee4 100644 --- a/lmdeploy/pytorch/paging/seq_states/states.py +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -9,6 +9,12 @@ def _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'): """Free the sequence.""" + if scheduler.block_trie.enable: + scheduler.block_trie.discard_state_checkpoint_for_seq(seq) + scheduler.block_trie.release_state_checkpoint_restore_for_seq(seq) + seq.prefix_cache.last_shared_node = None + seq.prefix_cache.match_start_step = -1 + seq.cached_tokens = 0 if seq.num_blocks > 0: scheduler.block_manager.free(seq) if seq.logical_state >= 0: diff --git a/lmdeploy/pytorch/paging/state_manager.py b/lmdeploy/pytorch/paging/state_manager.py index f6e84d0c07..5760748739 100644 --- a/lmdeploy/pytorch/paging/state_manager.py +++ b/lmdeploy/pytorch/paging/state_manager.py @@ -33,33 +33,101 @@ def get_num_free(self): class StateManager: - - def __init__(self, num_states: int, num_reserved: int = 0): + """Manage runtime and checkpoint ownership over one elastic state pool. 
+ + Runtime sequence states have a configurable capacity cap so a large prefix + checkpoint budget cannot starve active requests. Checkpoint states borrow + from the same allocator and are evicted by ``BlockTrie`` when runtime slots + need to be recovered. + """ + + def __init__(self, + num_states: int, + num_reserved: int = 0, + num_runtime_states: int = None): if num_states is None: num_states = 1 - self.allocator = StateAllocator(num_states, offset=num_reserved) + self.num_states = num_states + self.num_reserved = num_reserved + num_available = max(0, num_states - num_reserved) + + if num_runtime_states is None: + num_runtime_states = num_available + num_runtime_states = max(0, min(num_runtime_states, num_available)) + + self.num_runtime_states = num_runtime_states + self.allocator = StateAllocator(num_available, offset=num_reserved) + self._runtime_states: set[int] = set() + self._checkpoint_states: set[int] = set() def is_allocated(self, seq: SchedulerSequence): """Check if a sequence is allocated.""" return seq.logical_state >= 0 + def allocate_state(self): + """Allocate one state-cache slot for an active sequence.""" + if self.get_num_free_runtime() <= 0: + raise RuntimeError('No free states.') + state_id = int(self.allocator.allocate()) + self._runtime_states.add(state_id) + return state_id + + def free_state(self, state_id: int): + """Free one state-cache slot.""" + state_id = int(state_id) + if state_id not in self._runtime_states: + raise RuntimeError(f'State {state_id} is not a runtime state.') + self._runtime_states.remove(state_id) + self.allocator.free(state_id) + + def allocate_checkpoint_state(self): + """Allocate one frozen prefix-cache checkpoint state slot.""" + state_id = int(self.allocator.allocate()) + self._checkpoint_states.add(state_id) + return state_id + + def free_checkpoint_state(self, state_id: int): + """Free one frozen prefix-cache checkpoint state slot.""" + state_id = int(state_id) + if state_id not in self._checkpoint_states: + 
raise RuntimeError(f'State {state_id} is not a checkpoint state.') + self._checkpoint_states.remove(state_id) + self.allocator.free(state_id) + def allocate(self, seq: SchedulerSequence): """Allocate states for a sequence.""" if self.is_allocated(seq): return None - seq.logical_state = self.allocator.allocate() + seq.logical_state = self.allocate_state() def free(self, seq: SchedulerSequence): """Free states for a sequence.""" if not self.is_allocated(seq): return None - self.allocator.free(seq.logical_state) + self.free_state(seq.logical_state) seq.logical_state = -1 def get_num_free(self): """Get num free.""" return self.allocator.get_num_free() + def get_num_free_runtime(self): + """Get slots still available under the runtime-state cap.""" + free_runtime_capacity = self.num_runtime_states - len(self._runtime_states) + return max(0, min(free_runtime_capacity, self.allocator.get_num_free())) + + def get_num_free_checkpoint(self): + """Get raw free slots that checkpoint saves may reserve.""" + return self.allocator.get_num_free() + + def get_num_runtime_states(self): + """Get num allocated runtime states.""" + return len(self._runtime_states) + + def get_num_allocated_checkpoint_states(self): + """Get num allocated checkpoint states.""" + return len(self._checkpoint_states) + def build_state_manager(cache_config: CacheConfig) -> StateManager: """Build state manager.""" @@ -70,7 +138,8 @@ def build_state_manager(cache_config: CacheConfig) -> StateManager: num_state_caches = num_reserved # `num_state_caches` is the number of allocated cache rows, including - # reserved rows. Allocatable state ids must therefore stop before - # `num_state_caches` to remain valid conv_state row indices. - num_states = max(num_state_caches - num_reserved, 0) - return StateManager(num_states, num_reserved) + # reserved rows. StateManager subtracts reserved rows internally, so pass + # the total row count to keep allocatable state ids below num_state_caches. 
+ return StateManager(num_state_caches, + num_reserved, + num_runtime_states=cache_config.max_batches) diff --git a/lmdeploy/pytorch/strategies/ar/model_inputs.py b/lmdeploy/pytorch/strategies/ar/model_inputs.py index 9fa515f233..bc4754536a 100644 --- a/lmdeploy/pytorch/strategies/ar/model_inputs.py +++ b/lmdeploy/pytorch/strategies/ar/model_inputs.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from collections.abc import Sequence + import numpy as np import torch @@ -129,7 +131,9 @@ def index_select_model_inputs(inputs: ModelInputs, max_q_seqlen: int | None = None, max_kv_seqlen: int | None = None, sum_kv_seqlen: int | None = None, - num_ignored_history: torch.Tensor | None = None): + num_ignored_history: torch.Tensor | None = None, + state_prefix_cache_save_src_offsets: Sequence[int] | None = None, + state_prefix_cache_save_offsets: Sequence[int] | None = None): """Index select model inputs by indices.""" assert inputs.is_decoding, 'Only support index_select in decoding.' 
@@ -192,6 +196,8 @@ def index_select_model_inputs(inputs: ModelInputs, local_adapter_ids=local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + state_prefix_cache_save_src_offsets=state_prefix_cache_save_src_offsets, + state_prefix_cache_save_offsets=state_prefix_cache_save_offsets, target_hidden_states=target_hidden_states, target_position_ids=target_position_ids, mrope_pos_ids=mrope_pos_ids, diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 61f80e9d27..0d39b7e0ae 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -50,8 +50,13 @@ def update_token_ids(self, self.append_routed_experts(routed_experts) if mode == UpdateTokenMode.INPUTS: + self.cached_tokens = 0 + self.prefix_cache.suppress_match_stats = False + self.prefix_cache.match_start_step = -1 + self.input_start_pos = self.num_all_ids + self.input_end_pos = self.input_start_pos + num_valid self.arrive_time = time.perf_counter() - self.output_start_pos = self.num_all_ids + len(token_ids) + self.output_start_pos = self.input_end_pos self._num_token_ids += num_valid self.num_new_tokens = 0 else: diff --git a/lmdeploy/pytorch/strategies/ar/step_inputs.py b/lmdeploy/pytorch/strategies/ar/step_inputs.py index 5fad97dcf2..32df5af4a8 100644 --- a/lmdeploy/pytorch/strategies/ar/step_inputs.py +++ b/lmdeploy/pytorch/strategies/ar/step_inputs.py @@ -105,6 +105,8 @@ def _reindex_model_inputs(inputs: ModelInputs, delta: ModelInputsDelta) -> Model max_kv_seqlen=delta.max_kv_seqlen, sum_kv_seqlen=delta.sum_kv_seqlen, num_ignored_history=delta.num_ignored_history, + state_prefix_cache_save_src_offsets=delta.state_prefix_cache_save_src_offsets, + state_prefix_cache_save_offsets=delta.state_prefix_cache_save_offsets, ) diff --git a/lmdeploy/pytorch/strategies/ar_spec/sequence.py b/lmdeploy/pytorch/strategies/ar_spec/sequence.py index b7fae0442f..d2359ca7a5 100644 --- 
a/lmdeploy/pytorch/strategies/ar_spec/sequence.py +++ b/lmdeploy/pytorch/strategies/ar_spec/sequence.py @@ -57,7 +57,9 @@ def generated_ids(self) -> np.ndarray: def _update_token_ids_inputs(self, token_ids: np.ndarray): """Append tokens.""" num_tokens = len(token_ids) - self.output_start_pos = self.num_valid_ids + num_tokens + self.input_start_pos = self.num_valid_ids + self.input_end_pos = self.input_start_pos + num_tokens + self.output_start_pos = self.input_end_pos self._num_valid_ids = self._num_valid_ids + num_tokens self._num_token_ids = num_tokens self.num_new_tokens = 0 @@ -133,6 +135,9 @@ def update_token_ids(self, if draft_token_ids is not None: draft_token_ids = _to_ndarray(draft_token_ids) if mode == UpdateTokenMode.INPUTS: + self.cached_tokens = 0 + self.prefix_cache.suppress_match_stats = False + self.prefix_cache.match_start_step = -1 self._update_token_ids_inputs(token_ids) elif mode == UpdateTokenMode.PREFILL: self._update_token_ids_prefill(token_ids, draft_token_ids, diff --git a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py index 95c9eda0c1..5a6a2ca838 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py +++ b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py @@ -41,6 +41,8 @@ def _reindex_model_inputs_arspec( max_kv_seqlen = delta.max_kv_seqlen sum_kv_seqlen = delta.sum_kv_seqlen num_ignored_history = delta.num_ignored_history + state_prefix_cache_save_src_offsets = delta.state_prefix_cache_save_src_offsets + state_prefix_cache_save_offsets = delta.state_prefix_cache_save_offsets # required inputs — reshape by num_spec_tokens+1 for spec decoding inputs_ids = inputs.input_ids.reshape(1, -1, num_spec_tokens + 1) @@ -89,6 +91,8 @@ def _reindex_model_inputs_arspec( local_adapter_ids=local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + state_prefix_cache_save_src_offsets=state_prefix_cache_save_src_offsets, + 
state_prefix_cache_save_offsets=state_prefix_cache_save_offsets, mrope_pos_ids=mrope_pos_ids, ) diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py index 4b6ac470b4..71d5639f94 100644 --- a/lmdeploy/pytorch/strategies/dllm/sequence.py +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -193,6 +193,11 @@ def update_token_ids(self, dllm_mask: np.ndarray = _to_ndarray(dllm_mask) if mode == UpdateTokenMode.INPUTS: + self.cached_tokens = 0 + self.prefix_cache.suppress_match_stats = False + self.prefix_cache.match_start_step = -1 + self.input_start_pos = self.num_valid_ids + self.input_end_pos = self.input_start_pos + len(token_ids) self._update_token_ids_inputs(token_ids, dllm_mask) elif mode == UpdateTokenMode.PREFILL: self._update_token_ids_prefill(token_ids, dllm_mask) diff --git a/lmdeploy/pytorch/strategies/dllm/step_inputs.py b/lmdeploy/pytorch/strategies/dllm/step_inputs.py index a7d92d54cb..653c64be23 100644 --- a/lmdeploy/pytorch/strategies/dllm/step_inputs.py +++ b/lmdeploy/pytorch/strategies/dllm/step_inputs.py @@ -134,6 +134,8 @@ def _reindex_model_inputs_dllm( max_kv_seqlen = delta.max_kv_seqlen sum_kv_seqlen = delta.sum_kv_seqlen num_ignored_history = delta.num_ignored_history + state_prefix_cache_save_src_offsets = delta.state_prefix_cache_save_src_offsets + state_prefix_cache_save_offsets = delta.state_prefix_cache_save_offsets # required inputs — reshape by block_size for DLLM inputs_ids = inputs.input_ids.reshape(1, -1, block_size) @@ -176,6 +178,8 @@ def _reindex_model_inputs_dllm( local_adapter_ids=local_adapter_ids, model_metas=model_metas, state_offsets=state_offsets, + state_prefix_cache_save_src_offsets=state_prefix_cache_save_src_offsets, + state_prefix_cache_save_offsets=state_prefix_cache_save_offsets, ) diff --git a/lmdeploy/serve/core/vl_async_engine.py b/lmdeploy/serve/core/vl_async_engine.py index 9e6c9ac25d..d246a20f75 100644 --- a/lmdeploy/serve/core/vl_async_engine.py +++ 
b/lmdeploy/serve/core/vl_async_engine.py @@ -25,14 +25,17 @@ def __init__(self, if backend == 'pytorch': try_import_deeplink(backend_config.device_type) - if backend_config and backend_config.enable_prefix_caching: - backend_config.enable_prefix_caching = False - logger.warning('Prefix caching is disabled since LMDeploy hasn\'t support in on VL models yet') self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config, trust_remote_code=trust_remote_code) + if backend_config and backend_config.enable_prefix_caching: + supports_prefix_caching = backend == 'pytorch' and getattr(self.vl_encoder, '_uses_new_preprocess', False) + if not supports_prefix_caching: + backend_config.enable_prefix_caching = False + logger.warning('Prefix caching is disabled for this VL model path. ' + 'Only PyTorch new-preprocess multimodal inputs are supported.') super().__init__(model_path, backend=backend, backend_config=backend_config, diff --git a/tests/pytorch/engine/test_cache_engine.py b/tests/pytorch/engine/test_cache_engine.py index 85bddccccd..ca601f5588 100644 --- a/tests/pytorch/engine/test_cache_engine.py +++ b/tests/pytorch/engine/test_cache_engine.py @@ -1,12 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio +import numpy as np import pytest +import torch from lmdeploy.pytorch.config import CacheConfig from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch -from lmdeploy.pytorch.engine.cache_engine import CacheEngine +from lmdeploy.pytorch.engine.cache_engine import CacheEngine, StateCacheEngine def test_allocate_caches_requires_block_size_divisible_by_kernel_block_size(): @@ -35,3 +37,114 @@ def test_pd_migration_rejects_split_kernel_blocks(): with pytest.raises(RuntimeError, match='PD migration does not support block_size != kernel_block_size'): asyncio.run(cache_engine.migrate(migration_inputs)) + + +def _make_state_cache_engine(num_caches: int = 4): + cache_engine = object.__new__(StateCacheEngine) + cache_engine.cache_config = CacheConfig(max_batches=1, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + num_state_caches=num_caches, + states_shapes=[((2, 3), torch.float32), ((2, ), torch.float16)]) + cache_engine.mem_pool, cache_engine._state_caches = StateCacheEngine.allocate_caches( + num_caches=num_caches, + state_shapes=cache_engine.cache_config.states_shapes, + device='cpu', + ) + return cache_engine + + +def test_state_cache_engine_copy_caches_copies_all_state_views(): + cache_engine = _make_state_cache_engine() + conv_state, recurrent_state = cache_engine.state_caches + + conv_state[1].fill_(3.0) + recurrent_state[1].fill_(5.0) + + cache_engine.copy_caches(1, 2) + + assert torch.equal(conv_state[2], conv_state[1]) + assert torch.equal(recurrent_state[2], recurrent_state[1]) + + +def test_state_cache_engine_copy_caches_supports_batched_indices(): + cache_engine = _make_state_cache_engine() + conv_state, recurrent_state = cache_engine.state_caches + + conv_state[0].fill_(1.0) + recurrent_state[0].fill_(2.0) + conv_state[1].fill_(3.0) + recurrent_state[1].fill_(4.0) + + cache_engine.copy_caches((1, 0), (3, 2)) + + assert torch.equal(conv_state[2], 
conv_state[0]) + assert torch.equal(recurrent_state[2], recurrent_state[0]) + assert torch.equal(conv_state[3], conv_state[1]) + assert torch.equal(recurrent_state[3], recurrent_state[1]) + + +def test_state_cache_engine_copy_caches_accepts_host_integer_scalars(): + cache_engine = _make_state_cache_engine() + conv_state, recurrent_state = cache_engine.state_caches + + conv_state[1].fill_(7.0) + recurrent_state[1].fill_(9.0) + + cache_engine.copy_caches(np.int64(1), np.int64(2)) + + assert torch.equal(conv_state[2], conv_state[1]) + assert torch.equal(recurrent_state[2], recurrent_state[1]) + + +def test_state_cache_engine_copy_caches_coalesces_contiguous_ranges(): + ranges = list(StateCacheEngine._copy_ranges([4, 1, 5, 0, 6, 9], [20, 11, 21, 10, 22, 30])) + + assert ranges == [(0, 10, 2), (4, 20, 3), (9, 30, 1)] + assert list(StateCacheEngine._copy_ranges([], [])) == [] + + +def test_state_cache_engine_copy_caches_rejects_mismatched_indices(): + cache_engine = _make_state_cache_engine() + + with pytest.raises(ValueError, match='same number of elements'): + cache_engine.copy_caches([0, 1], [2]) + + +def test_state_cache_engine_copy_caches_rejects_tensor_indices(): + cache_engine = _make_state_cache_engine() + + with pytest.raises(TypeError, match='host integers'): + cache_engine.copy_caches(torch.tensor([0, 1]), torch.tensor([2, 3])) + + +def test_state_cache_engine_copy_caches_rejects_tensor_sequence_items(): + cache_engine = _make_state_cache_engine() + + with pytest.raises(TypeError, match='host integers'): + cache_engine.copy_caches([torch.tensor(0)], [2]) + + +def test_state_cache_engine_copy_caches_rejects_out_of_range_indices(): + cache_engine = _make_state_cache_engine() + + with pytest.raises(ValueError, match='out of range'): + cache_engine.copy_caches([-1], [2]) + + with pytest.raises(ValueError, match='out of range'): + cache_engine.copy_caches([0], [4]) + + +def test_state_cache_engine_copy_caches_rejects_overlapping_indices(): + cache_engine = 
_make_state_cache_engine() + + with pytest.raises(ValueError, match='must not overlap'): + cache_engine.copy_caches([0, 1], [1, 2]) + + +def test_state_cache_engine_copy_caches_rejects_duplicate_destinations(): + cache_engine = _make_state_cache_engine() + + with pytest.raises(ValueError, match='duplicate'): + cache_engine.copy_caches([0, 1], [2, 2]) diff --git a/tests/pytorch/engine/test_executor_base.py b/tests/pytorch/engine/test_executor_base.py index 6e20fddaba..72ae90029b 100644 --- a/tests/pytorch/engine/test_executor_base.py +++ b/tests/pytorch/engine/test_executor_base.py @@ -2,9 +2,13 @@ from types import SimpleNamespace import pytest +import torch +from lmdeploy.messages import PytorchEngineConfig from lmdeploy.pytorch.config import CacheConfig -from lmdeploy.pytorch.engine.cache_engine import CacheEngine +from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.pytorch.engine.cache_engine import CacheEngine, StateCacheEngine +from lmdeploy.pytorch.engine.config_builder import ConfigBuilder from lmdeploy.pytorch.engine.executor.base import ExecutorBase, _CacheBlockSize @@ -14,7 +18,7 @@ def __init__(self, empty_init: bool): super().__init__( model_path='', model_config=SimpleNamespace(sliding_window=None, states_shapes=None), - cache_config=SimpleNamespace(), + cache_config=SimpleNamespace(role=EngineRole.Hybrid), backend_config=SimpleNamespace(), dist_config=SimpleNamespace(dp=1, world_size=1), misc_config=SimpleNamespace(empty_init=empty_init), @@ -108,6 +112,44 @@ def test_sync_spec_cache_block_size_updates_kernel_block_size(): assert spec_cache_config.kernel_block_size == 16 +def test_executor_disables_prefix_cache_with_spec_decode(): + cache_config = CacheConfig(max_batches=1, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + enable_prefix_caching=True) + model_config = SimpleNamespace(sliding_window=None) + + ExecutorBase(model_path='', + model_config=model_config, + cache_config=cache_config, + 
backend_config=SimpleNamespace(), + dist_config=SimpleNamespace(dp=1, world_size=1), + misc_config=SimpleNamespace(), + specdecode_config=SimpleNamespace()) + + assert not cache_config.enable_prefix_caching + + +def test_executor_disables_prefix_cache_with_pd_role(): + cache_config = CacheConfig(max_batches=1, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + enable_prefix_caching=True, + role=EngineRole.Prefill) + model_config = SimpleNamespace(sliding_window=None) + + ExecutorBase(model_path='', + model_config=model_config, + cache_config=cache_config, + backend_config=SimpleNamespace(), + dist_config=SimpleNamespace(dp=1, world_size=1), + misc_config=SimpleNamespace()) + + assert not cache_config.enable_prefix_caching + + def test_get_rank_cache_block_sizes_only_charges_spec_rank(): executor = object.__new__(ExecutorBase) executor.dist_config = SimpleNamespace(attn_tp=2) @@ -163,3 +205,86 @@ def test_update_num_gpu_blocks_can_be_limited_by_non_spec_rank(): assert executor.cache_config.num_gpu_blocks == 3 assert spec_cache_config.num_gpu_blocks == 3 + + +def test_get_state_cache_mem_uses_prefix_cache_state_budget(): + executor = object.__new__(ExecutorBase) + state_shapes = [((2, ), torch.float32)] + executor.cache_config = CacheConfig(max_batches=4, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + states_shapes=state_shapes, + prefix_cache_state_budget=3) + + mem = executor._get_state_cache_mem() + + expected_num_state_caches = 4 + 1 + 3 + expected_mem = StateCacheEngine.get_cache_state_size(state_shapes) * expected_num_state_caches + assert executor.cache_config.num_state_caches == expected_num_state_caches + assert mem == expected_mem + + +def test_get_state_cache_mem_keeps_ssm_prefix_cache_enabled_without_extra_budget(): + executor = object.__new__(ExecutorBase) + state_shapes = [((2, ), torch.float32)] + executor.cache_config = CacheConfig(max_batches=4, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + states_shapes=state_shapes, 
+ enable_prefix_caching=True, + prefix_cache_state_budget=0) + + executor._get_state_cache_mem() + + assert executor.cache_config.num_state_caches == 4 + 1 + assert executor.cache_config.enable_prefix_caching + + +def test_get_state_cache_mem_keeps_budgeted_ssm_prefix_cache_enabled(): + executor = object.__new__(ExecutorBase) + state_shapes = [((2, ), torch.float32)] + executor.cache_config = CacheConfig(max_batches=4, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + states_shapes=state_shapes, + enable_prefix_caching=True, + prefix_cache_state_budget=2) + + executor._get_state_cache_mem() + + assert executor.cache_config.num_state_caches == 4 + 1 + 2 + assert executor.cache_config.enable_prefix_caching + + +def test_get_state_cache_mem_leaves_non_ssm_prefix_cache_enabled(): + executor = object.__new__(ExecutorBase) + executor.cache_config = CacheConfig(max_batches=4, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=0, + enable_prefix_caching=True, + prefix_cache_state_budget=0) + + mem = executor._get_state_cache_mem() + + assert mem == 0 + assert executor.cache_config.enable_prefix_caching + + +def test_build_cache_config_carries_prefix_cache_state_budget(): + engine_config = PytorchEngineConfig(max_batch_size=4, + prefix_cache_state_budget=3, + prefix_cache_decode_state_interval=128) + + cache_config = ConfigBuilder.build_cache_config(engine_config) + + assert cache_config.prefix_cache_state_budget == 3 + assert cache_config.prefix_cache_decode_state_interval == 128 + + +def test_engine_config_rejects_unaligned_prefix_cache_decode_state_interval(): + with pytest.raises(AssertionError): + PytorchEngineConfig(max_batch_size=4, prefix_cache_decode_state_interval=96) diff --git a/tests/pytorch/engine/test_inputs_maker.py b/tests/pytorch/engine/test_inputs_maker.py new file mode 100644 index 0000000000..5923722877 --- /dev/null +++ b/tests/pytorch/engine/test_inputs_maker.py @@ -0,0 +1,335 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import asyncio +from dataclasses import dataclass +from types import SimpleNamespace + +from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.pytorch.engine.engine_loop import EngineLoop +from lmdeploy.pytorch.engine.inputs_maker import ( + InputsMakerAsync, + LongContextChunker, + _compact_state_prefix_cache_restore_offsets, + _compact_state_prefix_cache_save_offsets, +) +from lmdeploy.pytorch.messages import MessageStatus + + +@dataclass +class _DummyMultiModal: + start: int + end: int + + +class _DummySeq: + + def __init__(self, + history_ids: int, + token_ids: int, + all_multimodals: dict, + input_multimodals: dict, + match_start_step: int = -1): + self.num_history_ids = history_ids + self.num_token_ids = token_ids + self.history_multimodals = SimpleNamespace(multimodals=all_multimodals) + self._input_multimodals = input_multimodals + self.prefix_cache = SimpleNamespace(match_start_step=match_start_step) + self.return_logits = False + self.return_routed_experts = False + + def get_input_multimodals(self): + return self._input_multimodals + + def get_chunk_limit_multimodals(self): + match_start = self.prefix_cache.match_start_step + if match_start >= 0 and self.num_history_ids > match_start: + end = self.num_history_ids + self.num_token_ids + return { + key: [mm for mm in value if match_start <= mm.start and mm.end <= end] + for key, value in self.history_multimodals.multimodals.items() + } + return self.get_input_multimodals() + + +def _state_seq(logical_state: int, restore_state: int = -1): + return SimpleNamespace(logical_state=logical_state, + prefix_cache=SimpleNamespace(restore_state=restore_state)) + + +class _FakeScheduler: + + def __init__(self, running): + self.running = running + + def schedule(self, is_prefill: bool, prealloc_size: int): + return SimpleNamespace(running=self.running, swap_in_map={}, swap_out_map={}) + + +class _FakeEngineStrategy: + + def get_prealloc_size(self, is_prefill: bool): + return 0 + + +class 
_FakeSamplingStrategy: + + def make_sampling_inputs(self, running): + return None + + +class _FakeModelAgentStrategy: + + def make_extra_inputs(self, running, inputs): + return None + + def make_stopping_criteria(self, running): + return None + + +def test_engine_loop_skips_prefix_cache_publish_when_disabled(): + + class _DisabledBlockTrie: + enable = False + + def commit_state_checkpoints(self, seqs): + raise AssertionError('disabled prefix cache must not commit state checkpoints') + + def release_state_checkpoint_restores(self, seqs): + raise AssertionError('disabled prefix cache must not release state checkpoint restores') + + loop = EngineLoop.__new__(EngineLoop) + loop.scheduler = SimpleNamespace(block_trie=_DisabledBlockTrie()) + + loop._publish_forward_prefix_cache([object()], has_state_checkpoint_save=True) + + +def test_engine_loop_keeps_state_save_pinned_until_output_boundary(): + events = [] + + class _BlockTrie: + enable = True + pinned = False + + def commit_state_checkpoints(self, seqs, acquire_save_ref=False): + events.append(('commit', acquire_save_ref)) + assert acquire_save_ref + self.pinned = True + + def release_state_checkpoint_restores(self, seqs): + events.append(('release_restore', self.pinned)) + + def release_state_checkpoint_saves(self, seqs): + events.append(('release_save', self.pinned)) + self.pinned = False + + class _InputsMaker: + + def __init__(self, block_trie): + self.block_trie = block_trie + + def update_running_seqs(self, running, model_inputs): + events.append('update_running') + + async def prefetch_next_inputs(self): + events.append(('prefetch', self.block_trie.pinned)) + return None, None + + class _Executor: + + def __init__(self, block_trie): + self.block_trie = block_trie + + async def get_output_async(self): + events.append(('get_output', self.block_trie.pinned)) + return None + + block_trie = _BlockTrie() + loop = EngineLoop.__new__(EngineLoop) + loop.scheduler = SimpleNamespace(block_trie=block_trie, 
collect_migration_done=lambda: None) + loop.inputs_maker = _InputsMaker(block_trie) + loop.executor = _Executor(block_trie) + model_inputs = SimpleNamespace(state_prefix_cache_save_offsets=[1]) + forward_inputs = dict(inputs=model_inputs, delta=None) + + forward_inputs, next_running = asyncio.run(loop._main_loop_get_outputs([object()], forward_inputs)) + + assert forward_inputs is None + assert next_running is None + assert events == [ + 'update_running', + ('commit', True), + ('release_restore', True), + ('prefetch', True), + ('get_output', True), + ('release_save', True), + ] + assert not block_trie.pinned + + +def test_long_context_chunker_uses_cached_multimodal_size_for_chunk_limit(): + image = _DummyMultiModal(start=512, end=5888) + seq = _DummySeq( + history_ids=5888, + token_ids=1056, + all_multimodals={'image': [image]}, + input_multimodals={}, + match_start_step=0, + ) + + chunker = LongContextChunker(max_prefill_token_num=512) + assert chunker.is_long_context(seq) + + chunker.set_seq(seq) + + assert chunker.max_prefill_num == 5376 + assert chunker.is_last_chunk() + chunk_size, multimodals = chunker.next_chunk_size() + assert chunk_size == 1056 + assert multimodals is None + + +def test_long_context_chunker_only_tracks_remaining_multimodals(): + cached_image = _DummyMultiModal(start=512, end=5888) + remaining_image = _DummyMultiModal(start=6400, end=7424) + seq = _DummySeq( + history_ids=5888, + token_ids=2000, + all_multimodals={'image': [cached_image, remaining_image]}, + input_multimodals={'image': [remaining_image]}, + match_start_step=0, + ) + + chunker = LongContextChunker(max_prefill_token_num=512) + chunker.set_seq(seq) + chunk_size, multimodals = chunker.next_chunk_size() + + assert chunker.max_prefill_num == 5376 + assert chunk_size == 2000 + assert multimodals == {'image': [remaining_image]} + + +def test_single_forward_multimodal_long_context_stays_normal_prefill_for_spec_decoding(): + image = _DummyMultiModal(start=0, end=1024) + seq = 
_DummySeq( + history_ids=0, + token_ids=1024, + all_multimodals={'image': [image]}, + input_multimodals={'image': [image]}, + ) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = True + maker.scheduler = _FakeScheduler([seq]) + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.running_seqs = [] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker.create_model_inputs = lambda seqs, is_prefill: model_inputs + maker.create_model_inputs_delta_valid_only = lambda: (None, [], []) + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs['inputs'] is model_inputs + assert not model_inputs.is_chunk + assert not model_inputs.is_first_chunk + assert not model_inputs.is_last_chunk + assert not model_inputs.is_chunk_multimodal + + +def test_spec_decoding_text_turn_ignores_previous_multimodal_chunk_limit(): + previous_image = _DummyMultiModal(start=512, end=5888) + seq = _DummySeq( + history_ids=5888, + token_ids=1056, + all_multimodals={'image': [previous_image]}, + input_multimodals={}, + ) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = True + maker.scheduler = _FakeScheduler([seq]) + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = 
LongContextChunker(max_prefill_token_num=512) + maker.running_seqs = [] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker.create_model_inputs_long_context = lambda seq, chunk_size, multimodals: model_inputs + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs['inputs'] is model_inputs + assert model_inputs.is_first_chunk + assert not model_inputs.is_chunk_multimodal + + +def test_long_context_final_chunk_preserves_multimodal_flag_for_spec_decoding(): + image = _DummyMultiModal(start=0, end=1024) + seq = _DummySeq( + history_ids=512, + token_ids=512, + all_multimodals={'image': [image]}, + input_multimodals={}, + ) + model_inputs = SimpleNamespace(is_decoding=False, + is_chunk=False, + is_first_chunk=False, + is_last_chunk=False, + is_chunk_multimodal=False) + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.config = SimpleNamespace(role=EngineRole.Decode) + maker.spec_decoding = True + maker.scheduler = _FakeScheduler([]) + maker.engine_strategy = _FakeEngineStrategy() + maker.sampling_strategy = _FakeSamplingStrategy() + maker.model_agent_strategy = _FakeModelAgentStrategy() + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + seq.status = MessageStatus.RUNNING + maker.long_context_chunker.seq = seq + maker.long_context_chunker.has_multimodal = True + maker.long_context_chunker.max_prefill_num = 512 + maker.running_seqs = [] + maker.to_evict_seqs = [] + maker._decode_count = 0 + maker.create_model_inputs = lambda seqs, is_prefill: model_inputs + maker.create_model_inputs_delta_valid_only = lambda: (None, [seq], []) + + forward_inputs = maker._make_forward_inputs(prefill=True) + + assert forward_inputs['inputs'] is model_inputs + assert model_inputs.is_chunk + assert not model_inputs.is_first_chunk + assert model_inputs.is_last_chunk + assert model_inputs.is_chunk_multimodal + assert not maker.long_context_chunker.enabled() + + +def 
test_state_prefix_cache_restore_offsets_are_compact(): + messages = [_state_seq(4, 11), _state_seq(5, -1), _state_seq(6, 13)] + + src_offsets, dst_offsets = _compact_state_prefix_cache_restore_offsets(messages) + + assert src_offsets == (11, 13) + assert dst_offsets == (4, 6) + + +def test_state_prefix_cache_save_offsets_are_compact(): + messages = [_state_seq(4), _state_seq(5), _state_seq(6)] + + src_offsets, dst_offsets = _compact_state_prefix_cache_save_offsets(messages, [-1, 21, 22]) + + assert src_offsets == (5, 6) + assert dst_offsets == (21, 22) diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 5736e4d006..8e5cf21033 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -1,9 +1,13 @@ import numpy as np import pytest +import torch +from lmdeploy.pytorch import messages as messages_module from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig -from lmdeploy.pytorch.messages import SequenceMeta +from lmdeploy.pytorch.messages import SamplingParam, SequenceMeta, UpdateTokenMode +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.paging import Scheduler +from lmdeploy.vl.constants import Modality class TestBlockTire: @@ -49,6 +53,21 @@ def seq_meta(self, block_size): def scheduler(self, cache_config, scheduler_config, seq_meta): yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + @pytest.fixture + def ssm_cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, + block_size=block_size, + num_cpu_blocks=num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + enable_prefix_caching=True, + num_state_caches=max_batch_size + 1 + 8, + prefix_cache_state_budget=8, + states_shapes=[((1, ), torch.float32)]) + + @pytest.fixture + def ssm_scheduler(self, ssm_cache_config, scheduler_config, seq_meta): + yield 
Scheduler(scheduler_config=scheduler_config, cache_config=ssm_cache_config, seq_meta=seq_meta) + @pytest.fixture def block_mgr(self, scheduler): yield scheduler.block_manager @@ -57,6 +76,49 @@ def block_mgr(self, scheduler): def block_trie(self, scheduler): yield scheduler.block_trie + def _image_multimodals(self, + start: int, + end: int, + value: float, + image_token_id: int = 99, + content_hash: str | None = None): + data = torch.full((2, 2), value, dtype=torch.float32) + return dict(image=[MultiModalData(data=data, + start=start, + end=end, + meta=dict(image_token_id=image_token_id), + content_hash=content_hash)]) + + def _modal_data(self, start: int, end: int, value: float, modality: Modality): + data = torch.full((2, 2), value, dtype=torch.float32) + return MultiModalData(data=data, + start=start, + end=end, + modality=modality, + meta=dict(token_id=int(value))) + + def _multi_image_multimodals(self, spans: list[tuple[int, int, float]]): + return dict(image=[ + MultiModalData(data=torch.full((2, 2), value, dtype=torch.float32), + start=start, + end=end, + modality=Modality.IMAGE, + meta=dict(image_token_id=99)) for start, end, value in spans + ]) + + def _routed_experts(self, num_tokens: int, offset: int = 0): + values = np.arange(offset, offset + num_tokens * 2, dtype=np.uint16) + return values.reshape(num_tokens, 2, 1) + + def _add_ready_ssm_checkpoint(self, scheduler, token_ids): + seq = scheduler.add_session(len(scheduler.sessions)).add_sequence(token_ids) + scheduler.block_manager.allocate(seq) + scheduler.block_trie.allocate(seq) + state_idx = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq) + assert state_idx >= 0 + assert scheduler.block_trie.commit_state_checkpoint_for_seq(seq) + return seq, seq.prefix_cache.last_shared_node, state_idx + def test_allocate(self, block_trie, block_mgr, scheduler): allocator = block_trie.allocator sess = scheduler.add_session(0) @@ -72,7 +134,7 @@ def test_allocate(self, block_trie, block_mgr, scheduler): 
assert len(logical_blocks) == 3 ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [2, 2, 1]) - node = getattr(seq.logical_blocks, 'last_shared_node', None) + node = seq.prefix_cache.last_shared_node assert node is not None assert node.num_matched == block_size * 2 assert np.array_equal(node.tokens, [2] * block_size) @@ -88,7 +150,7 @@ def test_allocate(self, block_trie, block_mgr, scheduler): assert len(logical_blocks) == 4 ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [2, 2, 2, 1]) - node = getattr(seq.logical_blocks, 'last_shared_node', None) + node = seq.prefix_cache.last_shared_node assert node is not None assert node.num_matched == block_size * 3 expect_tokens = [3] * (block_size // 2) + [4] * (block_size // 2) @@ -116,7 +178,7 @@ def test_match(self, block_trie, block_mgr, scheduler): assert len(logical_blocks) == 1 ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [3]) - node = getattr(seq.logical_blocks, 'last_shared_node', None) + node = seq.prefix_cache.last_shared_node assert node is not None assert node.num_matched == block_size assert np.array_equal(node.tokens, [1] * block_size) @@ -134,6 +196,351 @@ def test_match(self, block_trie, block_mgr, scheduler): ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [4, 3]) + def test_match_after_sequence_blocks_are_freed(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] * (block_size // 2) + seq = sess.add_sequence(token_ids) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + seq.state.free() + + assert seq.num_history_ids == 0 + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.last_shared_node is None + + block_trie.match(seq) + + assert seq.num_history_ids == block_size * 2 + 
assert len(seq.logical_blocks) == 2 + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched == block_size * 2 + + def test_match_replays_cached_routed_experts(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] + sampling_param = SamplingParam(return_routed_experts=True) + seq = sess.add_sequence(token_ids, sampling_param=sampling_param) + experts = self._routed_experts(block_size * 2) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + seq.append_routed_experts(experts) + block_trie.cache_routed_experts_for_seq(seq) + + matched = sess.add_sequence(token_ids, sampling_param=sampling_param) + block_trie.match(matched) + + assert matched.num_history_ids == block_size * 2 + assert np.array_equal(matched.all_routed_experts.get_real(), experts) + + def test_match_skips_routed_expert_replay_when_not_requested(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] + seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True)) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + seq.append_routed_experts(self._routed_experts(block_size * 2)) + block_trie.cache_routed_experts_for_seq(seq) + + matched = sess.add_sequence(token_ids) + block_trie.match(matched) + + assert matched.num_history_ids == block_size * 2 + assert len(matched.all_routed_experts) == 0 + + def test_existing_node_can_be_enriched_with_routed_experts(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + assert node is not None + assert 
node.routed_experts is None + + expert_seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True)) + experts = self._routed_experts(block_size * 2) + expert_seq.append_routed_experts(experts) + block_mgr.allocate(expert_seq) + block_trie.allocate(expert_seq) + + assert node.routed_experts is not None + matched = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True)) + block_trie.match(matched) + assert np.array_equal(matched.all_routed_experts.get_real(), experts) + + def test_match_does_not_partially_replay_routed_experts(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] + seq = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True)) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + seq.append_routed_experts(self._routed_experts(block_size)) + block_trie.cache_routed_experts_for_seq(seq) + + matched = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True)) + block_trie.match(matched) + + assert matched.num_history_ids == block_size * 2 + assert len(matched.all_routed_experts) == 0 + + def test_missing_replay_does_not_enrich_from_misaligned_tail(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] + seq = sess.add_sequence(token_ids) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + last_node = seq.prefix_cache.last_shared_node + assert last_node is not None + assert last_node.routed_experts is None + + matched = sess.add_sequence(token_ids, sampling_param=SamplingParam(return_routed_experts=True)) + block_trie.match(matched) + + assert matched.num_history_ids == block_size * 2 + assert len(matched.all_routed_experts) == 0 + + matched.append_routed_experts(self._routed_experts(1, 
offset=1000)) + block_trie.cache_routed_experts_for_seq(matched) + + assert last_node.routed_experts is None + + def test_match_multimodal_same_hash(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [99] * block_size + [2] * block_size + [3] + + seq = sess.add_sequence(token_ids, multimodals=self._image_multimodals(block_size, block_size * 2, 1.0)) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + seq = sess.add_sequence(token_ids, multimodals=self._image_multimodals(block_size, block_size * 2, 1.0)) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 3 + assert seq.num_history_ids == block_size * 3 + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched == block_size * 3 + + def test_match_multimodal_different_hash(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [99] * block_size + [2] * block_size + [3] + + seq = sess.add_sequence(token_ids, multimodals=self._image_multimodals(block_size, block_size * 2, 1.0)) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + seq = sess.add_sequence(token_ids, multimodals=self._image_multimodals(block_size, block_size * 2, 2.0)) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 1 + assert seq.num_history_ids == block_size + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched == block_size + + def test_match_multimodal_uses_precomputed_content_hash(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [99] * block_size + [2] * block_size + [3] + + seq = sess.add_sequence( + token_ids, + multimodals=self._image_multimodals(block_size, block_size * 2, 1.0, content_hash='image-a'), + ) + block_mgr.allocate(seq) + 
block_trie.allocate(seq) + + seq = sess.add_sequence( + token_ids, + multimodals=self._image_multimodals(block_size, block_size * 2, 2.0, content_hash='image-a'), + ) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 3 + assert seq.num_history_ids == block_size * 3 + assert seq.prefix_cache.metas[0].content_hash == 'image-a' + + def test_match_multimodal_different_precomputed_content_hash(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [99] * block_size + [2] * block_size + [3] + + seq = sess.add_sequence( + token_ids, + multimodals=self._image_multimodals(block_size, block_size * 2, 1.0, content_hash='image-a'), + ) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + seq = sess.add_sequence( + token_ids, + multimodals=self._image_multimodals(block_size, block_size * 2, 1.0, content_hash='image-b'), + ) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 1 + assert seq.num_history_ids == block_size + assert seq.prefix_cache.metas[0].content_hash == 'image-b' + + def test_multimodal_prefix_cache_meta_skips_hash_when_prefix_cache_disabled(self, cache_config, scheduler_config, + seq_meta, monkeypatch): + cache_config.enable_prefix_caching = False + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + def _fail_hash(*args, **kwargs): + raise AssertionError('disabled prefix cache should not hash multimodal payloads') + + monkeypatch.setattr(messages_module, 'make_multimodal_content_hash', _fail_hash) + + sess = scheduler.add_session(0) + seq = sess.add_sequence([99] * sess.seq_meta.block_size, + multimodals=self._image_multimodals(0, sess.seq_meta.block_size, 1.0)) + + assert seq.prefix_cache.metas == [] + assert not seq.history_multimodals.empty() + + def test_match_multimodal_clamps_before_split_span(self, block_trie, block_mgr, scheduler): + allocator = block_trie.allocator + sess = 
scheduler.add_session(0) + block_size = sess.seq_meta.block_size + start = block_size // 2 + end = block_size + block_size // 2 + token_ids = [99] * block_size + [99] * block_size + [3] + + seq = sess.add_sequence(token_ids, multimodals=self._image_multimodals(start, end, 1.0)) + block_mgr.allocate(seq) + block_trie.allocate(seq) + cached_blocks = seq.logical_blocks.get_real_blocks()[:1] + + token_ids = [99] * block_size + [98] * block_size + [3] + seq = sess.add_sequence(token_ids, multimodals=self._image_multimodals(start, end, 1.0)) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.num_history_ids == 0 + assert np.array_equal(allocator.get_ref_count(cached_blocks), [2]) + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched == 0 + + def test_match_multimodal_clamp_keeps_previous_images(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [7] * (block_size * 7 + block_size // 2) + image1 = (block_size, block_size * 2, 1.0) + image2 = (block_size * 3, block_size * 4, 2.0) + image3 = (block_size * 6, block_size * 7 + block_size // 4, 3.0) + + seq = sess.add_sequence(token_ids, multimodals=self._multi_image_multimodals([image1, image2, image3])) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + seq = sess.add_sequence(token_ids, multimodals=self._multi_image_multimodals([image1, image2, image3])) + block_trie.match(seq) + assert len(seq.logical_blocks) == 6 + assert seq.num_history_ids == block_size * 6 + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched == block_size * 6 + + different_last_image = (image3[0], image3[1], 4.0) + seq = sess.add_sequence( + token_ids, + multimodals=self._multi_image_multimodals([image1, image2, different_last_image]), + ) + block_trie.match(seq) + assert len(seq.logical_blocks) == 6 + assert seq.num_history_ids == block_size * 6 + + 
different_middle_image = (image2[0], image2[1], 5.0) + seq = sess.add_sequence( + token_ids, + multimodals=self._multi_image_multimodals([image1, different_middle_image, image3]), + ) + block_trie.match(seq) + assert len(seq.logical_blocks) == 3 + assert seq.num_history_ids == block_size * 3 + + def test_match_multimodal_clamp_rechecks_after_block_rounding(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [99] * (block_size * 7 + block_size // 2) + image1 = (block_size // 2, block_size * 5 + block_size // 2, 1.0) + image2 = (block_size * 5 + block_size // 2 + 2, block_size * 7 + block_size // 2, 2.0) + + seq = sess.add_sequence(token_ids, multimodals=self._multi_image_multimodals([image1, image2])) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + seq = sess.add_sequence(token_ids, multimodals=self._multi_image_multimodals([image1, image2])) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.num_history_ids == 0 + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched == 0 + + def test_match_multimodal_extra_hash_order_is_canonical(self, block_trie, block_mgr, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [99] * block_size + [3] + image = self._modal_data(2, 6, 1.0, Modality.IMAGE) + video = self._modal_data(8, 12, 2.0, Modality.VIDEO) + + seq = sess.add_sequence(token_ids, multimodals=dict(image=[image], video=[video])) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + image = self._modal_data(2, 6, 1.0, Modality.IMAGE) + video = self._modal_data(8, 12, 2.0, Modality.VIDEO) + seq = sess.add_sequence(token_ids, multimodals=dict(video=[video], image=[image])) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 1 + assert seq.num_history_ids == block_size + node = seq.prefix_cache.last_shared_node + assert node is not None + assert node.num_matched 
== block_size + + def test_prefix_cache_extra_hash_lookup_is_block_indexed(self, scheduler): + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [99] * block_size * 4 + [3] + multimodals = dict(image=[ + self._modal_data(1, block_size + 1, 1.0, Modality.IMAGE), + self._modal_data(block_size * 2 + 1, block_size * 2 + 4, 2.0, Modality.IMAGE), + self._modal_data(block_size * 3 + 2, block_size * 3 + 6, 3.0, Modality.IMAGE), + ]) + seq = sess.add_sequence(token_ids, multimodals=multimodals) + + block0_hashes = seq.get_prefix_cache_extra_hashes(0, block_size) + block1_hashes = seq.get_prefix_cache_extra_hashes(block_size, block_size * 2) + block2_hashes = seq.get_prefix_cache_extra_hashes(block_size * 2, block_size * 3) + block3_hashes = seq.get_prefix_cache_extra_hashes(block_size * 3, block_size * 4) + + assert len(block0_hashes) == 1 + assert block0_hashes == block1_hashes + assert len(block2_hashes) == 1 + assert len(block3_hashes) == 1 + assert len(seq.prefix_cache.block_extra_hashes) == 4 + assert seq.prefix_cache.num_indexed_metas == 3 + def test_evict(self, block_trie, scheduler, num_gpu_blocks): block_mgr = block_trie.block_manager sess = scheduler.add_session(0) @@ -156,3 +563,854 @@ def test_evict(self, block_trie, scheduler, num_gpu_blocks): new_leaf = next(iter(block_trie.leaves)) assert leaf != new_leaf assert block_mgr.get_num_free_gpu_blocks() == 5 + + def test_evict_prunes_stale_non_leaf_entry(self, block_trie, scheduler, num_gpu_blocks): + block_mgr = block_trie.block_manager + allocator = block_trie.allocator + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + seq = sess.add_sequence(token_ids) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + leaf = seq.prefix_cache.last_shared_node + parent = leaf.parent + assert parent.parent is not None + assert leaf in block_trie.leaves + assert parent not in block_trie.leaves + + 
block_trie.leaves.add(parent) + allocator._log_mem.access_time[leaf.block] = 0 + allocator._log_mem.access_time[parent.block] = 1 + block_mgr.free(seq) + seq.set_step(0) + + assert block_trie.evict(2) == 2 + assert len(block_trie.leaves) == 0 + assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks + + def test_match_ssm_requires_ready_state_checkpoint(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + [2] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + + seq = sess.add_sequence(token_ids) + block_trie.match(seq) + assert len(seq.logical_blocks) == 0 + assert seq.num_history_ids == 0 + assert seq.prefix_cache.restore_state == -1 + + state_idx = block_trie.reserve_state_checkpoint(node) + block_trie.mark_state_checkpoint_ready(node) + + seq = sess.add_sequence(token_ids) + block_trie.match(seq) + assert len(seq.logical_blocks) == 2 + assert seq.num_history_ids == block_size * 2 + assert seq.prefix_cache.restore_state == state_idx + + def test_match_ssm_clamps_to_deepest_ready_state_checkpoint(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 3 + [2] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + leaf = seq.prefix_cache.last_shared_node + checkpoint_node = leaf.parent + state_idx = block_trie.reserve_state_checkpoint(checkpoint_node) + block_trie.mark_state_checkpoint_ready(checkpoint_node) + + seq = sess.add_sequence(token_ids) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 2 + assert seq.num_history_ids == block_size * 2 + assert seq.prefix_cache.restore_state == state_idx + + def 
test_match_ssm_replays_cached_routed_experts(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + sampling_param = SamplingParam(return_routed_experts=True) + seq = sess.add_sequence(token_ids, sampling_param=sampling_param) + experts = self._routed_experts(block_size * 2) + + block_mgr.allocate(seq) + block_trie.allocate(seq) + seq.append_routed_experts(experts) + block_trie.cache_routed_experts_for_seq(seq) + state_idx = block_trie.reserve_state_checkpoint_for_seq(seq) + assert state_idx >= 0 + assert block_trie.commit_state_checkpoint_for_seq(seq) + + matched = sess.add_sequence(token_ids + [3], sampling_param=sampling_param) + block_trie.match(matched) + + assert matched.prefix_cache.restore_state == state_idx + assert matched.num_history_ids == block_size * 2 + assert np.array_equal(matched.all_routed_experts.get_real(), experts) + + def test_match_ssm_sparse_index_misses_without_block_walk(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + num_blocks = 12 + token_ids = [] + for block_id in range(num_blocks): + token_ids.extend([block_id + 1] * block_size) + token_ids.append(99) + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + block_trie.reserve_state_checkpoint(node) + block_trie.mark_state_checkpoint_ready(node) + + miss_token_ids = token_ids.copy() + miss_token_ids[(num_blocks - 1) * block_size:num_blocks * block_size] = [777] * block_size + seq = sess.add_sequence(miss_token_ids) + calls = 0 + get_hashes = seq.get_prefix_cache_extra_hashes + + def count_hashes(start, end): + nonlocal calls + calls += 1 + return get_hashes(start, end) + + 
seq.get_prefix_cache_extra_hashes = count_hashes + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + assert calls == 1 + + def test_match_ssm_sparse_index_verifies_hash_collision_exactly(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + block_trie.reserve_state_checkpoint(node) + block_trie.mark_state_checkpoint_ready(node) + + miss_token_ids = [1] * block_size + [4] * block_size + [3] + seq = sess.add_sequence(miss_token_ids) + collision_key = block_trie._make_state_checkpoint_lookup_key(seq, block_size * 2) + block_trie._state_checkpoint_index.setdefault(collision_key, []).append(node) + block_trie._state_checkpoint_steps.setdefault(seq.adapter_name, set()).add(block_size * 2) + + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + + def test_match_ssm_keeps_request_mismatch_checkpoint_candidate(self, ssm_scheduler): + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + _, node, state_idx = self._add_ready_ssm_checkpoint(ssm_scheduler, token_ids) + key = ssm_scheduler.block_trie._make_state_checkpoint_node_key(node) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + + # Same indexed step and last block, but a different earlier block. This + # is a miss for this request, not proof that the cached checkpoint is + # stale globally. 
+ miss_token_ids = [4] * block_size + [2] * block_size + [3] + seq = ssm_scheduler.add_session(100).add_sequence(miss_token_ids) + ssm_scheduler.block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + assert node.state_idx == state_idx + assert node.state_ready + assert node in ssm_scheduler.block_trie._state_checkpoint_index[key] + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + + def test_match_ssm_drops_stale_sparse_index_entry_only(self, ssm_scheduler): + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + _, node, state_idx = self._add_ready_ssm_checkpoint(ssm_scheduler, token_ids) + block_trie = ssm_scheduler.block_trie + canonical_key = block_trie._make_state_checkpoint_node_key(node) + + miss_token_ids = [1] * block_size + [4] * block_size + [3] + seq = ssm_scheduler.add_session(100).add_sequence(miss_token_ids) + stale_key = block_trie._make_state_checkpoint_lookup_key(seq, block_size * 2) + assert stale_key != canonical_key + block_trie._state_checkpoint_index.setdefault(stale_key, []).append(node) + block_trie._state_checkpoint_steps.setdefault(seq.adapter_name, set()).add(block_size * 2) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + assert stale_key not in block_trie._state_checkpoint_index + assert node in block_trie._state_checkpoint_index[canonical_key] + assert block_trie._state_checkpoint_steps[node.adapter_name] == {node.num_matched} + assert node.state_idx == state_idx + assert node.state_ready + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + + def test_match_ssm_releases_detached_stale_checkpoint_candidate(self, ssm_scheduler): + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + _, node, _ = 
self._add_ready_ssm_checkpoint(ssm_scheduler, token_ids) + block_trie = ssm_scheduler.block_trie + key = block_trie._make_state_checkpoint_node_key(node) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + + node.parent = None + seq = ssm_scheduler.add_session(100).add_sequence(token_ids + [3]) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + assert key not in block_trie._state_checkpoint_index + assert node.adapter_name not in block_trie._state_checkpoint_steps + assert node.state_idx == -1 + assert not node.state_ready + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + 1 + + def test_match_ssm_keeps_pinned_stale_checkpoint_candidate(self, ssm_scheduler): + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + _, node, state_idx = self._add_ready_ssm_checkpoint(ssm_scheduler, token_ids) + block_trie = ssm_scheduler.block_trie + key = block_trie._make_state_checkpoint_node_key(node) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + + node.state_ref_count = 1 + node.parent = None + seq = ssm_scheduler.add_session(100).add_sequence(token_ids + [3]) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + assert node.state_idx == state_idx + assert node.state_ready + assert node.state_ref_count == 1 + assert node in block_trie._state_checkpoint_index[key] + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + + def test_match_ssm_releases_unready_indexed_checkpoint_candidate(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = 
seq.prefix_cache.last_shared_node + state_idx = block_trie.reserve_state_checkpoint(node) + assert state_idx >= 0 + assert not node.state_ready + key = block_trie._make_state_checkpoint_node_key(node) + block_trie._state_checkpoint_index.setdefault(key, []).append(node) + block_trie._state_checkpoint_steps.setdefault(node.adapter_name, set()).add(node.num_matched) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + + seq = sess.add_sequence(token_ids + [3]) + block_trie.match(seq) + + assert len(seq.logical_blocks) == 0 + assert seq.prefix_cache.restore_state == -1 + assert key not in block_trie._state_checkpoint_index + assert node.adapter_name not in block_trie._state_checkpoint_steps + assert node.state_idx == -1 + assert not node.state_ready + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + 1 + + def test_ssm_checkpoint_index_rejects_unready_node(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + assert block_trie.reserve_state_checkpoint(node) >= 0 + + with pytest.raises(RuntimeError, match='unready SSM prefix-cache checkpoint'): + block_trie._index_state_checkpoint(node) + + def test_ssm_checkpoint_save_publishes_to_sparse_index(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx = block_trie.reserve_state_checkpoint_for_seq(seq) + node = seq.prefix_cache.last_shared_node + + assert state_idx >= 0 + assert seq.prefix_cache.save_state == state_idx + 
assert not node.state_ready + + assert block_trie.commit_state_checkpoint_for_seq(seq) + assert node.state_ready + assert seq.prefix_cache.save_state == -1 + + seq = sess.add_sequence(token_ids + [2]) + block_trie.match(seq) + + assert seq.num_history_ids == block_size * 2 + assert seq.prefix_cache.restore_state == state_idx + + def test_ssm_restore_acquire_survives_tail_allocation(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + block_size = ssm_scheduler.seq_meta.block_size + checkpoint_tokens = [1] * block_size * 2 + suffix_tokens = [2] * block_size * 3 + [3] + + _, checkpoint_node, state_idx = self._add_ready_ssm_checkpoint(ssm_scheduler, checkpoint_tokens) + + seq = ssm_scheduler.add_session(100).add_sequence(checkpoint_tokens + suffix_tokens) + block_trie.match(seq) + assert seq.num_history_ids == block_size * 2 + assert seq.prefix_cache.restore_state == state_idx + assert seq.prefix_cache.restore_node is checkpoint_node + + block_mgr.allocate(seq) + block_trie.allocate(seq) + assert seq.prefix_cache.last_shared_node is not checkpoint_node + assert seq.prefix_cache.restore_node is checkpoint_node + + assert block_trie.acquire_state_checkpoint_restore_for_seq(seq) + assert checkpoint_node.state_ref_count == 1 + assert block_trie.release_state_checkpoint_restore_for_seq(seq) + assert checkpoint_node.state_ref_count == 0 + assert seq.prefix_cache.restore_node is None + + def test_ssm_checkpoint_release_rejects_pinned_state(self, ssm_scheduler): + block_trie = ssm_scheduler.block_trie + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size * 2 + + _, checkpoint_node, state_idx = self._add_ready_ssm_checkpoint(ssm_scheduler, token_ids) + + seq = ssm_scheduler.add_session(100).add_sequence(token_ids + [2]) + block_trie.match(seq) + assert seq.prefix_cache.restore_state == state_idx + assert block_trie.acquire_state_checkpoint_restore_for_seq(seq) + + with pytest.raises(RuntimeError, 
match='Cannot release a pinned'): + block_trie.release_state_checkpoint(checkpoint_node) + + assert block_trie.release_state_checkpoint_restore_for_seq(seq) + block_trie.release_state_checkpoint(checkpoint_node) + + def test_ssm_checkpoint_restore_release_detects_lost_ref(self, ssm_scheduler): + block_trie = ssm_scheduler.block_trie + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size * 2 + + _, checkpoint_node, state_idx = self._add_ready_ssm_checkpoint(ssm_scheduler, token_ids) + + seq = ssm_scheduler.add_session(100).add_sequence(token_ids + [2]) + block_trie.match(seq) + assert seq.prefix_cache.restore_state == state_idx + assert block_trie.acquire_state_checkpoint_restore_for_seq(seq) + checkpoint_node.state_ref_count = 0 + + with pytest.raises(RuntimeError, match='lost its node reference'): + block_trie.release_state_checkpoint_restore_for_seq(seq) + + checkpoint_node.state_ref_count = 1 + assert block_trie.release_state_checkpoint_restore_for_seq(seq) + + def test_ssm_checkpoint_ready_index_is_idempotent(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + block_trie.reserve_state_checkpoint(node) + + block_trie.mark_state_checkpoint_ready(node) + block_trie.mark_state_checkpoint_ready(node) + + key = block_trie._make_state_checkpoint_node_key(node) + assert block_trie._state_checkpoint_index[key] == [node] + assert block_trie._state_checkpoint_steps[node.adapter_name] == {node.num_matched} + + def test_ssm_checkpoint_unindex_removes_duplicate_entries(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + 
token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + block_trie.reserve_state_checkpoint(node) + block_trie.mark_state_checkpoint_ready(node) + key = block_trie._make_state_checkpoint_node_key(node) + block_trie._state_checkpoint_index[key].extend([node, node]) + + block_trie.release_state_checkpoint(node) + + assert key not in block_trie._state_checkpoint_index + assert node.adapter_name not in block_trie._state_checkpoint_steps + assert node.state_idx == -1 + assert not node.state_ready + + def test_ssm_checkpoint_pending_save_discard_releases_slot(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + state_idx = block_trie.reserve_state_checkpoint_for_seq(seq) + node = seq.prefix_cache.last_shared_node + + assert state_idx >= 0 + assert node.state_idx == state_idx + assert not node.state_ready + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states - 1 + + assert block_trie.discard_state_checkpoint_for_seq(seq) + assert seq.prefix_cache.save_state == -1 + assert seq.prefix_cache.save_step == 0 + assert seq.prefix_cache.save_node is None + assert node.state_idx == -1 + assert not node.state_ready + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + + def test_ssm_checkpoint_commit_allows_sequence_to_advance_past_save_step(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + save_step = block_size * 2 + token_ids = [1] * save_step + + seq = 
sess.add_sequence(token_ids) + seq.set_step(save_step - 1) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx = block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) + node = seq.prefix_cache.save_node + + assert state_idx >= 0 + seq.update_token_ids([2], mode=UpdateTokenMode.DECODE) + + assert block_trie.commit_state_checkpoint_for_seq(seq) + assert seq.prefix_cache.decode_state_node is node + assert node.state_ready + + def test_ssm_checkpoint_commit_failure_discards_detached_pending_slot(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + state_idx = block_trie.reserve_state_checkpoint_for_seq(seq) + node = seq.prefix_cache.last_shared_node + + assert state_idx >= 0 + node.parent = None + + assert not block_trie.commit_state_checkpoint_for_seq(seq) + assert seq.prefix_cache.save_state == -1 + assert seq.prefix_cache.save_step == 0 + assert seq.prefix_cache.save_node is None + assert node.state_idx == -1 + assert ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + + def test_ssm_decode_checkpoint_replaces_previous_unpinned_state(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + step = block_size * 2 + + seq = sess.add_sequence([1] * step) + seq.set_step(step - 1) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx_a = block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) + node_a = seq.prefix_cache.save_node + + assert state_idx_a >= 0 + assert seq.prefix_cache.save_is_decode + assert 
block_trie.commit_state_checkpoint_for_seq(seq) + assert seq.prefix_cache.decode_state_node is node_a + assert node_a.state_ready + + seq.update_token_ids([2] * block_size, mode=UpdateTokenMode.DECODE) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx_b = block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) + node_b = seq.prefix_cache.save_node + + assert state_idx_b >= 0 + assert node_a.state_idx == -1 + assert not node_a.state_ready + assert seq.prefix_cache.decode_state_node is None + assert block_trie.commit_state_checkpoint_for_seq(seq) + assert seq.prefix_cache.decode_state_node is node_b + assert node_b.state_ready + + def test_ssm_decode_checkpoint_skip_replacement_when_previous_is_pinned(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + step = block_size * 2 + + seq = sess.add_sequence([1] * step) + seq.set_step(step - 1) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx = block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) + node = seq.prefix_cache.save_node + assert state_idx >= 0 + assert block_trie.commit_state_checkpoint_for_seq(seq) + node.state_ref_count = 1 + + seq.update_token_ids([2] * block_size, mode=UpdateTokenMode.DECODE) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + assert block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) == -1 + assert seq.prefix_cache.decode_state_node is node + assert node.state_idx == state_idx + assert node.state_ready + assert seq.prefix_cache.save_state == -1 + + def test_ssm_decode_checkpoint_keeps_old_state_when_new_node_is_pending(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + step = block_size * 2 + old_tokens = [1] * step + 
new_tokens = [2] * block_size + + seq = sess.add_sequence(old_tokens) + seq.set_step(step - 1) + block_mgr.allocate(seq) + block_trie.allocate(seq) + old_state_idx = block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) + old_node = seq.prefix_cache.save_node + + assert old_state_idx >= 0 + assert block_trie.commit_state_checkpoint_for_seq(seq) + assert seq.prefix_cache.decode_state_node is old_node + assert old_node.state_ready + + seq.update_token_ids(new_tokens, mode=UpdateTokenMode.DECODE) + block_mgr.allocate(seq) + block_trie.allocate(seq) + new_node = seq.prefix_cache.last_shared_node + + pending_seq = sess.add_sequence(old_tokens + new_tokens) + block_mgr.allocate(pending_seq) + block_trie.allocate(pending_seq) + assert pending_seq.prefix_cache.last_shared_node is new_node + pending_state_idx = block_trie.reserve_state_checkpoint_for_seq(pending_seq) + assert pending_state_idx >= 0 + assert new_node.state_idx == pending_state_idx + assert not new_node.state_ready + + assert block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) == -1 + assert seq.prefix_cache.decode_state_node is old_node + assert old_node.state_idx == old_state_idx + assert old_node.state_ready + assert seq.prefix_cache.save_state == -1 + + def test_ssm_decode_checkpoint_keeps_old_state_when_new_step_is_not_allocated(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + step = block_size * 2 + + seq = sess.add_sequence([1] * step) + seq.set_step(step - 1) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx = block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) + node = seq.prefix_cache.save_node + assert state_idx >= 0 + assert block_trie.commit_state_checkpoint_for_seq(seq) + + seq.update_token_ids([2] * block_size, mode=UpdateTokenMode.DECODE) + block_mgr.allocate(seq) + + assert 
block_trie.reserve_decode_state_checkpoint_for_seq(seq, interval=block_size) == -1 + assert seq.prefix_cache.decode_state_node is node + assert node.state_idx == state_idx + assert node.state_ready + + def test_ssm_checkpoint_save_uses_explicit_chunk_step(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + checkpoint_step = block_size * 2 + token_ids = [1] * block_size * 4 + [2] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + state_idx = block_trie.reserve_state_checkpoint_for_seq(seq, step=checkpoint_step) + + assert state_idx >= 0 + assert seq.prefix_cache.save_state == state_idx + assert seq.prefix_cache.save_step == checkpoint_step + + # Long-context chunking advances the sequence step before the executor + # output is committed. The checkpoint should still attach to the + # ancestor node for the chunk boundary. 
+ seq.set_step(checkpoint_step) + assert block_trie.commit_state_checkpoint_for_seq(seq) + + seq = sess.add_sequence(token_ids[:checkpoint_step] + [3]) + block_trie.match(seq) + + assert seq.num_history_ids == checkpoint_step + assert seq.prefix_cache.restore_state == state_idx + + def test_ssm_checkpoint_save_skips_partial_tail(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + [2] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + assert block_trie.reserve_state_checkpoint_for_seq(seq) == -1 + assert seq.prefix_cache.save_state == -1 + + def test_ssm_checkpoint_save_skips_when_no_state_slot(self, ssm_cache_config, scheduler_config, seq_meta): + cache_config = ssm_cache_config + cache_config.num_state_caches = 1 + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + block_mgr = scheduler.block_manager + block_trie = scheduler.block_trie + sess = scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + + assert block_trie.reserve_state_checkpoint_for_seq(seq) == -1 + assert seq.prefix_cache.save_state == -1 + + def test_ssm_checkpoint_save_skips_duplicate_unready_node(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + block_size = ssm_scheduler.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq_a = ssm_scheduler.add_session(0).add_sequence(token_ids) + block_mgr.allocate(seq_a) + block_trie.allocate(seq_a) + + seq_b = ssm_scheduler.add_session(1).add_sequence(token_ids) + block_mgr.allocate(seq_b) + block_trie.allocate(seq_b) + + node = seq_a.prefix_cache.last_shared_node + assert node is 
seq_b.prefix_cache.last_shared_node + + state_idx_a = block_trie.reserve_state_checkpoint_for_seq(seq_a) + state_idx_b = block_trie.reserve_state_checkpoint_for_seq(seq_b) + + assert state_idx_a >= 0 + assert state_idx_b == -1 + assert node.state_idx == state_idx_a + assert not node.state_ready + assert seq_a.prefix_cache.save_state == state_idx_a + assert seq_a.prefix_cache.save_node is node + assert seq_b.prefix_cache.save_state == -1 + assert seq_b.prefix_cache.save_node is None + + assert block_trie.commit_state_checkpoint_for_seq(seq_a) + matched = ssm_scheduler.add_session(2).add_sequence(token_ids + [2]) + block_trie.match(matched) + assert matched.prefix_cache.restore_state == state_idx_a + assert matched.num_history_ids == block_size * 2 + + def test_ssm_checkpoint_save_evicts_unpinned_state_only(self, ssm_cache_config, scheduler_config, seq_meta): + cache_config = ssm_cache_config + cache_config.prefix_cache_state_budget = 0 + cache_config.num_state_caches = 2 + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + block_size = scheduler.seq_meta.block_size + token_ids_a = [1] * block_size * 2 + token_ids_b = [2] * block_size * 2 + + _, node_a, state_idx_a = self._add_ready_ssm_checkpoint(scheduler, token_ids_a) + seq_b = scheduler.add_session(99).add_sequence(token_ids_b) + scheduler.block_manager.allocate(seq_b) + scheduler.block_trie.allocate(seq_b) + state_idx_b = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq_b) + + assert state_idx_b >= 0 + assert state_idx_b == state_idx_a + assert node_a.state_idx == -1 + assert not node_a.state_ready + assert scheduler.state_manager.get_num_free_checkpoint() == 0 + + assert scheduler.block_trie.commit_state_checkpoint_for_seq(seq_b) + + seq_a = scheduler.add_session(100).add_sequence(token_ids_a + [3]) + scheduler.block_trie.match(seq_a) + assert seq_a.prefix_cache.restore_state == -1 + + seq_b = scheduler.add_session(101).add_sequence(token_ids_b + 
[3]) + scheduler.block_trie.match(seq_b) + assert seq_b.prefix_cache.restore_state == state_idx_b + + def test_ssm_checkpoint_save_producer_pin_blocks_eviction_until_release(self, ssm_cache_config, scheduler_config, + seq_meta): + cache_config = ssm_cache_config + cache_config.prefix_cache_state_budget = 0 + cache_config.num_state_caches = 2 + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + block_size = scheduler.seq_meta.block_size + token_ids_a = [1] * block_size * 2 + token_ids_b = [2] * block_size * 2 + + seq_a = scheduler.add_session(0).add_sequence(token_ids_a) + scheduler.block_manager.allocate(seq_a) + scheduler.block_trie.allocate(seq_a) + state_idx_a = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq_a) + node_a = seq_a.prefix_cache.save_node + + assert state_idx_a >= 0 + assert scheduler.block_trie.commit_state_checkpoint_for_seq(seq_a, acquire_save_ref=True) + assert node_a.state_ready + assert node_a.state_ref_count == 1 + + matched = scheduler.add_session(1).add_sequence(token_ids_a + [3]) + scheduler.block_trie.match(matched) + assert matched.prefix_cache.restore_state == state_idx_a + + seq_b = scheduler.add_session(2).add_sequence(token_ids_b) + scheduler.block_manager.allocate(seq_b) + scheduler.block_trie.allocate(seq_b) + + assert scheduler.block_trie.reserve_state_checkpoint_for_seq(seq_b) == -1 + assert node_a.state_idx == state_idx_a + assert node_a.state_ready + assert node_a.state_ref_count == 1 + + assert scheduler.block_trie.release_state_checkpoint_save_for_seq(seq_a) + assert node_a.state_ref_count == 0 + + state_idx_b = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq_b) + assert state_idx_b == state_idx_a + assert node_a.state_idx == -1 + assert not node_a.state_ready + + def test_ssm_checkpoint_state_eviction_skips_pinned_restore(self, ssm_cache_config, scheduler_config, seq_meta): + cache_config = ssm_cache_config + cache_config.prefix_cache_state_budget = 0 
+ cache_config.num_state_caches = 3 + scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + block_size = scheduler.seq_meta.block_size + token_ids_a = [1] * block_size * 2 + token_ids_b = [2] * block_size * 2 + token_ids_c = [3] * block_size * 2 + + _, node_a, state_idx_a = self._add_ready_ssm_checkpoint(scheduler, token_ids_a) + _, node_b, state_idx_b = self._add_ready_ssm_checkpoint(scheduler, token_ids_b) + + seq_a = scheduler.add_session(100).add_sequence(token_ids_a + [4]) + scheduler.block_trie.match(seq_a) + assert seq_a.prefix_cache.restore_state == state_idx_a + assert scheduler.block_trie.acquire_state_checkpoint_restore_for_seq(seq_a) + assert node_a.state_ref_count == 1 + + seq_c = scheduler.add_session(101).add_sequence(token_ids_c) + scheduler.block_manager.allocate(seq_c) + scheduler.block_trie.allocate(seq_c) + state_idx_c = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq_c) + + assert state_idx_c >= 0 + assert node_a.state_idx == state_idx_a + assert node_a.state_ready + assert node_b.state_idx == -1 + assert not node_b.state_ready + assert state_idx_c == state_idx_b + + assert scheduler.block_trie.release_state_checkpoint_restore_for_seq(seq_a) + assert node_a.state_ref_count == 0 + assert seq_a.prefix_cache.restore_state == -1 + + def test_evict_ssm_releases_state_checkpoint(self, ssm_scheduler): + block_mgr = ssm_scheduler.block_manager + block_trie = ssm_scheduler.block_trie + sess = ssm_scheduler.add_session(0) + block_size = sess.seq_meta.block_size + token_ids = [1] * block_size * 2 + [2] + + seq = sess.add_sequence(token_ids) + block_mgr.allocate(seq) + block_trie.allocate(seq) + node = seq.prefix_cache.last_shared_node + block_trie.reserve_state_checkpoint(node) + block_trie.mark_state_checkpoint_ready(node) + free_states = ssm_scheduler.state_manager.get_num_free_checkpoint() + + block_mgr.free(seq) + seq.set_step(0) + block_trie.evict(1) + + assert 
ssm_scheduler.state_manager.get_num_free_checkpoint() == free_states + 1 diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index de8e37cca0..d91b217be2 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -2,8 +2,11 @@ import torch from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig -from lmdeploy.pytorch.messages import MessageStatus, SequenceMeta +from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest +from lmdeploy.pytorch.engine.inputs_maker import _compact_state_prefix_cache_save_offsets +from lmdeploy.pytorch.messages import MessageStatus, SequenceMeta, UpdateTokenMode from lmdeploy.pytorch.paging.scheduler import Scheduler +from lmdeploy.pytorch.paging.state_manager import StateManager class TestScheduler: @@ -179,3 +182,467 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): assert seq1.status == MessageStatus.READY assert seq2.status == MessageStatus.WAITING assert block_manager.get_num_free_gpu_blocks() == 2 + + +def test_state_manager_reserves_system_state_slot(): + manager = StateManager(num_states=3, num_reserved=1) + + assert manager.allocate_state() == 1 + assert manager.allocate_state() == 2 + with pytest.raises(RuntimeError, match='No free states'): + manager.allocate_state() + + +def test_state_manager_checkpoint_can_borrow_idle_runtime_slots(): + manager = StateManager(num_states=5, num_reserved=1, num_runtime_states=2) + + checkpoints = [manager.allocate_checkpoint_state() for _ in range(4)] + assert checkpoints == [1, 2, 3, 4] + with pytest.raises(RuntimeError, match='No free states'): + manager.allocate_checkpoint_state() + + manager.free_checkpoint_state(checkpoints[0]) + manager.free_checkpoint_state(checkpoints[1]) + assert manager.allocate_state() == checkpoints[1] + assert manager.allocate_state() == checkpoints[0] + with pytest.raises(RuntimeError, match='No free states'): + 
manager.allocate_state() + + +def test_state_manager_caps_runtime_count_even_with_extra_free_slots(): + manager = StateManager(num_states=6, num_reserved=1, num_runtime_states=2) + + assert manager.num_runtime_states == 2 + assert manager.allocate_state() == 1 + assert manager.allocate_state() == 2 + assert manager.get_num_free() == 3 + assert manager.get_num_free_runtime() == 0 + with pytest.raises(RuntimeError, match='No free states'): + manager.allocate_state() + + +def _make_ssm_scheduler(max_batch_size: int = 1, prefix_cache_state_budget: int = 0): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + block_size = 16 + cache_config = CacheConfig(max_batches=max_batch_size, + block_size=block_size, + num_cpu_blocks=4, + num_gpu_blocks=16, + enable_prefix_caching=True, + num_state_caches=max_batch_size + 1 + prefix_cache_state_budget, + prefix_cache_state_budget=prefix_cache_state_budget, + states_shapes=[((1, ), torch.float32)]) + scheduler_config = SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') + seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy()) + return Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + +def _add_ready_ssm_checkpoint(scheduler: Scheduler, token_ids: list[int]): + session = scheduler.add_session(len(scheduler.sessions)) + seq = session.add_sequence(token_ids) + scheduler.block_manager.allocate(seq) + scheduler.block_trie.allocate(seq) + state_idx = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq) + assert state_idx >= 0 + assert scheduler.block_trie.commit_state_checkpoint_for_seq(seq) + node = seq.prefix_cache.last_shared_node + session.remove_sequence(seq) + return node, state_idx + + +def test_ssm_runtime_state_reclaims_borrowed_checkpoint_slot(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=0) + block_size = scheduler.seq_meta.block_size + 
node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + seq = scheduler.add_session(100).add_sequence([2] * block_size * 2) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [seq] + assert seq.logical_state == state_idx + assert node.state_idx == -1 + assert not node.state_ready + assert scheduler.state_manager.get_num_runtime_states() == 1 + assert scheduler.state_manager.get_num_allocated_checkpoint_states() == 0 + + +def test_ssm_long_chunked_request_schedules_with_only_runtime_state_slot(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=0) + scheduler.cache_config.max_prefill_token_num = scheduler.seq_meta.block_size * 2 + block_size = scheduler.seq_meta.block_size + token_ids = [1] * block_size + [2] * block_size + [3] * block_size + seq = scheduler.add_session(100).add_sequence(token_ids) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [seq] + assert seq.logical_state >= 0 + assert scheduler.state_manager.get_num_runtime_states() == 1 + assert scheduler.state_manager.get_num_allocated_checkpoint_states() == 0 + assert scheduler.block_trie.reserve_state_checkpoint_for_seq(seq, step=block_size * 2) == -1 + + +def test_ssm_runtime_state_waits_when_only_checkpoint_slot_is_pinned(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=0) + block_size = scheduler.seq_meta.block_size + node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + node.state_ref_count = 1 + seq = scheduler.add_session(100).add_sequence([2] * block_size * 2) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [] + assert seq.status == MessageStatus.WAITING + assert seq.logical_state == -1 + assert node.state_idx == state_idx + assert node.state_ready + + +def test_ssm_same_batch_duplicate_checkpoint_save_has_unique_dst_offsets(): + scheduler = _make_ssm_scheduler(max_batch_size=2, 
prefix_cache_state_budget=2) + block_size = scheduler.seq_meta.block_size + token_ids = [1] * block_size * 2 + + seq_a = scheduler.add_session(100).add_sequence(token_ids) + seq_b = scheduler.add_session(101).add_sequence(token_ids) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [seq_a, seq_b] + assert seq_a.logical_state >= 0 + assert seq_b.logical_state >= 0 + assert seq_a.logical_state != seq_b.logical_state + assert seq_a.prefix_cache.last_shared_node is seq_b.prefix_cache.last_shared_node + + save_state_offsets = [ + scheduler.block_trie.reserve_state_checkpoint_for_seq(seq) for seq in output.running + ] + save_src_offsets, save_dst_offsets = _compact_state_prefix_cache_save_offsets(output.running, + save_state_offsets) + + assert save_src_offsets == (seq_a.logical_state, ) + assert save_dst_offsets == (save_state_offsets[0], ) + assert save_state_offsets[0] >= 0 + assert save_state_offsets[1] == -1 + assert len(save_dst_offsets) == len(set(save_dst_offsets)) + + +def test_ssm_end_session_discards_pending_checkpoint_reservation(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1) + block_size = scheduler.seq_meta.block_size + session = scheduler.add_session(100) + seq = session.add_sequence([1] * block_size * 2) + scheduler.block_manager.allocate(seq) + scheduler.block_trie.allocate(seq) + scheduler.state_manager.allocate(seq) + + state_idx = scheduler.block_trie.reserve_state_checkpoint_for_seq(seq) + node = seq.prefix_cache.save_node + assert state_idx >= 0 + assert node is not None + assert scheduler.state_manager.get_num_allocated_checkpoint_states() == 1 + + scheduler.end_session(100) + + assert 100 not in scheduler.sessions + assert node.state_idx == -1 + assert not node.state_ready + assert scheduler.state_manager.get_num_runtime_states() == 0 + assert scheduler.state_manager.get_num_allocated_checkpoint_states() == 0 + + +def test_ssm_end_session_releases_acquired_restore_checkpoint(): + 
scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1) + block_size = scheduler.seq_meta.block_size + node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + seq = scheduler.add_session(100).add_sequence([1] * block_size * 2 + [2]) + + scheduler.block_trie.match(seq) + assert seq.prefix_cache.restore_state == state_idx + assert scheduler.block_trie.acquire_state_checkpoint_restore_for_seq(seq) + assert node.state_ref_count == 1 + + scheduler.end_session(100) + + assert 100 not in scheduler.sessions + assert node.state_idx == state_idx + assert node.state_ready + assert node.state_ref_count == 0 + + +def test_ssm_failed_restore_schedule_rolls_back_match(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=0) + block_size = scheduler.seq_meta.block_size + node, state_idx = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + node.state_ref_count = 1 + seq = scheduler.add_session(100).add_sequence([1] * block_size * 2 + [2]) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [] + assert seq.status == MessageStatus.WAITING + assert seq.num_history_ids == 0 + assert len(seq.logical_blocks) == 0 + assert seq.cached_tokens == 0 + assert seq.prefix_cache.last_shared_node is None + assert seq.prefix_cache.restore_state == -1 + assert seq.prefix_cache.restore_node is None + assert node.state_idx == state_idx + assert node.state_ready + assert scheduler.block_trie.stats.num_query_tokens == 0 + assert scheduler.block_trie.stats.num_hit_tokens == 0 + + node.state_ref_count = 0 + output = scheduler.schedule(is_prefill=True) + + assert output.running == [seq] + assert seq.status == MessageStatus.READY + assert seq.num_history_ids == 0 + assert seq.prefix_cache.restore_state == -1 + assert seq.logical_state == state_idx + assert node.state_idx == -1 + assert not node.state_ready + assert scheduler.block_trie.stats.num_query_tokens == 0 + assert 
scheduler.block_trie.stats.num_hit_tokens == 0 + + +def test_ssm_scheduler_preserves_matched_checkpoint_when_evicting_for_runtime_state(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1) + block_size = scheduler.seq_meta.block_size + node_a, state_idx_a = _add_ready_ssm_checkpoint(scheduler, [1] * block_size * 2) + node_b, state_idx_b = _add_ready_ssm_checkpoint(scheduler, [2] * block_size * 2) + seq = scheduler.add_session(100).add_sequence([1] * block_size * 2 + [3]) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [seq] + assert seq.num_history_ids == block_size * 2 + assert seq.cached_tokens == block_size * 2 + assert seq.prefix_cache.restore_state == state_idx_a + assert seq.prefix_cache.restore_node is node_a + assert seq.prefix_cache.restore_state_acquired + assert seq.logical_state == state_idx_b + assert node_a.state_idx == state_idx_a + assert node_a.state_ready + assert node_a.state_ref_count == 1 + assert node_b.state_idx == -1 + assert not node_b.state_ready + assert scheduler.block_trie.stats.num_hit_tokens == block_size * 2 + + assert scheduler.block_trie.release_state_checkpoint_restore_for_seq(seq) + + +def test_ssm_scheduler_evicts_stopped_runtime_state_with_free_checkpoint_slot(): + scheduler = _make_ssm_scheduler(max_batch_size=1, prefix_cache_state_budget=1) + block_size = scheduler.seq_meta.block_size + seq_a = scheduler.add_session(100).add_sequence([1] * block_size) + + output = scheduler.schedule(is_prefill=True) + assert output.running == [seq_a] + assert seq_a.logical_state >= 0 + assert scheduler.state_manager.get_num_free() == 1 + assert scheduler.state_manager.get_num_free_runtime() == 0 + + seq_a.state.stop() + seq_b = scheduler.add_session(101).add_sequence([2] * block_size) + + output = scheduler.schedule(is_prefill=True) + + assert output.running == [seq_b] + assert seq_b.logical_state >= 0 + assert seq_a.logical_state == -1 + assert seq_a.status == MessageStatus.STOPPED 
def _make_ar_prefix_cache_scheduler(block_size, num_gpu_blocks, num_cpu_blocks=0, **cache_overrides):
    """Build a single-batch scheduler with prefix caching enabled.

    The tests below all need the same scheduler wiring and differ only in
    cache sizing, so the shared construction lives here.

    Args:
        block_size: Block size passed to both ``SequenceMeta`` and
            ``CacheConfig``.
        num_gpu_blocks: Value for ``CacheConfig.num_gpu_blocks``.
        num_cpu_blocks: Value for ``CacheConfig.num_cpu_blocks``.
        **cache_overrides: Additional ``CacheConfig`` keyword arguments
            (for example ``max_prefill_token_num``). Anything not given
            keeps the ``CacheConfig`` default.

    Returns:
        A ``Scheduler`` built with ``max_batches=1``,
        ``max_session_len=128``, ``max_request_output_len=64`` and
        ``eviction_type='recompute'``.
    """
    # Kept as a function-scoped import, exactly as the individual tests did.
    from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy

    seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy())
    cache_config = CacheConfig(max_batches=1,
                               block_size=block_size,
                               num_cpu_blocks=num_cpu_blocks,
                               num_gpu_blocks=num_gpu_blocks,
                               enable_prefix_caching=True,
                               **cache_overrides)
    scheduler_config = SchedulerConfig(max_batches=1,
                                       max_session_len=128,
                                       max_request_output_len=64,
                                       eviction_type='recompute')
    return Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)


def test_schedule_migration_matches_current_sequence():
    """A sequence added with a migration request is returned by
    ``_schedule_migration`` and ends up in ``MIGRATION_READY``."""
    block_size = 16
    scheduler = _make_ar_prefix_cache_scheduler(block_size, num_gpu_blocks=4, num_cpu_blocks=4)
    migration_request = MigrationRequest(protocol=MigrationProtocol.RDMA,
                                         remote_engine_id='prefill-0',
                                         remote_session_id=7,
                                         remote_token_id=8,
                                         remote_block_ids=[1])
    seq = scheduler.add_session(100).add_sequence([1] * block_size, migration_request=migration_request)

    output = scheduler._schedule_migration()

    assert output == [seq]
    assert seq.status == MessageStatus.MIGRATION_READY


def test_scheduler_publishes_cached_tokens_for_accepted_prefix_hit():
    """A prompt sharing two full blocks with an earlier, stopped sequence
    reports those two blocks as history and as cached tokens; the cached
    token count is cleared again by the next ``update_token_ids`` call."""
    block_size = 16
    scheduler = _make_ar_prefix_cache_scheduler(block_size, num_gpu_blocks=8)

    # Populate the prefix cache with two full blocks, then stop the owner.
    cached = scheduler.add_session(0).add_sequence([1] * block_size + [2] * block_size + [3])
    scheduler.schedule(is_prefill=True)
    cached.state.stop()

    # Same first two blocks, different trailing token.
    seq = scheduler.add_session(1).add_sequence([1] * block_size + [2] * block_size + [4])
    output = scheduler.schedule(is_prefill=True)

    assert output.running == [seq]
    assert seq.num_history_ids == block_size * 2
    assert seq.cached_tokens == block_size * 2

    seq.update_token_ids(torch.tensor([5]))

    assert seq.cached_tokens == 0
    assert seq.prefix_cache.match_start_step == -1


def test_scheduler_reports_zero_cached_tokens_for_prefix_miss():
    """A prompt that shares no block with the previously scheduled sequence
    is scheduled with no history and zero cached tokens."""
    block_size = 16
    scheduler = _make_ar_prefix_cache_scheduler(block_size, num_gpu_blocks=8)

    cached = scheduler.add_session(0).add_sequence([1] * block_size + [2])
    scheduler.schedule(is_prefill=True)
    cached.state.stop()

    # Entirely different tokens from the cached sequence.
    seq = scheduler.add_session(1).add_sequence([3] * block_size + [4])
    output = scheduler.schedule(is_prefill=True)

    assert output.running == [seq]
    assert seq.num_history_ids == 0
    assert seq.cached_tokens == 0


def test_scheduler_cached_tokens_only_count_current_prompt_after_session_eviction():
    """After a sequence is stopped, freed, extended with new tokens and
    re-activated, rescheduling restores two blocks of history but reports
    zero cached tokens."""
    block_size = 16
    scheduler = _make_ar_prefix_cache_scheduler(block_size, num_gpu_blocks=8)

    session = scheduler.add_session(0)
    seq = session.add_sequence([1] * block_size + [2] * block_size + [3])
    scheduler.schedule(is_prefill=True)
    seq.update_token_ids(torch.tensor([9]), mode=UpdateTokenMode.PREFILL)
    seq.state.stop()
    seq.state.free()

    # Append a follow-up turn of four tokens to the same sequence.
    seq.update_token_ids(torch.tensor([4] * 4))
    assert seq.input_start_pos == block_size * 2 + 2
    assert seq.input_end_pos == block_size * 2 + 6
    seq.state.activate()

    output = scheduler.schedule(is_prefill=True)

    assert output.running == [seq]
    assert seq.num_history_ids == block_size * 2
    assert seq.cached_tokens == 0


def test_scheduler_excludes_recompute_eviction_prefix_hits_from_stats():
    """A sequence evicted to make room for another one is flagged with
    ``suppress_match_stats``; when it is scheduled again the flag is cleared
    and neither cached tokens nor block-trie hit/query stats are recorded."""
    block_size = 16
    scheduler = _make_ar_prefix_cache_scheduler(block_size, num_gpu_blocks=4)

    seq = scheduler.add_session(0).add_sequence([1] * block_size + [2] * block_size + [3])
    output = scheduler.schedule(is_prefill=True)
    assert output.running == [seq]

    seq.state.evict()
    pressure = scheduler.add_session(1).add_sequence([9] * block_size * 3)
    # Start from clean counters so only the re-schedule below is measured.
    scheduler.block_trie.stats.reset()

    assert scheduler.eviction_helper.evict_for_seq(pressure, [seq], 0)
    assert seq.prefix_cache.suppress_match_stats
    # Drop the competing sequence so that only `seq` is left to schedule.
    pressure.session.remove_sequence(pressure)

    output = scheduler.schedule(is_prefill=True)

    assert output.running == [seq]
    assert seq.num_history_ids >= block_size
    assert seq.cached_tokens == 0
    assert not seq.prefix_cache.suppress_match_stats
    assert scheduler.block_trie.stats.num_query_tokens == 0
    assert scheduler.block_trie.stats.num_hit_tokens == 0


def test_scheduler_rolls_back_prefix_hit_that_would_start_long_context_chunk_from_middle():
    """With ``max_prefill_token_num`` set to two blocks, a five-block prompt
    whose first two blocks are already cached is scheduled with no history,
    zero cached tokens and no block-trie stats recorded."""
    block_size = 16
    scheduler = _make_ar_prefix_cache_scheduler(block_size,
                                                num_gpu_blocks=8,
                                                max_prefill_token_num=block_size * 2)

    # Seed the cache directly through the block manager and trie instead of
    # running a schedule step.
    cached = scheduler.add_session(0).add_sequence([1] * block_size + [2] * block_size)
    scheduler.block_manager.allocate(cached)
    scheduler.block_trie.allocate(cached)
    cached.state.stop()

    token_ids = [1] * block_size + [2] * block_size + [3] * block_size
    token_ids += [4] * block_size + [5] * block_size
    seq = scheduler.add_session(1).add_sequence(token_ids)

    output = scheduler.schedule(is_prefill=True)

    assert output.running == [seq]
    assert seq.num_history_ids == 0
    assert seq.num_token_ids == len(token_ids)
    assert seq.cached_tokens == 0
    assert scheduler.block_trie.stats.num_query_tokens == 0
    assert scheduler.block_trie.stats.num_hit_tokens == 0