From 63438c29a2e3f46c8d6d049941c2f19e37caa884 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:36:07 -0700 Subject: [PATCH 1/8] [TRTLLM-11608][feat] Chunked KV cache transfer with early block release Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor --- .../batch_manager/kvCacheManager.h | 19 + .../batch_manager/kvCacheManager.cpp | 53 +++ .../nanobind/batch_manager/kvCacheManager.cpp | 4 +- .../_torch/disaggregation/base/transfer.py | 12 +- .../_torch/disaggregation/native/transfer.py | 94 ++++- .../_torch/disaggregation/transceiver.py | 155 +++++++- .../_torch/pyexecutor/resource_manager.py | 18 + tensorrt_llm/llmapi/llm_args.py | 15 + .../disaggregated/test_chunked_transfer.py | 373 ++++++++++++++++++ .../disaggregated/test_kv_transfer.py | 291 +++++++++++++- tests/unittest/llmapi/test_llm_args.py | 17 + 11 files changed, 1028 insertions(+), 23 deletions(-) create mode 100644 tests/unittest/disaggregated/test_chunked_transfer.py diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index c665f7a8df95..7182cae16350 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -973,6 +973,13 @@ class WindowBlockManager std::optional releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest); + //! \brief Release the first numBlocks prefix blocks of a sequence. + //! \details Used by disaggregated serving to free sender-side KV memory + //! for blocks whose data has already been transferred. Reuses the + //! detachFrontBlock mechanism (decRefCount + eviction policy release). + //! 
Cumulative: calling with 3 then 5 releases blocks 0-4 total. + void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks); + //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); @@ -1514,6 +1521,13 @@ class BlockManager std::optional releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false); + //! \brief Release the first numBlocks prefix blocks of a sequence. + //! \details Mirrors detachFrontBlock logic: decRefCount + eviction policy + //! release for each prefix block. The mNumFrontBlocksRemoved counter on + //! GenerationRequest ensures releaseBlocks (called during removeSequence) + //! skips already-freed prefix blocks. + void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks); + [[nodiscard]] std::vector storeBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false); @@ -2431,6 +2445,11 @@ class KVCacheManager : public BaseKVCacheManager [[nodiscard]] std::optional removeSequence(LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt, bool pinOnRelease = false) override; + //! \brief Release prefix blocks for a sequence without removing it. + //! \details Used by disaggregated serving for early block release during + //! chunked KV cache transfer. No-op if the sequence does not exist. 
+    void releasePrefixBlocks(LlmRequest::RequestIdType requestId, SizeType32 numBlocks);
+
     void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) override;
 
     [[nodiscard]] runtime::ITensor::SharedPtr getBlockPoolPointers() const override
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 0fb8af1527ae..028444cec48e 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -2897,6 +2897,14 @@ std::optional BlockManager::releaseBlocks(
     return lastStoredId;
 }
 
+void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks)
+{
+    for (auto& [_, manager] : mWindowBlockManagers)
+    {
+        manager.releasePrefixBlocks(sequence, numBlocks);
+    }
+}
+
 void BlockManager::pinBlocks(GenerationRequest& sequence)
 {
     for (auto& [_, manager] : mWindowBlockManagers)
@@ -3709,6 +3717,42 @@ void WindowBlockManager::detachFrontBlock(GenerationRequest& sequence)
         sequence.getNumFrontBlocksRemoved(mWindowSize));
 }
 
+//! \brief Release leading blocks of a sequence until numBlocks have been removed from this window in total.
+//! \details numBlocks is cumulative and is clamped to the number of blocks allocated to the sequence in
+//! this window. Blocks already counted as removed are skipped, so repeated calls are safe.
+void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks)
+{
+    TLLM_CHECK_WITH_INFO(
+        sequence.getBeamWidth() == 1, "[kv cache manager] releasePrefixBlocks does not support beamWidth > 1");
+
+    auto const requestId = sequence.getRequestId();
+    auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
+    SizeType32 const target = std::min(numBlocks, static_cast<SizeType32>(allocatedBlocks.size()));
+
+    // Read the removed-front-blocks counter for this window, the same way detachFrontBlock does.
+    // removeFrontBlock(mWindowSize) below is what moves the loop forward for this window.
+    while (sequence.getNumFrontBlocksRemoved(mWindowSize) < target)
+    {
+        SizeType32 const blockIdx = sequence.getNumFrontBlocksRemoved(mWindowSize);
+        auto& block = allocatedBlocks.at(blockIdx);
+
+        TLLM_LOG_DEBUG("%s::releasePrefixBlocks - Releasing block %d from sequence %lu", mLogPrefix.c_str(),
+            block->getBlockId(), requestId);
+
+        // Drop one reference if any is held; hand the block to the eviction policy only once no refs remain.
+        if (block->hasRefs())
+        {
+            block->decRefCount();
+        }
+        if (!block->hasRefs())
+        {
+            mEvictionPolicy->releaseBlock(block);
+        }
+
+        sequence.removeFrontBlock(mWindowSize);
+    }
+}
+
PrefixReuseSummary KVCacheManager::analyzePrefixReuse( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { @@ -3885,6 +3923,21 @@ std::optional KVCacheManager::removeSequence( return lastStoredId; } +void KVCacheManager::releasePrefixBlocks(RequestIdType requestId, SizeType32 numBlocks) +{ + if (numBlocks <= 0) + { + return; + } + std::scoped_lock lock(mSequencesMtx); + auto it = mSequences.find(requestId); + if (it == mSequences.end()) + { + return; + } + mBlockManager.releasePrefixBlocks(it->second, numBlocks); +} + std::vector KVCacheManager::storeBlocksForReuse( RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks) { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 12b29d4981e2..4e43e74d4a96 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -683,7 +683,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("copy_linear_attention_block", &tbk::KVCacheManager::copyLinearAttentionBlock, nb::arg("llm_request"), nb::call_guard()) .def("copy_linear_attention_block_batch", &tbk::KVCacheManager::copyLinearAttentionBlockBatch, - nb::arg("llm_requests"), nb::call_guard()); + nb::arg("llm_requests"), nb::call_guard()) + .def("release_prefix_blocks", &tbk::KVCacheManager::releasePrefixBlocks, nb::arg("request_id"), + nb::arg("num_blocks"), nb::call_guard()); } void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) diff --git a/tensorrt_llm/_torch/disaggregation/base/transfer.py b/tensorrt_llm/_torch/disaggregation/base/transfer.py index a80b03153e42..222301e5b144 100644 --- a/tensorrt_llm/_torch/disaggregation/base/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/base/transfer.py @@ -158,7 +158,17 @@ def __init__(self, sender: SenderBase, args: SessionArgsBase): self._sender = sender @abstractmethod - def 
send(self, slice: KVSlice) -> None: ... + def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None: + """Send a KV slice. + + Args: + slice: The KV slice describing which source blocks to send. + chunk_block_offset: Block offset into the receiver's full + destination block list for this chunk. Used by sender-side + chunking to slice the receiver's destination blocks correctly. + Defaults to 0 for monolithic transfer. + """ + ... class RxSessionBase(_SessionBase): diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index 1068bebfd7e0..1deb5d4f2d9d 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -7,7 +7,7 @@ import weakref from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import msgpack import numpy as np @@ -202,6 +202,17 @@ def __init__(self, params: DisaggregatedParams, slot: Optional[int]): class KVSendTask(SendTaskBase): + """A per-slice send task within a TxSession. + + Args: + kv_slice: The KV slice describing which blocks to transfer. + params: Disaggregated serving parameters for this request. + slice_id: Index of this slice within the session's task list. + chunk_block_offset: Block offset into the receiver's full + destination block list. Used by sender-side chunking to + slice the receiver's destination blocks correctly. 
+ """ + def __init__( self, kv_slice: KVSlice, @@ -209,13 +220,15 @@ def __init__( slice_id: int, prompt_len: Optional[int] = None, beam_width: int = 1, - ): + chunk_block_offset: int = 0, + ) -> None: super().__init__(params) self.slice_id = slice_id self.transferred_count = 0 self._slice = kv_slice self._prompt_len = prompt_len self._beam_width = beam_width + self.chunk_block_offset = chunk_block_offset class Sender(SenderBase): @@ -491,7 +504,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): f"in {status.value} state; sending FAILED to receiver" ) # Task may have been enqueued after cancel() already iterated kv_tasks, - # so its future was never set by cancel(). Set it here as a fallback. + # so its event was never set by cancel(). Set it here as a fallback. task.fail( RuntimeError(f"session {write_meta.unique_rid} {status.value}, transfer aborted") ) @@ -501,7 +514,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): str(self._instance_rank).encode("ascii"), str(write_meta.unique_rid).encode("ascii"), str(write_meta.slice_id).encode("ascii"), - b"True", # is_last_slice — ensures receiver resolves its task future + b"True", # is_last_slice — ensures receiver resolves its task event AgentResult.FAILED.value.encode("ascii"), ] ) @@ -518,13 +531,19 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): if timer: timer.record_transfer_end(write_meta.peer_rank) - ## TODO: just last slice need to send task state? + # The receiver always has a single monolithic task (slice_id=0). + # Sender-side chunking is transparent to the receiver: only the + # last chunk carries is_last_slice=True so the receiver knows + # when all data has arrived. Intermediate chunk results are + # sent (not suppressed) so that RDMA failures propagate to the + # receiver immediately rather than requiring a timeout. 
+ receiver_slice_id = 0 self._get_or_connect_thread_dealer(write_meta.peer_endpoint).send( [ MessageType.KV_AGENT_RESULT, str(self._instance_rank).encode("ascii"), str(write_meta.unique_rid).encode("ascii"), - str(write_meta.slice_id).encode("ascii"), + str(receiver_slice_id).encode("ascii"), str(write_meta.is_last_slice).encode("ascii"), agent_result.value.encode("ascii"), ] @@ -551,6 +570,29 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): ) else: task.complete() + if session._on_chunk_transferred is not None: + try: + # Use the max across layer groups as the + # cumulative release count. For asymmetric + # layer groups (e.g., sliding window), shorter + # groups may have fewer blocks per chunk, but + # each WindowBlockManager independently clamps + # to its own allocated block count via + # min(numBlocks, allocatedBlocks.size()). + num_blocks = max( + (len(ids) for ids in task._slice.block_ids_per_layer_groups), + default=0, + ) + session._on_chunk_transferred( + request_id=session.request_id, + chunk_block_offset=task.chunk_block_offset, + num_blocks=num_blocks, + ) + except Exception as e: + logger.warning( + f"on_chunk_transferred callback failed for " + f"request {session.request_id} slice {write_meta.slice_id}: {e}" + ) logger.debug( f"deliver_kv_to_agent completed: unique_rid={write_meta.unique_rid}, " @@ -683,10 +725,20 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write dst_block_ids_per_groups = req_info.block_ids_per_layer_groups src_block_ids_per_groups = task._slice.block_ids_per_layer_groups - # Aggregate fragments from all matching pools using numpy concatenation + chunk_offset = task.chunk_block_offset for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items(): src_block_ids = src_block_ids_per_groups[self_lg] - dst_block_ids = dst_block_ids_per_groups[peer_lg] + full_dst_block_ids = dst_block_ids_per_groups[peer_lg] + + # When sender uses chunking, the receiver sends all dst + # blocks in a single 
RecvReqInfo. Slice dst to match + # this task's src chunk position. + if chunk_offset > 0 or len(src_block_ids) < len(full_dst_block_ids): + dst_block_ids = full_dst_block_ids[ + chunk_offset : chunk_offset + len(src_block_ids) + ] + else: + dst_block_ids = full_dst_block_ids # Speculative decoding: generation may have one extra draft-token block. block_diff = dst_block_ids.size - src_block_ids.size @@ -943,7 +995,7 @@ def _respond_with_kv(self, _send_id: bytes, message: list[bytes]): self._save_peer_req_info(info) tasks = list(session.kv_tasks) # No tasks: no worker will send KV_AGENT_RESULT FAILED to the receiver. - # Send it directly to unblock the receiver's TRANSFERRING task future; + # Send it directly to unblock the receiver's TRANSFERRING task event; # CANCEL_SESSION alone would leave it stuck indefinitely. if not tasks and session.status in (SessionStatus.ERROR, SessionStatus.CANCELLED): self._send_failed_result_to_receiver(info) @@ -1076,6 +1128,7 @@ def __init__( timeout_s: Optional[float] = None, prompt_len: Optional[int] = None, beam_width: int = 1, + on_chunk_transferred: Optional[Callable] = None, ): super().__init__( sender, @@ -1091,6 +1144,7 @@ def __init__( self.kv_tasks = [] self.aux_task = None self.lock = threading.Lock() + self._on_chunk_transferred = on_chunk_transferred self._exception: Optional[Exception] = None self._closed = False @@ -1126,7 +1180,7 @@ def status(self) -> SessionStatus: return SessionStatus.TRANSFERRING return SessionStatus.READY if self.receiver_ready else SessionStatus.INIT - def send(self, slice: KVSlice) -> None: + def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None: with self.lock: params = self._base_args.params slice_id = len(self.kv_tasks) @@ -1136,6 +1190,7 @@ def send(self, slice: KVSlice) -> None: slice_id, prompt_len=self._base_args.prompt_len, beam_width=self._base_args.beam_width, + chunk_block_offset=chunk_block_offset, ) task._unique_rid = self.disagg_request_id self.kv_tasks.append(task) 
@@ -1982,7 +2037,23 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp self._rank_info.sender_endpoints = endpoints self._rank_info.layer_num_per_pp = layer_num_per_pp - def create_tx_session(self, request: LlmRequest) -> TxSession: + def create_tx_session( + self, + request: LlmRequest, + on_chunk_transferred: Optional[Callable] = None, + ) -> TxSession: + """Create a TxSession for the given request. + + Args: + request: The LLM request to create a send session for. + on_chunk_transferred: Optional callback invoked on the + sender worker thread after each chunk's RDMA completes. + Signature: ``(request_id: int, chunk_block_offset: int, + num_blocks: int) -> None``. + + Returns: + A new ``TxSession`` ready to accept ``send()`` calls. + """ params = request.py_disaggregated_params assert params is not None return TxSession( @@ -1993,6 +2064,7 @@ def create_tx_session(self, request: LlmRequest) -> TxSession: timeout_s=self._config.tx_timeout_s, prompt_len=request.prompt_len, beam_width=request.py_beam_width, + on_chunk_transferred=on_chunk_transferred, ) def create_rx_session(self, request: LlmRequest) -> RxSession: diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index 6db88215b816..d75e64df9078 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -1,7 +1,9 @@ +import math +import queue import uuid from collections import defaultdict from itertools import chain -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -96,6 +98,9 @@ def __init__( self._recv_reqs = {} self._wait_reqs = {} self._page_table = self._transfer_worker.page_table + self._chunk_size_blocks = cache_transceiver_config.chunk_size_blocks + self._pending_prefix_releases: queue.Queue[Tuple[int, int]] = queue.Queue() + 
self._chunk_callback: Optional[Callable] = self._make_chunk_callback() def _broadcast_instance_name(self) -> str: if self._dist.rank == 0: @@ -221,6 +226,121 @@ def _create_kv_slice( token_range=token_range, ) + def _collect_block_ids(self, req: LlmRequest) -> List[List[int]]: + """Collect all valid block IDs per layer group for a request. + + Args: + req: The LLM request whose KV cache block IDs to collect. + + Returns: + A list of block ID lists, one per layer group. Each inner + list contains the physical block IDs for that layer group, + filtered for sliding-window relevance. + """ + return self._create_kv_slice(req).block_ids_per_layer_groups + + def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: + """Create one or more KVSlice objects for a request. + + When ``chunk_size_blocks`` is ``None``, returns a single slice + covering all blocks. Otherwise, each layer group's block ID + list is partitioned into slices of at most ``chunk_size_blocks`` + blocks. + + Args: + req: The LLM request to create slices for. + + Returns: + A list of ``KVSlice`` objects. Only the last slice has + ``is_last_slice=True``. + + Raises: + ValueError: If the reassembled block IDs from all slices do not + match the original block IDs. 
+ """ + base_slice = self._create_kv_slice(req) + all_block_ids = base_slice.block_ids_per_layer_groups + + if self._chunk_size_blocks is None: + return [base_slice] + + max_blocks = max((len(ids) for ids in all_block_ids), default=0) + if max_blocks == 0: + return [base_slice] + + num_chunks = math.ceil(max_blocks / self._chunk_size_blocks) + slices: List[KVSlice] = [] + for chunk_idx in range(num_chunks): + start = chunk_idx * self._chunk_size_blocks + end = start + self._chunk_size_blocks + is_last = chunk_idx == num_chunks - 1 + + chunk_block_ids = [ids[start:end] for ids in all_block_ids] + slices.append( + KVSlice( + is_last_slice=is_last, + block_ids_per_layer_groups=chunk_block_ids, + mamba_state_index=base_slice.mamba_state_index, + token_range=base_slice.token_range, + ) + ) + + for lg_idx, original_ids in enumerate(all_block_ids): + reassembled = np.concatenate([s.block_ids_per_layer_groups[lg_idx] for s in slices]) + if not np.array_equal(reassembled, original_ids): + raise ValueError( + f"Chunking integrity check failed for layer group {lg_idx}: " + f"expected {len(original_ids)} blocks, got {len(reassembled)}" + ) + + return slices + + def _make_chunk_callback(self) -> Optional[Callable]: + """Return a callback for early prefix block release. + + The callback is invoked on the sender worker thread after each + chunk's RDMA finishes. It enqueues a release request that the + main thread drains via ``_drain_pending_releases``. + + Early release is disabled when: + - ``chunk_size_blocks`` is not set (no chunking) + - The KV cache manager does not support ``release_prefix_blocks`` + + The callback is created once at init time and shared across all + sessions (all sessions use the same release queue). + + Returns: + A callback ``(request_id, chunk_block_offset, num_blocks) -> None`` + if chunking is enabled and the KV cache manager supports + ``release_prefix_blocks``, otherwise ``None``. 
+ """ + if self._chunk_size_blocks is None: + return None + if not hasattr(self._kv_cache_manager, "release_prefix_blocks"): + return None + + release_queue = self._pending_prefix_releases + + def _on_chunk_transferred(request_id: int, chunk_block_offset: int, num_blocks: int): + cumulative_blocks = chunk_block_offset + num_blocks + release_queue.put((request_id, cumulative_blocks)) + + return _on_chunk_transferred + + def _drain_pending_releases(self) -> None: + """Process all queued prefix block releases on the main thread. + + Drains the ``_pending_prefix_releases`` queue and calls + ``release_prefix_blocks`` on the KV cache manager for each + entry. Must be called from the main executor thread only. + """ + while True: + try: + request_id, num_blocks = self._pending_prefix_releases.get_nowait() + except queue.Empty: + break + self._kv_cache_manager.release_prefix_blocks(request_id, num_blocks) + @staticmethod def _split_packed_beam_block_ids( block_ids: np.ndarray, @@ -422,7 +542,13 @@ def _get_or_create_send_session(self, req: LlmRequest) -> TxSessionBase: rid = get_unique_rid(req) assert rid is not None if rid not in self._send_sessions: - self._send_sessions[rid] = self._transfer_worker.create_tx_session(req) + # Skip early release for beam_width > 1: C++ releasePrefixBlocks + # asserts beamWidth == 1. Chunking still works, but blocks are freed + # at session teardown instead. 
+ callback = self._chunk_callback if req.sampling_config.beam_width <= 1 else None + self._send_sessions[rid] = self._transfer_worker.create_tx_session( + req, on_chunk_transferred=callback + ) return self._send_sessions[rid] def _finalize_send(self, req: LlmRequest, session: TxSessionBase): @@ -446,7 +572,12 @@ def _finalize_send(self, req: LlmRequest, session: TxSessionBase): def respond_and_send_async(self, req: LlmRequest): session = self._get_or_create_send_session(req) req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - session.send(self._create_kv_slice(req)) + chunk_block_offset = 0 + for kv_slice in self._create_kv_slices(req): + session.send(kv_slice, chunk_block_offset=chunk_block_offset) + chunk_block_offset += max( + (len(ids) for ids in kv_slice.block_ids_per_layer_groups), default=0 + ) self._finalize_send(req, session) @nvtx_range("KvCacheTransceiverV2.request_and_receive_sync") @@ -482,7 +613,17 @@ def request_and_receive_sync(self, req: LlmRequest): self._recv_reqs.pop(rid, None) @nvtx_range("KvCacheTransceiverV2.request_and_receive_async") - def request_and_receive_async(self, req: LlmRequest): + def request_and_receive_async(self, req: LlmRequest) -> None: + """Start background KV cache receive from the context server. + + The receiver always uses a single monolithic slice. Chunking is + sender-only: the sender splits its source blocks into chunks and + slices the receiver's destination blocks to match each chunk. + + Args: + req: The generation request whose KV cache blocks to receive + into. 
+ """ rid = get_unique_rid(req) if rid in self._recv_sessions: logger.warning( @@ -492,12 +633,16 @@ def request_and_receive_async(self, req: LlmRequest): req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS session = self._transfer_worker.create_rx_session(req) self._recv_sessions[rid] = session - session.receive(self._create_kv_slice(req)) + all_block_ids = self._collect_block_ids(req) + full_slice = KVSlice(is_last_slice=True, block_ids_per_layer_groups=all_block_ids) + session.receive(full_slice) self._recv_reqs[rid] = req def check_context_transfer_status( self, at_least_request_num: Optional[int], mark_complete: bool = False ): + self._drain_pending_releases() + block_all = at_least_request_num is None wait_num = at_least_request_num if not block_all else 0 diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 1ef1c2843d8d..8237e6740323 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1010,6 +1010,24 @@ def free_resources(self, request: LlmRequest, pin_on_release: bool = False): return self.impl.remove_sequence(request.py_request_id, request, pin_on_release) + def release_prefix_blocks(self, request_id: int, num_blocks: int) -> None: + """Release leading blocks from a request's V1 KV cache. + + Used by disaggregated serving to free sender-side KV memory + for blocks whose data has already been transferred. The + underlying C++ ``KVCacheManager::releasePrefixBlocks`` frees + blocks via the eviction policy so they can be reused. + + Args: + request_id: The request whose KV cache to partially free. + num_blocks: Number of leading blocks to release + (cumulative from the start of the sequence). + + Note: + No-op if the sequence does not exist (already removed). 
+ """ + self.impl.release_prefix_blocks(request_id, num_blocks) + def store_blocks_for_reuse(self, request: LlmRequest, pin_blocks: bool = False): diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 97a2ec22ccb7..21da01b90186 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3365,7 +3365,22 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror): "Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms" ) + chunk_size_blocks: Optional[PositiveInt] = Field( + default=None, + description= + "Maximum number of KV cache blocks per layer group per chunk for " + "chunked KV cache transfer. When set, each layer group's block list " + "is partitioned into slices of at most this many blocks, and each " + "slice is transferred independently. The total data per chunk is " + "approximately chunk_size_blocks * num_layer_groups * slot_bytes. " + "This reduces per-transfer NIXL descriptor pressure for long " + "sequences. When None (default), the entire KV cache is transferred " + "in a single slice. Only effective with the Python transceiver " + "(transceiver_runtime='PYTHON').") + def _to_pybind(self): + # chunk_size_blocks is consumed by the Python transceiver only + # and has no C++ counterpart, so it is intentionally omitted. return _CacheTransceiverConfig( backend=_CacheTransceiverBackendType.from_string(self.backend), max_tokens_in_buffer=self.max_tokens_in_buffer, diff --git a/tests/unittest/disaggregated/test_chunked_transfer.py b/tests/unittest/disaggregated/test_chunked_transfer.py new file mode 100644 index 000000000000..da6d21194cf5 --- /dev/null +++ b/tests/unittest/disaggregated/test_chunked_transfer.py @@ -0,0 +1,373 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for chunked KV cache transfer (sender-only chunking). + +These tests validate the session state machine, callback plumbing, and +release queue mechanics using the real TxSession/RxSession classes with +lightweight stub sender/receiver objects. +""" + +import queue +from unittest.mock import MagicMock + +import pytest + +from tensorrt_llm import DisaggregatedParams +from tensorrt_llm._torch.disaggregation.base.transfer import KVSlice, SessionStatus, WaitResult +from tensorrt_llm._torch.disaggregation.native.transfer import ( + AgentResult, + KVSendTask, + RxSession, + TaskStatus, + TxSession, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_params(rid: int = 42) -> DisaggregatedParams: + return DisaggregatedParams(disagg_request_id=rid) + + +def _stub_sender(): + """Create a stub sender with no-op methods needed by TxSession.""" + sender = MagicMock() + sender.setup_session = MagicMock() + sender._get_req_info = MagicMock(return_value=None) + sender.dispatch_task = MagicMock() + return sender + + +def _stub_receiver(): + """Create a stub receiver with no-op methods needed by RxSession.""" + receiver = MagicMock() + receiver.setup_session = MagicMock() + receiver.dispatch_task = MagicMock() + return receiver + + +def 
_make_tx_session(num_slices: int, rid: int = 42, **kwargs) -> TxSession: + """Create a real TxSession and send num_slices slices into it.""" + params = _make_params(rid) + session = TxSession( + request_id=rid, + params=params, + sender=_stub_sender(), + aux_slot=None, + **kwargs, + ) + for i in range(num_slices): + s = KVSlice( + is_last_slice=(i == num_slices - 1), + block_ids_per_layer_groups=[[i]], + ) + session.send(s, chunk_block_offset=i) + return session + + +def _make_rx_session(num_slices: int, rid: int = 42) -> RxSession: + """Create a real RxSession and receive num_slices slices into it.""" + params = _make_params(rid) + session = RxSession( + request_id=rid, + params=params, + receiver=_stub_receiver(), + aux_slot=None, + ) + for i in range(num_slices): + s = KVSlice( + is_last_slice=(i == num_slices - 1), + block_ids_per_layer_groups=[[i]], + ) + session.receive(s) + return session + + +# --------------------------------------------------------------------------- +# KVSendTask tests +# --------------------------------------------------------------------------- + + +def test_kv_send_task_chunk_block_offset(): + """KVSendTask stores chunk_block_offset correctly.""" + s = KVSlice(is_last_slice=False, block_ids_per_layer_groups=[[0, 1]]) + task = KVSendTask(s, _make_params(), slice_id=1, chunk_block_offset=512) + assert task.chunk_block_offset == 512 + assert task.slice_id == 1 + assert task._slice is s + + +def test_kv_send_task_default_offset(): + """Default chunk_block_offset is 0.""" + s = KVSlice(is_last_slice=True, block_ids_per_layer_groups=[[0]]) + task = KVSendTask(s, _make_params(), slice_id=0) + assert task.chunk_block_offset == 0 + + +# --------------------------------------------------------------------------- +# TxSession multi-slice status tests (real class) +# --------------------------------------------------------------------------- + + +def test_tx_session_status_init_until_all_transferred(): + """TxSession status is not KV_TRANSFERRED 
until ALL tasks complete.""" + session = _make_tx_session(3) + session.receiver_ready = True + assert session.status == SessionStatus.TRANSFERRING or session.status == SessionStatus.READY + + session.kv_tasks[0].status = TaskStatus.TRANSFERRED + assert session.status != SessionStatus.KV_TRANSFERRED + + session.kv_tasks[1].status = TaskStatus.TRANSFERRED + assert session.status != SessionStatus.KV_TRANSFERRED + + session.kv_tasks[2].status = TaskStatus.TRANSFERRED + assert session.status == SessionStatus.KV_TRANSFERRED + + +def test_tx_session_status_error_on_any_failure(): + """TxSession status is ERROR if any task fails.""" + session = _make_tx_session(3) + session.kv_tasks[0].status = TaskStatus.TRANSFERRED + session.kv_tasks[1].status = TaskStatus.ERROR + assert session.status == SessionStatus.ERROR + + +def test_tx_session_wait_complete_all_tasks(): + """TxSession.wait_complete blocks on all task futures.""" + session = _make_tx_session(3) + for task in session.kv_tasks: + task.future.set_result(AgentResult.SUCCESS) + task.status = TaskStatus.TRANSFERRED + + result = session.wait_complete(need_aux=False, timeout=1.0) + assert result == WaitResult.COMPLETED + + +def test_tx_session_wait_complete_fails_on_partial_failure(): + """TxSession.wait_complete returns FAILED if any task fails.""" + session = _make_tx_session(3) + session.kv_tasks[0].future.set_result(AgentResult.SUCCESS) + session.kv_tasks[0].status = TaskStatus.TRANSFERRED + session.kv_tasks[1].future.set_result(AgentResult.FAILED) + session.kv_tasks[1].status = TaskStatus.ERROR + session.kv_tasks[2].future.set_result(AgentResult.SUCCESS) + session.kv_tasks[2].status = TaskStatus.TRANSFERRED + + result = session.wait_complete(need_aux=False, timeout=1.0) + assert result == WaitResult.FAILED + + +# --------------------------------------------------------------------------- +# RxSession multi-slice status tests (real class) +# --------------------------------------------------------------------------- + + 
+def test_rx_session_status_checks_all_tasks(): + """RxSession status is KV_TRANSFERRED only when ALL tasks complete.""" + session = _make_rx_session(3) + assert session.status == SessionStatus.INIT + + session._kv_tasks[0].status = TaskStatus.TRANSFERRED + session._kv_tasks[1].status = TaskStatus.TRANSFERRING + assert session.status == SessionStatus.TRANSFERRING + + session._kv_tasks[1].status = TaskStatus.TRANSFERRED + session._kv_tasks[2].status = TaskStatus.TRANSFERRED + assert session.status == SessionStatus.KV_TRANSFERRED + + +def test_rx_session_status_error_on_any_failure(): + """RxSession status is ERROR if any task fails.""" + session = _make_rx_session(2) + session._kv_tasks[0].status = TaskStatus.TRANSFERRED + session._kv_tasks[1].status = TaskStatus.ERROR + assert session.status == SessionStatus.ERROR + + +def test_rx_session_process_aux_uses_last_task(): + """process_aux_agent_result uses the last task's expected_transfers.""" + session = _make_rx_session(3) + session._kv_tasks[0].expected_transfers = 99 + session._kv_tasks[1].expected_transfers = 99 + session._kv_tasks[2].expected_transfers = 1 + + session.process_aux_agent_result(0, AgentResult.SUCCESS) + assert session._aux_status == TaskStatus.TRANSFERRED + + +def test_rx_session_wait_complete_all_tasks(): + """RxSession.wait_complete blocks on all task futures.""" + session = _make_rx_session(3) + for task in session._kv_tasks: + task.future.set_result(AgentResult.SUCCESS) + task.status = TaskStatus.TRANSFERRED + + result = session.wait_complete(need_aux=False) + assert result == WaitResult.COMPLETED + + +def test_rx_session_wait_complete_fails_on_partial_failure(): + """RxSession.wait_complete returns FAILED if any task fails.""" + session = _make_rx_session(2) + session._kv_tasks[0].future.set_result(AgentResult.SUCCESS) + session._kv_tasks[0].status = TaskStatus.TRANSFERRED + session._kv_tasks[1].future.set_exception(RuntimeError("transfer failed")) + session._kv_tasks[1].status = 
TaskStatus.ERROR + + result = session.wait_complete(need_aux=False) + assert result == WaitResult.FAILED + + +# --------------------------------------------------------------------------- +# Chunk completion callback tests +# --------------------------------------------------------------------------- + + +def test_chunk_callback_enqueues_release(): + """Callback from _make_chunk_callback enqueues the correct release entries.""" + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + transceiver = MagicMock() + transceiver._chunk_size_blocks = 64 + transceiver._pending_prefix_releases = queue.Queue() + transceiver._kv_cache_manager.release_prefix_blocks = MagicMock() + + callback = KvCacheTransceiverV2._make_chunk_callback(transceiver) + assert callback is not None + + callback(request_id=7, chunk_block_offset=0, num_blocks=64) + callback(request_id=7, chunk_block_offset=64, num_blocks=64) + callback(request_id=7, chunk_block_offset=128, num_blocks=64) + + results = [] + while not transceiver._pending_prefix_releases.empty(): + results.append(transceiver._pending_prefix_releases.get_nowait()) + + assert results == [(7, 64), (7, 128), (7, 192)] + + +def test_drain_pending_releases(): + """_drain_pending_releases calls release_prefix_blocks for each entry.""" + transceiver = MagicMock() + transceiver._pending_prefix_releases = queue.Queue() + transceiver._kv_cache_manager = MagicMock() + transceiver._pending_prefix_releases.put((10, 64)) + transceiver._pending_prefix_releases.put((10, 128)) + transceiver._pending_prefix_releases.put((20, 32)) + + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + KvCacheTransceiverV2._drain_pending_releases(transceiver) + + calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list + assert len(calls) == 3 + assert calls[0].args == (10, 64) + assert calls[1].args == (10, 128) + assert calls[2].args == (20, 32) + + +@pytest.mark.parametrize( + 
"has_release,chunk_size,expected_none", + [ + (False, 64, True), + (True, None, True), + (False, None, True), + (True, 64, False), + ], + ids=["no_release_api", "no_chunking", "neither", "with_release_and_chunking"], +) +def test_make_chunk_callback_conditions(has_release, chunk_size, expected_none): + """_make_chunk_callback returns None unless both release API and chunking enabled.""" + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + transceiver = MagicMock() + transceiver._chunk_size_blocks = chunk_size + transceiver._pending_prefix_releases = queue.Queue() + if has_release: + transceiver._kv_cache_manager.release_prefix_blocks = MagicMock() + else: + del transceiver._kv_cache_manager.release_prefix_blocks + + result = KvCacheTransceiverV2._make_chunk_callback(transceiver) + assert (result is None) == expected_none + + +def test_chunk_callback_then_drain(): + """End-to-end: callback enqueues, drain calls release_prefix_blocks.""" + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + transceiver = MagicMock() + transceiver._chunk_size_blocks = 4 + transceiver._pending_prefix_releases = queue.Queue() + transceiver._kv_cache_manager.release_prefix_blocks = MagicMock() + + callback = KvCacheTransceiverV2._make_chunk_callback(transceiver) + assert callback is not None + + callback(request_id=1, chunk_block_offset=0, num_blocks=4) + callback(request_id=1, chunk_block_offset=4, num_blocks=4) + callback(request_id=1, chunk_block_offset=8, num_blocks=2) + + KvCacheTransceiverV2._drain_pending_releases(transceiver) + + calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list + assert len(calls) == 3 + assert calls[0].args == (1, 4) + assert calls[1].args == (1, 8) + assert calls[2].args == (1, 10) + + +# --------------------------------------------------------------------------- +# Mid-transfer chunk failure tests +# 
--------------------------------------------------------------------------- + + +def test_tx_session_mid_chunk_failure(): + """If one chunk fails mid-transfer, the session reports ERROR.""" + session = _make_tx_session(4) + + session.kv_tasks[0].future.set_result(AgentResult.SUCCESS) + session.kv_tasks[0].status = TaskStatus.TRANSFERRED + session.kv_tasks[1].future.set_result(AgentResult.SUCCESS) + session.kv_tasks[1].status = TaskStatus.TRANSFERRED + session.kv_tasks[2].future.set_exception(RuntimeError("RDMA failed")) + session.kv_tasks[2].status = TaskStatus.ERROR + session.kv_tasks[3].future.set_result(AgentResult.SUCCESS) + session.kv_tasks[3].status = TaskStatus.TRANSFERRED + + assert session.status == SessionStatus.ERROR + result = session.wait_complete(need_aux=False, timeout=1.0) + assert result == WaitResult.FAILED + + +def test_rx_session_mid_chunk_failure(): + """If one chunk fails mid-transfer on receiver, the session reports ERROR.""" + session = _make_rx_session(4) + + session._kv_tasks[0].future.set_result(AgentResult.SUCCESS) + session._kv_tasks[0].status = TaskStatus.TRANSFERRED + session._kv_tasks[1].future.set_exception(RuntimeError("RDMA failed")) + session._kv_tasks[1].status = TaskStatus.ERROR + session._kv_tasks[2].future.set_result(AgentResult.SUCCESS) + session._kv_tasks[2].status = TaskStatus.TRANSFERRED + session._kv_tasks[3].future.set_result(AgentResult.SUCCESS) + session._kv_tasks[3].status = TaskStatus.TRANSFERRED + + assert session.status == SessionStatus.ERROR + result = session.wait_complete(need_aux=False) + assert result == WaitResult.FAILED diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index caf8dcb3f6d6..a9110371d084 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -148,6 +148,69 @@ def test_session_status_enum(): assert len(SessionStatus) == 7 +# 
--------------------------------------------------------------------------- +# Chunked KV slice creation tests +# --------------------------------------------------------------------------- + + +def _chunk_block_ids(all_block_ids, chunk_size_blocks): + """Call the real _create_kv_slices via a mock transceiver.""" + from unittest.mock import MagicMock + + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + transceiver = MagicMock() + transceiver._chunk_size_blocks = chunk_size_blocks + transceiver._collect_block_ids = MagicMock(return_value=all_block_ids) + transceiver._create_kv_slices = KvCacheTransceiverV2._create_kv_slices.__get__(transceiver) + + req = MagicMock() + return transceiver._create_kv_slices(req) + + +@pytest.mark.parametrize( + "all_block_ids,chunk_size,expected_num_slices", + [ + ([[0, 1, 2, 3, 4, 5, 6, 7]], None, 1), + ([[0, 1, 2, 3, 4, 5, 6, 7]], 4, 2), + ([list(range(10))], 4, 3), + ([[], []], 4, 1), + ([[0, 1, 2]], 64, 1), + ], + ids=["no_chunking", "even_split", "uneven_split", "empty_blocks", "chunk_larger_than_total"], +) +def test_create_kv_slices_basic(all_block_ids, chunk_size, expected_num_slices): + """Chunking produces the expected number of slices.""" + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=chunk_size) + assert len(slices) == expected_num_slices + assert slices[-1].is_last_slice is True + if expected_num_slices > 1: + for s in slices[:-1]: + assert s.is_last_slice is False + + +def test_create_kv_slices_integrity_check(): + """Reassembled block IDs from all slices must match the original.""" + all_block_ids = [list(range(17)), list(range(5))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4) + for lg_idx, original in enumerate(all_block_ids): + reassembled = [] + for s in slices: + reassembled.extend(s.block_ids_per_layer_groups[lg_idx]) + assert reassembled == original + + +def test_create_kv_slices_multiple_layer_groups(): + """Different layer groups with different 
block counts produce correct chunking.""" + all_block_ids = [list(range(8)), list(range(3))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4) + assert len(slices) == 2 + assert slices[0].block_ids_per_layer_groups[0] == [0, 1, 2, 3] + assert slices[1].block_ids_per_layer_groups[0] == [4, 5, 6, 7] + assert slices[0].block_ids_per_layer_groups[1] == [0, 1, 2] + assert slices[1].block_ids_per_layer_groups[1] == [] + + def create_transfer_worker_setup( ctx_tp: int, ctx_pp: int, @@ -1277,7 +1340,7 @@ def test_session_cancel_before_send(): @pytest.mark.timeout(60) def test_session_cancel_after_send(): - """TxSession cancelled after send() queues INIT tasks; future raises.""" + """TxSession cancelled after send() queues INIT tasks fails the event wait.""" tensorrt_llm.logger.set_level("debug") setup = create_transfer_worker_setup( ctx_tp=1, @@ -1312,19 +1375,237 @@ def test_session_cancel_after_send(): page_table = ctx_transfer_worker._rank_info.page_table block_ids_per_groups = [np.array([], dtype=np.int64) for _ in page_table.layer_groups] kv_slice = KVSlice(is_last_slice=True, block_ids_per_layer_groups=block_ids_per_groups) - future = tx_session.send(kv_slice) + tx_session.send(kv_slice) # No receiver registered yet; task is INIT. tx_session.cancel() assert tx_session.status == SessionStatus.CANCELLED assert tx_session.has_failed() - # Future for the cancelled INIT task must raise. 
- with pytest.raises(Exception): - future.result(timeout=5.0) + assert tx_session.wait_complete() == WaitResult.FAILED tx_session.close() finally: ctx_transfer_worker.shutdown() + + +def _setup_chunked_request(setup, ctx_request_id, gen_request_id, request_len): + """Create requests, allocate KV, and collect block IDs for chunked transfer tests.""" + ctx_transfer_workers = setup["ctx_transfer_workers"] + ctx_kv_cache_managers = setup["ctx_kv_cache_managers"] + gen_transfer_workers = setup["gen_transfer_workers"] + gen_kv_cache_managers = setup["gen_kv_cache_managers"] + ctx_info_endpoint = setup["ctx_info_endpoint"] + use_v2 = setup["use_v2"] + tokens_per_block = setup["tokens_per_block"] + + sampling_params = SamplingParams() + unique_rid = uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF + + ctx_request = LlmRequest( + request_id=ctx_request_id, + max_new_tokens=1, + input_tokens=list(range(request_len)), + sampling_config=tensorrt_llm.bindings.SamplingConfig( + sampling_params._get_sampling_config() + ), + is_streaming=False, + llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY, + ) + ctx_request.py_disaggregated_params = DisaggregatedParams(disagg_request_id=unique_rid) + + gen_request = LlmRequest( + request_id=gen_request_id, + max_new_tokens=1, + input_tokens=list(range(request_len)), + sampling_config=tensorrt_llm.bindings.SamplingConfig( + sampling_params._get_sampling_config() + ), + is_streaming=False, + llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY, + ) + gen_request.py_disaggregated_params = DisaggregatedParams( + ctx_request_id=ctx_request.py_request_id, + ctx_dp_rank=0, + ctx_info_endpoint=ctx_info_endpoint, + disagg_request_id=unique_rid, + ) + + ctx_kv_caches, gen_kv_caches = [], [] + for mgr in ctx_kv_cache_managers: + if use_v2: + kv = mgr._create_kv_cache(ctx_request.py_request_id, None, None) + assert kv.resume(torch.cuda.current_stream().cuda_stream) + assert kv.resize(request_len) + ctx_kv_caches.append(kv) + else: + 
mgr.impl.add_sequence_batch( + [(ctx_request.py_request_id, request_len, 1)], [ctx_request] + ) + + for mgr in gen_kv_cache_managers: + if use_v2: + kv = mgr._create_kv_cache(gen_request.py_request_id, None, None) + assert kv.resume(torch.cuda.current_stream().cuda_stream) + assert kv.resize(request_len) + gen_kv_caches.append(kv) + else: + mgr.impl.add_sequence_batch( + [(gen_request.py_request_id, request_len, 1)], [gen_request] + ) + + ctx_block_ids = [ + get_block_ids_per_layer_groups(mgr, tw, ctx_request.py_request_id, use_v2, tokens_per_block) + for mgr, tw in zip(ctx_kv_cache_managers, ctx_transfer_workers) + ] + gen_block_ids = [ + get_block_ids_per_layer_groups(mgr, tw, gen_request.py_request_id, use_v2, tokens_per_block) + for mgr, tw in zip(gen_kv_cache_managers, gen_transfer_workers) + ] + + return { + "ctx_request": ctx_request, + "gen_request": gen_request, + "ctx_kv_caches": ctx_kv_caches, + "gen_kv_caches": gen_kv_caches, + "ctx_block_ids": ctx_block_ids, + "gen_block_ids": gen_block_ids, + } + + +def _verify_and_cleanup_chunked(setup, ctx_info, sender_sessions, receiver_sessions): + """Shared verification and cleanup for chunked transfer tests.""" + ctx_kv_cache_managers = setup["ctx_kv_cache_managers"] + gen_kv_cache_managers = setup["gen_kv_cache_managers"] + use_v2 = setup["use_v2"] + + ctx_block_ids = ctx_info["ctx_block_ids"] + gen_block_ids = ctx_info["gen_block_ids"] + + for session in sender_sessions: + assert session.status == SessionStatus.KV_TRANSFERRED + for session in receiver_sessions: + assert session.status == SessionStatus.KV_TRANSFERRED + + num_layer_groups = len(ctx_block_ids[0]) + for lg_id in range(num_layer_groups): + ctx_data = [ + get_block_data(mgr, bids[lg_id], lg_id, use_v2, ctx_info["ctx_request"].py_request_id) + for mgr, bids in zip(ctx_kv_cache_managers, ctx_block_ids) + ] + gen_data = [ + get_block_data(mgr, bids[lg_id], lg_id, use_v2, ctx_info["gen_request"].py_request_id) + for mgr, bids in 
zip(gen_kv_cache_managers, gen_block_ids) + ] + for c, g in zip(ctx_data, gen_data): + assert c.equal(g), f"Layer group {lg_id}: data mismatch with chunked transfer" + + for s in receiver_sessions: + s.close() + for s in sender_sessions: + s.close() + if use_v2: + torch.cuda.current_stream().synchronize() + for kv in ctx_info["ctx_kv_caches"]: + kv.close() + for kv in ctx_info["gen_kv_caches"]: + kv.close() + + +def add_and_verify_chunked_request( + setup, + ctx_request_id, + gen_request_id, + request_len, + chunk_size_blocks, +): + """Chunked transfer variant: sender sends N slices, receiver sends 1.""" + import math + + ctx_transfer_workers = setup["ctx_transfer_workers"] + gen_transfer_workers = setup["gen_transfer_workers"] + + ctx_info = _setup_chunked_request(setup, ctx_request_id, gen_request_id, request_len) + ctx_block_ids = ctx_info["ctx_block_ids"] + gen_block_ids = ctx_info["gen_block_ids"] + + sender_sessions = [tw.create_tx_session(ctx_info["ctx_request"]) for tw in ctx_transfer_workers] + for sender_session, block_ids_per_groups in zip(sender_sessions, ctx_block_ids): + max_blocks = max(len(ids) for ids in block_ids_per_groups) + num_chunks = math.ceil(max_blocks / chunk_size_blocks) + chunk_offset = 0 + for chunk_idx in range(num_chunks): + start = chunk_idx * chunk_size_blocks + end = start + chunk_size_blocks + is_last = chunk_idx == num_chunks - 1 + chunk_block_ids = [ids[start:end] for ids in block_ids_per_groups] + kv_slice = KVSlice( + is_last_slice=is_last, + block_ids_per_layer_groups=chunk_block_ids, + ) + sender_session.send(kv_slice, chunk_block_offset=chunk_offset) + chunk_offset += max(len(ids) for ids in chunk_block_ids) + + receiver_sessions = [ + tw.create_rx_session(ctx_info["gen_request"]) for tw in gen_transfer_workers + ] + for recv_session, block_ids_per_groups in zip(receiver_sessions, gen_block_ids): + full_slice = KVSlice( + is_last_slice=True, + block_ids_per_layer_groups=block_ids_per_groups, + ) + 
recv_session.receive(full_slice) + + for session in sender_sessions: + result = session.wait_complete() + assert result == WaitResult.COMPLETED, f"tx wait_complete returned {result}" + for session in receiver_sessions: + result = session.wait_complete(blocking=True) + assert result == WaitResult.COMPLETED, f"rx wait_complete returned {result}" + + _verify_and_cleanup_chunked(setup, ctx_info, sender_sessions, receiver_sessions) + + +CHUNKED_TEST_CONFIGS = [ + (1, 1, False, 1, 1, False, False, True, "v2_tp1_pp1_chunked"), + (1, 1, False, 1, 1, False, False, False, "v1_tp1_pp1_chunked"), +] + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + "ctx_tp,ctx_pp,ctx_enable_dp,gen_tp,gen_pp,gen_enable_dp,is_mla,use_v2", + [(c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]) for c in CHUNKED_TEST_CONFIGS], + ids=[c[8] for c in CHUNKED_TEST_CONFIGS], +) +def test_transfer_worker_chunked( + ctx_tp, ctx_pp, ctx_enable_dp, gen_tp, gen_pp, gen_enable_dp, is_mla, use_v2 +): + """Test transfer worker with sender-side chunking for V1 and V2.""" + tensorrt_llm.logger.set_level("info") + logger.info(f"Test transfer worker {'V2' if use_v2 else 'V1'} with chunked transfer") + + setup = create_transfer_worker_setup( + ctx_tp=ctx_tp, + ctx_pp=ctx_pp, + ctx_enable_dp=ctx_enable_dp, + gen_tp=gen_tp, + gen_pp=gen_pp, + gen_enable_dp=gen_enable_dp, + is_mla=is_mla, + use_v2=use_v2, + ) + + request_len = setup["request_len"] + tokens_per_block = setup["tokens_per_block"] + total_blocks = (request_len + tokens_per_block - 1) // tokens_per_block + chunk_size = max(1, total_blocks // 2) + + try: + add_and_verify_chunked_request(setup, 0, 1, request_len, chunk_size_blocks=chunk_size) + add_and_verify_chunked_request(setup, 2, 3, request_len * 2, chunk_size_blocks=chunk_size) + finally: + for worker in setup["ctx_transfer_workers"]: + worker.shutdown() for worker in setup["gen_transfer_workers"]: worker.shutdown() diff --git a/tests/unittest/llmapi/test_llm_args.py 
b/tests/unittest/llmapi/test_llm_args.py index d0c37dc581a3..0e4c7185c5dd 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -1761,6 +1761,23 @@ def test_cache_transceiver_config_arbitrary_args(self): CacheTransceiverConfig(backend="UCX", invalid_config="should_fail") assert "invalid_config" in str(exc_info.value) + def test_cache_transceiver_config_chunk_size_blocks(self): + """Test chunk_size_blocks field validation.""" + config = CacheTransceiverConfig(chunk_size_blocks=64) + assert config.chunk_size_blocks == 64 + + config_none = CacheTransceiverConfig(chunk_size_blocks=None) + assert config_none.chunk_size_blocks is None + + config_default = CacheTransceiverConfig() + assert config_default.chunk_size_blocks is None + + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + CacheTransceiverConfig(chunk_size_blocks=0) + + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + CacheTransceiverConfig(chunk_size_blocks=-1) + def test_torch_compile_config_arbitrary_args(self): """Test that TorchCompileConfig rejects arbitrary arguments.""" # Valid arguments should work From 6f4fefd469cdb7ecb00e465cdf5acc661f9ffaf3 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Tue, 7 Apr 2026 10:13:31 -0700 Subject: [PATCH 2/8] fix: preserve mamba_state_index in chunked slices and fix VSWA shared counter - Replace _collect_block_ids with _collect_base_slice to preserve the full KVSlice metadata (including mamba_state_index) through all new code paths: _create_kv_slices (sender) and request_and_receive_async (receiver). Without this, Mamba/hybrid-state model transfers would lose required state metadata. - Fix VSWA shared counter bug in WindowBlockManager::releasePrefixBlocks: snapshot mNumFrontBlocksRemoved before iterating window managers so each manager releases blocks from the same range. 
Previously the first manager advanced the shared counter, causing subsequent managers to skip their own blocks entirely. - Guard chunking integrity assertion with __debug__ to avoid O(N) CPU overhead on the hot path in optimized builds. - Add tests for mamba_state_index propagation through chunked slices. Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 7 +++-- .../batch_manager/kvCacheManager.cpp | 22 +++++++++++---- .../_torch/disaggregation/transceiver.py | 24 ++++++++++------ .../disaggregated/test_kv_transfer.py | 28 +++++++++++++++++-- 4 files changed, 61 insertions(+), 20 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 7182cae16350..31b6300604c9 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -973,12 +973,13 @@ class WindowBlockManager std::optional releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest); - //! \brief Release the first numBlocks prefix blocks of a sequence. + //! \brief Release prefix blocks in range [startIdx, numBlocks) for a sequence. //! \details Used by disaggregated serving to free sender-side KV memory //! for blocks whose data has already been transferred. Reuses the //! detachFrontBlock mechanism (decRefCount + eviction policy release). - //! Cumulative: calling with 3 then 5 releases blocks 0-4 total. - void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks); + //! Called by BlockManager::releasePrefixBlocks which coordinates the + //! shared mNumFrontBlocksRemoved counter across all window managers. + void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 startIdx, SizeType32 numBlocks); //! 
\brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 028444cec48e..819bf004b4de 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2899,9 +2899,19 @@ std::optional BlockManager::releaseBlocks( void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks) { + // Snapshot the counter before iterating so that every WindowBlockManager + // releases the same range. Without this, the first manager would advance + // the shared mNumFrontBlocksRemoved counter and subsequent managers would + // see the counter already at the target, skipping their own blocks. + SizeType32 const startIdx = sequence.getNumFrontBlocksRemoved(); for (auto& [_, manager] : mWindowBlockManagers) { - manager.releasePrefixBlocks(sequence, numBlocks); + manager.releasePrefixBlocks(sequence, startIdx, numBlocks); + } + // Advance the shared counter once, after all managers have released. 
+ while (sequence.getNumFrontBlocksRemoved() < numBlocks) + { + sequence.removeFrontBlock(0); } } @@ -3717,7 +3727,7 @@ void WindowBlockManager::detachFrontBlock(GenerationRequest& sequence) sequence.getNumFrontBlocksRemoved(mWindowSize)); } -void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks) +void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 startIdx, SizeType32 numBlocks) { TLLM_CHECK_WITH_INFO( sequence.getBeamWidth() == 1, "[kv cache manager] releasePrefixBlocks does not support beamWidth > 1"); @@ -3726,9 +3736,11 @@ void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeTy auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); SizeType32 const target = std::min(numBlocks, static_cast(allocatedBlocks.size())); - while (sequence.getNumFrontBlocksRemoved() < target) + // Release blocks in range [startIdx, target). The shared + // mNumFrontBlocksRemoved counter is advanced by BlockManager after + // all WindowBlockManagers have processed the same range. 
+ for (SizeType32 blockIdx = startIdx; blockIdx < target; ++blockIdx) { - SizeType32 const blockIdx = sequence.getNumFrontBlocksRemoved(); auto& block = allocatedBlocks.at(blockIdx); TLLM_LOG_DEBUG("%s::releasePrefixBlocks - Releasing block %d from sequence %lu", mLogPrefix.c_str(), @@ -3742,8 +3754,6 @@ void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeTy { mEvictionPolicy->releaseBlock(block); } - - sequence.removeFrontBlock(mWindowSize); } } diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index d75e64df9078..1476ae8a95fb 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -226,18 +226,20 @@ def _create_kv_slice( token_range=token_range, ) - def _collect_block_ids(self, req: LlmRequest) -> List[List[int]]: - """Collect all valid block IDs per layer group for a request. + def _collect_base_slice(self, req: LlmRequest) -> KVSlice: + """Collect a full KVSlice (including metadata) for a request. + + This returns the complete slice produced by ``_create_kv_slice``, + preserving fields like ``mamba_state_index`` that are required + for hybrid-state model transfers. Args: req: The LLM request whose KV cache block IDs to collect. Returns: - A list of block ID lists, one per layer group. Each inner - list contains the physical block IDs for that layer group, - filtered for sliding-window relevance. + A ``KVSlice`` with all metadata populated. """ - return self._create_kv_slice(req).block_ids_per_layer_groups + return self._create_kv_slice(req) def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: """Create one or more KVSlice objects for a request. @@ -258,7 +260,7 @@ def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: ValueError: If the reassembled block IDs from all slices do not match the original block IDs. 
""" - base_slice = self._create_kv_slice(req) + base_slice = self._collect_base_slice(req) all_block_ids = base_slice.block_ids_per_layer_groups if self._chunk_size_blocks is None: @@ -633,8 +635,12 @@ def request_and_receive_async(self, req: LlmRequest) -> None: req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS session = self._transfer_worker.create_rx_session(req) self._recv_sessions[rid] = session - all_block_ids = self._collect_block_ids(req) - full_slice = KVSlice(is_last_slice=True, block_ids_per_layer_groups=all_block_ids) + base_slice = self._collect_base_slice(req) + full_slice = KVSlice( + is_last_slice=True, + block_ids_per_layer_groups=base_slice.block_ids_per_layer_groups, + mamba_state_index=base_slice.mamba_state_index, + ) session.receive(full_slice) self._recv_reqs[rid] = req diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index a9110371d084..999414e20a9f 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -153,15 +153,22 @@ def test_session_status_enum(): # --------------------------------------------------------------------------- -def _chunk_block_ids(all_block_ids, chunk_size_blocks): +def _chunk_block_ids(all_block_ids, chunk_size_blocks, mamba_state_index=None): """Call the real _create_kv_slices via a mock transceiver.""" from unittest.mock import MagicMock + from tensorrt_llm._torch.disaggregation.base.transfer import KVSlice from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + base_slice = KVSlice( + is_last_slice=True, + block_ids_per_layer_groups=all_block_ids, + mamba_state_index=mamba_state_index, + ) + transceiver = MagicMock() transceiver._chunk_size_blocks = chunk_size_blocks - transceiver._collect_block_ids = MagicMock(return_value=all_block_ids) + transceiver._collect_base_slice = MagicMock(return_value=base_slice) transceiver._create_kv_slices = 
KvCacheTransceiverV2._create_kv_slices.__get__(transceiver) req = MagicMock() @@ -211,6 +218,23 @@ def test_create_kv_slices_multiple_layer_groups(): assert slices[1].block_ids_per_layer_groups[1] == [] +def test_create_kv_slices_preserves_mamba_state_index(): + """mamba_state_index is propagated to every chunk slice.""" + all_block_ids = [list(range(8))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4, mamba_state_index=42) + assert len(slices) == 2 + for s in slices: + assert s.mamba_state_index == 42 + + +def test_create_kv_slices_none_mamba_state_index(): + """mamba_state_index=None is preserved when not set.""" + all_block_ids = [list(range(4))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4) + assert len(slices) == 1 + assert slices[0].mamba_state_index is None + + def create_transfer_worker_setup( ctx_tp: int, ctx_pp: int, From 396127c424c5159eea65a4c52c139c9bddd2616d Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:02:49 -0700 Subject: [PATCH 3/8] Address CodeRabbit review nitpicks - Update copyright year to 2026 in nanobind kvCacheManager.cpp - Add OnChunkTransferredCallback type alias for precise callback typing - Add strict=True to zip() calls in chunked transfer tests Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor --- .../_torch/disaggregation/native/transfer.py | 6 ++++-- tests/unittest/disaggregated/test_kv_transfer.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index 1deb5d4f2d9d..d07578957f2c 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -57,6 +57,8 @@ AttentionTypeCpp = tensorrt_llm.bindings.internal.batch_manager.AttentionType LlmRequestType = 
tensorrt_llm.bindings.internal.batch_manager.LlmRequestType +OnChunkTransferredCallback = Callable[[int, int, int], None] + # Number of worker threads for KV transfer queues (default: 1) KV_TRANSFER_NUM_THREADS = int(os.environ.get("TRTLLM_KV_TRANSFER_NUM_THREADS", "1")) @@ -1128,7 +1130,7 @@ def __init__( timeout_s: Optional[float] = None, prompt_len: Optional[int] = None, beam_width: int = 1, - on_chunk_transferred: Optional[Callable] = None, + on_chunk_transferred: Optional[OnChunkTransferredCallback] = None, ): super().__init__( sender, @@ -2040,7 +2042,7 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp def create_tx_session( self, request: LlmRequest, - on_chunk_transferred: Optional[Callable] = None, + on_chunk_transferred: Optional[OnChunkTransferredCallback] = None, ) -> TxSession: """Create a TxSession for the given request. diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index 999414e20a9f..b57f9fe66755 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -1479,11 +1479,11 @@ def _setup_chunked_request(setup, ctx_request_id, gen_request_id, request_len): ctx_block_ids = [ get_block_ids_per_layer_groups(mgr, tw, ctx_request.py_request_id, use_v2, tokens_per_block) - for mgr, tw in zip(ctx_kv_cache_managers, ctx_transfer_workers) + for mgr, tw in zip(ctx_kv_cache_managers, ctx_transfer_workers, strict=True) ] gen_block_ids = [ get_block_ids_per_layer_groups(mgr, tw, gen_request.py_request_id, use_v2, tokens_per_block) - for mgr, tw in zip(gen_kv_cache_managers, gen_transfer_workers) + for mgr, tw in zip(gen_kv_cache_managers, gen_transfer_workers, strict=True) ] return { @@ -1514,13 +1514,13 @@ def _verify_and_cleanup_chunked(setup, ctx_info, sender_sessions, receiver_sessi for lg_id in range(num_layer_groups): ctx_data = [ get_block_data(mgr, bids[lg_id], lg_id, use_v2, 
ctx_info["ctx_request"].py_request_id) - for mgr, bids in zip(ctx_kv_cache_managers, ctx_block_ids) + for mgr, bids in zip(ctx_kv_cache_managers, ctx_block_ids, strict=True) ] gen_data = [ get_block_data(mgr, bids[lg_id], lg_id, use_v2, ctx_info["gen_request"].py_request_id) - for mgr, bids in zip(gen_kv_cache_managers, gen_block_ids) + for mgr, bids in zip(gen_kv_cache_managers, gen_block_ids, strict=True) ] - for c, g in zip(ctx_data, gen_data): + for c, g in zip(ctx_data, gen_data, strict=True): assert c.equal(g), f"Layer group {lg_id}: data mismatch with chunked transfer" for s in receiver_sessions: @@ -1553,7 +1553,7 @@ def add_and_verify_chunked_request( gen_block_ids = ctx_info["gen_block_ids"] sender_sessions = [tw.create_tx_session(ctx_info["ctx_request"]) for tw in ctx_transfer_workers] - for sender_session, block_ids_per_groups in zip(sender_sessions, ctx_block_ids): + for sender_session, block_ids_per_groups in zip(sender_sessions, ctx_block_ids, strict=True): max_blocks = max(len(ids) for ids in block_ids_per_groups) num_chunks = math.ceil(max_blocks / chunk_size_blocks) chunk_offset = 0 @@ -1572,7 +1572,7 @@ def add_and_verify_chunked_request( receiver_sessions = [ tw.create_rx_session(ctx_info["gen_request"]) for tw in gen_transfer_workers ] - for recv_session, block_ids_per_groups in zip(receiver_sessions, gen_block_ids): + for recv_session, block_ids_per_groups in zip(receiver_sessions, gen_block_ids, strict=True): full_slice = KVSlice( is_last_slice=True, block_ids_per_layer_groups=block_ids_per_groups, From f09f29f1a9bbbca2bb590eeb428dc521f76f1c69 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Wed, 8 Apr 2026 13:19:49 -0700 Subject: [PATCH 4/8] Address reviewer comments from eopXD and pcastonguay - Fix chunking integrity check: use np.array_equal() instead of == for numpy array comparison, raise ValueError instead of assert (eopXD comment on transceiver.py) - Add explicit VSWA limitation 
comment in BlockManager::releasePrefixBlocks documenting the single-window-size assumption (eopXD comment on kvCacheManager.cpp) - Auto-select Python transceiver when chunk_size_blocks is set and backend is NIXL/DEFAULT. The C++ transceiver does not support chunked transfer; this makes chunking work without requiring users to manually set transceiver_runtime="PYTHON" (pcastonguay comment on transceiver.py) Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor --- .../batch_manager/kvCacheManager.cpp | 7 ++++++ .../_torch/pyexecutor/kv_cache_transceiver.py | 24 +++++++++++++++++-- tensorrt_llm/llmapi/llm_args.py | 8 ++++--- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 819bf004b4de..9efbb8c025fa 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2899,6 +2899,13 @@ std::optional BlockManager::releaseBlocks( void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks) { + // NOTE: This assumes a single window size (no VSWA). With different window + // sizes, each WindowBlockManager may have a different number of allocated + // blocks, so releasing the same numBlocks from all managers would need + // per-window-size handling. Disaggregated serving does not support VSWA + // today (gated by should_store_blocks: not is_vswa in the executor and + // beamWidth == 1 assertion in WindowBlockManager::releasePrefixBlocks). + // // Snapshot the counter before iterating so that every WindowBlockManager // releases the same range. 
Without this, the first manager would advance // the shared mNumFrontBlocksRemoved counter and subsequent managers would diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 6367829d55ce..2210a4165d2a 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -70,12 +70,32 @@ def create_kv_cache_transceiver( f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server " f"hangs or lower-than-expected performance.") + # Auto-select Python transceiver when chunk_size_blocks is set, + # since the C++ transceiver does not support chunked transfer. + # Only applies to NIXL/DEFAULT backends (the Python transceiver + # does not support UCX, MPI, or MOONCAKE). + use_python = cache_transceiver_config.transceiver_runtime == "PYTHON" + if (not use_python + and cache_transceiver_config.chunk_size_blocks is not None): + if cache_transceiver_config.backend in (None, "DEFAULT", "NIXL"): + logger.info( + "chunk_size_blocks is set; auto-selecting Python transceiver " + "for chunked KV cache transfer support") + use_python = True + else: + logger.warning( + f"chunk_size_blocks is set but backend " + f"'{cache_transceiver_config.backend}' requires the C++ " + f"transceiver, which does not support chunked transfer. " + f"chunk_size_blocks will be ignored. 
Use NIXL backend to " + f"enable chunked transfer.") + # Select transceiver implementation based on transceiver_runtime # transceiver_runtime == None or "CPP" -> use C++ transceiver (default) # transceiver_runtime == "PYTHON" -> use Python transceiver - if cache_transceiver_config.transceiver_runtime == "PYTHON": + if use_python: # Python transceiver currently only supports NIXL and DEFAULT backend - if cache_transceiver_config.backend not in ("DEFAULT", "NIXL"): + if cache_transceiver_config.backend not in (None, "DEFAULT", "NIXL"): raise ValueError( f"Python transceiver currently only supports NIXL or DEFAULT backend, " f"got {cache_transceiver_config.backend}. " diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 21da01b90186..1a285cfad10a 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3374,9 +3374,11 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror): "slice is transferred independently. The total data per chunk is " "approximately chunk_size_blocks * num_layer_groups * slot_bytes. " "This reduces per-transfer NIXL descriptor pressure for long " - "sequences. When None (default), the entire KV cache is transferred " - "in a single slice. Only effective with the Python transceiver " - "(transceiver_runtime='PYTHON').") + "sequences and enables early block release to free GPU memory " + "incrementally during transfer. When None (default), the entire " + "KV cache is transferred in a single slice. When set with NIXL " + "backend (default), the Python transceiver is auto-selected. 
" + "Not supported with UCX, MPI, or MOONCAKE backends.") def _to_pybind(self): # chunk_size_blocks is consumed by the Python transceiver only From d70941e41d93dec35c3c5bbd72f603942ea2ef31 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:38:57 -0700 Subject: [PATCH 5/8] Move chunk_block_offset into KVSlice dataclass Per reviewer feedback (chuangz0, Shixiaowei02): chunk_block_offset belongs as a member of KVSlice rather than a function parameter on send(). The KVSlice dataclass was designed to carry all slice metadata. - Add chunk_block_offset: int = 0 to KVSlice dataclass - Remove chunk_block_offset from TxSessionBase.send() signature - Remove chunk_block_offset from TxSession.send() signature - Remove chunk_block_offset from KVSendTask.__init__ - Read chunk_block_offset from task._slice in _build_kv_write_meta and _deliver_kv_to_agent callback - Set chunk_block_offset on each KVSlice in _create_kv_slices - Update all tests accordingly Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 11 ++++ .../batch_manager/kvCacheManager.cpp | 15 ++++- .../_torch/disaggregation/base/transfer.py | 10 ++-- .../_torch/disaggregation/native/transfer.py | 14 ++--- .../_torch/disaggregation/transceiver.py | 40 +++++++++++-- .../_torch/pyexecutor/kv_cache_transceiver.py | 26 ++++++++- .../disaggregated/test_chunked_transfer.py | 56 ++++++++++++++++--- .../disaggregated/test_kv_transfer.py | 3 +- 8 files changed, 144 insertions(+), 31 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 31b6300604c9..94e8cc1a4fa0 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -681,6 
+681,17 @@ class GenerationRequest ++mNumFrontBlocksRemovedPerWindow.at(windowSize); } + //! \brief Advance ``mNumFrontBlocksRemoved`` without touching cache blocks. + //! \details Used by ``BlockManager::releasePrefixBlocks`` to advance the + //! shared front-block counter once after every ``WindowBlockManager`` has + //! processed the same prefix range. Has clearer intent than calling + //! ``removeFrontBlock`` with a sentinel ``windowSize`` value, and is robust + //! to future changes that consume the ``windowSize`` argument. + void incrementNumFrontBlocksRemoved() + { + ++mNumFrontBlocksRemoved; + } + void removeLastBlock(SizeType32 windowSize) { for (auto& beamBlockIds : mCacheBlockIds.at(windowSize)) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 9efbb8c025fa..c9f9686a7c50 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2916,9 +2916,12 @@ void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 n manager.releasePrefixBlocks(sequence, startIdx, numBlocks); } // Advance the shared counter once, after all managers have released. + // Uses incrementNumFrontBlocksRemoved (counter-only) instead of + // removeFrontBlock so the intent is explicit and we do not depend on + // removeFrontBlock ignoring its windowSize argument. while (sequence.getNumFrontBlocksRemoved() < numBlocks) { - sequence.removeFrontBlock(0); + sequence.incrementNumFrontBlocksRemoved(); } } @@ -3942,6 +3945,16 @@ std::optional KVCacheManager::removeSequence( void KVCacheManager::releasePrefixBlocks(RequestIdType requestId, SizeType32 numBlocks) { + // Hard precondition: BlockManager::releasePrefixBlocks advances the shared + // mNumFrontBlocksRemoved counter to numBlocks for every WindowBlockManager, + // even when a window has fewer than numBlocks allocated. 
Under variable + // sliding window attention (VSWA), that would cause WindowBlockManager:: + // releaseBlocks (called during removeSequence) to underrun rbegin() and + // skip tail blocks for the smaller window. Disagg serving already gates + // VSWA out, but we enforce the assumption here so the C++ API contract is + // self-defending instead of relying on caller discipline. + TLLM_CHECK_WITH_INFO( + !mBlockManager.isVariableWindow(), "releasePrefixBlocks does not support variable sliding window attention"); if (numBlocks <= 0) { return; diff --git a/tensorrt_llm/_torch/disaggregation/base/transfer.py b/tensorrt_llm/_torch/disaggregation/base/transfer.py index 222301e5b144..3ff67cc8d5c2 100644 --- a/tensorrt_llm/_torch/disaggregation/base/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/base/transfer.py @@ -65,6 +65,7 @@ class KVSlice: ) # Physical block IDs per layer group, each np.ndarray(dtype=np.int64) is_last_slice: bool = False mamba_state_index: Optional[int] = None + chunk_block_offset: int = 0 class SessionStatus(Enum): @@ -158,15 +159,14 @@ def __init__(self, sender: SenderBase, args: SessionArgsBase): self._sender = sender @abstractmethod - def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None: + def send(self, slice: KVSlice) -> None: """Send a KV slice. Args: slice: The KV slice describing which source blocks to send. - chunk_block_offset: Block offset into the receiver's full - destination block list for this chunk. Used by sender-side - chunking to slice the receiver's destination blocks correctly. - Defaults to 0 for monolithic transfer. + The slice's ``chunk_block_offset`` field indicates the offset + into the receiver's destination block list for sender-side + chunking. """ ... 
diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index d07578957f2c..da490c953e8e 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -208,11 +208,10 @@ class KVSendTask(SendTaskBase): Args: kv_slice: The KV slice describing which blocks to transfer. + The slice's ``chunk_block_offset`` field indicates the + offset into the receiver's destination block list. params: Disaggregated serving parameters for this request. slice_id: Index of this slice within the session's task list. - chunk_block_offset: Block offset into the receiver's full - destination block list. Used by sender-side chunking to - slice the receiver's destination blocks correctly. """ def __init__( @@ -222,7 +221,6 @@ def __init__( slice_id: int, prompt_len: Optional[int] = None, beam_width: int = 1, - chunk_block_offset: int = 0, ) -> None: super().__init__(params) self.slice_id = slice_id @@ -230,7 +228,6 @@ def __init__( self._slice = kv_slice self._prompt_len = prompt_len self._beam_width = beam_width - self.chunk_block_offset = chunk_block_offset class Sender(SenderBase): @@ -587,7 +584,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): ) session._on_chunk_transferred( request_id=session.request_id, - chunk_block_offset=task.chunk_block_offset, + chunk_block_offset=task._slice.chunk_block_offset, num_blocks=num_blocks, ) except Exception as e: @@ -727,7 +724,7 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write dst_block_ids_per_groups = req_info.block_ids_per_layer_groups src_block_ids_per_groups = task._slice.block_ids_per_layer_groups - chunk_offset = task.chunk_block_offset + chunk_offset = task._slice.chunk_block_offset for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items(): src_block_ids = src_block_ids_per_groups[self_lg] full_dst_block_ids = dst_block_ids_per_groups[peer_lg] @@ -1182,7 
+1179,7 @@ def status(self) -> SessionStatus: return SessionStatus.TRANSFERRING return SessionStatus.READY if self.receiver_ready else SessionStatus.INIT - def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None: + def send(self, slice: KVSlice) -> None: with self.lock: params = self._base_args.params slice_id = len(self.kv_tasks) @@ -1192,7 +1189,6 @@ def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None: slice_id, prompt_len=self._base_args.prompt_len, beam_width=self._base_args.beam_width, - chunk_block_offset=chunk_block_offset, ) task._unique_rid = self.disagg_request_id self.kv_tasks.append(task) diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index 1476ae8a95fb..91371f21391e 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -146,6 +146,10 @@ def shutdown(self): if getattr(self, "_shutdown", False): return self._shutdown = True + # Drain any pending prefix-release entries before tearing down sessions + # so memory frees in the same shutdown step instead of leaking until + # removeSequence cleans up at session close. + self._drain_pending_releases() for session in list(self._send_sessions.values()): session.close() for session in list(self._recv_sessions.values()): @@ -272,6 +276,7 @@ def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: num_chunks = math.ceil(max_blocks / self._chunk_size_blocks) slices: List[KVSlice] = [] + block_offset = 0 for chunk_idx in range(num_chunks): start = chunk_idx * self._chunk_size_blocks end = start + self._chunk_size_blocks @@ -284,8 +289,16 @@ def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: block_ids_per_layer_groups=chunk_block_ids, mamba_state_index=base_slice.mamba_state_index, token_range=base_slice.token_range, + chunk_block_offset=block_offset, ) ) + # Use the max length across layer groups to advance the receiver + # offset. 
This is the contract that lets receiver-side slicing in + # native/transfer.py (`_build_kv_write_meta`) trim the per-LG dst + # range with `len(src_block_ids)`, so asymmetric layer groups still + # land at the right destination position even though the offset is + # shared across groups. + block_offset += max((len(ids) for ids in chunk_block_ids), default=0) for lg_idx, original_ids in enumerate(all_block_ids): reassembled = np.concatenate([s.block_ids_per_layer_groups[lg_idx] for s in slices]) @@ -318,8 +331,24 @@ def _make_chunk_callback(self) -> Optional[Callable]: """ if self._chunk_size_blocks is None: return None + manager_name = type(self._kv_cache_manager).__name__ if not hasattr(self._kv_cache_manager, "release_prefix_blocks"): + # Surface the gate decision in logs so a typo or missing wrapper on + # the manager side is observable at startup, not silent. + logger.warning( + "Chunked KV transfer is enabled (chunk_size_blocks=%s) but %s " + "does not implement release_prefix_blocks; early prefix block " + "release is disabled. Blocks will be freed at session teardown.", + self._chunk_size_blocks, + manager_name, + ) return None + logger.info( + "Chunked KV transfer with early prefix block release enabled " + "(chunk_size_blocks=%s, manager=%s).", + self._chunk_size_blocks, + manager_name, + ) release_queue = self._pending_prefix_releases @@ -515,6 +544,11 @@ def _build_to_process( return to_process def _close_failed_sessions(self, sessions: dict, reqs: dict, failed: list): + # Drain pending prefix releases before closing failed sessions so that + # already-completed chunks of healthy sister sessions free memory now + # rather than waiting for the next check_context_transfer_status pass. + # No-op when the queue is empty, including on the gen-side path. 
+ self._drain_pending_releases() for rid in failed: reqs[rid].state = LlmRequestState.DISAGG_TRANS_ERROR sessions[rid].close() @@ -574,12 +608,8 @@ def _finalize_send(self, req: LlmRequest, session: TxSessionBase): def respond_and_send_async(self, req: LlmRequest): session = self._get_or_create_send_session(req) req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - chunk_block_offset = 0 for kv_slice in self._create_kv_slices(req): - session.send(kv_slice, chunk_block_offset=chunk_block_offset) - chunk_block_offset += max( - (len(ids) for ids in kv_slice.block_ids_per_layer_groups), default=0 - ) + session.send(kv_slice) self._finalize_send(req, session) @nvtx_range("KvCacheTransceiverV2.request_and_receive_sync") diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 2210a4165d2a..2c88d58b3720 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -78,9 +78,15 @@ def create_kv_cache_transceiver( if (not use_python and cache_transceiver_config.chunk_size_blocks is not None): if cache_transceiver_config.backend in (None, "DEFAULT", "NIXL"): - logger.info( - "chunk_size_blocks is set; auto-selecting Python transceiver " - "for chunked KV cache transfer support") + # Use warning (not info) so users notice the transceiver swap and + # the implied perf / staging-buffer characteristics change. Leave + # chunk_size_blocks unset to keep the C++ transceiver (and lose + # chunked transfer + early block release). + logger.warning( + "chunk_size_blocks is set; auto-selecting the Python " + "transceiver instead of the C++ transceiver to enable " + "chunked KV cache transfer + early block release. " + "Leave chunk_size_blocks unset to keep the C++ transceiver.") use_python = True else: logger.warning( @@ -90,6 +96,20 @@ def create_kv_cache_transceiver( f"chunk_size_blocks will be ignored.
Use NIXL backend to " f"enable chunked transfer.") + # Warn when chunk_size_blocks is below the recommended floor. The Pydantic + # field is PositiveInt (>=1), but values below ~16 push the per-chunk RDMA + # overhead into the regime where it dominates transfer throughput. + _MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS = 16 + if (cache_transceiver_config.chunk_size_blocks is not None + and cache_transceiver_config.chunk_size_blocks + < _MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS): + logger.warning( + f"chunk_size_blocks={cache_transceiver_config.chunk_size_blocks} " + f"is below the recommended floor of " + f"{_MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS}; per-chunk RDMA overhead " + f"may dominate transfer throughput. Consider 64-128 for " + f"long-context workloads (ISL >= 32K).") + # Select transceiver implementation based on transceiver_runtime # transceiver_runtime == None or "CPP" -> use C++ transceiver (default) # transceiver_runtime == "PYTHON" -> use Python transceiver diff --git a/tests/unittest/disaggregated/test_chunked_transfer.py b/tests/unittest/disaggregated/test_chunked_transfer.py index da6d21194cf5..c140be329a34 100644 --- a/tests/unittest/disaggregated/test_chunked_transfer.py +++ b/tests/unittest/disaggregated/test_chunked_transfer.py @@ -74,8 +74,9 @@ def _make_tx_session(num_slices: int, rid: int = 42, **kwargs) -> TxSession: s = KVSlice( is_last_slice=(i == num_slices - 1), block_ids_per_layer_groups=[[i]], + chunk_block_offset=i, ) - session.send(s, chunk_block_offset=i) + session.send(s) return session @@ -103,19 +104,19 @@ def _make_rx_session(num_slices: int, rid: int = 42) -> RxSession: def test_kv_send_task_chunk_block_offset(): - """KVSendTask stores chunk_block_offset correctly.""" - s = KVSlice(is_last_slice=False, block_ids_per_layer_groups=[[0, 1]]) - task = KVSendTask(s, _make_params(), slice_id=1, chunk_block_offset=512) - assert task.chunk_block_offset == 512 + """KVSendTask reads chunk_block_offset from the slice.""" + s = KVSlice(is_last_slice=False, 
block_ids_per_layer_groups=[[0, 1]], chunk_block_offset=512) + task = KVSendTask(s, _make_params(), slice_id=1) + assert task._slice.chunk_block_offset == 512 assert task.slice_id == 1 assert task._slice is s def test_kv_send_task_default_offset(): - """Default chunk_block_offset is 0.""" + """Default chunk_block_offset on KVSlice is 0.""" s = KVSlice(is_last_slice=True, block_ids_per_layer_groups=[[0]]) task = KVSendTask(s, _make_params(), slice_id=0) - assert task.chunk_block_offset == 0 + assert task._slice.chunk_block_offset == 0 # --------------------------------------------------------------------------- @@ -281,6 +282,47 @@ def test_drain_pending_releases(): assert calls[2].args == (20, 32) +def test_drain_pending_releases_tolerates_stale_rid(): + """A pending release for a request that was already removed must be a no-op. + + Models the production race where the sender worker enqueues a release + after the main thread has already torn the sequence down via + ``removeSequence``. ``KVCacheManager.release_prefix_blocks`` returns + early in that case, so ``_drain_pending_releases`` must not raise. + """ + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + transceiver = MagicMock() + transceiver._pending_prefix_releases = queue.Queue() + transceiver._kv_cache_manager = MagicMock() + # Manager wrapper is a no-op for unknown rids; drain must propagate that + # no-op semantics rather than crashing. 
+ transceiver._kv_cache_manager.release_prefix_blocks = MagicMock(return_value=None) + + transceiver._pending_prefix_releases.put((9999, 64)) # unknown rid + transceiver._pending_prefix_releases.put((9999, 128)) + + KvCacheTransceiverV2._drain_pending_releases(transceiver) + + calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list + assert len(calls) == 2 + assert calls[0].args == (9999, 64) + assert calls[1].args == (9999, 128) + + +def test_drain_pending_releases_empty_queue_is_noop(): + """Drain on an empty queue is a no-op and never calls the manager.""" + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + transceiver = MagicMock() + transceiver._pending_prefix_releases = queue.Queue() + transceiver._kv_cache_manager = MagicMock() + + KvCacheTransceiverV2._drain_pending_releases(transceiver) + + transceiver._kv_cache_manager.release_prefix_blocks.assert_not_called() + + @pytest.mark.parametrize( "has_release,chunk_size,expected_none", [ diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index b57f9fe66755..1728feac00ec 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -1565,8 +1565,9 @@ def add_and_verify_chunked_request( kv_slice = KVSlice( is_last_slice=is_last, block_ids_per_layer_groups=chunk_block_ids, + chunk_block_offset=chunk_offset, ) - sender_session.send(kv_slice, chunk_block_offset=chunk_offset) + sender_session.send(kv_slice) chunk_offset += max(len(ids) for ids in chunk_block_ids) receiver_sessions = [ From 97eb94a25da520d74e70ee260b092f7ba444f770 Mon Sep 17 00:00:00 2001 From: Athena Cai Date: Mon, 22 Jun 2026 21:47:45 +0000 Subject: [PATCH 6/8] fix ups Signed-off-by: Athena Cai --- .../batch_manager/kvCacheManager.h | 16 +++-- .../batch_manager/kvCacheManager.cpp | 41 +++++++------ .../batch_manager/kvCacheManagerTest.cpp | 53 +++++++++++++++++ 
.../_torch/disaggregation/native/transfer.py | 10 +++- .../_torch/disaggregation/transceiver.py | 4 ++ .../accuracy/test_disaggregated_serving.py | 43 ++++++++++++++ .../test_lists/test-db/l0_dgx_b200.yml | 1 + .../disaggregated/test_chunked_transfer.py | 59 +++++++------------ .../disaggregated/test_kv_transfer.py | 3 + 9 files changed, 163 insertions(+), 67 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 94e8cc1a4fa0..2c122fbeb8e5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -681,15 +681,13 @@ class GenerationRequest ++mNumFrontBlocksRemovedPerWindow.at(windowSize); } - //! \brief Advance ``mNumFrontBlocksRemoved`` without touching cache blocks. + //! \brief Advance the per-window front-block counter without touching cache blocks. //! \details Used by ``BlockManager::releasePrefixBlocks`` to advance the - //! shared front-block counter once after every ``WindowBlockManager`` has - //! processed the same prefix range. Has clearer intent than calling - //! ``removeFrontBlock`` with a sentinel ``windowSize`` value, and is robust - //! to future changes that consume the ``windowSize`` argument. - void incrementNumFrontBlocksRemoved() + //! single-window front-block counter once after every ``WindowBlockManager`` has + //! processed the same prefix range. + void incrementNumFrontBlocksRemoved(SizeType32 windowSize) { - ++mNumFrontBlocksRemoved; + ++mNumFrontBlocksRemovedPerWindow.at(windowSize); } void removeLastBlock(SizeType32 windowSize) @@ -989,7 +987,7 @@ class WindowBlockManager //! for blocks whose data has already been transferred. Reuses the //! detachFrontBlock mechanism (decRefCount + eviction policy release). //! Called by BlockManager::releasePrefixBlocks which coordinates the - //! shared mNumFrontBlocksRemoved counter across all window managers. + //! 
single-window front-block counter across all window managers. void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 startIdx, SizeType32 numBlocks); //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks @@ -1535,7 +1533,7 @@ class BlockManager //! \brief Release the first numBlocks prefix blocks of a sequence. //! \details Mirrors detachFrontBlock logic: decRefCount + eviction policy - //! release for each prefix block. The mNumFrontBlocksRemoved counter on + //! release for each prefix block. The front-block counter on //! GenerationRequest ensures releaseBlocks (called during removeSequence) //! skips already-freed prefix blocks. void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index c9f9686a7c50..a0ead678967e 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2906,22 +2906,22 @@ void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 n // today (gated by should_store_blocks: not is_vswa in the executor and // beamWidth == 1 assertion in WindowBlockManager::releasePrefixBlocks). // + auto const windowSize = mWindowBlockManagers.cbegin()->first; // Snapshot the counter before iterating so that every WindowBlockManager // releases the same range. Without this, the first manager would advance - // the shared mNumFrontBlocksRemoved counter and subsequent managers would - // see the counter already at the target, skipping their own blocks. - SizeType32 const startIdx = sequence.getNumFrontBlocksRemoved(); + // the single-window front-block counter and subsequent managers would see + // the counter already at the target, skipping their own blocks. 
+ SizeType32 const startIdx = sequence.getNumFrontBlocksRemoved(windowSize); for (auto& [_, manager] : mWindowBlockManagers) { manager.releasePrefixBlocks(sequence, startIdx, numBlocks); } - // Advance the shared counter once, after all managers have released. + // Advance the single-window counter once, after all managers have released. // Uses incrementNumFrontBlocksRemoved (counter-only) instead of - // removeFrontBlock so the intent is explicit and we do not depend on - // removeFrontBlock ignoring its windowSize argument. - while (sequence.getNumFrontBlocksRemoved() < numBlocks) + // removeFrontBlock so the intent is explicit. + while (sequence.getNumFrontBlocksRemoved(windowSize) < numBlocks) { - sequence.incrementNumFrontBlocksRemoved(); + sequence.incrementNumFrontBlocksRemoved(windowSize); } } @@ -3746,23 +3746,30 @@ void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeTy auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); SizeType32 const target = std::min(numBlocks, static_cast(allocatedBlocks.size())); - // Release blocks in range [startIdx, target). The shared - // mNumFrontBlocksRemoved counter is advanced by BlockManager after + // Release blocks in range [startIdx, target). The single-window + // front-block counter is advanced by BlockManager after // all WindowBlockManagers have processed the same range. for (SizeType32 blockIdx = startIdx; blockIdx < target; ++blockIdx) { auto& block = allocatedBlocks.at(blockIdx); + auto releasedBlock = block; TLLM_LOG_DEBUG("%s::releasePrefixBlocks - Releasing block %d from sequence %lu", mLogPrefix.c_str(), - block->getBlockId(), requestId); + releasedBlock->getBlockId(), requestId); - if (block->hasRefs()) + // Replace the sequence slot with a placeholder, matching detachFrontBlock(). + // removeSequence later walks allocatedBlocks in releaseBlocks(); leaving the + // real block here would release it a second time and corrupt the eviction + // policy's free-block count. 
+ block = KVCacheBlock::createPlaceholder(); + + if (releasedBlock->hasRefs()) { - block->decRefCount(); + releasedBlock->decRefCount(); } - if (!block->hasRefs()) + if (!releasedBlock->hasRefs()) { - mEvictionPolicy->releaseBlock(block); + mEvictionPolicy->releaseBlock(releasedBlock); } } } @@ -3945,8 +3952,8 @@ std::optional KVCacheManager::removeSequence( void KVCacheManager::releasePrefixBlocks(RequestIdType requestId, SizeType32 numBlocks) { - // Hard precondition: BlockManager::releasePrefixBlocks advances the shared - // mNumFrontBlocksRemoved counter to numBlocks for every WindowBlockManager, + // Hard precondition: BlockManager::releasePrefixBlocks advances the + // single-window front-block counter to numBlocks for every WindowBlockManager, // even when a window has fewer than numBlocks allocated. Under variable // sliding window attention (VSWA), that would cause WindowBlockManager:: // releaseBlocks (called during removeSequence) to underrun rbegin() and diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index b1c91ae09d4c..44e114a936ed 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -270,6 +270,59 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) std::runtime_error); } +TEST_F(KVCacheManagerTest, BlockManagerReleasePrefixBlocksDoesNotDoubleFreeOnTeardown) +{ + auto constexpr numLayers = 12; + auto constexpr numKvHeads = 6; + auto constexpr sizePerHead = 128; + auto constexpr tokensPerBlock = 4; + auto constexpr blocksInPrimaryPool = 8; + auto constexpr blocksInSecondaryPool = 0; + auto constexpr maxNumSequences = 8; + auto const stream = std::make_shared(); + + auto constexpr beamWidth = 1; + auto constexpr numBlocksPerBeam = 4; + auto constexpr numTokens = tokensPerBlock * numBlocksPerBeam; + auto constexpr maxAttentionWindow = numTokens; + + auto const blocksPerWindow = 
BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; + + BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, + maxNumSequences, stream, maxAttentionWindow, beamWidth, + std::vector{maxAttentionWindow}, nvinfer1::DataType::kHALF, 0, maxAttentionWindow); + blockManager.allocatePools(false); + + SizeType32 constexpr maxNewTokens{0}; + tr::SamplingConfig const samplingConfig{beamWidth}; + bool constexpr isStreaming{false}; + + auto tokens = std::make_shared(); + for (SizeType32 i = 0; i < numTokens; ++i) + { + tokens->push_back(i); + } + + LlmRequest::RequestIdType constexpr requestId{42}; + auto llmReq = std::make_shared(requestId, maxNewTokens, tokens, samplingConfig, isStreaming); + GenerationRequest seq{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; + + (void) blockManager.addSequenceBatch( + {&seq}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq)}, maxAttentionWindow, /*isEnableBlockReuse=*/false); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocksPerBeam); + + blockManager.releasePrefixBlocks(seq, 2); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 2); + + // releasePrefixBlocks has cumulative semantics. This should release only + // one additional block rather than releasing the first two again. 
+ blockManager.releasePrefixBlocks(seq, 3); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 1); + + blockManager.releaseBlocks(seq); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); +} + template void writePatternToOffloadedBlocksDRAM(T* rawBlockPtr, int blockSize, int mask) { diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index da490c953e8e..9a5d7c11d250 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -1168,6 +1168,10 @@ def disagg_request_id(self) -> int: def status(self) -> SessionStatus: if self._terminal_status is not None: return self._terminal_status + if self._exception is not None or any(t.status == TaskStatus.ERROR for t in self.kv_tasks): + return SessionStatus.ERROR + if self.aux_task is not None and self.aux_task.status == TaskStatus.ERROR: + return SessionStatus.ERROR kv_all_transferred = bool(self.kv_tasks) and all( t.status == TaskStatus.TRANSFERRED for t in self.kv_tasks ) @@ -1755,15 +1759,15 @@ def process_kv_agent_result( ) def process_aux_agent_result(self, _peer_rank: int, status: AgentResult): - # Aux is session-level (not per-slice); expected_transfers is identical - # across all kv_tasks, so any task provides the right count. + # Aux is session-level (not per-slice); use the final KV task's + # expected transfer count so chunked sessions wait for all senders. 
with self.lock: if not self._kv_tasks: logger.warning( f"Aux result received before any KV tasks for request {self.request_id}" ) return - task = self._kv_tasks[0] + task = self._kv_tasks[-1] if status == AgentResult.SUCCESS: self._aux_count += 1 diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index 91371f21391e..bfa18ac2924d 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -353,6 +353,10 @@ def _make_chunk_callback(self) -> Optional[Callable]: release_queue = self._pending_prefix_releases def _on_chunk_transferred(request_id: int, chunk_block_offset: int, num_blocks: int): + logger.debug( + f"Early release _on_chunk_transferred: request_id: {request_id}, " + f"chunk_block_offset: {chunk_block_offset}, num_blocks: {num_blocks}" + ) cumulative_blocks = chunk_block_offset + num_blocks release_queue.put((request_id, cumulative_blocks)) diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 5c0039e8ac32..ff7db8fbf1e9 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -733,6 +733,49 @@ def test_kv_cache_v2_nixl_python(self): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @skip_pre_hopper + @pytest.mark.skip_less_device(2) + @parametrize_with_ids("chunk_size_blocks", [64]) + @parametrize_with_ids("enable_block_reuse", [False, True]) + def test_chunked_kv_transfer_nixl_python_accuracy(self, + chunk_size_blocks: int, + enable_block_reuse: bool): + """Test chunked KV transfer accuracy using Python transceiver and C++ KVCacheManager.""" + kv_cache_config = { + "use_kv_cache_manager_v2": False, + "enable_block_reuse": enable_block_reuse, + } + cache_transceiver_config = { + "backend": "NIXL", + "transceiver_runtime": "PYTHON", 
+ "max_tokens_in_buffer": 4096, + "chunk_size_blocks": chunk_size_blocks, + } + ctx_server_config = { + "disable_overlap_scheduler": True, + "kv_cache_config": dict(kv_cache_config), + "cache_transceiver_config": dict(cache_transceiver_config), + } + gen_server_config = { + "disable_overlap_scheduler": False, + "kv_cache_config": dict(kv_cache_config), + "cache_transceiver_config": dict(cache_transceiver_config), + } + disaggregated_server_config = { + "hostname": "localhost", + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + }, + "generation_servers": { + "num_instances": 1, + }, + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @pytest.mark.skip_less_device(2) def test_ngram(self): speculative_decoding_config = { diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index c1ea19000626..685eaa868cf4 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -17,6 +17,7 @@ l0_dgx_b200: tests: - unittest/_torch/misc/test_autotuner.py::test_autotuner_distributed_strategy - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-TRTLLM] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_chunked_kv_transfer_nixl_python_accuracy # ------------- KV Cache V2 Scheduler IT (multi-GPU) --------------- - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_draft_tokens - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_chunked_draft_tokens diff --git a/tests/unittest/disaggregated/test_chunked_transfer.py b/tests/unittest/disaggregated/test_chunked_transfer.py index c140be329a34..7ffc066b415b 100644 --- a/tests/unittest/disaggregated/test_chunked_transfer.py +++ 
b/tests/unittest/disaggregated/test_chunked_transfer.py @@ -67,7 +67,6 @@ def _make_tx_session(num_slices: int, rid: int = 42, **kwargs) -> TxSession: request_id=rid, params=params, sender=_stub_sender(), - aux_slot=None, **kwargs, ) for i in range(num_slices): @@ -87,7 +86,6 @@ def _make_rx_session(num_slices: int, rid: int = 42) -> RxSession: request_id=rid, params=params, receiver=_stub_receiver(), - aux_slot=None, ) for i in range(num_slices): s = KVSlice( @@ -152,24 +150,20 @@ def test_tx_session_wait_complete_all_tasks(): """TxSession.wait_complete blocks on all task futures.""" session = _make_tx_session(3) for task in session.kv_tasks: - task.future.set_result(AgentResult.SUCCESS) - task.status = TaskStatus.TRANSFERRED + task.complete() - result = session.wait_complete(need_aux=False, timeout=1.0) + result = session.wait_complete() assert result == WaitResult.COMPLETED def test_tx_session_wait_complete_fails_on_partial_failure(): """TxSession.wait_complete returns FAILED if any task fails.""" session = _make_tx_session(3) - session.kv_tasks[0].future.set_result(AgentResult.SUCCESS) - session.kv_tasks[0].status = TaskStatus.TRANSFERRED - session.kv_tasks[1].future.set_result(AgentResult.FAILED) - session.kv_tasks[1].status = TaskStatus.ERROR - session.kv_tasks[2].future.set_result(AgentResult.SUCCESS) - session.kv_tasks[2].status = TaskStatus.TRANSFERRED + session.kv_tasks[0].complete() + session.kv_tasks[1].fail(RuntimeError("transfer failed")) + session.kv_tasks[2].complete() - result = session.wait_complete(need_aux=False, timeout=1.0) + result = session.wait_complete() assert result == WaitResult.FAILED @@ -215,22 +209,19 @@ def test_rx_session_wait_complete_all_tasks(): """RxSession.wait_complete blocks on all task futures.""" session = _make_rx_session(3) for task in session._kv_tasks: - task.future.set_result(AgentResult.SUCCESS) - task.status = TaskStatus.TRANSFERRED + task.complete() - result = session.wait_complete(need_aux=False) + result = 
session.wait_complete() assert result == WaitResult.COMPLETED def test_rx_session_wait_complete_fails_on_partial_failure(): """RxSession.wait_complete returns FAILED if any task fails.""" session = _make_rx_session(2) - session._kv_tasks[0].future.set_result(AgentResult.SUCCESS) - session._kv_tasks[0].status = TaskStatus.TRANSFERRED - session._kv_tasks[1].future.set_exception(RuntimeError("transfer failed")) - session._kv_tasks[1].status = TaskStatus.ERROR + session._kv_tasks[0].complete() + session._kv_tasks[1].fail(RuntimeError("transfer failed")) - result = session.wait_complete(need_aux=False) + result = session.wait_complete() assert result == WaitResult.FAILED @@ -383,17 +374,13 @@ def test_tx_session_mid_chunk_failure(): """If one chunk fails mid-transfer, the session reports ERROR.""" session = _make_tx_session(4) - session.kv_tasks[0].future.set_result(AgentResult.SUCCESS) - session.kv_tasks[0].status = TaskStatus.TRANSFERRED - session.kv_tasks[1].future.set_result(AgentResult.SUCCESS) - session.kv_tasks[1].status = TaskStatus.TRANSFERRED - session.kv_tasks[2].future.set_exception(RuntimeError("RDMA failed")) - session.kv_tasks[2].status = TaskStatus.ERROR - session.kv_tasks[3].future.set_result(AgentResult.SUCCESS) - session.kv_tasks[3].status = TaskStatus.TRANSFERRED + session.kv_tasks[0].complete() + session.kv_tasks[1].complete() + session.kv_tasks[2].fail(RuntimeError("RDMA failed")) + session.kv_tasks[3].complete() assert session.status == SessionStatus.ERROR - result = session.wait_complete(need_aux=False, timeout=1.0) + result = session.wait_complete() assert result == WaitResult.FAILED @@ -401,15 +388,11 @@ def test_rx_session_mid_chunk_failure(): """If one chunk fails mid-transfer on receiver, the session reports ERROR.""" session = _make_rx_session(4) - session._kv_tasks[0].future.set_result(AgentResult.SUCCESS) - session._kv_tasks[0].status = TaskStatus.TRANSFERRED - session._kv_tasks[1].future.set_exception(RuntimeError("RDMA failed")) - 
session._kv_tasks[1].status = TaskStatus.ERROR - session._kv_tasks[2].future.set_result(AgentResult.SUCCESS) - session._kv_tasks[2].status = TaskStatus.TRANSFERRED - session._kv_tasks[3].future.set_result(AgentResult.SUCCESS) - session._kv_tasks[3].status = TaskStatus.TRANSFERRED + session._kv_tasks[0].complete() + session._kv_tasks[1].fail(RuntimeError("RDMA failed")) + session._kv_tasks[2].complete() + session._kv_tasks[3].complete() assert session.status == SessionStatus.ERROR - result = session.wait_complete(need_aux=False) + result = session.wait_complete() assert result == WaitResult.FAILED diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index 1728feac00ec..6282455e2579 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -1551,6 +1551,7 @@ def add_and_verify_chunked_request( ctx_info = _setup_chunked_request(setup, ctx_request_id, gen_request_id, request_len) ctx_block_ids = ctx_info["ctx_block_ids"] gen_block_ids = ctx_info["gen_block_ids"] + token_range = TokenRange(start=0, end=request_len) sender_sessions = [tw.create_tx_session(ctx_info["ctx_request"]) for tw in ctx_transfer_workers] for sender_session, block_ids_per_groups in zip(sender_sessions, ctx_block_ids, strict=True): @@ -1565,6 +1566,7 @@ def add_and_verify_chunked_request( kv_slice = KVSlice( is_last_slice=is_last, block_ids_per_layer_groups=chunk_block_ids, + token_range=token_range, chunk_block_offset=chunk_offset, ) sender_session.send(kv_slice) @@ -1577,6 +1579,7 @@ def add_and_verify_chunked_request( full_slice = KVSlice( is_last_slice=True, block_ids_per_layer_groups=block_ids_per_groups, + token_range=token_range, ) recv_session.receive(full_slice) From 2702a1037821c373f5c0a1abbd3879872b97e8ea Mon Sep 17 00:00:00 2001 From: Athena Cai Date: Tue, 23 Jun 2026 22:11:00 +0000 Subject: [PATCH 7/8] address coderabbit comments Signed-off-by: Athena Cai --- 
.../_torch/disaggregation/native/transfer.py | 16 ++++++++++------ .../_torch/pyexecutor/kv_cache_transceiver.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index 9a5d7c11d250..47c32c60572a 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -732,10 +732,14 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write # When sender uses chunking, the receiver sends all dst # blocks in a single RecvReqInfo. Slice dst to match # this task's src chunk position. - if chunk_offset > 0 or len(src_block_ids) < len(full_dst_block_ids): - dst_block_ids = full_dst_block_ids[ - chunk_offset : chunk_offset + len(src_block_ids) - ] + if chunk_offset > 0 or not task._slice.is_last_slice: + chunk_end = chunk_offset + len(src_block_ids) + if chunk_end > full_dst_block_ids.size: + raise ValueError( + f"dst chunk range out of bounds: offset={chunk_offset}, " + f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}" + ) + dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end] else: dst_block_ids = full_dst_block_ids @@ -747,10 +751,10 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write f"src={src_block_ids.size}, dst={dst_block_ids.size}" ) dst_block_ids = dst_block_ids[:-1] - elif block_diff > 1: + elif block_diff != 0: raise ValueError( f"src/dst block count mismatch: {src_block_ids.size} vs " - f"{dst_block_ids.size} (expected diff <= 1)" + f"{dst_block_ids.size} (expected 0 <= diff <= 1)" ) tpb = extractor.page_table.tokens_per_block token_range = task._slice.token_range diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 2c88d58b3720..294cd42fa559 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py 
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -74,8 +74,9 @@ def create_kv_cache_transceiver( # since the C++ transceiver does not support chunked transfer. # Only applies to NIXL/DEFAULT backends (the Python transceiver # does not support UCX, MPI, or MOONCAKE). - use_python = cache_transceiver_config.transceiver_runtime == "PYTHON" - if (not use_python + runtime = cache_transceiver_config.transceiver_runtime + use_python = runtime == "PYTHON" + if (runtime is None and cache_transceiver_config.chunk_size_blocks is not None): if cache_transceiver_config.backend in (None, "DEFAULT", "NIXL"): # Use warning (not info) so users notice the transceiver swap and @@ -95,6 +96,12 @@ def create_kv_cache_transceiver( f"transceiver, which does not support chunked transfer. " f"chunk_size_blocks will be ignored. Use NIXL backend to " f"enable chunked transfer.") + elif (runtime == "CPP" + and cache_transceiver_config.chunk_size_blocks is not None): + logger.warning( + "chunk_size_blocks is set but transceiver_runtime='CPP' " + "explicitly disables Python auto-selection; " + "chunk_size_blocks will be ignored.") # Warn when chunk_size_blocks is below the recommended floor. 
The Pydantic # field is PositiveInt (>=1), but values below ~16 push the per-chunk RDMA From 423f69891d6a6408e773c092ddb5371a07408396 Mon Sep 17 00:00:00 2001 From: Athena Cai Date: Tue, 23 Jun 2026 22:48:08 +0000 Subject: [PATCH 8/8] Revert early block release implementation Signed-off-by: Athena Cai --- .../batch_manager/kvCacheManager.h | 29 ---- .../batch_manager/kvCacheManager.cpp | 90 ----------- .../nanobind/batch_manager/kvCacheManager.cpp | 4 +- .../batch_manager/kvCacheManagerTest.cpp | 53 ------- .../_torch/disaggregation/native/transfer.py | 39 +---- .../_torch/disaggregation/transceiver.py | 90 +---------- .../_torch/pyexecutor/kv_cache_transceiver.py | 17 +- .../_torch/pyexecutor/resource_manager.py | 18 --- tensorrt_llm/llmapi/llm_args.py | 3 +- .../disaggregated/test_chunked_transfer.py | 148 +----------------- 10 files changed, 17 insertions(+), 474 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 2c122fbeb8e5..c665f7a8df95 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -681,15 +681,6 @@ class GenerationRequest ++mNumFrontBlocksRemovedPerWindow.at(windowSize); } - //! \brief Advance the per-window front-block counter without touching cache blocks. - //! \details Used by ``BlockManager::releasePrefixBlocks`` to advance the - //! single-window front-block counter once after every ``WindowBlockManager`` has - //! processed the same prefix range. - void incrementNumFrontBlocksRemoved(SizeType32 windowSize) - { - ++mNumFrontBlocksRemovedPerWindow.at(windowSize); - } - void removeLastBlock(SizeType32 windowSize) { for (auto& beamBlockIds : mCacheBlockIds.at(windowSize)) @@ -982,14 +973,6 @@ class WindowBlockManager std::optional releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest); - //! 
\brief Release prefix blocks in range [startIdx, numBlocks) for a sequence. - //! \details Used by disaggregated serving to free sender-side KV memory - //! for blocks whose data has already been transferred. Reuses the - //! detachFrontBlock mechanism (decRefCount + eviction policy release). - //! Called by BlockManager::releasePrefixBlocks which coordinates the - //! single-window front-block counter across all window managers. - void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 startIdx, SizeType32 numBlocks); - //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); @@ -1531,13 +1514,6 @@ class BlockManager std::optional releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false); - //! \brief Release the first numBlocks prefix blocks of a sequence. - //! \details Mirrors detachFrontBlock logic: decRefCount + eviction policy - //! release for each prefix block. The front-block counter on - //! GenerationRequest ensures releaseBlocks (called during removeSequence) - //! skips already-freed prefix blocks. - void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks); - [[nodiscard]] std::vector storeBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false); @@ -2455,11 +2431,6 @@ class KVCacheManager : public BaseKVCacheManager [[nodiscard]] std::optional removeSequence(LlmRequest::RequestIdType requestId, OptionalRef llmRequest = std::nullopt, bool pinOnRelease = false) override; - //! \brief Release prefix blocks for a sequence without removing it. - //! \details Used by disaggregated serving for early block release during - //! chunked KV cache transfer. No-op if the sequence does not exist. 
- void releasePrefixBlocks(LlmRequest::RequestIdType requestId, SizeType32 numBlocks); - void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) override; [[nodiscard]] runtime::ITensor::SharedPtr getBlockPoolPointers() const override diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index a0ead678967e..0fb8af1527ae 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2897,34 +2897,6 @@ std::optional BlockManager::releaseBlocks( return lastStoredId; } -void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks) -{ - // NOTE: This assumes a single window size (no VSWA). With different window - // sizes, each WindowBlockManager may have a different number of allocated - // blocks, so releasing the same numBlocks from all managers would need - // per-window-size handling. Disaggregated serving does not support VSWA - // today (gated by should_store_blocks: not is_vswa in the executor and - // beamWidth == 1 assertion in WindowBlockManager::releasePrefixBlocks). - // - auto const windowSize = mWindowBlockManagers.cbegin()->first; - // Snapshot the counter before iterating so that every WindowBlockManager - // releases the same range. Without this, the first manager would advance - // the single-window front-block counter and subsequent managers would see - // the counter already at the target, skipping their own blocks. - SizeType32 const startIdx = sequence.getNumFrontBlocksRemoved(windowSize); - for (auto& [_, manager] : mWindowBlockManagers) - { - manager.releasePrefixBlocks(sequence, startIdx, numBlocks); - } - // Advance the single-window counter once, after all managers have released. - // Uses incrementNumFrontBlocksRemoved (counter-only) instead of - // removeFrontBlock so the intent is explicit. 
- while (sequence.getNumFrontBlocksRemoved(windowSize) < numBlocks) - { - sequence.incrementNumFrontBlocksRemoved(windowSize); - } -} - void BlockManager::pinBlocks(GenerationRequest& sequence) { for (auto& [_, manager] : mWindowBlockManagers) @@ -3737,43 +3709,6 @@ void WindowBlockManager::detachFrontBlock(GenerationRequest& sequence) sequence.getNumFrontBlocksRemoved(mWindowSize)); } -void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 startIdx, SizeType32 numBlocks) -{ - TLLM_CHECK_WITH_INFO( - sequence.getBeamWidth() == 1, "[kv cache manager] releasePrefixBlocks does not support beamWidth > 1"); - - auto const requestId = sequence.getRequestId(); - auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); - SizeType32 const target = std::min(numBlocks, static_cast(allocatedBlocks.size())); - - // Release blocks in range [startIdx, target). The single-window - // front-block counter is advanced by BlockManager after - // all WindowBlockManagers have processed the same range. - for (SizeType32 blockIdx = startIdx; blockIdx < target; ++blockIdx) - { - auto& block = allocatedBlocks.at(blockIdx); - auto releasedBlock = block; - - TLLM_LOG_DEBUG("%s::releasePrefixBlocks - Releasing block %d from sequence %lu", mLogPrefix.c_str(), - releasedBlock->getBlockId(), requestId); - - // Replace the sequence slot with a placeholder, matching detachFrontBlock(). - // removeSequence later walks allocatedBlocks in releaseBlocks(); leaving the - // real block here would release it a second time and corrupt the eviction - // policy's free-block count. 
- block = KVCacheBlock::createPlaceholder(); - - if (releasedBlock->hasRefs()) - { - releasedBlock->decRefCount(); - } - if (!releasedBlock->hasRefs()) - { - mEvictionPolicy->releaseBlock(releasedBlock); - } - } -} - PrefixReuseSummary KVCacheManager::analyzePrefixReuse( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { @@ -3950,31 +3885,6 @@ std::optional KVCacheManager::removeSequence( return lastStoredId; } -void KVCacheManager::releasePrefixBlocks(RequestIdType requestId, SizeType32 numBlocks) -{ - // Hard precondition: BlockManager::releasePrefixBlocks advances the - // single-window front-block counter to numBlocks for every WindowBlockManager, - // even when a window has fewer than numBlocks allocated. Under variable - // sliding window attention (VSWA), that would cause WindowBlockManager:: - // releaseBlocks (called during removeSequence) to underrun rbegin() and - // skip tail blocks for the smaller window. Disagg serving already gates - // VSWA out, but we enforce the assumption here so the C++ API contract is - // self-defending instead of relying on caller discipline. 
- TLLM_CHECK_WITH_INFO( - !mBlockManager.isVariableWindow(), "releasePrefixBlocks does not support variable sliding window attention"); - if (numBlocks <= 0) - { - return; - } - std::scoped_lock lock(mSequencesMtx); - auto it = mSequences.find(requestId); - if (it == mSequences.end()) - { - return; - } - mBlockManager.releasePrefixBlocks(it->second, numBlocks); -} - std::vector KVCacheManager::storeBlocksForReuse( RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks) { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 4e43e74d4a96..12b29d4981e2 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -683,9 +683,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("copy_linear_attention_block", &tbk::KVCacheManager::copyLinearAttentionBlock, nb::arg("llm_request"), nb::call_guard()) .def("copy_linear_attention_block_batch", &tbk::KVCacheManager::copyLinearAttentionBlockBatch, - nb::arg("llm_requests"), nb::call_guard()) - .def("release_prefix_blocks", &tbk::KVCacheManager::releasePrefixBlocks, nb::arg("request_id"), - nb::arg("num_blocks"), nb::call_guard()); + nb::arg("llm_requests"), nb::call_guard()); } void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 44e114a936ed..b1c91ae09d4c 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -270,59 +270,6 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) std::runtime_error); } -TEST_F(KVCacheManagerTest, BlockManagerReleasePrefixBlocksDoesNotDoubleFreeOnTeardown) -{ - auto constexpr numLayers = 12; - auto constexpr numKvHeads = 6; - auto constexpr sizePerHead = 128; - auto 
constexpr tokensPerBlock = 4; - auto constexpr blocksInPrimaryPool = 8; - auto constexpr blocksInSecondaryPool = 0; - auto constexpr maxNumSequences = 8; - auto const stream = std::make_shared(); - - auto constexpr beamWidth = 1; - auto constexpr numBlocksPerBeam = 4; - auto constexpr numTokens = tokensPerBlock * numBlocksPerBeam; - auto constexpr maxAttentionWindow = numTokens; - - auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; - - BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, - maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, nvinfer1::DataType::kHALF, 0, maxAttentionWindow); - blockManager.allocatePools(false); - - SizeType32 constexpr maxNewTokens{0}; - tr::SamplingConfig const samplingConfig{beamWidth}; - bool constexpr isStreaming{false}; - - auto tokens = std::make_shared(); - for (SizeType32 i = 0; i < numTokens; ++i) - { - tokens->push_back(i); - } - - LlmRequest::RequestIdType constexpr requestId{42}; - auto llmReq = std::make_shared(requestId, maxNewTokens, tokens, samplingConfig, isStreaming); - GenerationRequest seq{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; - - (void) blockManager.addSequenceBatch( - {&seq}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq)}, maxAttentionWindow, /*isEnableBlockReuse=*/false); - EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocksPerBeam); - - blockManager.releasePrefixBlocks(seq, 2); - EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 2); - - // releasePrefixBlocks has cumulative semantics. This should release only - // one additional block rather than releasing the first two again. 
- blockManager.releasePrefixBlocks(seq, 3); - EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 1); - - blockManager.releaseBlocks(seq); - EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); -} - template void writePatternToOffloadedBlocksDRAM(T* rawBlockPtr, int blockSize, int mask) { diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index 47c32c60572a..315fe31edf3d 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -7,7 +7,7 @@ import weakref from dataclasses import dataclass from enum import Enum -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import msgpack import numpy as np @@ -57,8 +57,6 @@ AttentionTypeCpp = tensorrt_llm.bindings.internal.batch_manager.AttentionType LlmRequestType = tensorrt_llm.bindings.internal.batch_manager.LlmRequestType -OnChunkTransferredCallback = Callable[[int, int, int], None] - # Number of worker threads for KV transfer queues (default: 1) KV_TRANSFER_NUM_THREADS = int(os.environ.get("TRTLLM_KV_TRANSFER_NUM_THREADS", "1")) @@ -569,29 +567,6 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): ) else: task.complete() - if session._on_chunk_transferred is not None: - try: - # Use the max across layer groups as the - # cumulative release count. For asymmetric - # layer groups (e.g., sliding window), shorter - # groups may have fewer blocks per chunk, but - # each WindowBlockManager independently clamps - # to its own allocated block count via - # min(numBlocks, allocatedBlocks.size()). 
- num_blocks = max( - (len(ids) for ids in task._slice.block_ids_per_layer_groups), - default=0, - ) - session._on_chunk_transferred( - request_id=session.request_id, - chunk_block_offset=task._slice.chunk_block_offset, - num_blocks=num_blocks, - ) - except Exception as e: - logger.warning( - f"on_chunk_transferred callback failed for " - f"request {session.request_id} slice {write_meta.slice_id}: {e}" - ) logger.debug( f"deliver_kv_to_agent completed: unique_rid={write_meta.unique_rid}, " @@ -751,10 +726,10 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write f"src={src_block_ids.size}, dst={dst_block_ids.size}" ) dst_block_ids = dst_block_ids[:-1] - elif block_diff != 0: + elif block_diff > 1: raise ValueError( f"src/dst block count mismatch: {src_block_ids.size} vs " - f"{dst_block_ids.size} (expected 0 <= diff <= 1)" + f"{dst_block_ids.size} (expected diff <= 1)" ) tpb = extractor.page_table.tokens_per_block token_range = task._slice.token_range @@ -1131,7 +1106,6 @@ def __init__( timeout_s: Optional[float] = None, prompt_len: Optional[int] = None, beam_width: int = 1, - on_chunk_transferred: Optional[OnChunkTransferredCallback] = None, ): super().__init__( sender, @@ -1147,7 +1121,6 @@ def __init__( self.kv_tasks = [] self.aux_task = None self.lock = threading.Lock() - self._on_chunk_transferred = on_chunk_transferred self._exception: Optional[Exception] = None self._closed = False @@ -2046,16 +2019,11 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp def create_tx_session( self, request: LlmRequest, - on_chunk_transferred: Optional[OnChunkTransferredCallback] = None, ) -> TxSession: """Create a TxSession for the given request. Args: request: The LLM request to create a send session for. - on_chunk_transferred: Optional callback invoked on the - sender worker thread after each chunk's RDMA completes. - Signature: ``(request_id: int, chunk_block_offset: int, - num_blocks: int) -> None``. 
Returns: A new ``TxSession`` ready to accept ``send()`` calls. @@ -2070,7 +2038,6 @@ def create_tx_session( timeout_s=self._config.tx_timeout_s, prompt_len=request.prompt_len, beam_width=request.py_beam_width, - on_chunk_transferred=on_chunk_transferred, ) def create_rx_session(self, request: LlmRequest) -> RxSession: diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index bfa18ac2924d..b2f3ec3ca514 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -1,9 +1,8 @@ import math -import queue import uuid from collections import defaultdict from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, List, Optional, cast import numpy as np import torch @@ -99,8 +98,6 @@ def __init__( self._wait_reqs = {} self._page_table = self._transfer_worker.page_table self._chunk_size_blocks = cache_transceiver_config.chunk_size_blocks - self._pending_prefix_releases: queue.Queue[Tuple[int, int]] = queue.Queue() - self._chunk_callback: Optional[Callable] = self._make_chunk_callback() def _broadcast_instance_name(self) -> str: if self._dist.rank == 0: @@ -146,10 +143,6 @@ def shutdown(self): if getattr(self, "_shutdown", False): return self._shutdown = True - # Drain any pending prefix-release entries before tearing down sessions - # so memory frees in the same shutdown step instead of leaking until - # removeSequence cleans up at session close. - self._drain_pending_releases() for session in list(self._send_sessions.values()): session.close() for session in list(self._recv_sessions.values()): @@ -310,72 +303,6 @@ def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: return slices - def _make_chunk_callback(self) -> Optional[Callable]: - """Return a callback for early prefix block release. 
- - The callback is invoked on the sender worker thread after each - chunk's RDMA finishes. It enqueues a release request that the - main thread drains via ``_drain_pending_releases``. - - Early release is disabled when: - - ``chunk_size_blocks`` is not set (no chunking) - - The KV cache manager does not support ``release_prefix_blocks`` - - The callback is created once at init time and shared across all - sessions (all sessions use the same release queue). - - Returns: - A callback ``(request_id, chunk_block_offset, num_blocks) -> None`` - if chunking is enabled and the KV cache manager supports - ``release_prefix_blocks``, otherwise ``None``. - """ - if self._chunk_size_blocks is None: - return None - manager_name = type(self._kv_cache_manager).__name__ - if not hasattr(self._kv_cache_manager, "release_prefix_blocks"): - # Surface the gate decision in logs so a typo or missing wrapper on - # the manager side is observable at startup, not silent. - logger.warning( - "Chunked KV transfer is enabled (chunk_size_blocks=%s) but %s " - "does not implement release_prefix_blocks; early prefix block " - "release is disabled. Blocks will be freed at session teardown.", - self._chunk_size_blocks, - manager_name, - ) - return None - logger.info( - "Chunked KV transfer with early prefix block release enabled " - "(chunk_size_blocks=%s, manager=%s).", - self._chunk_size_blocks, - manager_name, - ) - - release_queue = self._pending_prefix_releases - - def _on_chunk_transferred(request_id: int, chunk_block_offset: int, num_blocks: int): - logger.debug( - f"Early release _on_chunk_transferred: request_id: {request_id}, " - f"chunk_block_offset: {chunk_block_offset}, num_blocks: {num_blocks}" - ) - cumulative_blocks = chunk_block_offset + num_blocks - release_queue.put((request_id, cumulative_blocks)) - - return _on_chunk_transferred - - def _drain_pending_releases(self) -> None: - """Process all queued prefix block releases on the main thread. 
- - Drains the ``_pending_prefix_releases`` queue and calls - ``release_prefix_blocks`` on the KV cache manager for each - entry. Must be called from the main executor thread only. - """ - while True: - try: - request_id, num_blocks = self._pending_prefix_releases.get_nowait() - except queue.Empty: - break - self._kv_cache_manager.release_prefix_blocks(request_id, num_blocks) - @staticmethod def _split_packed_beam_block_ids( block_ids: np.ndarray, @@ -548,11 +475,6 @@ def _build_to_process( return to_process def _close_failed_sessions(self, sessions: dict, reqs: dict, failed: list): - # Drain pending prefix releases before closing failed sessions so that - # already-completed chunks of healthy sister sessions free memory now - # rather than waiting for the next check_context_transfer_status pass. - # No-op when the queue is empty, including on the gen-side path. - self._drain_pending_releases() for rid in failed: reqs[rid].state = LlmRequestState.DISAGG_TRANS_ERROR sessions[rid].close() @@ -582,13 +504,7 @@ def _get_or_create_send_session(self, req: LlmRequest) -> TxSessionBase: rid = get_unique_rid(req) assert rid is not None if rid not in self._send_sessions: - # Skip early release for beam_width > 1: C++ releasePrefixBlocks - # asserts beamWidth == 1. Chunking still works, but blocks are freed - # at session teardown instead. 
- callback = self._chunk_callback if req.sampling_config.beam_width <= 1 else None - self._send_sessions[rid] = self._transfer_worker.create_tx_session( - req, on_chunk_transferred=callback - ) + self._send_sessions[rid] = self._transfer_worker.create_tx_session(req) return self._send_sessions[rid] def _finalize_send(self, req: LlmRequest, session: TxSessionBase): @@ -681,8 +597,6 @@ def request_and_receive_async(self, req: LlmRequest) -> None: def check_context_transfer_status( self, at_least_request_num: Optional[int], mark_complete: bool = False ): - self._drain_pending_releases() - block_all = at_least_request_num is None wait_num = at_least_request_num if not block_all else 0 diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 294cd42fa559..75d69bf2a453 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -66,9 +66,9 @@ def create_kv_cache_transceiver( "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended") elif cache_transceiver_config.backend == "UCX": logger.info( - f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting " - f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server " - f"hangs or lower-than-expected performance.") + "Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting " + "UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server " + "hangs or lower-than-expected performance.") # Auto-select Python transceiver when chunk_size_blocks is set, # since the C++ transceiver does not support chunked transfer. @@ -82,11 +82,11 @@ def create_kv_cache_transceiver( # Use warning (not info) so users notice the transceiver swap and # the implied perf / staging-buffer characteristics change. 
Set # transceiver_runtime='CPP' explicitly to opt out (and lose - # chunked transfer + early block release). + # chunked transfer). logger.warning( "chunk_size_blocks is set; auto-selecting the Python " "transceiver instead of the C++ transceiver to enable " - "chunked KV cache transfer + early block release. " + "chunked KV cache transfer. " "Set transceiver_runtime='CPP' to disable this auto-selection.") use_python = True else: @@ -98,10 +98,9 @@ def create_kv_cache_transceiver( f"enable chunked transfer.") elif (runtime == "CPP" and cache_transceiver_config.chunk_size_blocks is not None): - logger.warning( - "chunk_size_blocks is set but transceiver_runtime='CPP' " - "explicitly disables Python auto-selection; " - "chunk_size_blocks will be ignored.") + logger.warning("chunk_size_blocks is set but transceiver_runtime='CPP' " + "explicitly disables Python auto-selection; " + "chunk_size_blocks will be ignored.") # Warn when chunk_size_blocks is below the recommended floor. The Pydantic # field is PositiveInt (>=1), but values below ~16 push the per-chunk RDMA diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 8237e6740323..1ef1c2843d8d 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1010,24 +1010,6 @@ def free_resources(self, request: LlmRequest, pin_on_release: bool = False): return self.impl.remove_sequence(request.py_request_id, request, pin_on_release) - def release_prefix_blocks(self, request_id: int, num_blocks: int) -> None: - """Release leading blocks from a request's V1 KV cache. - - Used by disaggregated serving to free sender-side KV memory - for blocks whose data has already been transferred. The - underlying C++ ``KVCacheManager::releasePrefixBlocks`` frees - blocks via the eviction policy so they can be reused. - - Args: - request_id: The request whose KV cache to partially free. 
- num_blocks: Number of leading blocks to release - (cumulative from the start of the sequence). - - Note: - No-op if the sequence does not exist (already removed). - """ - self.impl.release_prefix_blocks(request_id, num_blocks) - def store_blocks_for_reuse(self, request: LlmRequest, pin_blocks: bool = False): diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1a285cfad10a..71cebd3e1c61 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3374,8 +3374,7 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror): "slice is transferred independently. The total data per chunk is " "approximately chunk_size_blocks * num_layer_groups * slot_bytes. " "This reduces per-transfer NIXL descriptor pressure for long " - "sequences and enables early block release to free GPU memory " - "incrementally during transfer. When None (default), the entire " + "sequences. When None (default), the entire " "KV cache is transferred in a single slice. When set with NIXL " "backend (default), the Python transceiver is auto-selected. " "Not supported with UCX, MPI, or MOONCAKE backends.") diff --git a/tests/unittest/disaggregated/test_chunked_transfer.py b/tests/unittest/disaggregated/test_chunked_transfer.py index 7ffc066b415b..50304ae009bf 100644 --- a/tests/unittest/disaggregated/test_chunked_transfer.py +++ b/tests/unittest/disaggregated/test_chunked_transfer.py @@ -14,16 +14,12 @@ # limitations under the License. """Unit tests for chunked KV cache transfer (sender-only chunking). -These tests validate the session state machine, callback plumbing, and -release queue mechanics using the real TxSession/RxSession classes with -lightweight stub sender/receiver objects. +These tests validate the session state machine using the real +TxSession/RxSession classes with lightweight stub sender/receiver objects. 
""" -import queue from unittest.mock import MagicMock -import pytest - from tensorrt_llm import DisaggregatedParams from tensorrt_llm._torch.disaggregation.base.transfer import KVSlice, SessionStatus, WaitResult from tensorrt_llm._torch.disaggregation.native.transfer import ( @@ -225,146 +221,6 @@ def test_rx_session_wait_complete_fails_on_partial_failure(): assert result == WaitResult.FAILED -# --------------------------------------------------------------------------- -# Chunk completion callback tests -# --------------------------------------------------------------------------- - - -def test_chunk_callback_enqueues_release(): - """Callback from _make_chunk_callback enqueues the correct release entries.""" - from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 - - transceiver = MagicMock() - transceiver._chunk_size_blocks = 64 - transceiver._pending_prefix_releases = queue.Queue() - transceiver._kv_cache_manager.release_prefix_blocks = MagicMock() - - callback = KvCacheTransceiverV2._make_chunk_callback(transceiver) - assert callback is not None - - callback(request_id=7, chunk_block_offset=0, num_blocks=64) - callback(request_id=7, chunk_block_offset=64, num_blocks=64) - callback(request_id=7, chunk_block_offset=128, num_blocks=64) - - results = [] - while not transceiver._pending_prefix_releases.empty(): - results.append(transceiver._pending_prefix_releases.get_nowait()) - - assert results == [(7, 64), (7, 128), (7, 192)] - - -def test_drain_pending_releases(): - """_drain_pending_releases calls release_prefix_blocks for each entry.""" - transceiver = MagicMock() - transceiver._pending_prefix_releases = queue.Queue() - transceiver._kv_cache_manager = MagicMock() - transceiver._pending_prefix_releases.put((10, 64)) - transceiver._pending_prefix_releases.put((10, 128)) - transceiver._pending_prefix_releases.put((20, 32)) - - from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 - - 
KvCacheTransceiverV2._drain_pending_releases(transceiver) - - calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list - assert len(calls) == 3 - assert calls[0].args == (10, 64) - assert calls[1].args == (10, 128) - assert calls[2].args == (20, 32) - - -def test_drain_pending_releases_tolerates_stale_rid(): - """A pending release for a request that was already removed must be a no-op. - - Models the production race where the sender worker enqueues a release - after the main thread has already torn the sequence down via - ``removeSequence``. ``KVCacheManager.release_prefix_blocks`` returns - early in that case, so ``_drain_pending_releases`` must not raise. - """ - from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 - - transceiver = MagicMock() - transceiver._pending_prefix_releases = queue.Queue() - transceiver._kv_cache_manager = MagicMock() - # Manager wrapper is a no-op for unknown rids; drain must propagate that - # no-op semantics rather than crashing. 
- transceiver._kv_cache_manager.release_prefix_blocks = MagicMock(return_value=None) - - transceiver._pending_prefix_releases.put((9999, 64)) # unknown rid - transceiver._pending_prefix_releases.put((9999, 128)) - - KvCacheTransceiverV2._drain_pending_releases(transceiver) - - calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list - assert len(calls) == 2 - assert calls[0].args == (9999, 64) - assert calls[1].args == (9999, 128) - - -def test_drain_pending_releases_empty_queue_is_noop(): - """Drain on an empty queue is a no-op and never calls the manager.""" - from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 - - transceiver = MagicMock() - transceiver._pending_prefix_releases = queue.Queue() - transceiver._kv_cache_manager = MagicMock() - - KvCacheTransceiverV2._drain_pending_releases(transceiver) - - transceiver._kv_cache_manager.release_prefix_blocks.assert_not_called() - - -@pytest.mark.parametrize( - "has_release,chunk_size,expected_none", - [ - (False, 64, True), - (True, None, True), - (False, None, True), - (True, 64, False), - ], - ids=["no_release_api", "no_chunking", "neither", "with_release_and_chunking"], -) -def test_make_chunk_callback_conditions(has_release, chunk_size, expected_none): - """_make_chunk_callback returns None unless both release API and chunking enabled.""" - from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 - - transceiver = MagicMock() - transceiver._chunk_size_blocks = chunk_size - transceiver._pending_prefix_releases = queue.Queue() - if has_release: - transceiver._kv_cache_manager.release_prefix_blocks = MagicMock() - else: - del transceiver._kv_cache_manager.release_prefix_blocks - - result = KvCacheTransceiverV2._make_chunk_callback(transceiver) - assert (result is None) == expected_none - - -def test_chunk_callback_then_drain(): - """End-to-end: callback enqueues, drain calls release_prefix_blocks.""" - from 
tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 - - transceiver = MagicMock() - transceiver._chunk_size_blocks = 4 - transceiver._pending_prefix_releases = queue.Queue() - transceiver._kv_cache_manager.release_prefix_blocks = MagicMock() - - callback = KvCacheTransceiverV2._make_chunk_callback(transceiver) - assert callback is not None - - callback(request_id=1, chunk_block_offset=0, num_blocks=4) - callback(request_id=1, chunk_block_offset=4, num_blocks=4) - callback(request_id=1, chunk_block_offset=8, num_blocks=2) - - KvCacheTransceiverV2._drain_pending_releases(transceiver) - - calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list - assert len(calls) == 3 - assert calls[0].args == (1, 4) - assert calls[1].args == (1, 8) - assert calls[2].args == (1, 10) - - # --------------------------------------------------------------------------- # Mid-transfer chunk failure tests # ---------------------------------------------------------------------------