|
22 | 22 | from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2, Role |
23 | 23 | from tensorrt_llm.llmapi.llm_args import KvCacheConfig as KvCacheConfigV2 |
24 | 24 | from tensorrt_llm.mapping import Mapping |
| 25 | +from tensorrt_llm.runtime.kv_cache_manager_v2 import GpuCacheTierConfig |
25 | 26 |
|
26 | 27 | DataType = tensorrt_llm.bindings.DataType |
27 | 28 | CacheType = tensorrt_llm.bindings.internal.batch_manager.CacheType |
@@ -64,6 +65,39 @@ def _create_kv_cache_manager_v2( |
64 | 65 | class TestPerLayerHeadDimBasic(unittest.TestCase): |
65 | 66 | """Tests that don't allocate GPU memory or use uniform buffer sizes.""" |
66 | 67 |
|
| 68 | + def test_build_cache_config_reserves_concurrent_decode_slots(self): |
| 69 | + mgr = KVCacheManagerV2.__new__(KVCacheManagerV2) |
| 70 | + mgr.kv_cache_type = CacheType.SELF |
| 71 | + mgr.dtype = DataType.HALF |
| 72 | + mgr.kv_factor = 2 |
| 73 | + mgr.max_batch_size = 4 |
| 74 | + mgr.max_attention_window_vec = [4, None] |
| 75 | + mgr.num_local_layers = 2 |
| 76 | + mgr.pp_layers = [0, 1] |
| 77 | + mgr.num_kv_heads_per_layer = [1, 1] |
| 78 | + mgr.head_dim_per_layer = [8, 8] |
| 79 | + |
| 80 | + config = mgr._build_cache_config( |
| 81 | + KvCacheConfigV2(max_tokens=128, enable_block_reuse=False), |
| 82 | + tokens_per_block=8, |
| 83 | + vocab_size=32000, |
| 84 | + cache_tiers=[GpuCacheTierConfig(quota=1 << 20)], |
| 85 | + ) |
| 86 | + |
| 87 | + self.assertEqual(len(config.constraints), 1) |
| 88 | + decode_constraint = config.constraints[0] |
| 89 | + self.assertEqual(len(decode_constraint.kv_caches), mgr.max_batch_size) |
| 90 | + for kv_cache in decode_constraint.kv_caches: |
| 91 | + self.assertEqual(kv_cache.capacity, 8) |
| 92 | + self.assertEqual(kv_cache.history_length, 7) |
| 93 | + |
| 94 | + # Keep the previous StorageManager fallback ratio basis so adding the |
| 95 | + # constraint only floors min slots and does not change ratio selection. |
| 96 | + self.assertIsNotNone(config.typical_step) |
| 97 | + self.assertEqual(len(config.typical_step.kv_caches), 1) |
| 98 | + self.assertEqual(config.typical_step.kv_caches[0].capacity, 2049) |
| 99 | + self.assertEqual(config.typical_step.kv_caches[0].history_length, 2048) |
| 100 | + |
67 | 101 | def test_per_layer_head_dim_wrong_length(self): |
68 | 102 | """Test that mismatched list length raises assertion.""" |
69 | 103 | with self.assertRaises(AssertionError): |
|
0 commit comments