Skip to content

Commit b053331

Browse files
committed
[None][feat] Support capacity-only KV cache compaction
Signed-off-by: Hudayday <32944717+Hudayday@users.noreply.github.com>
1 parent e5a05b2 commit b053331

3 files changed

Lines changed: 151 additions & 2 deletions

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,13 +2488,34 @@ def update_resources(
24882488
if req.state in (LlmRequestState.GENERATION_COMPLETE, LlmRequestState.CONTEXT_INIT)
24892489
else kv_cache.capacity - req.py_rewind_len
24902490
)
2491-
success = kv_cache.resize(new_capacity, req.max_beam_num_tokens - 1)
2491+
capacity_only = getattr(req, "py_kv_cache_decode_capacity_only", False) is True
2492+
history_length = None if capacity_only else req.max_beam_num_tokens - 1
2493+
compaction = getattr(req, "py_kv_cache_compaction", None)
2494+
consume_compaction = capacity_only and compaction is not None
2495+
if consume_compaction:
2496+
target_capacity, published_capacity, event = compaction
2497+
capacity_growth = kv_cache.capacity - published_capacity
2498+
if capacity_growth < 0:
2499+
raise ValueError(
2500+
f"Request {req.py_request_id} capacity {kv_cache.capacity} "
2501+
f"fell below published capacity {published_capacity}"
2502+
)
2503+
# K+1 retains every block addressable by this forward. Resizing
2504+
# may race the full-table offset copy, but only rewrites the
2505+
# unreachable tail; the stream event protects page reuse.
2506+
if event is not None:
2507+
self._stream.wait_event(event)
2508+
if new_capacity is not None:
2509+
new_capacity = target_capacity + capacity_growth - req.py_rewind_len
2510+
success = kv_cache.resize(new_capacity, history_length)
24922511
if not success:
24932512
raise ValueError(
24942513
f"Failed to resize KV cache for request {req.py_request_id} "
24952514
f"to capacity {new_capacity} and history length "
2496-
f"{req.max_beam_num_tokens - 1} tokens at generation update"
2515+
f"{history_length} tokens at generation update"
24972516
)
2517+
if consume_compaction:
2518+
req.py_kv_cache_compaction = None
24982519

24992520
def copy_batch_block_offsets(
25002521
self,

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ l0_a10:
3636
- unittest/_torch/executor/test_kv_pool_rebalance.py
3737
- unittest/_torch/executor/test_disagg_index_mapper_early_release.py
3838
- unittest/_torch/pyexecutor/test_kv_cache_compression_manager.py
39+
- unittest/_torch/pyexecutor/test_kv_cache_v2_capacity_only.py
3940
- unittest/_torch/modules/dwdp/test_dwdp_fixup_moe_backends.py
4041
- unittest/_torch/modules/dwdp/test_dwdp_manager.py
4142
- unittest/_torch/modules/dwdp/test_dwdp_mapping.py
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from types import SimpleNamespace
17+
from unittest.mock import MagicMock
18+
19+
import pytest
20+
21+
from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2
22+
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
23+
24+
25+
def _manager() -> KVCacheManagerV2:
26+
manager = KVCacheManagerV2.__new__(KVCacheManagerV2)
27+
manager.is_draft = True
28+
manager.enable_block_reuse = False
29+
manager.kv_cache_map = {}
30+
manager._stream = MagicMock()
31+
return manager
32+
33+
34+
def _request(request_id: int, *, rewind: int = 0, complete: bool = False) -> MagicMock:
35+
request = MagicMock()
36+
request.py_request_id = request_id
37+
request.py_rewind_len = rewind
38+
request.max_beam_num_tokens = 201
39+
request.py_kv_cache_decode_capacity_only = False
40+
request.py_kv_cache_compaction = None
41+
request.state = (
42+
LlmRequestState.GENERATION_COMPLETE if complete else LlmRequestState.GENERATION_IN_PROGRESS
43+
)
44+
return request
45+
46+
47+
def _cache(capacity: int = 256) -> MagicMock:
48+
cache = MagicMock()
49+
cache.is_active = True
50+
cache.capacity = capacity
51+
cache.resize.return_value = True
52+
return cache
53+
54+
55+
def test_capacity_only_is_request_scoped() -> None:
56+
manager = _manager()
57+
regular = _request(1, rewind=3)
58+
compacted = _request(2, rewind=5)
59+
regular_cache = _cache()
60+
compacted_cache = _cache()
61+
manager.kv_cache_map = {1: regular_cache, 2: compacted_cache}
62+
compacted.py_kv_cache_decode_capacity_only = True
63+
64+
manager.update_resources(SimpleNamespace(generation_requests=[regular, compacted]))
65+
66+
regular_cache.resize.assert_called_once_with(253, 200)
67+
compacted_cache.resize.assert_called_once_with(251, None)
68+
69+
70+
def test_missing_capacity_only_flag_is_fail_closed() -> None:
71+
manager = _manager()
72+
request = MagicMock()
73+
request.py_request_id = 1
74+
request.py_rewind_len = 3
75+
request.max_beam_num_tokens = 201
76+
request.py_kv_cache_compaction = None
77+
request.state = LlmRequestState.GENERATION_IN_PROGRESS
78+
cache = _cache()
79+
manager.kv_cache_map[1] = cache
80+
81+
manager.update_resources(SimpleNamespace(generation_requests=[request]))
82+
83+
cache.resize.assert_called_once_with(253, 200)
84+
85+
86+
def test_capacity_only_completion_preserves_history() -> None:
87+
manager = _manager()
88+
request = _request(1, complete=True)
89+
cache = _cache()
90+
manager.kv_cache_map[1] = cache
91+
request.py_kv_cache_decode_capacity_only = True
92+
93+
manager.update_resources(SimpleNamespace(generation_requests=[request]))
94+
95+
cache.resize.assert_called_once_with(None, None)
96+
97+
98+
def test_compaction_target_preserves_overlap_growth() -> None:
99+
manager = _manager()
100+
request = _request(1)
101+
event = MagicMock()
102+
request.py_kv_cache_compaction = (129, 256, event)
103+
cache = _cache(capacity=257)
104+
manager.kv_cache_map[1] = cache
105+
request.py_kv_cache_decode_capacity_only = True
106+
107+
manager.update_resources(SimpleNamespace(generation_requests=[request]))
108+
109+
manager._stream.wait_event.assert_called_once_with(event)
110+
cache.resize.assert_called_once_with(130, None)
111+
assert request.py_kv_cache_compaction is None
112+
113+
114+
def test_failed_compaction_resize_keeps_target() -> None:
115+
manager = _manager()
116+
request = _request(1)
117+
target = (129, 256, MagicMock())
118+
request.py_kv_cache_compaction = target
119+
cache = _cache(capacity=256)
120+
cache.resize.return_value = False
121+
manager.kv_cache_map[1] = cache
122+
request.py_kv_cache_decode_capacity_only = True
123+
124+
with pytest.raises(ValueError, match="Failed to resize KV cache"):
125+
manager.update_resources(SimpleNamespace(generation_requests=[request]))
126+
127+
assert request.py_kv_cache_compaction is target

0 commit comments

Comments
 (0)