Skip to content

Commit 207015d

Browse files
author
Kevin-Li-2025
committed
Bound decode slot constraint by token budget
Signed-off-by: Kevin-Li-2025 <2242139@qq.com>
1 parent 4a9783b commit 207015d

2 files changed

Lines changed: 25 additions & 1 deletion

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1052,6 +1052,7 @@ def _build_cache_config(
1052 1052
constraints=[
1053 1053
self._build_concurrent_decode_constraint(
1054 1054
max_batch_size=self.max_batch_size,
1055+
max_tokens=kv_cache_config.max_tokens,
1055 1056
tokens_per_block=tokens_per_block,
1056 1057
)
1057 1058
],
@@ -1080,10 +1081,12 @@ def _extra_buffers_per_layer(
1080 1081

1081 1082
@staticmethod
1082 1083
def _build_concurrent_decode_constraint(
1083-
*, max_batch_size: int, tokens_per_block: int
1084+
*, max_batch_size: int, max_tokens: Optional[int], tokens_per_block: int
1084 1085
) -> BatchDesc:
1085 1086
assert max_batch_size > 0
1086 1087
assert tokens_per_block > 0
1088+
if max_tokens is not None:
1089+
max_batch_size = max(1, min(max_batch_size, max_tokens // tokens_per_block))
1087 1090
return BatchDesc(
1088 1091
[
1089 1092
KVCacheDesc(capacity=tokens_per_block, history_length=tokens_per_block - 1)

tests/unittest/_torch/executor/test_per_layer_head_dim.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,27 @@ def test_build_cache_config_reserves_concurrent_decode_slots(self):
98 98
self.assertEqual(config.typical_step.kv_caches[0].capacity, 2049)
99 99
self.assertEqual(config.typical_step.kv_caches[0].history_length, 2048)
100 100

101+
def test_build_cache_config_bounds_concurrent_decode_slots_by_max_tokens(self):
102+
mgr = KVCacheManagerV2.__new__(KVCacheManagerV2)
103+
mgr.kv_cache_type = CacheType.SELF
104+
mgr.dtype = DataType.HALF
105+
mgr.kv_factor = 2
106+
mgr.max_batch_size = 4
107+
mgr.max_attention_window_vec = [4, None]
108+
mgr.num_local_layers = 2
109+
mgr.pp_layers = [0, 1]
110+
mgr.num_kv_heads_per_layer = [1, 1]
111+
mgr.head_dim_per_layer = [8, 8]
112+
113+
config = mgr._build_cache_config(
114+
KvCacheConfigV2(max_tokens=16, enable_block_reuse=False),
115+
tokens_per_block=8,
116+
vocab_size=32000,
117+
cache_tiers=[GpuCacheTierConfig(quota=1 << 20)],
118+
)
119+
120+
self.assertEqual(len(config.constraints[0].kv_caches), 2)
121+
101 122
def test_per_layer_head_dim_wrong_length(self):
102 123
"""Test that mismatched list length raises assertion."""
103 124
with self.assertRaises(AssertionError):

0 commit comments

Comments
 (0)