Skip to content

Commit d2b3cfc

Browse files
author
Kevin-Li-2025
committed
Reserve KV cache slots for concurrent decode
Signed-off-by: Kevin-Li-2025 <2242139@qq.com>
1 parent f75d795 commit d2b3cfc

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@
3838
from tensorrt_llm.runtime.kv_cache_manager_v2 import (
3939
DEFAULT_BEAM_INDEX,
4040
AttentionLayerConfig,
41+
BatchDesc,
4142
BufferConfig,
4243
CacheTierConfig,
4344
DiskCacheTierConfig,
4445
GpuCacheTierConfig,
4546
HostCacheTierConfig,
4647
KVCacheIterationStatsDelta,
48+
KVCacheDesc,
4749
LayerId,
4850
ReuseScope,
4951
TokenIdExt,
@@ -1047,6 +1049,15 @@ def _build_cache_config(
10471049
cache_tiers=cache_tiers,
10481050
max_util_for_resume=kv_cache_config.max_util_for_resume,
10491051
enable_stats=self.enable_stats,
1052+
constraints=[
1053+
self._build_concurrent_decode_constraint(
1054+
max_batch_size=self.max_batch_size,
1055+
tokens_per_block=tokens_per_block,
1056+
)
1057+
],
1058+
# Preserve StorageManager's historical fallback ratio while using
1059+
# constraints only as a min-slot floor.
1060+
typical_step=BatchDesc([KVCacheDesc(capacity=2049, history_length=2048)]),
10501061
layers=layer_configs,
10511062
)
10521063

@@ -1067,6 +1078,19 @@ def _extra_buffers_per_layer(
10671078
"""
10681079
return None
10691080

1081+
@staticmethod
1082+
def _build_concurrent_decode_constraint(
1083+
*, max_batch_size: int, tokens_per_block: int
1084+
) -> BatchDesc:
1085+
assert max_batch_size > 0
1086+
assert tokens_per_block > 0
1087+
return BatchDesc(
1088+
[
1089+
KVCacheDesc(capacity=tokens_per_block, history_length=tokens_per_block - 1)
1090+
for _ in range(max_batch_size)
1091+
]
1092+
)
1093+
10701094
@property
10711095
def blocks_in_primary_pool(self) -> int:
10721096
"""

tests/unittest/_torch/executor/test_per_layer_head_dim.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2, Role
2323
from tensorrt_llm.llmapi.llm_args import KvCacheConfig as KvCacheConfigV2
2424
from tensorrt_llm.mapping import Mapping
25+
from tensorrt_llm.runtime.kv_cache_manager_v2 import GpuCacheTierConfig
2526

2627
DataType = tensorrt_llm.bindings.DataType
2728
CacheType = tensorrt_llm.bindings.internal.batch_manager.CacheType
@@ -64,6 +65,39 @@ def _create_kv_cache_manager_v2(
6465
class TestPerLayerHeadDimBasic(unittest.TestCase):
6566
"""Tests that don't allocate GPU memory or use uniform buffer sizes."""
6667

68+
def test_build_cache_config_reserves_concurrent_decode_slots(self):
69+
mgr = KVCacheManagerV2.__new__(KVCacheManagerV2)
70+
mgr.kv_cache_type = CacheType.SELF
71+
mgr.dtype = DataType.HALF
72+
mgr.kv_factor = 2
73+
mgr.max_batch_size = 4
74+
mgr.max_attention_window_vec = [4, None]
75+
mgr.num_local_layers = 2
76+
mgr.pp_layers = [0, 1]
77+
mgr.num_kv_heads_per_layer = [1, 1]
78+
mgr.head_dim_per_layer = [8, 8]
79+
80+
config = mgr._build_cache_config(
81+
KvCacheConfigV2(max_tokens=128, enable_block_reuse=False),
82+
tokens_per_block=8,
83+
vocab_size=32000,
84+
cache_tiers=[GpuCacheTierConfig(quota=1 << 20)],
85+
)
86+
87+
self.assertEqual(len(config.constraints), 1)
88+
decode_constraint = config.constraints[0]
89+
self.assertEqual(len(decode_constraint.kv_caches), mgr.max_batch_size)
90+
for kv_cache in decode_constraint.kv_caches:
91+
self.assertEqual(kv_cache.capacity, 8)
92+
self.assertEqual(kv_cache.history_length, 7)
93+
94+
# Keep the previous StorageManager fallback ratio basis so adding the
95+
# constraint only floors min slots and does not change ratio selection.
96+
self.assertIsNotNone(config.typical_step)
97+
self.assertEqual(len(config.typical_step.kv_caches), 1)
98+
self.assertEqual(config.typical_step.kv_caches[0].capacity, 2049)
99+
self.assertEqual(config.typical_step.kv_caches[0].history_length, 2048)
100+
67101
def test_per_layer_head_dim_wrong_length(self):
68102
"""Test that mismatched list length raises assertion."""
69103
with self.assertRaises(AssertionError):

0 commit comments

Comments
 (0)