Skip to content

Commit 75f5ddc

Browse files
authored
Refactor prefix caching for pytorch engine (#4618)
* finish vlm * add context hash * early hash * ssm prefix caching prefill * decoding ssm * refactor sequence * add comment * enable when prefix_cache_decode_state_interval=0 * optimize copy state * better copy cache * easy engine loop func * more fix * fix end states * update block trie * refactor block trie * add hit rate metrics * add longbenchv2 * fix * add check and raise * Fix Qwen3 Omni fake model dtype fixture * fix gdr kernel for tilelang>=0.1.9 * add flag * fix bugs * update comment * remove init state * remove unrelated * fix duplicate node * fix reserve_decode_state_checkpoint_for_seq * fix 27b * fix * update * fix long chunk * fix lint * add acache_tokens * fix metric
1 parent 4f25485 commit 75f5ddc

35 files changed

Lines changed: 4198 additions & 101 deletions

lmdeploy/cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def add_parser_chat():
5454
ArgumentHelper.device(pt_group)
5555
ArgumentHelper.eager_mode(pt_group)
5656
ArgumentHelper.dllm_block_length(pt_group)
57+
ArgumentHelper.prefix_cache_state_budget(pt_group)
58+
ArgumentHelper.prefix_cache_decode_state_interval(pt_group)
5759
# common engine args
5860
dtype_act = ArgumentHelper.dtype(pt_group)
5961
tp_act = ArgumentHelper.tp(pt_group)

lmdeploy/cli/serve.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ def add_parser_api_server():
107107
ArgumentHelper.enable_return_routed_experts(pt_group)
108108
ArgumentHelper.distributed_executor_backend(pt_group)
109109
ArgumentHelper.kernel_block_size(pt_group)
110+
ArgumentHelper.prefix_cache_state_budget(pt_group)
111+
ArgumentHelper.prefix_cache_decode_state_interval(pt_group)
110112

111113
# common engine args
112114
disable_vision_encoder = ArgumentHelper.disable_vision_encoder(pt_group)
@@ -234,6 +236,8 @@ def api_server(args):
234236
session_len=args.session_len,
235237
adapters=adapters,
236238
enable_prefix_caching=args.enable_prefix_caching,
239+
prefix_cache_state_budget=args.prefix_cache_state_budget,
240+
prefix_cache_decode_state_interval=args.prefix_cache_decode_state_interval,
237241
device_type=args.device,
238242
quant_policy=args.quant_policy,
239243
eager_mode=args.eager_mode,

lmdeploy/cli/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,30 @@ def enable_prefix_caching(parser):
585585
default=False,
586586
help='Enable cache and match prefix')
587587

588+
@staticmethod
589+
def prefix_cache_state_budget(parser):
590+
"""Add argument prefix_cache_state_budget to parser."""
591+
592+
return parser.add_argument('--prefix-cache-state-budget',
593+
type=int,
594+
default=0,
595+
help='Extra SSM state-cache slots budgeted for prefix-cache checkpoints. '
596+
'0 adds no extra slots, but checkpoints may borrow idle runtime state slots. '
597+
'Only used by the PyTorch engine.')
598+
599+
@staticmethod
600+
def prefix_cache_decode_state_interval(parser):
601+
"""Add argument prefix_cache_decode_state_interval to parser."""
602+
603+
return parser.add_argument('--prefix-cache-decode-state-interval',
604+
type=int,
605+
default=0,
606+
help='Token interval for SSM decode-state prefix-cache checkpoints. '
607+
'0 disables decode checkpoint saves while keeping prefill/chunk checkpoints. '
608+
'Use a positive multiple of block size only for long SSM decoding where later '
609+
'requests can reuse decode prefixes; smaller values improve hit granularity '
610+
'but use more checkpoint memory and copy work. Only used by the PyTorch engine.')
611+
588612
@staticmethod
589613
def num_tokens_per_iter(parser):
590614
return parser.add_argument('--num-tokens-per-iter',

lmdeploy/messages.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,17 @@ class PytorchEngineConfig:
367367
max_batch_size is always captured.
368368
thread_safe: thread safe engine instance.
369369
enable_prefix_caching: Enable token match and sharing caches.
370+
prefix_cache_state_budget: Extra SSM state-cache slots budgeted for
371+
prefix-cache checkpoints. 0 adds no extra slots, but SSM
372+
checkpoints may still borrow idle runtime state slots.
373+
prefix_cache_decode_state_interval: Token interval for SSM decode
374+
state checkpoints. 0 disables decode-state checkpoint saves; prefill
375+
and chunk checkpoints may still be saved. Keep 0 unless the workload
376+
has long SSM decoding and repeated continuations that can reuse
377+
decode checkpoints. Smaller positive values create more hit points
378+
but use more checkpoint memory and copy work; larger values reduce
379+
overhead but make decode-prefix hits less likely. Positive values
380+
must be multiples of the cache block size.
370381
device_type: The inference device type, options ['cuda']
371382
eager_mode: Enable "eager" mode or not
372383
custom_module_map: nn module map customized by users. Once
@@ -428,6 +439,8 @@ class PytorchEngineConfig:
428439
cudagraph_capture_batch_sizes: list[int] | None = None
429440
thread_safe: bool = False
430441
enable_prefix_caching: bool = False
442+
prefix_cache_state_budget: int = 0
443+
prefix_cache_decode_state_interval: int = 0
431444
device_type: str = 'cuda'
432445
eager_mode: bool = False
433446
custom_module_map: dict[str, str] = None
@@ -472,6 +485,8 @@ def __post_init__(self):
472485
assert self.max_prefill_token_num >= 0, \
473486
'invalid max_prefill_token_num'
474487
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
488+
assert self.prefix_cache_state_budget >= 0, 'invalid prefix_cache_state_budget'
489+
assert self.prefix_cache_decode_state_interval >= 0, 'invalid prefix_cache_decode_state_interval'
475490
try:
476491
self.quant_policy = QuantPolicy(self.quant_policy)
477492
except ValueError as e:
@@ -485,6 +500,9 @@ def __post_init__(self):
485500
(f'block_size must be >= kernel_block_size and an integer multiple '
486501
f'of kernel_block_size, but got block_size {self.block_size} '
487502
f'and kernel_block_size {self.kernel_block_size}')
503+
if self.prefix_cache_decode_state_interval > 0:
504+
assert self.prefix_cache_decode_state_interval % self.block_size == 0, (
505+
'prefix_cache_decode_state_interval must be a multiple of block_size')
488506
if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:
489507
assert False, \
490508
'kv cache quantization only works for CUDA and ASCEND.'

lmdeploy/metrics/loggers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ def __init__(self, model_name: str, max_model_len: int, dp_rank: int = 0):
191191
documentation='GPU KV-cache usage. 1 means 100 percent usage.',
192192
labelnames=labelnames).labels(*labelvalues)
193193

194+
self.gauge_prefix_cache_hit_rate = prometheus_client.Gauge(
195+
name='lmdeploy:prefix_cache_hit_rate',
196+
documentation='Prefix-cache hit rate. 1 means 100 percent of queried prefix tokens hit.',
197+
labelnames=labelnames).labels(*labelvalues)
198+
194199
#
195200
# Counters
196201
#
@@ -359,6 +364,7 @@ def record_schedule(self, stats: SchedulerStats) -> None:
359364
self.gauge_scheduler_running.set(stats.num_running_reqs)
360365
self.gauge_scheduler_waiting.set(stats.num_waiting_reqs)
361366
self.gauge_gpu_cache_usage.set(stats.gpu_cache_usage)
367+
self.gauge_prefix_cache_hit_rate.set(stats.prefix_cache_hit_rate)
362368

363369
def record_iteration(self, stats: IterationStats) -> None:
364370
"""Report token-related metrics to prometheus."""

lmdeploy/pytorch/block.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def __init__(self, blocks: np.ndarray = None):
2424
assert blocks.ndim == 1
2525
self._blocks = blocks
2626
self._num_real = len(blocks)
27-
self.last_shared_node = None
2827

2928
def reserve(self, size: int):
3029
"""Reserve cache size."""
@@ -67,7 +66,6 @@ def resize(self, num_blocks: int):
6766
def reset(self):
6867
"""reset."""
6968
self.resize(0)
70-
self.last_shared_node = None
7169

7270
def clone(self):
7371
"""Clone logical blocks."""

lmdeploy/pytorch/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ class CacheConfig:
121121
quant_policy: QuantPolicy = QuantPolicy.NONE
122122
device_type: str = 'cuda'
123123
num_state_caches: int = None
124+
prefix_cache_state_budget: int = 0
125+
prefix_cache_decode_state_interval: int = 0
124126
states_shapes: list[tuple] = field(default_factory=list)
125127

126128
# reserved blocks for dummy inputs, init to 0 for unit test.
@@ -132,11 +134,16 @@ class CacheConfig:
132134

133135
def __post_init__(self):
134136
"""Post init."""
137+
assert self.prefix_cache_state_budget >= 0, 'invalid prefix_cache_state_budget'
138+
assert self.prefix_cache_decode_state_interval >= 0, 'invalid prefix_cache_decode_state_interval'
135139
if self.window_size > 1 and self.enable_prefix_caching:
136140
logger.warning('Prefix caching is not available for window attention.')
137141
self.enable_prefix_caching = False
138142
if self.kernel_block_size == -1:
139143
self.kernel_block_size = self.block_size
144+
if self.prefix_cache_decode_state_interval > 0:
145+
assert self.prefix_cache_decode_state_interval % self.block_size == 0, (
146+
'prefix_cache_decode_state_interval must be a multiple of block_size')
140147
self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
141148
self.cudagraph_capture_batch_sizes, self.max_batches)
142149

lmdeploy/pytorch/engine/cache_engine.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# modify from: https://github.com/vllm-project/vllm
33
import json
44
import math
5+
from collections.abc import Sequence
56
from dataclasses import dataclass
7+
from operator import index as as_index
68

79
import torch
810

@@ -565,3 +567,79 @@ def get_cache_state_size(state_shapes: list[tuple[tuple[int], torch.dtype]]) ->
565567
def state_caches(self):
566568
"""State caches."""
567569
return self._state_caches
570+
571+
@staticmethod
572+
def _index_list(idx: int | Sequence[int]):
573+
"""Normalize host-side cache indices."""
574+
if isinstance(idx, torch.Tensor):
575+
raise TypeError('State cache copy indices must be host integers, not torch.Tensor.')
576+
if isinstance(idx, (str, bytes)):
577+
raise TypeError('State cache copy indices must be an int or a sequence of ints.')
578+
try:
579+
return [as_index(idx)]
580+
except TypeError:
581+
pass
582+
if not isinstance(idx, Sequence):
583+
raise TypeError('State cache copy indices must be an int or a sequence of ints.')
584+
if any(isinstance(item, torch.Tensor) for item in idx):
585+
raise TypeError('State cache copy indices must be host integers, not torch.Tensor.')
586+
return [as_index(item) for item in idx]
587+
588+
@staticmethod
589+
def _validate_index_bounds(indices: Sequence[int], num_caches: int):
590+
"""Check normalized cache indices are valid state slots."""
591+
for idx in indices:
592+
if idx < 0 or idx >= num_caches:
593+
raise ValueError(f'State cache index {idx} is out of range [0, {num_caches}).')
594+
595+
@staticmethod
596+
def _copy_ranges(src_list: list[int], dst_list: list[int]):
597+
"""Yield contiguous copy ranges as (src_start, dst_start, length)."""
598+
pairs = sorted(zip(src_list, dst_list))
599+
if len(pairs) == 0:
600+
return
601+
start_src = prev_src = pairs[0][0]
602+
start_dst = prev_dst = pairs[0][1]
603+
length = 1
604+
for src, dst in pairs[1:]:
605+
if src == prev_src + 1 and dst == prev_dst + 1:
606+
prev_src = src
607+
prev_dst = dst
608+
length += 1
609+
continue
610+
yield start_src, start_dst, length
611+
start_src = prev_src = src
612+
start_dst = prev_dst = dst
613+
length = 1
614+
yield start_src, start_dst, length
615+
616+
def copy_caches(self, src_idx: int | Sequence[int], dst_idx: int | Sequence[int]):
617+
"""Copy state cache slots.
618+
619+
This is the low-level primitive needed by SSM prefix caching: a frozen
620+
state checkpoint can be copied into a newly allocated runtime slot
621+
before the next forward.
622+
"""
623+
if len(self._state_caches) <= 0:
624+
return
625+
626+
src_list = self._index_list(src_idx)
627+
dst_list = self._index_list(dst_idx)
628+
if len(src_list) != len(dst_list):
629+
raise ValueError('src_idx and dst_idx must have the same number of elements.')
630+
if len(src_list) == 0:
631+
return
632+
num_caches = self.mem_pool.size(0)
633+
self._validate_index_bounds(src_list, num_caches)
634+
self._validate_index_bounds(dst_list, num_caches)
635+
dst_set = set(dst_list)
636+
if len(dst_set) != len(dst_list):
637+
raise ValueError('dst_idx must not contain duplicate entries.')
638+
if not set(src_list).isdisjoint(dst_set):
639+
raise ValueError('src_idx and dst_idx must not overlap for stream-ordered state copies.')
640+
641+
for src, dst, length in self._copy_ranges(src_list, dst_list):
642+
if length == 1:
643+
self.mem_pool[dst].copy_(self.mem_pool[src], non_blocking=True)
644+
else:
645+
self.mem_pool[dst:dst + length].copy_(self.mem_pool[src:src + length], non_blocking=True)

lmdeploy/pytorch/engine/config_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def build_cache_config(engine_config: PytorchEngineConfig):
7575
max_prefill_token_num=engine_config.max_prefill_token_num,
7676
cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes,
7777
enable_prefix_caching=engine_config.enable_prefix_caching,
78+
prefix_cache_state_budget=engine_config.prefix_cache_state_budget,
79+
prefix_cache_decode_state_interval=engine_config.prefix_cache_decode_state_interval,
7880
quant_policy=engine_config.quant_policy,
7981
device_type=engine_config.device_type,
8082
migration_backend=engine_config.migration_backend,

lmdeploy/pytorch/engine/engine.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..adapter.adapter import AdapterManager
2424
from ..config import CacheConfig, ModelConfig
2525
from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode
26+
from ..multimodal.data_type import ensure_multimodal_content_hashes
2627
from ..paging import Scheduler
2728
from ..strategies import build_strategy_factory
2829
from .base import EngineBase
@@ -412,6 +413,8 @@ def _on_add_message(self, reqs: list[Request], **kwargs):
412413

413414
input_ids = result.input_ids
414415
input_multimodals = result.input_multimodals
416+
if self.cache_config.enable_prefix_caching:
417+
input_multimodals = ensure_multimodal_content_hashes(input_multimodals)
415418

416419
req_data['token_ids'] = input_ids
417420
req_data['input_multimodals'] = input_multimodals

0 commit comments

Comments
 (0)