Skip to content

Commit 062bc37

Browse files
author
Kevin-Li-2025
committed
Bound decode slot constraint by token budget
Signed-off-by: Kevin-Li-2025 <2242139@qq.com>
1 parent 75e8beb commit 062bc37

2 files changed

Lines changed: 25 additions & 1 deletion

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -897,6 +897,7 @@ def _build_cache_config(
897897
constraints=[
898898
self._build_concurrent_decode_constraint(
899899
max_batch_size=self.max_batch_size,
900+
max_tokens=kv_cache_config.max_tokens,
900901
tokens_per_block=tokens_per_block,
901902
)
902903
],
@@ -925,10 +926,12 @@ def _extra_buffers_per_layer(
925926

926927
@staticmethod
927928
def _build_concurrent_decode_constraint(
928-
*, max_batch_size: int, tokens_per_block: int
929+
*, max_batch_size: int, max_tokens: Optional[int], tokens_per_block: int
929930
) -> BatchDesc:
930931
assert max_batch_size > 0
931932
assert tokens_per_block > 0
933+
if max_tokens is not None:
934+
max_batch_size = max(1, min(max_batch_size, max_tokens // tokens_per_block))
932935
return BatchDesc(
933936
[
934937
KVCacheDesc(capacity=tokens_per_block, history_length=tokens_per_block - 1)

tests/unittest/_torch/executor/test_per_layer_head_dim.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,27 @@ def test_build_cache_config_reserves_concurrent_decode_slots(self):
9898
self.assertEqual(config.typical_step.kv_caches[0].capacity, 2049)
9999
self.assertEqual(config.typical_step.kv_caches[0].history_length, 2048)
100100

101+
def test_build_cache_config_bounds_concurrent_decode_slots_by_max_tokens(self):
102+
mgr = KVCacheManagerV2.__new__(KVCacheManagerV2)
103+
mgr.kv_cache_type = CacheType.SELF
104+
mgr.dtype = DataType.HALF
105+
mgr.kv_factor = 2
106+
mgr.max_batch_size = 4
107+
mgr.max_attention_window_vec = [4, None]
108+
mgr.num_local_layers = 2
109+
mgr.pp_layers = [0, 1]
110+
mgr.num_kv_heads_per_layer = [1, 1]
111+
mgr.head_dim_per_layer = [8, 8]
112+
113+
config = mgr._build_cache_config(
114+
KvCacheConfigV2(max_tokens=16, enable_block_reuse=False),
115+
tokens_per_block=8,
116+
vocab_size=32000,
117+
cache_tiers=[GpuCacheTierConfig(quota=1 << 20)],
118+
)
119+
120+
self.assertEqual(len(config.constraints[0].kv_caches), 2)
121+
101122
def test_per_layer_head_dim_wrong_length(self):
102123
"""Test that mismatched list length raises assertion."""
103124
with self.assertRaises(AssertionError):

0 commit comments

Comments
 (0)