Skip to content

Commit 36634a6

Browse files
author
Kevin-Li-2025
committed
Reserve KV cache slots for concurrent decode
Signed-off-by: Kevin-Li-2025 <2242139@qq.com>
1 parent 2a18bd4 commit 36634a6

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
from tensorrt_llm.runtime.kv_cache_manager_v2 import (
3737
DEFAULT_BEAM_INDEX,
3838
AttentionLayerConfig,
39+
BatchDesc,
3940
BufferConfig,
4041
CacheTierConfig,
4142
DiskCacheTierConfig,
4243
GpuCacheTierConfig,
4344
HostCacheTierConfig,
45+
KVCacheDesc,
4446
LayerId,
4547
ReuseScope,
4648
TokenIdExt,
@@ -892,6 +894,15 @@ def _build_cache_config(
892894
vocab_size=vocab_size,
893895
cache_tiers=cache_tiers,
894896
max_util_for_resume=kv_cache_config.max_util_for_resume,
897+
constraints=[
898+
self._build_concurrent_decode_constraint(
899+
max_batch_size=self.max_batch_size,
900+
tokens_per_block=tokens_per_block,
901+
)
902+
],
903+
# Preserve StorageManager's historical fallback ratio while using
904+
# constraints only as a min-slot floor.
905+
typical_step=BatchDesc([KVCacheDesc(capacity=2049, history_length=2048)]),
895906
layers=layer_configs,
896907
)
897908

@@ -912,6 +923,19 @@ def _extra_buffers_per_layer(
912923
"""
913924
return None
914925

926+
@staticmethod
927+
def _build_concurrent_decode_constraint(
928+
*, max_batch_size: int, tokens_per_block: int
929+
) -> BatchDesc:
930+
assert max_batch_size > 0
931+
assert tokens_per_block > 0
932+
return BatchDesc(
933+
[
934+
KVCacheDesc(capacity=tokens_per_block, history_length=tokens_per_block - 1)
935+
for _ in range(max_batch_size)
936+
]
937+
)
938+
915939
@property
916940
def blocks_in_primary_pool(self) -> int:
917941
"""

tests/unittest/_torch/executor/test_per_layer_head_dim.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2, Role
2323
from tensorrt_llm.llmapi.llm_args import KvCacheConfig as KvCacheConfigV2
2424
from tensorrt_llm.mapping import Mapping
25+
from tensorrt_llm.runtime.kv_cache_manager_v2 import GpuCacheTierConfig
2526

2627
DataType = tensorrt_llm.bindings.DataType
2728
CacheType = tensorrt_llm.bindings.internal.batch_manager.CacheType
@@ -64,6 +65,39 @@ def _create_kv_cache_manager_v2(
6465
class TestPerLayerHeadDimBasic(unittest.TestCase):
6566
"""Tests that don't allocate GPU memory or use uniform buffer sizes."""
6667

68+
def test_build_cache_config_reserves_concurrent_decode_slots(self):
69+
mgr = KVCacheManagerV2.__new__(KVCacheManagerV2)
70+
mgr.kv_cache_type = CacheType.SELF
71+
mgr.dtype = DataType.HALF
72+
mgr.kv_factor = 2
73+
mgr.max_batch_size = 4
74+
mgr.max_attention_window_vec = [4, None]
75+
mgr.num_local_layers = 2
76+
mgr.pp_layers = [0, 1]
77+
mgr.num_kv_heads_per_layer = [1, 1]
78+
mgr.head_dim_per_layer = [8, 8]
79+
80+
config = mgr._build_cache_config(
81+
KvCacheConfigV2(max_tokens=128, enable_block_reuse=False),
82+
tokens_per_block=8,
83+
vocab_size=32000,
84+
cache_tiers=[GpuCacheTierConfig(quota=1 << 20)],
85+
)
86+
87+
self.assertEqual(len(config.constraints), 1)
88+
decode_constraint = config.constraints[0]
89+
self.assertEqual(len(decode_constraint.kv_caches), mgr.max_batch_size)
90+
for kv_cache in decode_constraint.kv_caches:
91+
self.assertEqual(kv_cache.capacity, 8)
92+
self.assertEqual(kv_cache.history_length, 7)
93+
94+
# Keep the previous StorageManager fallback ratio basis so adding the
95+
# constraint only floors min slots and does not change ratio selection.
96+
self.assertIsNotNone(config.typical_step)
97+
self.assertEqual(len(config.typical_step.kv_caches), 1)
98+
self.assertEqual(config.typical_step.kv_caches[0].capacity, 2049)
99+
self.assertEqual(config.typical_step.kv_caches[0].history_length, 2048)
100+
67101
def test_per_layer_head_dim_wrong_length(self):
68102
"""Test that mismatched list length raises assertion."""
69103
with self.assertRaises(AssertionError):

0 commit comments

Comments
 (0)