Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
bfd25d5
finish vlm
grimoire May 23, 2026
ec8e772
add context hash
grimoire May 23, 2026
a86ae37
early hash
grimoire May 23, 2026
6d13d65
ssm prefix caching prefill
grimoire May 23, 2026
9ca3a3b
decoding ssm
grimoire May 24, 2026
96f6db4
refactor sequence
grimoire May 24, 2026
1aa8773
add comment
grimoire May 24, 2026
3866b42
enable when prefix_cache_decode_state_interval=0
grimoire May 25, 2026
9da79a7
optimize copy state
grimoire May 25, 2026
64991ff
better copy cache
grimoire May 25, 2026
a6c2dde
easy engine loop func
grimoire May 25, 2026
baef92a
more fix
grimoire May 25, 2026
90ddc95
fix end states
grimoire May 25, 2026
79838a4
update block trie
grimoire May 25, 2026
22167ae
refactor block trie
grimoire May 25, 2026
e93959b
add hit rate metrics
grimoire May 26, 2026
235b43d
add longbenchv2
grimoire May 26, 2026
efd2ffc
fix
grimoire May 26, 2026
8c8a0ba
Merge branch 'main' into refactor-prefix-caching
grimoire May 27, 2026
ec728f7
add check and raise
grimoire May 27, 2026
d6abb5f
Merge branch 'main' of github.com:InternLM/lmdeploy into refactor-pre…
grimoire May 29, 2026
145d5d5
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 1, 2026
56e3b7d
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 2, 2026
49f42e1
Fix Qwen3 Omni fake model dtype fixture
grimoire Jun 2, 2026
4bbcf03
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 8, 2026
e1fadea
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 8, 2026
240aae0
fix gdr kernel for tilelang>=0.1.9
grimoire Jun 9, 2026
79e5706
add flag
grimoire Jun 9, 2026
48ea0ea
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 9, 2026
8f550e6
Merge branch 'fix-gdr-tilelang019' into refactor-prefix-caching
grimoire Jun 9, 2026
9938afc
fix bugs
grimoire Jun 9, 2026
ff1f74a
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 10, 2026
636b629
update comment
grimoire Jun 10, 2026
292e9c9
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 10, 2026
9d54645
remove init state
grimoire Jun 10, 2026
6a96306
remove unrelated
grimoire Jun 10, 2026
3546b7c
fix duplicate node
grimoire Jun 10, 2026
db1590e
fix reserve_decode_state_checkpoint_for_seq
grimoire Jun 10, 2026
126c3d1
fix 27b
grimoire Jun 10, 2026
a0426f2
fix
grimoire Jun 11, 2026
b2979b0
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 11, 2026
5ca1229
update
grimoire Jun 11, 2026
8c0ff32
fix long chunk
grimoire Jun 12, 2026
3979c8b
fix lint
grimoire Jun 12, 2026
665fcc2
add acache_tokens
grimoire Jun 15, 2026
b13846f
fix metric
grimoire Jun 15, 2026
3185fa1
Merge remote-tracking branch 'upstream/main' into refactor-prefix-cac…
grimoire Jun 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lmdeploy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def add_parser_chat():
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)
ArgumentHelper.dllm_block_length(pt_group)
ArgumentHelper.prefix_cache_state_budget(pt_group)
ArgumentHelper.prefix_cache_decode_state_interval(pt_group)
# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
tp_act = ArgumentHelper.tp(pt_group)
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def add_parser_api_server():
ArgumentHelper.enable_return_routed_experts(pt_group)
ArgumentHelper.distributed_executor_backend(pt_group)
ArgumentHelper.kernel_block_size(pt_group)
ArgumentHelper.prefix_cache_state_budget(pt_group)
ArgumentHelper.prefix_cache_decode_state_interval(pt_group)

# common engine args
disable_vision_encoder = ArgumentHelper.disable_vision_encoder(pt_group)
Expand Down Expand Up @@ -234,6 +236,8 @@ def api_server(args):
session_len=args.session_len,
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
prefix_cache_state_budget=args.prefix_cache_state_budget,
prefix_cache_decode_state_interval=args.prefix_cache_decode_state_interval,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
Expand Down
24 changes: 24 additions & 0 deletions lmdeploy/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,30 @@ def enable_prefix_caching(parser):
default=False,
help='Enable cache and match prefix')

@staticmethod
def prefix_cache_state_budget(parser):
"""Add argument prefix_cache_state_budget to parser."""

return parser.add_argument('--prefix-cache-state-budget',
type=int,
default=0,
help='Extra SSM state-cache slots budgeted for prefix-cache checkpoints. '
'0 adds no extra slots, but checkpoints may borrow idle runtime state slots. '
'Only used by the PyTorch engine.')

@staticmethod
def prefix_cache_decode_state_interval(parser):
Comment on lines +589 to +600

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please recommend their values for qwen3.5-35b/qwen3.5-35b-fp8/qwen3.5-397b/qwen3.5-397b-fp8

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefix_cache_state_budget is the number of extra states reserved for caching only; this value can be the same as max-batch-size if GPU memory is sufficient.
prefix_cache_decode_state_interval is the interval between checkpoint saves. This value should be determined by the input distribution. If most session lengths are larger than 1024, you can use 1024 as the prefix_cache_decode_state_interval.

"""Add argument prefix_cache_decode_state_interval to parser."""

return parser.add_argument('--prefix-cache-decode-state-interval',
type=int,
default=0,
help='Token interval for SSM decode-state prefix-cache checkpoints. '
'0 disables decode checkpoint saves while keeping prefill/chunk checkpoints. '
'Use a positive multiple of block size only for long SSM decoding where later '
'requests can reuse decode prefixes; smaller values improve hit granularity '
'but use more checkpoint memory and copy work. Only used by the PyTorch engine.')

@staticmethod
def num_tokens_per_iter(parser):
return parser.add_argument('--num-tokens-per-iter',
Expand Down
18 changes: 18 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,17 @@ class PytorchEngineConfig:
max_batch_size is always captured.
thread_safe: thread safe engine instance.
enable_prefix_caching: Enable token match and sharing caches.
prefix_cache_state_budget: Extra SSM state-cache slots budgeted for
prefix-cache checkpoints. 0 adds no extra slots, but SSM
checkpoints may still borrow idle runtime state slots.
prefix_cache_decode_state_interval: Token interval for SSM decode
state checkpoints. 0 disables decode-state checkpoint saves; prefill
and chunk checkpoints may still be saved. Keep 0 unless the workload
has long SSM decoding and repeated continuations that can reuse
decode checkpoints. Smaller positive values create more hit points
but use more checkpoint memory and copy work; larger values reduce
overhead but make decode-prefix hits less likely. Positive values
must be multiples of the cache block size.
device_type: The inference device type, options ['cuda']
eager_mode: Enable "eager" mode or not
custom_module_map: nn module map customized by users. Once
Expand Down Expand Up @@ -428,6 +439,8 @@ class PytorchEngineConfig:
cudagraph_capture_batch_sizes: list[int] | None = None
thread_safe: bool = False
enable_prefix_caching: bool = False
prefix_cache_state_budget: int = 0
prefix_cache_decode_state_interval: int = 0
device_type: str = 'cuda'
eager_mode: bool = False
custom_module_map: dict[str, str] = None
Expand Down Expand Up @@ -472,6 +485,8 @@ def __post_init__(self):
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.prefix_cache_state_budget >= 0, 'invalid prefix_cache_state_budget'
assert self.prefix_cache_decode_state_interval >= 0, 'invalid prefix_cache_decode_state_interval'
try:
self.quant_policy = QuantPolicy(self.quant_policy)
except ValueError as e:
Expand All @@ -485,6 +500,9 @@ def __post_init__(self):
(f'block_size must be >= kernel_block_size and an integer multiple '
f'of kernel_block_size, but got block_size {self.block_size} '
f'and kernel_block_size {self.kernel_block_size}')
if self.prefix_cache_decode_state_interval > 0:
assert self.prefix_cache_decode_state_interval % self.block_size == 0, (
'prefix_cache_decode_state_interval must be a multiple of block_size')
Comment on lines +503 to +505

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to update the help information of prefix_cache_decode_state_interval

if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:
assert False, \
'kv cache quantization only works for CUDA and ASCEND.'
Expand Down
6 changes: 6 additions & 0 deletions lmdeploy/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ def __init__(self, model_name: str, max_model_len: int, dp_rank: int = 0):
documentation='GPU KV-cache usage. 1 means 100 percent usage.',
labelnames=labelnames).labels(*labelvalues)

self.gauge_prefix_cache_hit_rate = prometheus_client.Gauge(
name='lmdeploy:prefix_cache_hit_rate',
documentation='Prefix-cache hit rate. 1 means 100 percent of queried prefix tokens hit.',
labelnames=labelnames).labels(*labelvalues)

#
# Counters
#
Expand Down Expand Up @@ -359,6 +364,7 @@ def record_schedule(self, stats: SchedulerStats) -> None:
self.gauge_scheduler_running.set(stats.num_running_reqs)
self.gauge_scheduler_waiting.set(stats.num_waiting_reqs)
self.gauge_gpu_cache_usage.set(stats.gpu_cache_usage)
self.gauge_prefix_cache_hit_rate.set(stats.prefix_cache_hit_rate)

def record_iteration(self, stats: IterationStats) -> None:
"""Report token-related metrics to prometheus."""
Expand Down
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(self, blocks: np.ndarray = None):
assert blocks.ndim == 1
self._blocks = blocks
self._num_real = len(blocks)
self.last_shared_node = None

def reserve(self, size: int):
"""Reserve cache size."""
Expand Down Expand Up @@ -67,7 +66,6 @@ def resize(self, num_blocks: int):
def reset(self):
"""reset."""
self.resize(0)
self.last_shared_node = None

def clone(self):
"""Clone logical blocks."""
Expand Down
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class CacheConfig:
quant_policy: QuantPolicy = QuantPolicy.NONE
device_type: str = 'cuda'
num_state_caches: int = None
prefix_cache_state_budget: int = 0
prefix_cache_decode_state_interval: int = 0
states_shapes: list[tuple] = field(default_factory=list)

# reserved blocks for dummy inputs, init to 0 for unit test.
Expand All @@ -132,11 +134,16 @@ class CacheConfig:

def __post_init__(self):
"""Post init."""
assert self.prefix_cache_state_budget >= 0, 'invalid prefix_cache_state_budget'
assert self.prefix_cache_decode_state_interval >= 0, 'invalid prefix_cache_decode_state_interval'
if self.window_size > 1 and self.enable_prefix_caching:
logger.warning('Prefix caching is not available for window attention.')
self.enable_prefix_caching = False
if self.kernel_block_size == -1:
self.kernel_block_size = self.block_size
if self.prefix_cache_decode_state_interval > 0:
assert self.prefix_cache_decode_state_interval % self.block_size == 0, (
'prefix_cache_decode_state_interval must be a multiple of block_size')
self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
self.cudagraph_capture_batch_sizes, self.max_batches)

Expand Down
78 changes: 78 additions & 0 deletions lmdeploy/pytorch/engine/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# modify from: https://github.com/vllm-project/vllm
import json
import math
from collections.abc import Sequence
from dataclasses import dataclass
from operator import index as as_index

import torch

Expand Down Expand Up @@ -565,3 +567,79 @@ def get_cache_state_size(state_shapes: list[tuple[tuple[int], torch.dtype]]) ->
def state_caches(self):
"""State caches."""
return self._state_caches

@staticmethod
def _index_list(idx: int | Sequence[int]):
"""Normalize host-side cache indices."""
if isinstance(idx, torch.Tensor):
raise TypeError('State cache copy indices must be host integers, not torch.Tensor.')
if isinstance(idx, (str, bytes)):
raise TypeError('State cache copy indices must be an int or a sequence of ints.')
try:
return [as_index(idx)]
except TypeError:
pass
if not isinstance(idx, Sequence):
raise TypeError('State cache copy indices must be an int or a sequence of ints.')
if any(isinstance(item, torch.Tensor) for item in idx):
raise TypeError('State cache copy indices must be host integers, not torch.Tensor.')
return [as_index(item) for item in idx]

@staticmethod
def _validate_index_bounds(indices: Sequence[int], num_caches: int):
"""Check normalized cache indices are valid state slots."""
for idx in indices:
if idx < 0 or idx >= num_caches:
raise ValueError(f'State cache index {idx} is out of range [0, {num_caches}).')

@staticmethod
def _copy_ranges(src_list: list[int], dst_list: list[int]):
"""Yield contiguous copy ranges as (src_start, dst_start, length)."""
pairs = sorted(zip(src_list, dst_list))
if len(pairs) == 0:
return
start_src = prev_src = pairs[0][0]
start_dst = prev_dst = pairs[0][1]
length = 1
for src, dst in pairs[1:]:
if src == prev_src + 1 and dst == prev_dst + 1:
prev_src = src
prev_dst = dst
length += 1
continue
yield start_src, start_dst, length
start_src = prev_src = src
start_dst = prev_dst = dst
length = 1
yield start_src, start_dst, length

def copy_caches(self, src_idx: int | Sequence[int], dst_idx: int | Sequence[int]):
    """Copy state cache slots from ``src_idx`` to ``dst_idx``.

    Low-level primitive used by SSM prefix caching: a frozen state
    checkpoint can be copied into a newly allocated runtime slot before the
    next forward. Returns early when there are no state caches or no
    indices to copy.

    Raises:
        ValueError: if the index lists differ in length, an index is out of
            range, ``dst_idx`` has duplicates, or sources and destinations
            overlap.
    """
    if len(self._state_caches) <= 0:
        return

    sources = self._index_list(src_idx)
    targets = self._index_list(dst_idx)
    if len(sources) != len(targets):
        raise ValueError('src_idx and dst_idx must have the same number of elements.')
    if not sources:
        return

    pool = self.mem_pool
    slot_count = pool.size(0)
    for indices in (sources, targets):
        self._validate_index_bounds(indices, slot_count)

    unique_targets = set(targets)
    if len(unique_targets) != len(targets):
        raise ValueError('dst_idx must not contain duplicate entries.')
    if unique_targets.intersection(sources):
        raise ValueError('src_idx and dst_idx must not overlap for stream-ordered state copies.')

    # Contiguous runs are copied with one sliced copy each instead of slot by slot.
    for src, dst, length in self._copy_ranges(sources, targets):
        if length == 1:
            pool[dst].copy_(pool[src], non_blocking=True)
        else:
            pool[dst:dst + length].copy_(pool[src:src + length], non_blocking=True)
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def build_cache_config(engine_config: PytorchEngineConfig):
max_prefill_token_num=engine_config.max_prefill_token_num,
cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes,
enable_prefix_caching=engine_config.enable_prefix_caching,
prefix_cache_state_budget=engine_config.prefix_cache_state_budget,
prefix_cache_decode_state_interval=engine_config.prefix_cache_decode_state_interval,
quant_policy=engine_config.quant_policy,
device_type=engine_config.device_type,
migration_backend=engine_config.migration_backend,
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..adapter.adapter import AdapterManager
from ..config import CacheConfig, ModelConfig
from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode
from ..multimodal.data_type import ensure_multimodal_content_hashes
from ..paging import Scheduler
from ..strategies import build_strategy_factory
from .base import EngineBase
Expand Down Expand Up @@ -412,6 +413,8 @@ def _on_add_message(self, reqs: list[Request], **kwargs):

input_ids = result.input_ids
input_multimodals = result.input_multimodals
if self.cache_config.enable_prefix_caching:
input_multimodals = ensure_multimodal_content_hashes(input_multimodals)

req_data['token_ids'] = input_ids
req_data['input_multimodals'] = input_multimodals
Expand Down
61 changes: 50 additions & 11 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
seq = running[0]
seq.append_routed_experts(all_routed_experts)
seq.append_logits(logits)
self.scheduler.block_trie.cache_routed_experts_for_seq(seq)
return dict()

new_token_timestamp = batched_outputs.new_token_timestamp
Expand All @@ -318,6 +319,7 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
batched_outputs=batched_outputs,
model_inputs=model_inputs,
delta=delta)
self.scheduler.block_trie.cache_routed_experts(running)

# generate output
outputs: dict[int, InferOutput] = dict()
Expand Down Expand Up @@ -378,6 +380,42 @@ async def _main_loop_try_send_next_inputs(self):
scheduler.collect_migration_done()
return await self.inputs_maker.send_next_inputs()

@staticmethod
def _has_state_checkpoint_save(model_inputs: 'ModelInputs | None', delta: 'ModelInputsDelta | None'):
"""Check whether the current forward reserved SSM checkpoints."""
return ((model_inputs is not None and model_inputs.state_prefix_cache_save_offsets is not None)
or (delta is not None and delta.state_prefix_cache_save_offsets is not None))

async def _prefetch_next_inputs(self):
    """Prefetch the next batch after collecting finished migrations.

    Migration completions are collected on the scheduler first; the result
    of ``inputs_maker.prefetch_next_inputs()`` is then awaited and returned.
    """
    scheduler = self.scheduler
    scheduler.collect_migration_done()
    prefetched = await self.inputs_maker.prefetch_next_inputs()
    return prefetched

def _publish_forward_prefix_cache(self, running: 'SeqList', has_state_checkpoint_save: bool):
    """Publish per-forward prefix-cache ownership before prefetching.

    Does nothing when the block trie is disabled. Otherwise state
    checkpoints are committed (acquiring a save ref) only when
    ``has_state_checkpoint_save`` is set, and checkpoint restores are
    released in every case.
    """
    block_trie = self.scheduler.block_trie
    if block_trie.enable:
        if has_state_checkpoint_save:
            block_trie.commit_state_checkpoints(running, acquire_save_ref=True)
        block_trie.release_state_checkpoint_restores(running)

def _release_forward_prefix_cache_saves(self, running: 'SeqList'):
    """Release producer refs after the forward output/event boundary.

    Skipped entirely when the block trie is disabled.
    """
    block_trie = self.scheduler.block_trie
    if block_trie.enable:
        block_trie.release_state_checkpoint_saves(running)

def _finish_forward_output(self,
out: 'BatchedOutputs | None',
running: 'SeqList',
model_inputs: 'ModelInputs | None',
delta: 'ModelInputsDelta | None'):
"""Publish outputs."""
if out is None:
return
step_outputs = self._make_infer_outputs(out, running=running, model_inputs=model_inputs, delta=delta)
self.resp_queue.put_nowait(step_outputs)

async def _main_loop_get_outputs(
self,
running: 'SeqList',
Expand All @@ -387,18 +425,19 @@ async def _main_loop_get_outputs(
model_inputs = forward_inputs['inputs']
delta = forward_inputs['delta']
self.inputs_maker.update_running_seqs(running, model_inputs)

# try prefetch inputs
self.scheduler.collect_migration_done()
forward_inputs, next_running = await self.inputs_maker.prefetch_next_inputs()

# send output
has_state_checkpoint_save = self._has_state_checkpoint_save(model_inputs, delta)

# ModelAgent executes queued forwards in send order. Once the current
# input is queued, matched checkpoints can be published before waiting
# for GPU output; save checkpoints keep a producer ref until the output
# event boundary so prefetch cannot evict/reuse their destination slots.
self._publish_forward_prefix_cache(running, has_state_checkpoint_save)
Comment thread
grimoire marked this conversation as resolved.
forward_inputs, next_running = await self._prefetch_next_inputs()
out = await self.executor.get_output_async()
if out is not None:
step_outputs = self._make_infer_outputs(out, running=running, model_inputs=model_inputs, delta=delta)
self.resp_queue.put_nowait(step_outputs)
# out might come from shared memory, need to explicitly delete to release memory in time
del out
self._release_forward_prefix_cache_saves(running)
self._finish_forward_output(out, running, model_inputs, delta)
# out might come from shared memory, need to explicitly delete to release memory in time
del out

return forward_inputs, next_running

Expand Down
Loading
Loading