Skip to content

Commit ef05e64

Browse files
committed
[None][refactor] Minimize V2 capacity-only decode integration
Signed-off-by: Hudayday <32944717+Hudayday@users.noreply.github.com>
1 parent aa846cf commit ef05e64

2 files changed

Lines changed: 90 additions & 584 deletions

File tree

tensorrt_llm/_torch/pyexecutor/kv_cache_manager_v2.py

Lines changed: 34 additions & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -801,10 +801,6 @@ def append_to_kv_heads_per_layer(
801801
self.enable_block_reuse = kv_cache_config.enable_block_reuse
802802
self.enable_partial_reuse = kv_cache_config.enable_partial_reuse
803803
self.disk_prefetch_num_reqs = kv_cache_config.disk_prefetch_num_reqs
804-
self._decode_capacity_only_requests: set[int] = set()
805-
self._pending_compacted_capacities: dict[
806-
int, tuple[int, int, Optional[torch.cuda.Event]]
807-
] = {}
808804

809805
# With pipeline parallelism, multiple microbatches can be in-flight
810806
# simultaneously, so we need slots for all concurrent sequences.
@@ -1287,142 +1283,6 @@ def is_request_active(self, request_id: int) -> bool:
12871283
kv_cache = self.kv_cache_map.get(request_id)
12881284
return kv_cache is not None and kv_cache.is_active
12891285

1290-
def enable_decode_capacity_only(self, request_id: int) -> None:
1291-
"""Preserve a request's finalized history while decoding.
1292-
1293-
Capacity-only decode is intended for KV compression methods that compact
1294-
uncommitted decode KV to the front of an all-full-attention cache. Context
1295-
updates still advance history to the finalized prompt prefix; decode
1296-
updates only resize capacity around that prefix.
1297-
1298-
Args:
1299-
request_id: Request whose decode capacity will be managed explicitly.
1300-
1301-
Raises:
1302-
ValueError: If block reuse is enabled or any local layer has an SWA or
1303-
SSM lifecycle.
1304-
"""
1305-
if self.enable_block_reuse:
1306-
raise ValueError("Decode capacity-only mode requires block reuse to be disabled")
1307-
if (
1308-
self.max_beam_width != 1
1309-
or self.num_extra_kv_tokens
1310-
or self.max_total_draft_tokens
1311-
or self._kv_reserve_draft_tokens
1312-
):
1313-
raise ValueError(
1314-
"Decode capacity-only mode currently supports single-token, beam-width-one "
1315-
"decoding only"
1316-
)
1317-
if any(window is not None for window in self.max_attention_window_vec) or any(
1318-
not isinstance(layer, AttentionLayerConfig) or layer.sliding_window_size is not None
1319-
for layer in self.kv_cache_manager_py_config.layers
1320-
):
1321-
raise ValueError(
1322-
"Decode capacity-only mode supports full-attention layers only; "
1323-
"SWA, VSWA, and SSM layers are not supported"
1324-
)
1325-
self._decode_capacity_only_requests.add(request_id)
1326-
1327-
def has_pending_compacted_capacity(self, request_id: int) -> bool:
1328-
"""Return whether a compacted capacity target is waiting to be consumed."""
1329-
return request_id in self._pending_compacted_capacities
1330-
1331-
def get_pre_forward_kv_length(self, request_id: int) -> int:
1332-
"""Return written KV tokens after scheduling and before the next forward.
1333-
1334-
The generation scheduler has reserved one unwritten slot at this point.
1335-
A pending compaction target can coexist with a later overlap reservation,
1336-
so derive the effective capacity from both instead of from request logical
1337-
length.
1338-
"""
1339-
if request_id not in self._decode_capacity_only_requests:
1340-
raise ValueError(f"Request {request_id} is not enabled for decode capacity-only mode")
1341-
kv_cache = self.kv_cache_map.get(request_id)
1342-
if kv_cache is None or not kv_cache.is_active:
1343-
raise ValueError(f"Request {request_id} has no active KV cache")
1344-
allocated_draft_len = self._allocated_draft_lens.get(request_id)
1345-
if allocated_draft_len is None:
1346-
raise ValueError(
1347-
f"Request {request_id} has no generation capacity reserved for this forward"
1348-
)
1349-
if allocated_draft_len:
1350-
raise ValueError(
1351-
"Decode capacity-only mode currently supports single-token, beam-width-one "
1352-
"decoding only"
1353-
)
1354-
effective_capacity = kv_cache.capacity
1355-
pending_target = self._pending_compacted_capacities.get(request_id)
1356-
if pending_target is not None:
1357-
target_capacity, published_capacity, _ = pending_target
1358-
capacity_growth = kv_cache.capacity - published_capacity
1359-
if capacity_growth < 0:
1360-
raise ValueError(
1361-
f"Request {request_id} capacity {kv_cache.capacity} fell below "
1362-
f"published capacity {published_capacity}"
1363-
)
1364-
effective_capacity = target_capacity + capacity_growth
1365-
if effective_capacity < 1:
1366-
raise ValueError(
1367-
f"Request {request_id} has invalid pre-forward capacity {effective_capacity}"
1368-
)
1369-
written_length = effective_capacity - 1
1370-
if written_length < kv_cache.history_length:
1371-
raise ValueError(
1372-
f"Request {request_id} pre-forward KV length {written_length} is below "
1373-
f"finalized history {kv_cache.history_length}"
1374-
)
1375-
return written_length
1376-
1377-
def set_compacted_capacity(
1378-
self,
1379-
request_id: int,
1380-
target_capacity: int,
1381-
event: Optional[torch.cuda.Event] = None,
1382-
) -> None:
1383-
"""Queue a one-shot physical capacity target for a compacted request.
1384-
1385-
The target is consumed by the next active generation update. Capacity
1386-
reserved after this call is added to the target, so an overlapped next
1387-
forward cannot lose its slot. If an event is supplied, the manager's
1388-
execution stream waits for it before releasing trailing KV pages.
1389-
1390-
Args:
1391-
request_id: Request previously enabled for capacity-only decode.
1392-
target_capacity: Physical capacity before the generation rewind is
1393-
applied.
1394-
event: Optional CUDA event recorded after compaction.
1395-
1396-
Raises:
1397-
ValueError: If the request is not enabled, has no KV cache, the target
1398-
does not leave a forward slot above finalized history, exceeds
1399-
current capacity, or another target is still pending.
1400-
"""
1401-
if request_id not in self._decode_capacity_only_requests:
1402-
raise ValueError(f"Request {request_id} is not enabled for decode capacity-only mode")
1403-
if target_capacity < 0:
1404-
raise ValueError(f"Compacted capacity must be non-negative, got {target_capacity}")
1405-
kv_cache = self.kv_cache_map.get(request_id)
1406-
if kv_cache is None:
1407-
raise ValueError(f"Request {request_id} has no KV cache to compact")
1408-
if target_capacity > kv_cache.capacity:
1409-
raise ValueError(
1410-
f"Compacted capacity {target_capacity} for request {request_id} "
1411-
f"cannot exceed current capacity {kv_cache.capacity}"
1412-
)
1413-
if target_capacity <= kv_cache.history_length:
1414-
raise ValueError(
1415-
f"Compacted capacity {target_capacity} for request {request_id} "
1416-
f"must leave a forward slot above finalized history {kv_cache.history_length}"
1417-
)
1418-
if request_id in self._pending_compacted_capacities:
1419-
raise ValueError(f"Request {request_id} already has a pending compacted capacity")
1420-
self._pending_compacted_capacities[request_id] = (
1421-
target_capacity,
1422-
kv_cache.capacity,
1423-
event,
1424-
)
1425-
14261286
def _effective_draft_len(self, req: LlmRequest) -> int:
14271287
"""Draft token length to use for next-step KV capacity calculation.
14281288
@@ -1524,53 +1384,23 @@ def revert_allocate_generation(self, req: LlmRequest) -> None:
15241384
host page-index buffer.
15251385
15261386
Mirror the effective draft length used in _required_gen_capacity
1527-
so disagg-gen-trans-complete revert stays symmetric. The scheduler
1528-
overwrites ``_allocated_draft_lens`` for every revert-eligible
1529-
allocation; a successful revert consumes that marker.
1387+
so disagg-gen-trans-complete revert stays symmetric.
15301388
"""
15311389
kv_cache = self.kv_cache_map.get(req.py_request_id)
15321390
if kv_cache is None or not kv_cache.is_active:
15331391
return
1534-
has_allocation_marker = req.py_request_id in self._allocated_draft_lens
1535-
draft_len = self._allocated_draft_lens.get(
1392+
draft_len = self._allocated_draft_lens.pop(
15361393
req.py_request_id, self._effective_draft_len(req)
15371394
)
1538-
pending_target = self._pending_compacted_capacities.get(req.py_request_id)
1539-
published_this_allocation = (
1540-
has_allocation_marker
1541-
and pending_target is not None
1542-
and pending_target[1] == kv_cache.capacity
1543-
)
15441395
reverted_cap = kv_cache.capacity - 1 - draft_len
15451396
if reverted_cap < 0:
1546-
self._allocated_draft_lens.pop(req.py_request_id, None)
15471397
return
1548-
reverted_pending_target = None
1549-
if published_this_allocation:
1550-
target_capacity, published_capacity, event = pending_target
1551-
reverted_target = target_capacity - 1 - draft_len
1552-
if reverted_target < kv_cache.history_length:
1553-
raise RuntimeError(
1554-
f"Reverting request {req.py_request_id} would move compacted "
1555-
f"capacity {reverted_target} below finalized history "
1556-
f"{kv_cache.history_length}"
1557-
)
1558-
reverted_pending_target = (
1559-
reverted_target,
1560-
published_capacity - 1 - draft_len,
1561-
event,
1562-
)
1563-
if pending_target is not None and pending_target[2] is not None:
1564-
self._stream.wait_event(pending_target[2])
15651398
if not kv_cache.resize(reverted_cap):
15661399
raise RuntimeError(
15671400
f"Failed to revert KV cache capacity for request "
15681401
f"{req.py_request_id} from {kv_cache.capacity} to "
15691402
f"{reverted_cap}"
15701403
)
1571-
self._allocated_draft_lens.pop(req.py_request_id, None)
1572-
if reverted_pending_target is not None:
1573-
self._pending_compacted_capacities[req.py_request_id] = reverted_pending_target
15741404

15751405
def revert_allocate_context(self, req: LlmRequest) -> None:
15761406
"""Undo the capacity growth from this iter's ``resize_context``.
@@ -2365,25 +2195,19 @@ def release_index_slot(self, request_id: int) -> None:
23652195
self._early_freed_index_requests.add(request_id)
23662196

23672197
def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
2368-
request_id = request.py_request_id
2369-
self._allocated_draft_lens.pop(request_id, None)
2370-
pending_target = self._pending_compacted_capacities.get(request_id)
2371-
if pending_target is not None and pending_target[2] is not None:
2372-
self._stream.wait_event(pending_target[2])
2373-
self._decode_capacity_only_requests.discard(request_id)
2374-
self._pending_compacted_capacities.pop(request_id, None)
2375-
kv_cache = self.kv_cache_map.pop(request_id, None)
2198+
self._allocated_draft_lens.pop(request.py_request_id, None)
2199+
kv_cache = self.kv_cache_map.pop(request.py_request_id, None)
23762200
if kv_cache is None:
2377-
self.impl.clear_stats_excluded(request_id)
2201+
self.impl.clear_stats_excluded(request.py_request_id)
23782202
return
23792203
kv_cache.discard_pending_stats()
23802204
self.try_commit_blocks_for_reuse(request, kv_cache)
23812205
kv_cache.close()
2382-
self.impl.clear_stats_excluded(request_id)
2383-
if request_id in self._early_freed_index_requests:
2384-
self._early_freed_index_requests.discard(request_id)
2206+
self.impl.clear_stats_excluded(request.py_request_id)
2207+
if request.py_request_id in self._early_freed_index_requests:
2208+
self._early_freed_index_requests.discard(request.py_request_id)
23852209
else:
2386-
self.index_mapper.remove_sequence(request_id)
2210+
self.index_mapper.remove_sequence(request.py_request_id)
23872211

23882212
def get_batch_cache_indices(
23892213
self, request_ids: List[int], layer_idx: Optional[int] = None
@@ -2656,46 +2480,39 @@ def update_resources(
26562480
# will be resumed by the scheduler on the next iteration.
26572481
if not kv_cache.is_active:
26582482
continue
2659-
completing = req.state in (
2660-
LlmRequestState.GENERATION_COMPLETE,
2661-
LlmRequestState.CONTEXT_INIT,
2483+
new_capacity = (
2484+
None
2485+
if req.state in (LlmRequestState.GENERATION_COMPLETE, LlmRequestState.CONTEXT_INIT)
2486+
else kv_cache.capacity - req.py_rewind_len
26622487
)
2663-
request_id = req.py_request_id
2664-
if request_id in self._decode_capacity_only_requests:
2665-
pending_target = self._pending_compacted_capacities.get(request_id)
2666-
if pending_target is not None:
2667-
target_capacity, published_capacity, event = pending_target
2668-
if event is not None:
2669-
self._stream.wait_event(event)
2670-
if completing:
2671-
new_capacity = None
2672-
elif pending_target is None:
2673-
new_capacity = kv_cache.capacity - req.py_rewind_len
2674-
else:
2675-
capacity_growth = kv_cache.capacity - published_capacity
2676-
if capacity_growth < 0:
2677-
raise ValueError(
2678-
f"Request {request_id} capacity {kv_cache.capacity} fell below "
2679-
f"published capacity {published_capacity}"
2680-
)
2681-
new_capacity = target_capacity + capacity_growth - req.py_rewind_len
2682-
success = kv_cache.resize(new_capacity, None)
2683-
if not success:
2488+
capacity_only = getattr(req, "py_kv_cache_decode_capacity_only", False) is True
2489+
history_length = None if capacity_only else req.max_beam_num_tokens - 1
2490+
compaction = getattr(req, "py_kv_cache_compaction", None)
2491+
consume_compaction = capacity_only and compaction is not None
2492+
if consume_compaction:
2493+
target_capacity, published_capacity, event = compaction
2494+
capacity_growth = kv_cache.capacity - published_capacity
2495+
if capacity_growth < 0:
26842496
raise ValueError(
2685-
f"Failed to resize KV cache for request {request_id} "
2686-
f"to capacity {new_capacity} while preserving its finalized history"
2497+
f"Request {req.py_request_id} capacity {kv_cache.capacity} "
2498+
f"fell below published capacity {published_capacity}"
26872499
)
2688-
if completing or pending_target is not None:
2689-
self._pending_compacted_capacities.pop(request_id, None)
2690-
continue
2691-
new_capacity = None if completing else kv_cache.capacity - req.py_rewind_len
2692-
success = kv_cache.resize(new_capacity, req.max_beam_num_tokens - 1)
2500+
# K+1 retains every block addressable by this forward. Resizing
2501+
# may race the full-table offset copy, but only rewrites the
2502+
# unreachable tail; the stream event protects page reuse.
2503+
if event is not None:
2504+
self._stream.wait_event(event)
2505+
if new_capacity is not None:
2506+
new_capacity = target_capacity + capacity_growth - req.py_rewind_len
2507+
success = kv_cache.resize(new_capacity, history_length)
26932508
if not success:
26942509
raise ValueError(
26952510
f"Failed to resize KV cache for request {req.py_request_id} "
26962511
f"to capacity {new_capacity} and history length "
2697-
f"{req.max_beam_num_tokens - 1} tokens at generation update"
2512+
f"{history_length} tokens at generation update"
26982513
)
2514+
if consume_compaction:
2515+
req.py_kv_cache_compaction = None
26992516

27002517
def copy_batch_block_offsets(
27012518
self,

0 commit comments

Comments
 (0)