diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py index caf80ecba7b5..347dbd1e8e78 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py @@ -38,12 +38,14 @@ from tensorrt_llm.runtime.kv_cache_manager_v2 import ( DEFAULT_BEAM_INDEX, AttentionLayerConfig, + BatchDesc, BufferConfig, CacheTierConfig, DiskCacheTierConfig, GpuCacheTierConfig, HostCacheTierConfig, KVCacheIterationStatsDelta, + KVCacheDesc, LayerId, ReuseScope, TokenIdExt, @@ -1047,6 +1049,16 @@ def _build_cache_config( cache_tiers=cache_tiers, max_util_for_resume=kv_cache_config.max_util_for_resume, enable_stats=self.enable_stats, + constraints=[ + self._build_concurrent_decode_constraint( + max_batch_size=self.max_batch_size, + max_tokens=kv_cache_config.max_tokens, + tokens_per_block=tokens_per_block, + ) + ], + # Preserve StorageManager's historical fallback ratio while using + # constraints only as a min-slot floor. + typical_step=BatchDesc([KVCacheDesc(capacity=2049, history_length=2048)]), layers=layer_configs, ) @@ -1067,6 +1079,21 @@ def _extra_buffers_per_layer( """ return None + @staticmethod + def _build_concurrent_decode_constraint( + *, max_batch_size: int, max_tokens: Optional[int], tokens_per_block: int + ) -> BatchDesc: + assert max_batch_size > 0 + assert tokens_per_block > 0 + if max_tokens is not None: + max_batch_size = max(1, min(max_batch_size, max_tokens // tokens_per_block)) + return BatchDesc( + [ + KVCacheDesc(capacity=tokens_per_block, history_length=tokens_per_block - 1) + for _ in range(max_batch_size) + ] + ) + @property def blocks_in_primary_pool(self) -> int: """ diff --git a/tests/unittest/_torch/executor/test_per_layer_head_dim.py b/tests/unittest/_torch/executor/test_per_layer_head_dim.py index bb4a87fe8766..c87e30727450 100644 --- a/tests/unittest/_torch/executor/test_per_layer_head_dim.py +++ b/tests/unittest/_torch/executor/test_per_layer_head_dim.py @@ -22,6 +22,7 @@ from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2, Role from tensorrt_llm.llmapi.llm_args import KvCacheConfig as KvCacheConfigV2 from tensorrt_llm.mapping import Mapping +from tensorrt_llm.runtime.kv_cache_manager_v2 import GpuCacheTierConfig DataType = tensorrt_llm.bindings.DataType CacheType = tensorrt_llm.bindings.internal.batch_manager.CacheType @@ -64,6 +65,60 @@ def _create_kv_cache_manager_v2( class TestPerLayerHeadDimBasic(unittest.TestCase): """Tests that don't allocate GPU memory or use uniform buffer sizes.""" + def test_build_cache_config_reserves_concurrent_decode_slots(self): + mgr = KVCacheManagerV2.__new__(KVCacheManagerV2) + mgr.kv_cache_type = CacheType.SELF + mgr.dtype = DataType.HALF + mgr.kv_factor = 2 + mgr.max_batch_size = 4 + mgr.max_attention_window_vec = [4, None] + mgr.num_local_layers = 2 + mgr.pp_layers = [0, 1] + mgr.num_kv_heads_per_layer = [1, 1] + mgr.head_dim_per_layer = [8, 8] + + config = mgr._build_cache_config( + KvCacheConfigV2(max_tokens=128, enable_block_reuse=False), + tokens_per_block=8, + vocab_size=32000, + cache_tiers=[GpuCacheTierConfig(quota=1 << 20)], + ) + + self.assertEqual(len(config.constraints), 1) + decode_constraint = config.constraints[0] + self.assertEqual(len(decode_constraint.kv_caches), mgr.max_batch_size) + for kv_cache in decode_constraint.kv_caches: + self.assertEqual(kv_cache.capacity, 8) + self.assertEqual(kv_cache.history_length, 7) + + # Keep the previous StorageManager fallback ratio basis so adding the + # constraint only floors min slots and does not change ratio selection. + self.assertIsNotNone(config.typical_step) + self.assertEqual(len(config.typical_step.kv_caches), 1) + self.assertEqual(config.typical_step.kv_caches[0].capacity, 2049) + self.assertEqual(config.typical_step.kv_caches[0].history_length, 2048) + + def test_build_cache_config_bounds_concurrent_decode_slots_by_max_tokens(self): + mgr = KVCacheManagerV2.__new__(KVCacheManagerV2) + mgr.kv_cache_type = CacheType.SELF + mgr.dtype = DataType.HALF + mgr.kv_factor = 2 + mgr.max_batch_size = 4 + mgr.max_attention_window_vec = [4, None] + mgr.num_local_layers = 2 + mgr.pp_layers = [0, 1] + mgr.num_kv_heads_per_layer = [1, 1] + mgr.head_dim_per_layer = [8, 8] + + config = mgr._build_cache_config( + KvCacheConfigV2(max_tokens=16, enable_block_reuse=False), + tokens_per_block=8, + vocab_size=32000, + cache_tiers=[GpuCacheTierConfig(quota=1 << 20)], + ) + + self.assertEqual(len(config.constraints[0].kv_caches), 2) + def test_per_layer_head_dim_wrong_length(self): """Test that mismatched list length raises assertion.""" with self.assertRaises(AssertionError):