diff --git a/tensorrt_llm/_torch/disaggregation/base/transfer.py b/tensorrt_llm/_torch/disaggregation/base/transfer.py index a80b03153e42..3ff67cc8d5c2 100644 --- a/tensorrt_llm/_torch/disaggregation/base/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/base/transfer.py @@ -65,6 +65,7 @@ class KVSlice: ) # Physical block IDs per layer group, each np.ndarray(dtype=np.int64) is_last_slice: bool = False mamba_state_index: Optional[int] = None + chunk_block_offset: int = 0 class SessionStatus(Enum): @@ -158,7 +159,16 @@ def __init__(self, sender: SenderBase, args: SessionArgsBase): self._sender = sender @abstractmethod - def send(self, slice: KVSlice) -> None: ... + def send(self, slice: KVSlice) -> None: + """Send a KV slice. + + Args: + slice: The KV slice describing which source blocks to send. + The slice's ``chunk_block_offset`` field indicates the offset + into the receiver's destination block list for sender-side + chunking. + """ + ... class RxSessionBase(_SessionBase): diff --git a/tensorrt_llm/_torch/disaggregation/native/transfer.py b/tensorrt_llm/_torch/disaggregation/native/transfer.py index 1068bebfd7e0..315fe31edf3d 100644 --- a/tensorrt_llm/_torch/disaggregation/native/transfer.py +++ b/tensorrt_llm/_torch/disaggregation/native/transfer.py @@ -202,6 +202,16 @@ def __init__(self, params: DisaggregatedParams, slot: Optional[int]): class KVSendTask(SendTaskBase): + """A per-slice send task within a TxSession. + + Args: + kv_slice: The KV slice describing which blocks to transfer. + The slice's ``chunk_block_offset`` field indicates the + offset into the receiver's destination block list. + params: Disaggregated serving parameters for this request. + slice_id: Index of this slice within the session's task list. + """ + def __init__( self, kv_slice: KVSlice, @@ -209,7 +219,7 @@ def __init__( slice_id: int, prompt_len: Optional[int] = None, beam_width: int = 1, - ): + ) -> None: super().__init__(params) self.slice_id = slice_id self.transferred_count = 0 @@ -491,7 +501,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): f"in {status.value} state; sending FAILED to receiver" ) # Task may have been enqueued after cancel() already iterated kv_tasks, - # so its future was never set by cancel(). Set it here as a fallback. + # so its event was never set by cancel(). Set it here as a fallback. task.fail( RuntimeError(f"session {write_meta.unique_rid} {status.value}, transfer aborted") ) @@ -501,7 +511,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): str(self._instance_rank).encode("ascii"), str(write_meta.unique_rid).encode("ascii"), str(write_meta.slice_id).encode("ascii"), - b"True", # is_last_slice — ensures receiver resolves its task future + b"True", # is_last_slice — ensures receiver resolves its task event AgentResult.FAILED.value.encode("ascii"), ] ) @@ -518,13 +528,19 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): if timer: timer.record_transfer_end(write_meta.peer_rank) - ## TODO: just last slice need to send task state? + # The receiver always has a single monolithic task (slice_id=0). + # Sender-side chunking is transparent to the receiver: only the + # last chunk carries is_last_slice=True so the receiver knows + # when all data has arrived. Intermediate chunk results are + # sent (not suppressed) so that RDMA failures propagate to the + # receiver immediately rather than requiring a timeout. + receiver_slice_id = 0 self._get_or_connect_thread_dealer(write_meta.peer_endpoint).send( [ MessageType.KV_AGENT_RESULT, str(self._instance_rank).encode("ascii"), str(write_meta.unique_rid).encode("ascii"), - str(write_meta.slice_id).encode("ascii"), + str(receiver_slice_id).encode("ascii"), str(write_meta.is_last_slice).encode("ascii"), agent_result.value.encode("ascii"), ] @@ -683,10 +699,24 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write dst_block_ids_per_groups = req_info.block_ids_per_layer_groups src_block_ids_per_groups = task._slice.block_ids_per_layer_groups - # Aggregate fragments from all matching pools using numpy concatenation + chunk_offset = task._slice.chunk_block_offset for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items(): src_block_ids = src_block_ids_per_groups[self_lg] - dst_block_ids = dst_block_ids_per_groups[peer_lg] + full_dst_block_ids = dst_block_ids_per_groups[peer_lg] + + # When sender uses chunking, the receiver sends all dst + # blocks in a single RecvReqInfo. Slice dst to match + # this task's src chunk position. + if chunk_offset > 0 or not task._slice.is_last_slice: + chunk_end = chunk_offset + len(src_block_ids) + if chunk_end > full_dst_block_ids.size: + raise ValueError( + f"dst chunk range out of bounds: offset={chunk_offset}, " + f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}" + ) + dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end] + else: + dst_block_ids = full_dst_block_ids # Speculative decoding: generation may have one extra draft-token block. block_diff = dst_block_ids.size - src_block_ids.size @@ -943,7 +973,7 @@ def _respond_with_kv(self, _send_id: bytes, message: list[bytes]): self._save_peer_req_info(info) tasks = list(session.kv_tasks) # No tasks: no worker will send KV_AGENT_RESULT FAILED to the receiver. - # Send it directly to unblock the receiver's TRANSFERRING task future; + # Send it directly to unblock the receiver's TRANSFERRING task event; # CANCEL_SESSION alone would leave it stuck indefinitely. if not tasks and session.status in (SessionStatus.ERROR, SessionStatus.CANCELLED): self._send_failed_result_to_receiver(info) @@ -1115,6 +1145,10 @@ def disagg_request_id(self) -> int: def status(self) -> SessionStatus: if self._terminal_status is not None: return self._terminal_status + if self._exception is not None or any(t.status == TaskStatus.ERROR for t in self.kv_tasks): + return SessionStatus.ERROR + if self.aux_task is not None and self.aux_task.status == TaskStatus.ERROR: + return SessionStatus.ERROR kv_all_transferred = bool(self.kv_tasks) and all( t.status == TaskStatus.TRANSFERRED for t in self.kv_tasks ) @@ -1702,15 +1736,15 @@ def process_kv_agent_result( ) def process_aux_agent_result(self, _peer_rank: int, status: AgentResult): - # Aux is session-level (not per-slice); expected_transfers is identical - # across all kv_tasks, so any task provides the right count. + # Aux is session-level (not per-slice); use the final KV task's + # expected transfer count so chunked sessions wait for all senders. with self.lock: if not self._kv_tasks: logger.warning( f"Aux result received before any KV tasks for request {self.request_id}" ) return - task = self._kv_tasks[0] + task = self._kv_tasks[-1] if status == AgentResult.SUCCESS: self._aux_count += 1 @@ -1982,7 +2016,18 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp self._rank_info.sender_endpoints = endpoints self._rank_info.layer_num_per_pp = layer_num_per_pp - def create_tx_session(self, request: LlmRequest) -> TxSession: + def create_tx_session( + self, + request: LlmRequest, + ) -> TxSession: + """Create a TxSession for the given request. + + Args: + request: The LLM request to create a send session for. + + Returns: + A new ``TxSession`` ready to accept ``send()`` calls. + """ params = request.py_disaggregated_params assert params is not None return TxSession( diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index 6db88215b816..b2f3ec3ca514 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -1,3 +1,4 @@ +import math import uuid from collections import defaultdict from itertools import chain @@ -96,6 +97,7 @@ def __init__( self._recv_reqs = {} self._wait_reqs = {} self._page_table = self._transfer_worker.page_table + self._chunk_size_blocks = cache_transceiver_config.chunk_size_blocks def _broadcast_instance_name(self) -> str: if self._dist.rank == 0: @@ -221,6 +223,86 @@ def _create_kv_slice( token_range=token_range, ) + def _collect_base_slice(self, req: LlmRequest) -> KVSlice: + """Collect a full KVSlice (including metadata) for a request. + + This returns the complete slice produced by ``_create_kv_slice``, + preserving fields like ``mamba_state_index`` that are required + for hybrid-state model transfers. + + Args: + req: The LLM request whose KV cache block IDs to collect. + + Returns: + A ``KVSlice`` with all metadata populated. + """ + return self._create_kv_slice(req) + + def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]: + """Create one or more KVSlice objects for a request. + + When ``chunk_size_blocks`` is ``None``, returns a single slice + covering all blocks. Otherwise, each layer group's block ID + list is partitioned into slices of at most ``chunk_size_blocks`` + blocks. + + Args: + req: The LLM request to create slices for. + + Returns: + A list of ``KVSlice`` objects. Only the last slice has + ``is_last_slice=True``. + + Raises: + ValueError: If the reassembled block IDs from all slices do not + match the original block IDs. + """ + base_slice = self._collect_base_slice(req) + all_block_ids = base_slice.block_ids_per_layer_groups + + if self._chunk_size_blocks is None: + return [base_slice] + + max_blocks = max((len(ids) for ids in all_block_ids), default=0) + if max_blocks == 0: + return [base_slice] + + num_chunks = math.ceil(max_blocks / self._chunk_size_blocks) + slices: List[KVSlice] = [] + block_offset = 0 + for chunk_idx in range(num_chunks): + start = chunk_idx * self._chunk_size_blocks + end = start + self._chunk_size_blocks + is_last = chunk_idx == num_chunks - 1 + + chunk_block_ids = [ids[start:end] for ids in all_block_ids] + slices.append( + KVSlice( + is_last_slice=is_last, + block_ids_per_layer_groups=chunk_block_ids, + mamba_state_index=base_slice.mamba_state_index, + token_range=base_slice.token_range, + chunk_block_offset=block_offset, + ) + ) + # Use the max length across layer groups to advance the receiver + # offset. This is the contract that lets receiver-side slicing in + # native/transfer.py (`_build_kv_write_meta`) trim the per-LG dst + # range with `len(src_block_ids)`, so asymmetric layer groups still + # land at the right destination position even though the offset is + # shared across groups. + block_offset += max((len(ids) for ids in chunk_block_ids), default=0) + + for lg_idx, original_ids in enumerate(all_block_ids): + reassembled = np.concatenate([s.block_ids_per_layer_groups[lg_idx] for s in slices]) + if not np.array_equal(reassembled, original_ids): + raise ValueError( + f"Chunking integrity check failed for layer group {lg_idx}: " + f"expected {len(original_ids)} blocks, got {len(reassembled)}" + ) + + return slices + @staticmethod def _split_packed_beam_block_ids( block_ids: np.ndarray, @@ -446,7 +528,8 @@ def _finalize_send(self, req: LlmRequest, session: TxSessionBase): def respond_and_send_async(self, req: LlmRequest): session = self._get_or_create_send_session(req) req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - session.send(self._create_kv_slice(req)) + for kv_slice in self._create_kv_slices(req): + session.send(kv_slice) self._finalize_send(req, session) @nvtx_range("KvCacheTransceiverV2.request_and_receive_sync") @@ -482,7 +565,17 @@ def request_and_receive_sync(self, req: LlmRequest): self._recv_reqs.pop(rid, None) @nvtx_range("KvCacheTransceiverV2.request_and_receive_async") - def request_and_receive_async(self, req: LlmRequest): + def request_and_receive_async(self, req: LlmRequest) -> None: + """Start background KV cache receive from the context server. + + The receiver always uses a single monolithic slice. Chunking is + sender-only: the sender splits its source blocks into chunks and + slices the receiver's destination blocks to match each chunk. + + Args: + req: The generation request whose KV cache blocks to receive + into. + """ rid = get_unique_rid(req) if rid in self._recv_sessions: logger.warning( @@ -492,7 +585,13 @@ def request_and_receive_async(self, req: LlmRequest): req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS session = self._transfer_worker.create_rx_session(req) self._recv_sessions[rid] = session - session.receive(self._create_kv_slice(req)) + base_slice = self._collect_base_slice(req) + full_slice = KVSlice( + is_last_slice=True, + block_ids_per_layer_groups=base_slice.block_ids_per_layer_groups, + mamba_state_index=base_slice.mamba_state_index, + ) + session.receive(full_slice) self._recv_reqs[rid] = req def check_context_transfer_status( diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index 6367829d55ce..75d69bf2a453 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -66,16 +66,62 @@ def create_kv_cache_transceiver( "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended") elif cache_transceiver_config.backend == "UCX": logger.info( - f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting " - f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server " - f"hangs or lower-than-expected performance.") + "Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting " + "UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server " + "hangs or lower-than-expected performance.") + + # Auto-select Python transceiver when chunk_size_blocks is set, + # since the C++ transceiver does not support chunked transfer. + # Only applies to NIXL/DEFAULT backends (the Python transceiver + # does not support UCX, MPI, or MOONCAKE). + runtime = cache_transceiver_config.transceiver_runtime + use_python = runtime == "PYTHON" + if (runtime is None + and cache_transceiver_config.chunk_size_blocks is not None): + if cache_transceiver_config.backend in (None, "DEFAULT", "NIXL"): + # Use warning (not info) so users notice the transceiver swap and + # the implied perf / staging-buffer characteristics change. Set + # transceiver_runtime='CPP' explicitly to opt out (and lose + # chunked transfer). + logger.warning( + "chunk_size_blocks is set; auto-selecting the Python " + "transceiver instead of the C++ transceiver to enable " + "chunked KV cache transfer. " + "Set transceiver_runtime='CPP' to disable this auto-selection.") + use_python = True + else: + logger.warning( + f"chunk_size_blocks is set but backend " + f"'{cache_transceiver_config.backend}' requires the C++ " + f"transceiver, which does not support chunked transfer. " + f"chunk_size_blocks will be ignored. Use NIXL backend to " + f"enable chunked transfer.") + elif (runtime == "CPP" + and cache_transceiver_config.chunk_size_blocks is not None): + logger.warning("chunk_size_blocks is set but transceiver_runtime='CPP' " + "explicitly disables Python auto-selection; " + "chunk_size_blocks will be ignored.") + + # Warn when chunk_size_blocks is below the recommended floor. The Pydantic + # field is PositiveInt (>=1), but values below ~16 push the per-chunk RDMA + # overhead into the regime where it dominates transfer throughput. + _MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS = 16 + if (cache_transceiver_config.chunk_size_blocks is not None + and cache_transceiver_config.chunk_size_blocks + < _MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS): + logger.warning( + f"chunk_size_blocks={cache_transceiver_config.chunk_size_blocks} " + f"is below the recommended floor of " + f"{_MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS}; per-chunk RDMA overhead " + f"may dominate transfer throughput. Consider 64-128 for " + f"long-context workloads (ISL >= 32K).") # Select transceiver implementation based on transceiver_runtime # transceiver_runtime == None or "CPP" -> use C++ transceiver (default) # transceiver_runtime == "PYTHON" -> use Python transceiver - if cache_transceiver_config.transceiver_runtime == "PYTHON": + if use_python: # Python transceiver currently only supports NIXL and DEFAULT backend - if cache_transceiver_config.backend not in ("DEFAULT", "NIXL"): + if cache_transceiver_config.backend not in (None, "DEFAULT", "NIXL"): raise ValueError( f"Python transceiver currently only supports NIXL or DEFAULT backend, " f"got {cache_transceiver_config.backend}. " diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 97a2ec22ccb7..71cebd3e1c61 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3365,7 +3365,23 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror): "Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms" ) + chunk_size_blocks: Optional[PositiveInt] = Field( + default=None, + description= + "Maximum number of KV cache blocks per layer group per chunk for " + "chunked KV cache transfer. When set, each layer group's block list " + "is partitioned into slices of at most this many blocks, and each " + "slice is transferred independently. The total data per chunk is " + "approximately chunk_size_blocks * num_layer_groups * slot_bytes. " + "This reduces per-transfer NIXL descriptor pressure for long " + "sequences. When None (default), the entire " + "KV cache is transferred in a single slice. When set with NIXL " + "backend (default), the Python transceiver is auto-selected. " + "Not supported with UCX, MPI, or MOONCAKE backends.") + def _to_pybind(self): + # chunk_size_blocks is consumed by the Python transceiver only + # and has no C++ counterpart, so it is intentionally omitted. return _CacheTransceiverConfig( backend=_CacheTransceiverBackendType.from_string(self.backend), max_tokens_in_buffer=self.max_tokens_in_buffer, diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 5c0039e8ac32..ff7db8fbf1e9 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -733,6 +733,49 @@ def test_kv_cache_v2_nixl_python(self): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @skip_pre_hopper + @pytest.mark.skip_less_device(2) + @parametrize_with_ids("chunk_size_blocks", [64]) + @parametrize_with_ids("enable_block_reuse", [False, True]) + def test_chunked_kv_transfer_nixl_python_accuracy(self, + chunk_size_blocks: int, + enable_block_reuse: bool): + """Test chunked KV transfer accuracy using Python transceiver and C++ KVCacheManager.""" + kv_cache_config = { + "use_kv_cache_manager_v2": False, + "enable_block_reuse": enable_block_reuse, + } + cache_transceiver_config = { + "backend": "NIXL", + "transceiver_runtime": "PYTHON", + "max_tokens_in_buffer": 4096, + "chunk_size_blocks": chunk_size_blocks, + } + ctx_server_config = { + "disable_overlap_scheduler": True, + "kv_cache_config": dict(kv_cache_config), + "cache_transceiver_config": dict(cache_transceiver_config), + } + gen_server_config = { + "disable_overlap_scheduler": False, + "kv_cache_config": dict(kv_cache_config), + "cache_transceiver_config": dict(cache_transceiver_config), + } + disaggregated_server_config = { + "hostname": "localhost", + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + }, + "generation_servers": { + "num_instances": 1, + }, + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @pytest.mark.skip_less_device(2) def test_ngram(self): speculative_decoding_config = { diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index c1ea19000626..685eaa868cf4 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -17,6 +17,7 @@ l0_dgx_b200: tests: - unittest/_torch/misc/test_autotuner.py::test_autotuner_distributed_strategy - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-TRTLLM] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_chunked_kv_transfer_nixl_python_accuracy # ------------- KV Cache V2 Scheduler IT (multi-GPU) --------------- - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_draft_tokens - kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_chunked_draft_tokens diff --git a/tests/unittest/disaggregated/test_chunked_transfer.py b/tests/unittest/disaggregated/test_chunked_transfer.py new file mode 100644 index 000000000000..50304ae009bf --- /dev/null +++ b/tests/unittest/disaggregated/test_chunked_transfer.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for chunked KV cache transfer (sender-only chunking). + +These tests validate the session state machine using the real +TxSession/RxSession classes with lightweight stub sender/receiver objects. +""" + +from unittest.mock import MagicMock + +from tensorrt_llm import DisaggregatedParams +from tensorrt_llm._torch.disaggregation.base.transfer import KVSlice, SessionStatus, WaitResult +from tensorrt_llm._torch.disaggregation.native.transfer import ( + AgentResult, + KVSendTask, + RxSession, + TaskStatus, + TxSession, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_params(rid: int = 42) -> DisaggregatedParams: + return DisaggregatedParams(disagg_request_id=rid) + + +def _stub_sender(): + """Create a stub sender with no-op methods needed by TxSession.""" + sender = MagicMock() + sender.setup_session = MagicMock() + sender._get_req_info = MagicMock(return_value=None) + sender.dispatch_task = MagicMock() + return sender + + +def _stub_receiver(): + """Create a stub receiver with no-op methods needed by RxSession.""" + receiver = MagicMock() + receiver.setup_session = MagicMock() + receiver.dispatch_task = MagicMock() + return receiver + + +def _make_tx_session(num_slices: int, rid: int = 42, **kwargs) -> TxSession: + """Create a real TxSession and send num_slices slices into it.""" + params = _make_params(rid) + session = TxSession( + request_id=rid, + params=params, + sender=_stub_sender(), + **kwargs, + ) + for i in range(num_slices): + s = KVSlice( + is_last_slice=(i == num_slices - 1), + block_ids_per_layer_groups=[[i]], + chunk_block_offset=i, + ) + session.send(s) + return session + + +def _make_rx_session(num_slices: int, rid: int = 42) -> RxSession: + """Create a real RxSession and receive num_slices slices into it.""" + params = _make_params(rid) + session = RxSession( + request_id=rid, + params=params, + receiver=_stub_receiver(), + ) + for i in range(num_slices): + s = KVSlice( + is_last_slice=(i == num_slices - 1), + block_ids_per_layer_groups=[[i]], + ) + session.receive(s) + return session + + +# --------------------------------------------------------------------------- +# KVSendTask tests +# --------------------------------------------------------------------------- + + +def test_kv_send_task_chunk_block_offset(): + """KVSendTask reads chunk_block_offset from the slice.""" + s = KVSlice(is_last_slice=False, block_ids_per_layer_groups=[[0, 1]], chunk_block_offset=512) + task = KVSendTask(s, _make_params(), slice_id=1) + assert task._slice.chunk_block_offset == 512 + assert task.slice_id == 1 + assert task._slice is s + + +def test_kv_send_task_default_offset(): + """Default chunk_block_offset on KVSlice is 0.""" + s = KVSlice(is_last_slice=True, block_ids_per_layer_groups=[[0]]) + task = KVSendTask(s, _make_params(), slice_id=0) + assert task._slice.chunk_block_offset == 0 + + +# --------------------------------------------------------------------------- +# TxSession multi-slice status tests (real class) +# --------------------------------------------------------------------------- + + +def test_tx_session_status_init_until_all_transferred(): + """TxSession status is not KV_TRANSFERRED until ALL tasks complete.""" + session = _make_tx_session(3) + session.receiver_ready = True + assert session.status == SessionStatus.TRANSFERRING or session.status == SessionStatus.READY + + session.kv_tasks[0].status = TaskStatus.TRANSFERRED + assert session.status != SessionStatus.KV_TRANSFERRED + + session.kv_tasks[1].status = TaskStatus.TRANSFERRED + assert session.status != SessionStatus.KV_TRANSFERRED + + session.kv_tasks[2].status = TaskStatus.TRANSFERRED + assert session.status == SessionStatus.KV_TRANSFERRED + + +def test_tx_session_status_error_on_any_failure(): + """TxSession status is ERROR if any task fails.""" + session = _make_tx_session(3) + session.kv_tasks[0].status = TaskStatus.TRANSFERRED + session.kv_tasks[1].status = TaskStatus.ERROR + assert session.status == SessionStatus.ERROR + + +def test_tx_session_wait_complete_all_tasks(): + """TxSession.wait_complete blocks on all task futures.""" + session = _make_tx_session(3) + for task in session.kv_tasks: + task.complete() + + result = session.wait_complete() + assert result == WaitResult.COMPLETED + + +def test_tx_session_wait_complete_fails_on_partial_failure(): + """TxSession.wait_complete returns FAILED if any task fails.""" + session = _make_tx_session(3) + session.kv_tasks[0].complete() + session.kv_tasks[1].fail(RuntimeError("transfer failed")) + session.kv_tasks[2].complete() + + result = session.wait_complete() + assert result == WaitResult.FAILED + + +# --------------------------------------------------------------------------- +# RxSession multi-slice status tests (real class) +# --------------------------------------------------------------------------- + + +def test_rx_session_status_checks_all_tasks(): + """RxSession status is KV_TRANSFERRED only when ALL tasks complete.""" + session = _make_rx_session(3) + assert session.status == SessionStatus.INIT + + session._kv_tasks[0].status = TaskStatus.TRANSFERRED + session._kv_tasks[1].status = TaskStatus.TRANSFERRING + assert session.status == SessionStatus.TRANSFERRING + + session._kv_tasks[1].status = TaskStatus.TRANSFERRED + session._kv_tasks[2].status = TaskStatus.TRANSFERRED + assert session.status == SessionStatus.KV_TRANSFERRED + + +def test_rx_session_status_error_on_any_failure(): + """RxSession status is ERROR if any task fails.""" + session = _make_rx_session(2) + session._kv_tasks[0].status = TaskStatus.TRANSFERRED + session._kv_tasks[1].status = TaskStatus.ERROR + assert session.status == SessionStatus.ERROR + + +def test_rx_session_process_aux_uses_last_task(): + """process_aux_agent_result uses the last task's expected_transfers.""" + session = _make_rx_session(3) + session._kv_tasks[0].expected_transfers = 99 + session._kv_tasks[1].expected_transfers = 99 + session._kv_tasks[2].expected_transfers = 1 + + session.process_aux_agent_result(0, AgentResult.SUCCESS) + assert session._aux_status == TaskStatus.TRANSFERRED + + +def test_rx_session_wait_complete_all_tasks(): + """RxSession.wait_complete blocks on all task futures.""" + session = _make_rx_session(3) + for task in session._kv_tasks: + task.complete() + + result = session.wait_complete() + assert result == WaitResult.COMPLETED + + +def test_rx_session_wait_complete_fails_on_partial_failure(): + """RxSession.wait_complete returns FAILED if any task fails.""" + session = _make_rx_session(2) + session._kv_tasks[0].complete() + session._kv_tasks[1].fail(RuntimeError("transfer failed")) + + result = session.wait_complete() + assert result == WaitResult.FAILED + + +# --------------------------------------------------------------------------- +# Mid-transfer chunk failure tests +# --------------------------------------------------------------------------- + + +def test_tx_session_mid_chunk_failure(): + """If one chunk fails mid-transfer, the session reports ERROR.""" + session = _make_tx_session(4) + + session.kv_tasks[0].complete() + session.kv_tasks[1].complete() + session.kv_tasks[2].fail(RuntimeError("RDMA failed")) + session.kv_tasks[3].complete() + + assert session.status == SessionStatus.ERROR + result = session.wait_complete() + assert result == WaitResult.FAILED + + +def test_rx_session_mid_chunk_failure(): + """If one chunk fails mid-transfer on receiver, the session reports ERROR.""" + session = _make_rx_session(4) + + session._kv_tasks[0].complete() + session._kv_tasks[1].fail(RuntimeError("RDMA failed")) + session._kv_tasks[2].complete() + session._kv_tasks[3].complete() + + assert session.status == SessionStatus.ERROR + result = session.wait_complete() + assert result == WaitResult.FAILED diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index caf8dcb3f6d6..6282455e2579 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -148,6 +148,93 @@ def test_session_status_enum(): assert len(SessionStatus) == 7 +# --------------------------------------------------------------------------- +# Chunked KV slice creation tests +# --------------------------------------------------------------------------- + + +def _chunk_block_ids(all_block_ids, chunk_size_blocks, mamba_state_index=None): + """Call the real _create_kv_slices via a mock transceiver.""" + from unittest.mock import MagicMock + + from tensorrt_llm._torch.disaggregation.base.transfer import KVSlice + from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 + + base_slice = KVSlice( + is_last_slice=True, + block_ids_per_layer_groups=all_block_ids, + mamba_state_index=mamba_state_index, + ) + + transceiver = MagicMock() + transceiver._chunk_size_blocks = chunk_size_blocks + transceiver._collect_base_slice = MagicMock(return_value=base_slice) + transceiver._create_kv_slices = KvCacheTransceiverV2._create_kv_slices.__get__(transceiver) + + req = MagicMock() + return transceiver._create_kv_slices(req) + + +@pytest.mark.parametrize( + "all_block_ids,chunk_size,expected_num_slices", + [ + ([[0, 1, 2, 3, 4, 5, 6, 7]], None, 1), + ([[0, 1, 2, 3, 4, 5, 6, 7]], 4, 2), + ([list(range(10))], 4, 3), + ([[], []], 4, 1), + ([[0, 1, 2]], 64, 1), + ], + ids=["no_chunking", "even_split", "uneven_split", "empty_blocks", "chunk_larger_than_total"], +) +def test_create_kv_slices_basic(all_block_ids, chunk_size, expected_num_slices): + """Chunking produces the expected number of slices.""" + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=chunk_size) + assert len(slices) == expected_num_slices + assert slices[-1].is_last_slice is True + if expected_num_slices > 1: + for s in slices[:-1]: + assert s.is_last_slice is False + + +def test_create_kv_slices_integrity_check(): + """Reassembled block IDs from all slices must match the original.""" + all_block_ids = [list(range(17)), list(range(5))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4) + for lg_idx, original in enumerate(all_block_ids): + reassembled = [] + for s in slices: + reassembled.extend(s.block_ids_per_layer_groups[lg_idx]) + assert reassembled == original + + +def test_create_kv_slices_multiple_layer_groups(): + """Different layer groups with different block counts produce correct chunking.""" + all_block_ids = [list(range(8)), list(range(3))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4) + assert len(slices) == 2 + assert slices[0].block_ids_per_layer_groups[0] == [0, 1, 2, 3] + assert slices[1].block_ids_per_layer_groups[0] == [4, 5, 6, 7] + assert slices[0].block_ids_per_layer_groups[1] == [0, 1, 2] + assert slices[1].block_ids_per_layer_groups[1] == [] + + +def test_create_kv_slices_preserves_mamba_state_index(): + """mamba_state_index is propagated to every chunk slice.""" + all_block_ids = [list(range(8))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4, mamba_state_index=42) + assert len(slices) == 2 + for s in slices: + assert s.mamba_state_index == 42 + + +def test_create_kv_slices_none_mamba_state_index(): + """mamba_state_index=None is preserved when not set.""" + all_block_ids = [list(range(4))] + slices = _chunk_block_ids(all_block_ids, chunk_size_blocks=4) + assert len(slices) == 1 + assert slices[0].mamba_state_index is None + + def create_transfer_worker_setup( ctx_tp: int, ctx_pp: int, @@ -1277,7 +1364,7 @@ def test_session_cancel_before_send(): @pytest.mark.timeout(60) def test_session_cancel_after_send(): - """TxSession cancelled after send() queues INIT tasks; future raises.""" + """TxSession cancelled after send() queues INIT tasks fails the event wait.""" tensorrt_llm.logger.set_level("debug") setup = create_transfer_worker_setup( ctx_tp=1, @@ -1312,19 +1399,241 @@ def test_session_cancel_after_send(): page_table = ctx_transfer_worker._rank_info.page_table block_ids_per_groups = [np.array([], dtype=np.int64) for _ in page_table.layer_groups] kv_slice = KVSlice(is_last_slice=True, block_ids_per_layer_groups=block_ids_per_groups) - future = tx_session.send(kv_slice) + tx_session.send(kv_slice) # No receiver registered yet; task is INIT. tx_session.cancel() assert tx_session.status == SessionStatus.CANCELLED assert tx_session.has_failed() - # Future for the cancelled INIT task must raise. - with pytest.raises(Exception): - future.result(timeout=5.0) + assert tx_session.wait_complete() == WaitResult.FAILED tx_session.close() finally: ctx_transfer_worker.shutdown() + + +def _setup_chunked_request(setup, ctx_request_id, gen_request_id, request_len): + """Create requests, allocate KV, and collect block IDs for chunked transfer tests.""" + ctx_transfer_workers = setup["ctx_transfer_workers"] + ctx_kv_cache_managers = setup["ctx_kv_cache_managers"] + gen_transfer_workers = setup["gen_transfer_workers"] + gen_kv_cache_managers = setup["gen_kv_cache_managers"] + ctx_info_endpoint = setup["ctx_info_endpoint"] + use_v2 = setup["use_v2"] + tokens_per_block = setup["tokens_per_block"] + + sampling_params = SamplingParams() + unique_rid = uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF + + ctx_request = LlmRequest( + request_id=ctx_request_id, + max_new_tokens=1, + input_tokens=list(range(request_len)), + sampling_config=tensorrt_llm.bindings.SamplingConfig( + sampling_params._get_sampling_config() + ), + is_streaming=False, + llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY, + ) + ctx_request.py_disaggregated_params = DisaggregatedParams(disagg_request_id=unique_rid) + + gen_request = LlmRequest( + request_id=gen_request_id, + max_new_tokens=1, + input_tokens=list(range(request_len)), + sampling_config=tensorrt_llm.bindings.SamplingConfig( + sampling_params._get_sampling_config() + ), + is_streaming=False, + llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY, + ) + gen_request.py_disaggregated_params = DisaggregatedParams( + ctx_request_id=ctx_request.py_request_id, + ctx_dp_rank=0, + ctx_info_endpoint=ctx_info_endpoint, + disagg_request_id=unique_rid, + ) + + ctx_kv_caches, gen_kv_caches = [], [] + for mgr in ctx_kv_cache_managers: + if use_v2: + kv = mgr._create_kv_cache(ctx_request.py_request_id, None, None) + assert kv.resume(torch.cuda.current_stream().cuda_stream) + assert kv.resize(request_len) + ctx_kv_caches.append(kv) + else: + mgr.impl.add_sequence_batch( + [(ctx_request.py_request_id, request_len, 1)], [ctx_request] + ) + + for mgr in gen_kv_cache_managers: + if use_v2: + kv = mgr._create_kv_cache(gen_request.py_request_id, None, None) + assert kv.resume(torch.cuda.current_stream().cuda_stream) + assert kv.resize(request_len) + gen_kv_caches.append(kv) + else: + mgr.impl.add_sequence_batch( + [(gen_request.py_request_id, request_len, 1)], [gen_request] + ) + + ctx_block_ids = [ + get_block_ids_per_layer_groups(mgr, tw, ctx_request.py_request_id, use_v2, tokens_per_block) + for mgr, tw in zip(ctx_kv_cache_managers, ctx_transfer_workers, strict=True) + ] + gen_block_ids = [ + get_block_ids_per_layer_groups(mgr, tw, gen_request.py_request_id, use_v2, tokens_per_block) + for mgr, tw in zip(gen_kv_cache_managers, gen_transfer_workers, strict=True) + ] + + return { + "ctx_request": ctx_request, + "gen_request": gen_request, + "ctx_kv_caches": ctx_kv_caches, + "gen_kv_caches": gen_kv_caches, + "ctx_block_ids": ctx_block_ids, + "gen_block_ids": gen_block_ids, + } + + +def _verify_and_cleanup_chunked(setup, ctx_info, sender_sessions, receiver_sessions): + """Shared verification and cleanup for chunked transfer tests.""" + ctx_kv_cache_managers = setup["ctx_kv_cache_managers"] + gen_kv_cache_managers = setup["gen_kv_cache_managers"] + use_v2 = setup["use_v2"] + + ctx_block_ids = ctx_info["ctx_block_ids"] + gen_block_ids = ctx_info["gen_block_ids"] + + for session in sender_sessions: + assert session.status == SessionStatus.KV_TRANSFERRED + for session in receiver_sessions: + assert session.status == SessionStatus.KV_TRANSFERRED + + num_layer_groups = len(ctx_block_ids[0]) + for lg_id in range(num_layer_groups): + ctx_data = [ + get_block_data(mgr, bids[lg_id], lg_id, use_v2, ctx_info["ctx_request"].py_request_id) + for mgr, bids in zip(ctx_kv_cache_managers, ctx_block_ids, strict=True) + ] + gen_data = [ + get_block_data(mgr, bids[lg_id], lg_id, use_v2, ctx_info["gen_request"].py_request_id) + for mgr, bids in zip(gen_kv_cache_managers, gen_block_ids, strict=True) + ] + for c, g in zip(ctx_data, gen_data, strict=True): + assert c.equal(g), f"Layer group {lg_id}: data mismatch with chunked transfer" + + for s in receiver_sessions: + s.close() + for s in sender_sessions: + s.close() + if use_v2: + torch.cuda.current_stream().synchronize() + for kv in ctx_info["ctx_kv_caches"]: + kv.close() + for kv in ctx_info["gen_kv_caches"]: + kv.close() + + +def add_and_verify_chunked_request( + setup, + ctx_request_id, + gen_request_id, + request_len, + chunk_size_blocks, +): + """Chunked transfer variant: sender sends N slices, receiver sends 1.""" + import math + + ctx_transfer_workers = setup["ctx_transfer_workers"] + gen_transfer_workers = setup["gen_transfer_workers"] + + ctx_info = _setup_chunked_request(setup, ctx_request_id, gen_request_id, request_len) + ctx_block_ids = ctx_info["ctx_block_ids"] + gen_block_ids = ctx_info["gen_block_ids"] + token_range = TokenRange(start=0, end=request_len) + + sender_sessions = [tw.create_tx_session(ctx_info["ctx_request"]) for tw in ctx_transfer_workers] + for sender_session, block_ids_per_groups in zip(sender_sessions, ctx_block_ids, strict=True): + max_blocks = max(len(ids) for ids in block_ids_per_groups) + num_chunks = math.ceil(max_blocks / chunk_size_blocks) + chunk_offset = 0 + for chunk_idx in range(num_chunks): + start = chunk_idx * chunk_size_blocks + end = start + chunk_size_blocks + is_last = chunk_idx == num_chunks - 1 + chunk_block_ids = [ids[start:end] for ids in block_ids_per_groups] + kv_slice = KVSlice( + is_last_slice=is_last, + block_ids_per_layer_groups=chunk_block_ids, + token_range=token_range, + chunk_block_offset=chunk_offset, + ) + sender_session.send(kv_slice) + chunk_offset += max(len(ids) for ids in chunk_block_ids) + + receiver_sessions = [ + tw.create_rx_session(ctx_info["gen_request"]) for tw in gen_transfer_workers + ] + for recv_session, block_ids_per_groups in zip(receiver_sessions, gen_block_ids, strict=True): + full_slice = KVSlice( + is_last_slice=True, + block_ids_per_layer_groups=block_ids_per_groups, + token_range=token_range, + ) + recv_session.receive(full_slice) + + for session in sender_sessions: + result = session.wait_complete() + assert result == WaitResult.COMPLETED, f"tx wait_complete returned {result}" + for session in receiver_sessions: + result = session.wait_complete(blocking=True) + assert result == WaitResult.COMPLETED, f"rx wait_complete returned {result}" + + _verify_and_cleanup_chunked(setup, ctx_info, sender_sessions, receiver_sessions) + + +CHUNKED_TEST_CONFIGS = [ + (1, 1, False, 1, 1, False, False, True, "v2_tp1_pp1_chunked"), + (1, 1, False, 1, 1, False, False, False, "v1_tp1_pp1_chunked"), +] + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + "ctx_tp,ctx_pp,ctx_enable_dp,gen_tp,gen_pp,gen_enable_dp,is_mla,use_v2", + [(c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]) for c in CHUNKED_TEST_CONFIGS], + ids=[c[8] for c in CHUNKED_TEST_CONFIGS], +) +def test_transfer_worker_chunked( + ctx_tp, ctx_pp, ctx_enable_dp, gen_tp, gen_pp, gen_enable_dp, is_mla, use_v2 +): + """Test transfer worker with sender-side chunking for V1 and V2.""" + tensorrt_llm.logger.set_level("info") + logger.info(f"Test transfer worker {'V2' if use_v2 else 'V1'} with chunked transfer") + + setup = create_transfer_worker_setup( + ctx_tp=ctx_tp, + ctx_pp=ctx_pp, + ctx_enable_dp=ctx_enable_dp, + gen_tp=gen_tp, + gen_pp=gen_pp, + gen_enable_dp=gen_enable_dp, + is_mla=is_mla, + use_v2=use_v2, + ) + + request_len = setup["request_len"] + tokens_per_block = setup["tokens_per_block"] + total_blocks = (request_len + tokens_per_block - 1) // tokens_per_block + chunk_size = max(1, total_blocks // 2) + + try: + add_and_verify_chunked_request(setup, 0, 1, request_len, chunk_size_blocks=chunk_size) + add_and_verify_chunked_request(setup, 2, 3, request_len * 2, chunk_size_blocks=chunk_size) + finally: + for worker in setup["ctx_transfer_workers"]: + worker.shutdown() for worker in setup["gen_transfer_workers"]: worker.shutdown() diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index d0c37dc581a3..0e4c7185c5dd 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -1761,6 +1761,23 @@ def test_cache_transceiver_config_arbitrary_args(self): CacheTransceiverConfig(backend="UCX", invalid_config="should_fail") assert "invalid_config" in str(exc_info.value) + def test_cache_transceiver_config_chunk_size_blocks(self): + """Test chunk_size_blocks field validation.""" + config = CacheTransceiverConfig(chunk_size_blocks=64) + assert config.chunk_size_blocks == 64 + + config_none = CacheTransceiverConfig(chunk_size_blocks=None) + assert config_none.chunk_size_blocks is None + + config_default = CacheTransceiverConfig() + assert config_default.chunk_size_blocks is None + + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + CacheTransceiverConfig(chunk_size_blocks=0) + + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + CacheTransceiverConfig(chunk_size_blocks=-1) + def test_torch_compile_config_arbitrary_args(self): """Test that TorchCompileConfig rejects arbitrary arguments.""" # Valid arguments should work