Skip to content

Commit 5cdbbf3

Browse files
committed
feat(kvcache): add KV-cache-compression reclaim hook to KVCacheManagerV2.update_resources
KVCacheManagerV2.update_resources grows every generation request's KV history by one token per decode step. KV-cache compression methods compact a request's kept tokens in place, after which the history should shrink and the freed blocks return to the pool. This establishes the integration point: a compression-evicted request reports its evicted-token count in py_kv_evicted_tokens, and update_resources has the hook where the non-growing reclaim (shrink to the compacted length + free the blocks) belongs. The reclaim itself needs a _KVCache non-growing shrink/free primitive (the planned page-sharing fork) and is left as a TODO follow-up; until then the hook is a no-op and every request grows exactly as before (byte-identical behavior). Adds unit tests pinning the grow / completing / suspended paths so the hook stays non-breaking, and registers them in l0_a10.yml. Signed-off-by: Tianrui Hu <tianruih@nvidia.com>
1 parent 0425801 commit 5cdbbf3

3 files changed

Lines changed: 105 additions & 4 deletions

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2480,11 +2480,23 @@ def update_resources(
24802480
# will be resumed by the scheduler on the next iteration.
24812481
if not kv_cache.is_active:
24822482
continue
2483-
new_capacity = (
2484-
None
2485-
if req.state in (LlmRequestState.GENERATION_COMPLETE, LlmRequestState.CONTEXT_INIT)
2486-
else kv_cache.capacity - req.py_rewind_len
2483+
completing = req.state in (
2484+
LlmRequestState.GENERATION_COMPLETE,
2485+
LlmRequestState.CONTEXT_INIT,
24872486
)
2487+
# KV-cache-compression reclaim hook. A request whose kept tokens were
2488+
# compacted to the front reports its evicted-token count in
2489+
# py_kv_evicted_tokens; when that is > 0 the non-growing reclaim — shrink
2490+
# history to (max_beam_num_tokens - evicted) and return the freed blocks to
2491+
# the pool — belongs here, in place of the grow below.
2492+
# TODO(kvcache): implement once _KVCache exposes the non-growing shrink/free
2493+
# primitive (the planned page-sharing fork); out of scope for this change.
2494+
# Until then every request grows as usual (still correct — the compression
2495+
# manager reconciles the read length; the freed blocks are not yet reclaimed).
2496+
evicted = int(getattr(req, "py_kv_evicted_tokens", 0) or 0)
2497+
if evicted > 0 and not completing:
2498+
pass # TODO(kvcache): reclaim to (req.max_beam_num_tokens - evicted)
2499+
new_capacity = None if completing else kv_cache.capacity - req.py_rewind_len
24882500
success = kv_cache.resize(new_capacity, req.max_beam_num_tokens - 1)
24892501
if not success:
24902502
raise ValueError(

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ l0_a10:
3535
- unittest/_torch/executor/test_kv_pool_rebalance.py
3636
- unittest/_torch/executor/test_disagg_index_mapper_early_release.py
3737
- unittest/_torch/pyexecutor/test_kv_cache_compression_manager.py
38+
- unittest/_torch/pyexecutor/test_kv_cache_v2_compression_reclaim.py
3839
- unittest/_torch/modules/dwdp/test_dwdp_fixup_moe_backends.py
3940
- unittest/_torch/modules/dwdp/test_dwdp_manager.py
4041
- unittest/_torch/modules/dwdp/test_dwdp_mapping.py
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Unit tests for KVCacheManagerV2's KV-cache-compression reclaim hook in
``update_resources``.

A request whose cache was compacted in place reports its evicted-token count in
``py_kv_evicted_tokens``. The actual non-growing reclaim (shrink history + return the
freed blocks to the pool) needs a ``_KVCache`` shrink/free primitive and is a TODO
follow-up, so for now the hook is a no-op and every request grows as before. These
tests pin that the hook does not change the existing grow / completing / suspended
paths (i.e. behavior is byte-identical, so the hook is non-breaking).
"""
11+
12+
from unittest.mock import MagicMock
13+
14+
from tensorrt_llm._torch.pyexecutor.kv_cache_manager_v2 import KVCacheManagerV2
15+
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
16+
17+
18+
def _fake_manager():
    """Build a bare KVCacheManagerV2 holding only the state these tests set up.

    The instance is created through ``__new__`` so ``__init__`` never runs; only
    ``is_draft`` and an empty ``kv_cache_map`` are attached.
    """
    manager = KVCacheManagerV2.__new__(KVCacheManagerV2)
    # Flagged as draft so the draft-token-location helper is skipped (orthogonal here).
    manager.is_draft = True
    # Starts empty; each test registers its own mock caches by request id.
    manager.kv_cache_map = {}
    return manager
24+
25+
26+
def _req(rid, max_beam, evicted=0, rewind=0, state=LlmRequestState.GENERATION_IN_PROGRESS):
    """Return a mock request carrying the given id, token counts and state.

    Args:
        rid: Value stored in ``py_request_id``.
        max_beam: Value stored in ``max_beam_num_tokens``.
        evicted: Value stored in ``py_kv_evicted_tokens``.
        rewind: Value stored in ``py_rewind_len``.
        state: Value stored in ``state``.
    """
    request = MagicMock()
    request.configure_mock(
        py_request_id=rid,
        max_beam_num_tokens=max_beam,
        py_kv_evicted_tokens=evicted,
        py_rewind_len=rewind,
        state=state,
    )
    return request
34+
35+
36+
def _kvcache(capacity=4096):
37+
k = MagicMock()
38+
k.is_active = True
39+
k.capacity = capacity
40+
k.resize.return_value = True
41+
return k
42+
43+
44+
def _batch(reqs):
45+
b = MagicMock()
46+
b.generation_requests = reqs
47+
return b
48+
49+
50+
def _run(manager, reqs):
    """Register a fresh mock cache per request, call update_resources once, and
    return the caches keyed by request id.

    Args:
        manager: Object exposing ``kv_cache_map`` and ``update_resources``.
        reqs: Requests, each exposing ``py_request_id``.
    """
    for request in reqs:
        request_id = request.py_request_id
        manager.kv_cache_map[request_id] = _kvcache()

    manager.update_resources(_batch(reqs))

    # Look the caches up again after the call so the caller can inspect resize().
    caches_by_id = {}
    for request in reqs:
        caches_by_id[request.py_request_id] = manager.kv_cache_map[request.py_request_id]
    return caches_by_id
55+
56+
57+
def test_evicted_request_grows_until_reclaim_implemented():
    """An evicted request still takes the ordinary grow path while the hook is a no-op.

    With ``py_kv_evicted_tokens=50`` and no rewind, the cache must be resized exactly
    once with ``(capacity, max_beam - 1)``; the eviction count does not change the call.
    """
    manager = _fake_manager()
    request = _req(1, max_beam=200, evicted=50, rewind=0)
    caches = _run(manager, [request])
    caches[1].resize.assert_called_once_with(4096, 199)
64+
65+
66+
def test_unevicted_request_grows_exactly_as_before():
    """Backward compatibility: without eviction, expect
    ``resize(capacity - rewind, max_beam - 1)``."""
    manager = _fake_manager()
    request = _req(2, max_beam=200, evicted=0, rewind=3)
    caches = _run(manager, [request])
    expected_capacity = 4096 - 3
    caches[2].resize.assert_called_once_with(expected_capacity, 199)
71+
72+
73+
def test_completing_request_resizes_with_none_capacity():
    """A request in GENERATION_COMPLETE state is resized with ``None`` as the capacity,
    i.e. ``resize(None, max_beam - 1)``."""
    manager = _fake_manager()
    request = _req(3, max_beam=200, state=LlmRequestState.GENERATION_COMPLETE)
    caches = _run(manager, [request])
    caches[3].resize.assert_called_once_with(None, 199)
78+
79+
80+
def test_inactive_cache_is_skipped():
    """A cache whose ``is_active`` is False is skipped entirely: resize is never called,
    even when the request reports evicted tokens."""
    manager = _fake_manager()
    request = _req(6, max_beam=200, evicted=50)

    suspended_cache = _kvcache()
    suspended_cache.is_active = False
    manager.kv_cache_map[6] = suspended_cache

    manager.update_resources(_batch([request]))
    suspended_cache.resize.assert_not_called()

0 commit comments

Comments
 (0)