Skip to content

Commit aa846cf

Browse files
committed
[None][feat] Add capacity-only decode path to KV cache manager V2
Allow opt-in full-attention requests to preserve finalized history while resizing only physical capacity. Queue event-ordered compaction targets, preserve overlap reservations and cancellation semantics, and expose the authoritative pre-forward KV length.

Signed-off-by: Hudayday <32944717+Hudayday@users.noreply.github.com>
1 parent 3f33620 commit aa846cf

3 files changed

Lines changed: 656 additions & 13 deletions

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 217 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,10 @@ def append_to_kv_heads_per_layer(
801801
self.enable_block_reuse = kv_cache_config.enable_block_reuse
802802
self.enable_partial_reuse = kv_cache_config.enable_partial_reuse
803803
self.disk_prefetch_num_reqs = kv_cache_config.disk_prefetch_num_reqs
804+
self._decode_capacity_only_requests: set[int] = set()
805+
self._pending_compacted_capacities: dict[
806+
int, tuple[int, int, Optional[torch.cuda.Event]]
807+
] = {}
804808

805809
# With pipeline parallelism, multiple microbatches can be in-flight
806810
# simultaneously, so we need slots for all concurrent sequences.
@@ -1283,6 +1287,142 @@ def is_request_active(self, request_id: int) -> bool:
12831287
kv_cache = self.kv_cache_map.get(request_id)
12841288
return kv_cache is not None and kv_cache.is_active
12851289

1290+
def enable_decode_capacity_only(self, request_id: int) -> None:
    """Opt a request into capacity-only decode.

    Capacity-only decode targets KV compression methods that compact
    uncommitted decode KV to the front of an all-full-attention cache. Context
    updates still advance history to the finalized prompt prefix, while decode
    updates only resize capacity around that prefix.

    Args:
        request_id: Request whose decode capacity will be managed explicitly.

    Raises:
        ValueError: If block reuse is enabled; if the manager is configured
            for anything other than single-token, beam-width-one decoding
            (beam width other than 1, extra KV tokens, draft tokens, or
            reserved draft tokens); or if any attention window is set or any
            configured layer is not a full-attention ``AttentionLayerConfig``.
    """
    if self.enable_block_reuse:
        raise ValueError("Decode capacity-only mode requires block reuse to be disabled")

    # Any of these settings means a decode step may need more than one KV slot.
    multi_token_decode = (
        self.max_beam_width != 1
        or self.num_extra_kv_tokens
        or self.max_total_draft_tokens
        or self._kv_reserve_draft_tokens
    )
    if multi_token_decode:
        raise ValueError(
            "Decode capacity-only mode currently supports single-token, beam-width-one "
            "decoding only"
        )

    # The layer configs are only inspected when no attention window is set.
    windowed = any(window is not None for window in self.max_attention_window_vec)
    non_full_attention = not windowed and any(
        not (isinstance(layer, AttentionLayerConfig) and layer.sliding_window_size is None)
        for layer in self.kv_cache_manager_py_config.layers
    )
    if windowed or non_full_attention:
        raise ValueError(
            "Decode capacity-only mode supports full-attention layers only; "
            "SWA, VSWA, and SSM layers are not supported"
        )

    self._decode_capacity_only_requests.add(request_id)
1326+
1327+
def has_pending_compacted_capacity(self, request_id: int) -> bool:
    """Report whether a queued compacted-capacity target exists for a request.

    Args:
        request_id: Request to look up.

    Returns:
        True if ``request_id`` has an entry in the pending compacted
        capacities that has not been removed yet, False otherwise.
    """
    pending_targets = self._pending_compacted_capacities
    return request_id in pending_targets
1330+
1331+
def get_pre_forward_kv_length(self, request_id: int) -> int:
    """Return the written KV token count after scheduling, before the forward.

    At this point the generation scheduler has reserved one unwritten slot.
    A pending compaction target can coexist with a later overlap reservation,
    so the effective capacity is derived from both the pending target and the
    capacity growth since it was published, not from the request's logical
    length.

    Args:
        request_id: Request enabled for capacity-only decode.

    Returns:
        The effective capacity minus the one reserved forward slot.

    Raises:
        ValueError: If the request is not enabled for capacity-only decode,
            has no active KV cache, has no generation allocation recorded for
            this forward, was allocated with a non-zero draft length, has a
            capacity below the published capacity of its pending target, or
            the resulting length is invalid or below finalized history.
    """
    if request_id not in self._decode_capacity_only_requests:
        raise ValueError(f"Request {request_id} is not enabled for decode capacity-only mode")

    kv_cache = self.kv_cache_map.get(request_id)
    if not (kv_cache is not None and kv_cache.is_active):
        raise ValueError(f"Request {request_id} has no active KV cache")

    reserved_draft_len = self._allocated_draft_lens.get(request_id)
    if reserved_draft_len is None:
        raise ValueError(
            f"Request {request_id} has no generation capacity reserved for this forward"
        )
    if reserved_draft_len:
        raise ValueError(
            "Decode capacity-only mode currently supports single-token, beam-width-one "
            "decoding only"
        )

    pending = self._pending_compacted_capacities.get(request_id)
    if pending is None:
        capacity_before_forward = kv_cache.capacity
    else:
        compacted_target, published_capacity, _event = pending
        # Slots reserved after the target was queued must be kept on top of it.
        grown_by = kv_cache.capacity - published_capacity
        if grown_by < 0:
            raise ValueError(
                f"Request {request_id} capacity {kv_cache.capacity} fell below "
                f"published capacity {published_capacity}"
            )
        capacity_before_forward = compacted_target + grown_by

    if capacity_before_forward < 1:
        raise ValueError(
            f"Request {request_id} has invalid pre-forward capacity {capacity_before_forward}"
        )

    # Exclude the single slot reserved for the upcoming forward.
    tokens_written = capacity_before_forward - 1
    if tokens_written < kv_cache.history_length:
        raise ValueError(
            f"Request {request_id} pre-forward KV length {tokens_written} is below "
            f"finalized history {kv_cache.history_length}"
        )
    return tokens_written
1376+
1377+
def set_compacted_capacity(
    self,
    request_id: int,
    target_capacity: int,
    event: Optional[torch.cuda.Event] = None,
) -> None:
    """Queue a one-shot physical capacity target for a compacted request.

    The next active generation update consumes the target. Capacity reserved
    after this call is added on top of it, so an overlapped next forward does
    not lose its slot. When an event is supplied, the manager's execution
    stream waits on it before trailing KV pages are released.

    Args:
        request_id: Request previously enabled for capacity-only decode.
        target_capacity: Physical capacity before the generation rewind is
            applied.
        event: Optional CUDA event recorded after compaction.

    Raises:
        ValueError: If the request is not enabled, the target is negative,
            the request has no KV cache, the target exceeds the current
            capacity, the target does not leave a forward slot above
            finalized history, or another target is still pending.
    """
    if request_id not in self._decode_capacity_only_requests:
        raise ValueError(f"Request {request_id} is not enabled for decode capacity-only mode")
    if target_capacity < 0:
        raise ValueError(f"Compacted capacity must be non-negative, got {target_capacity}")

    cache = self.kv_cache_map.get(request_id)
    if cache is None:
        raise ValueError(f"Request {request_id} has no KV cache to compact")

    current_capacity = cache.capacity
    if current_capacity < target_capacity:
        raise ValueError(
            f"Compacted capacity {target_capacity} for request {request_id} "
            f"cannot exceed current capacity {current_capacity}"
        )

    finalized_history = cache.history_length
    if finalized_history >= target_capacity:
        raise ValueError(
            f"Compacted capacity {target_capacity} for request {request_id} "
            f"must leave a forward slot above finalized history {finalized_history}"
        )

    pending_targets = self._pending_compacted_capacities
    if request_id in pending_targets:
        raise ValueError(f"Request {request_id} already has a pending compacted capacity")

    # Store the capacity seen now so later growth can be measured against it.
    pending_targets[request_id] = (target_capacity, current_capacity, event)
1425+
12861426
def _effective_draft_len(self, req: LlmRequest) -> int:
12871427
"""Draft token length to use for next-step KV capacity calculation.
12881428
@@ -1384,23 +1524,53 @@ def revert_allocate_generation(self, req: LlmRequest) -> None:
13841524
host page-index buffer.
13851525
13861526
Mirror the effective draft length used in _required_gen_capacity
1387-
so disagg-gen-trans-complete revert stays symmetric.
1527+
so disagg-gen-trans-complete revert stays symmetric. The scheduler
1528+
overwrites ``_allocated_draft_lens`` for every revert-eligible
1529+
allocation; a successful revert consumes that marker.
13881530
"""
13891531
kv_cache = self.kv_cache_map.get(req.py_request_id)
13901532
if kv_cache is None or not kv_cache.is_active:
13911533
return
1392-
draft_len = self._allocated_draft_lens.pop(
1534+
has_allocation_marker = req.py_request_id in self._allocated_draft_lens
1535+
draft_len = self._allocated_draft_lens.get(
13931536
req.py_request_id, self._effective_draft_len(req)
13941537
)
1538+
pending_target = self._pending_compacted_capacities.get(req.py_request_id)
1539+
published_this_allocation = (
1540+
has_allocation_marker
1541+
and pending_target is not None
1542+
and pending_target[1] == kv_cache.capacity
1543+
)
13951544
reverted_cap = kv_cache.capacity - 1 - draft_len
13961545
if reverted_cap < 0:
1546+
self._allocated_draft_lens.pop(req.py_request_id, None)
13971547
return
1548+
reverted_pending_target = None
1549+
if published_this_allocation:
1550+
target_capacity, published_capacity, event = pending_target
1551+
reverted_target = target_capacity - 1 - draft_len
1552+
if reverted_target < kv_cache.history_length:
1553+
raise RuntimeError(
1554+
f"Reverting request {req.py_request_id} would move compacted "
1555+
f"capacity {reverted_target} below finalized history "
1556+
f"{kv_cache.history_length}"
1557+
)
1558+
reverted_pending_target = (
1559+
reverted_target,
1560+
published_capacity - 1 - draft_len,
1561+
event,
1562+
)
1563+
if pending_target is not None and pending_target[2] is not None:
1564+
self._stream.wait_event(pending_target[2])
13981565
if not kv_cache.resize(reverted_cap):
13991566
raise RuntimeError(
14001567
f"Failed to revert KV cache capacity for request "
14011568
f"{req.py_request_id} from {kv_cache.capacity} to "
14021569
f"{reverted_cap}"
14031570
)
1571+
self._allocated_draft_lens.pop(req.py_request_id, None)
1572+
if reverted_pending_target is not None:
1573+
self._pending_compacted_capacities[req.py_request_id] = reverted_pending_target
14041574

14051575
def revert_allocate_context(self, req: LlmRequest) -> None:
14061576
"""Undo the capacity growth from this iter's ``resize_context``.
@@ -2195,19 +2365,25 @@ def release_index_slot(self, request_id: int) -> None:
21952365
self._early_freed_index_requests.add(request_id)
21962366

21972367
def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
2198-
self._allocated_draft_lens.pop(request.py_request_id, None)
2199-
kv_cache = self.kv_cache_map.pop(request.py_request_id, None)
2368+
request_id = request.py_request_id
2369+
self._allocated_draft_lens.pop(request_id, None)
2370+
pending_target = self._pending_compacted_capacities.get(request_id)
2371+
if pending_target is not None and pending_target[2] is not None:
2372+
self._stream.wait_event(pending_target[2])
2373+
self._decode_capacity_only_requests.discard(request_id)
2374+
self._pending_compacted_capacities.pop(request_id, None)
2375+
kv_cache = self.kv_cache_map.pop(request_id, None)
22002376
if kv_cache is None:
2201-
self.impl.clear_stats_excluded(request.py_request_id)
2377+
self.impl.clear_stats_excluded(request_id)
22022378
return
22032379
kv_cache.discard_pending_stats()
22042380
self.try_commit_blocks_for_reuse(request, kv_cache)
22052381
kv_cache.close()
2206-
self.impl.clear_stats_excluded(request.py_request_id)
2207-
if request.py_request_id in self._early_freed_index_requests:
2208-
self._early_freed_index_requests.discard(request.py_request_id)
2382+
self.impl.clear_stats_excluded(request_id)
2383+
if request_id in self._early_freed_index_requests:
2384+
self._early_freed_index_requests.discard(request_id)
22092385
else:
2210-
self.index_mapper.remove_sequence(request.py_request_id)
2386+
self.index_mapper.remove_sequence(request_id)
22112387

22122388
def get_batch_cache_indices(
22132389
self, request_ids: List[int], layer_idx: Optional[int] = None
@@ -2480,11 +2656,39 @@ def update_resources(
24802656
# will be resumed by the scheduler on the next iteration.
24812657
if not kv_cache.is_active:
24822658
continue
2483-
new_capacity = (
2484-
None
2485-
if req.state in (LlmRequestState.GENERATION_COMPLETE, LlmRequestState.CONTEXT_INIT)
2486-
else kv_cache.capacity - req.py_rewind_len
2659+
completing = req.state in (
2660+
LlmRequestState.GENERATION_COMPLETE,
2661+
LlmRequestState.CONTEXT_INIT,
24872662
)
2663+
request_id = req.py_request_id
2664+
if request_id in self._decode_capacity_only_requests:
2665+
pending_target = self._pending_compacted_capacities.get(request_id)
2666+
if pending_target is not None:
2667+
target_capacity, published_capacity, event = pending_target
2668+
if event is not None:
2669+
self._stream.wait_event(event)
2670+
if completing:
2671+
new_capacity = None
2672+
elif pending_target is None:
2673+
new_capacity = kv_cache.capacity - req.py_rewind_len
2674+
else:
2675+
capacity_growth = kv_cache.capacity - published_capacity
2676+
if capacity_growth < 0:
2677+
raise ValueError(
2678+
f"Request {request_id} capacity {kv_cache.capacity} fell below "
2679+
f"published capacity {published_capacity}"
2680+
)
2681+
new_capacity = target_capacity + capacity_growth - req.py_rewind_len
2682+
success = kv_cache.resize(new_capacity, None)
2683+
if not success:
2684+
raise ValueError(
2685+
f"Failed to resize KV cache for request {request_id} "
2686+
f"to capacity {new_capacity} while preserving its finalized history"
2687+
)
2688+
if completing or pending_target is not None:
2689+
self._pending_compacted_capacities.pop(request_id, None)
2690+
continue
2691+
new_capacity = None if completing else kv_cache.capacity - req.py_rewind_len
24882692
success = kv_cache.resize(new_capacity, req.max_beam_num_tokens - 1)
24892693
if not success:
24902694
raise ValueError(

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ l0_a10:
3535
- unittest/_torch/executor/test_kv_pool_rebalance.py
3636
- unittest/_torch/executor/test_disagg_index_mapper_early_release.py
3737
- unittest/_torch/pyexecutor/test_kv_cache_compression_manager.py
38+
- unittest/_torch/pyexecutor/test_kv_cache_v2_capacity_only.py
3839
- unittest/_torch/modules/dwdp/test_dwdp_fixup_moe_backends.py
3940
- unittest/_torch/modules/dwdp/test_dwdp_manager.py
4041
- unittest/_torch/modules/dwdp/test_dwdp_mapping.py

0 commit comments

Comments
 (0)