@@ -801,6 +801,10 @@ def append_to_kv_heads_per_layer(
801801 self .enable_block_reuse = kv_cache_config .enable_block_reuse
802802 self .enable_partial_reuse = kv_cache_config .enable_partial_reuse
803803 self .disk_prefetch_num_reqs = kv_cache_config .disk_prefetch_num_reqs
804+ self ._decode_capacity_only_requests : set [int ] = set ()
805+ self ._pending_compacted_capacities : dict [
806+ int , tuple [int , int , Optional [torch .cuda .Event ]]
807+ ] = {}
804808
805809 # With pipeline parallelism, multiple microbatches can be in-flight
806810 # simultaneously, so we need slots for all concurrent sequences.
@@ -1283,6 +1287,142 @@ def is_request_active(self, request_id: int) -> bool:
12831287 kv_cache = self .kv_cache_map .get (request_id )
12841288 return kv_cache is not None and kv_cache .is_active
12851289
1290+ def enable_decode_capacity_only (self , request_id : int ) -> None :
1291+ """Preserve a request's finalized history while decoding.
1292+
1293+ Capacity-only decode is intended for KV compression methods that compact
1294+ uncommitted decode KV to the front of an all-full-attention cache. Context
1295+ updates still advance history to the finalized prompt prefix; decode
1296+ updates only resize capacity around that prefix.
1297+
1298+ Args:
1299+ request_id: Request whose decode capacity will be managed explicitly.
1300+
1301+ Raises:
1302+ ValueError: If block reuse is enabled or any local layer has an SWA or
1303+ SSM lifecycle.
1304+ """
1305+ if self .enable_block_reuse :
1306+ raise ValueError ("Decode capacity-only mode requires block reuse to be disabled" )
1307+ if (
1308+ self .max_beam_width != 1
1309+ or self .num_extra_kv_tokens
1310+ or self .max_total_draft_tokens
1311+ or self ._kv_reserve_draft_tokens
1312+ ):
1313+ raise ValueError (
1314+ "Decode capacity-only mode currently supports single-token, beam-width-one "
1315+ "decoding only"
1316+ )
1317+ if any (window is not None for window in self .max_attention_window_vec ) or any (
1318+ not isinstance (layer , AttentionLayerConfig ) or layer .sliding_window_size is not None
1319+ for layer in self .kv_cache_manager_py_config .layers
1320+ ):
1321+ raise ValueError (
1322+ "Decode capacity-only mode supports full-attention layers only; "
1323+ "SWA, VSWA, and SSM layers are not supported"
1324+ )
1325+ self ._decode_capacity_only_requests .add (request_id )
1326+
1327+ def has_pending_compacted_capacity (self , request_id : int ) -> bool :
1328+ """Return whether a compacted capacity target is waiting to be consumed."""
1329+ return request_id in self ._pending_compacted_capacities
1330+
1331+ def get_pre_forward_kv_length (self , request_id : int ) -> int :
1332+ """Return written KV tokens after scheduling and before the next forward.
1333+
1334+ The generation scheduler has reserved one unwritten slot at this point.
1335+ A pending compaction target can coexist with a later overlap reservation,
1336+ so derive the effective capacity from both instead of from request logical
1337+ length.
1338+ """
1339+ if request_id not in self ._decode_capacity_only_requests :
1340+ raise ValueError (f"Request { request_id } is not enabled for decode capacity-only mode" )
1341+ kv_cache = self .kv_cache_map .get (request_id )
1342+ if kv_cache is None or not kv_cache .is_active :
1343+ raise ValueError (f"Request { request_id } has no active KV cache" )
1344+ allocated_draft_len = self ._allocated_draft_lens .get (request_id )
1345+ if allocated_draft_len is None :
1346+ raise ValueError (
1347+ f"Request { request_id } has no generation capacity reserved for this forward"
1348+ )
1349+ if allocated_draft_len :
1350+ raise ValueError (
1351+ "Decode capacity-only mode currently supports single-token, beam-width-one "
1352+ "decoding only"
1353+ )
1354+ effective_capacity = kv_cache .capacity
1355+ pending_target = self ._pending_compacted_capacities .get (request_id )
1356+ if pending_target is not None :
1357+ target_capacity , published_capacity , _ = pending_target
1358+ capacity_growth = kv_cache .capacity - published_capacity
1359+ if capacity_growth < 0 :
1360+ raise ValueError (
1361+ f"Request { request_id } capacity { kv_cache .capacity } fell below "
1362+ f"published capacity { published_capacity } "
1363+ )
1364+ effective_capacity = target_capacity + capacity_growth
1365+ if effective_capacity < 1 :
1366+ raise ValueError (
1367+ f"Request { request_id } has invalid pre-forward capacity { effective_capacity } "
1368+ )
1369+ written_length = effective_capacity - 1
1370+ if written_length < kv_cache .history_length :
1371+ raise ValueError (
1372+ f"Request { request_id } pre-forward KV length { written_length } is below "
1373+ f"finalized history { kv_cache .history_length } "
1374+ )
1375+ return written_length
1376+
1377+ def set_compacted_capacity (
1378+ self ,
1379+ request_id : int ,
1380+ target_capacity : int ,
1381+ event : Optional [torch .cuda .Event ] = None ,
1382+ ) -> None :
1383+ """Queue a one-shot physical capacity target for a compacted request.
1384+
1385+ The target is consumed by the next active generation update. Capacity
1386+ reserved after this call is added to the target, so an overlapped next
1387+ forward cannot lose its slot. If an event is supplied, the manager's
1388+ execution stream waits for it before releasing trailing KV pages.
1389+
1390+ Args:
1391+ request_id: Request previously enabled for capacity-only decode.
1392+ target_capacity: Physical capacity before the generation rewind is
1393+ applied.
1394+ event: Optional CUDA event recorded after compaction.
1395+
1396+ Raises:
1397+ ValueError: If the request is not enabled, has no KV cache, the target
1398+ does not leave a forward slot above finalized history, exceeds
1399+ current capacity, or another target is still pending.
1400+ """
1401+ if request_id not in self ._decode_capacity_only_requests :
1402+ raise ValueError (f"Request { request_id } is not enabled for decode capacity-only mode" )
1403+ if target_capacity < 0 :
1404+ raise ValueError (f"Compacted capacity must be non-negative, got { target_capacity } " )
1405+ kv_cache = self .kv_cache_map .get (request_id )
1406+ if kv_cache is None :
1407+ raise ValueError (f"Request { request_id } has no KV cache to compact" )
1408+ if target_capacity > kv_cache .capacity :
1409+ raise ValueError (
1410+ f"Compacted capacity { target_capacity } for request { request_id } "
1411+ f"cannot exceed current capacity { kv_cache .capacity } "
1412+ )
1413+ if target_capacity <= kv_cache .history_length :
1414+ raise ValueError (
1415+ f"Compacted capacity { target_capacity } for request { request_id } "
1416+ f"must leave a forward slot above finalized history { kv_cache .history_length } "
1417+ )
1418+ if request_id in self ._pending_compacted_capacities :
1419+ raise ValueError (f"Request { request_id } already has a pending compacted capacity" )
1420+ self ._pending_compacted_capacities [request_id ] = (
1421+ target_capacity ,
1422+ kv_cache .capacity ,
1423+ event ,
1424+ )
1425+
12861426 def _effective_draft_len (self , req : LlmRequest ) -> int :
12871427 """Draft token length to use for next-step KV capacity calculation.
12881428
@@ -1384,23 +1524,53 @@ def revert_allocate_generation(self, req: LlmRequest) -> None:
13841524 host page-index buffer.
13851525
13861526 Mirror the effective draft length used in _required_gen_capacity
1387- so disagg-gen-trans-complete revert stays symmetric.
1527+ so disagg-gen-trans-complete revert stays symmetric. The scheduler
1528+ overwrites ``_allocated_draft_lens`` for every revert-eligible
1529+ allocation; a successful revert consumes that marker.
13881530 """
13891531 kv_cache = self .kv_cache_map .get (req .py_request_id )
13901532 if kv_cache is None or not kv_cache .is_active :
13911533 return
1392- draft_len = self ._allocated_draft_lens .pop (
1534+ has_allocation_marker = req .py_request_id in self ._allocated_draft_lens
1535+ draft_len = self ._allocated_draft_lens .get (
13931536 req .py_request_id , self ._effective_draft_len (req )
13941537 )
1538+ pending_target = self ._pending_compacted_capacities .get (req .py_request_id )
1539+ published_this_allocation = (
1540+ has_allocation_marker
1541+ and pending_target is not None
1542+ and pending_target [1 ] == kv_cache .capacity
1543+ )
13951544 reverted_cap = kv_cache .capacity - 1 - draft_len
13961545 if reverted_cap < 0 :
1546+ self ._allocated_draft_lens .pop (req .py_request_id , None )
13971547 return
1548+ reverted_pending_target = None
1549+ if published_this_allocation :
1550+ target_capacity , published_capacity , event = pending_target
1551+ reverted_target = target_capacity - 1 - draft_len
1552+ if reverted_target < kv_cache .history_length :
1553+ raise RuntimeError (
1554+ f"Reverting request { req .py_request_id } would move compacted "
1555+ f"capacity { reverted_target } below finalized history "
1556+ f"{ kv_cache .history_length } "
1557+ )
1558+ reverted_pending_target = (
1559+ reverted_target ,
1560+ published_capacity - 1 - draft_len ,
1561+ event ,
1562+ )
1563+ if pending_target is not None and pending_target [2 ] is not None :
1564+ self ._stream .wait_event (pending_target [2 ])
13981565 if not kv_cache .resize (reverted_cap ):
13991566 raise RuntimeError (
14001567 f"Failed to revert KV cache capacity for request "
14011568 f"{ req .py_request_id } from { kv_cache .capacity } to "
14021569 f"{ reverted_cap } "
14031570 )
1571+ self ._allocated_draft_lens .pop (req .py_request_id , None )
1572+ if reverted_pending_target is not None :
1573+ self ._pending_compacted_capacities [req .py_request_id ] = reverted_pending_target
14041574
14051575 def revert_allocate_context (self , req : LlmRequest ) -> None :
14061576 """Undo the capacity growth from this iter's ``resize_context``.
@@ -2195,19 +2365,25 @@ def release_index_slot(self, request_id: int) -> None:
21952365 self ._early_freed_index_requests .add (request_id )
21962366
21972367 def free_resources (self , request : LlmRequest , pin_on_release : bool = False ):
2198- self ._allocated_draft_lens .pop (request .py_request_id , None )
2199- kv_cache = self .kv_cache_map .pop (request .py_request_id , None )
2368+ request_id = request .py_request_id
2369+ self ._allocated_draft_lens .pop (request_id , None )
2370+ pending_target = self ._pending_compacted_capacities .get (request_id )
2371+ if pending_target is not None and pending_target [2 ] is not None :
2372+ self ._stream .wait_event (pending_target [2 ])
2373+ self ._decode_capacity_only_requests .discard (request_id )
2374+ self ._pending_compacted_capacities .pop (request_id , None )
2375+ kv_cache = self .kv_cache_map .pop (request_id , None )
22002376 if kv_cache is None :
2201- self .impl .clear_stats_excluded (request . py_request_id )
2377+ self .impl .clear_stats_excluded (request_id )
22022378 return
22032379 kv_cache .discard_pending_stats ()
22042380 self .try_commit_blocks_for_reuse (request , kv_cache )
22052381 kv_cache .close ()
2206- self .impl .clear_stats_excluded (request . py_request_id )
2207- if request . py_request_id in self ._early_freed_index_requests :
2208- self ._early_freed_index_requests .discard (request . py_request_id )
2382+ self .impl .clear_stats_excluded (request_id )
2383+ if request_id in self ._early_freed_index_requests :
2384+ self ._early_freed_index_requests .discard (request_id )
22092385 else :
2210- self .index_mapper .remove_sequence (request . py_request_id )
2386+ self .index_mapper .remove_sequence (request_id )
22112387
22122388 def get_batch_cache_indices (
22132389 self , request_ids : List [int ], layer_idx : Optional [int ] = None
@@ -2480,11 +2656,39 @@ def update_resources(
24802656 # will be resumed by the scheduler on the next iteration.
24812657 if not kv_cache .is_active :
24822658 continue
2483- new_capacity = (
2484- None
2485- if req .state in (LlmRequestState .GENERATION_COMPLETE , LlmRequestState .CONTEXT_INIT )
2486- else kv_cache .capacity - req .py_rewind_len
2659+ completing = req .state in (
2660+ LlmRequestState .GENERATION_COMPLETE ,
2661+ LlmRequestState .CONTEXT_INIT ,
24872662 )
2663+ request_id = req .py_request_id
2664+ if request_id in self ._decode_capacity_only_requests :
2665+ pending_target = self ._pending_compacted_capacities .get (request_id )
2666+ if pending_target is not None :
2667+ target_capacity , published_capacity , event = pending_target
2668+ if event is not None :
2669+ self ._stream .wait_event (event )
2670+ if completing :
2671+ new_capacity = None
2672+ elif pending_target is None :
2673+ new_capacity = kv_cache .capacity - req .py_rewind_len
2674+ else :
2675+ capacity_growth = kv_cache .capacity - published_capacity
2676+ if capacity_growth < 0 :
2677+ raise ValueError (
2678+ f"Request { request_id } capacity { kv_cache .capacity } fell below "
2679+ f"published capacity { published_capacity } "
2680+ )
2681+ new_capacity = target_capacity + capacity_growth - req .py_rewind_len
2682+ success = kv_cache .resize (new_capacity , None )
2683+ if not success :
2684+ raise ValueError (
2685+ f"Failed to resize KV cache for request { request_id } "
2686+ f"to capacity { new_capacity } while preserving its finalized history"
2687+ )
2688+ if completing or pending_target is not None :
2689+ self ._pending_compacted_capacities .pop (request_id , None )
2690+ continue
2691+ new_capacity = None if completing else kv_cache .capacity - req .py_rewind_len
24882692 success = kv_cache .resize (new_capacity , req .max_beam_num_tokens - 1 )
24892693 if not success :
24902694 raise ValueError (
0 commit comments