Skip to content
12 changes: 11 additions & 1 deletion tensorrt_llm/_torch/disaggregation/base/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class KVSlice:
) # Physical block IDs per layer group, each np.ndarray(dtype=np.int64)
is_last_slice: bool = False
mamba_state_index: Optional[int] = None
chunk_block_offset: int = 0


class SessionStatus(Enum):
Expand Down Expand Up @@ -158,7 +159,16 @@ def __init__(self, sender: SenderBase, args: SessionArgsBase):
self._sender = sender

@abstractmethod
def send(self, slice: KVSlice) -> None:
    """Send a KV slice.

    Abstract: concrete sessions provide the transfer logic. There is
    exactly one definition of this method so that the ``@abstractmethod``
    marker applies to the documented stub (a second, undecorated ``def``
    after the decorated one would rebind the name and drop the marker).

    Args:
        slice: The KV slice describing which source blocks to send.
            The slice's ``chunk_block_offset`` field indicates the offset
            into the receiver's destination block list for sender-side
            chunking. (The parameter name shadows the ``slice`` builtin;
            it is kept because it is part of the public signature.)
    """
    ...


class RxSessionBase(_SessionBase):
Expand Down
69 changes: 57 additions & 12 deletions tensorrt_llm/_torch/disaggregation/native/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,24 @@ def __init__(self, params: DisaggregatedParams, slot: Optional[int]):


class KVSendTask(SendTaskBase):
"""A per-slice send task within a TxSession.

Args:
kv_slice: The KV slice describing which blocks to transfer.
The slice's ``chunk_block_offset`` field indicates the
offset into the receiver's destination block list.
params: Disaggregated serving parameters for this request.
slice_id: Index of this slice within the session's task list.
"""

def __init__(
self,
kv_slice: KVSlice,
params: DisaggregatedParams,
slice_id: int,
prompt_len: Optional[int] = None,
beam_width: int = 1,
):
) -> None:
super().__init__(params)
self.slice_id = slice_id
self.transferred_count = 0
Expand Down Expand Up @@ -491,7 +501,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
f"in {status.value} state; sending FAILED to receiver"
)
# Task may have been enqueued after cancel() already iterated kv_tasks,
# so its future was never set by cancel(). Set it here as a fallback.
# so its event was never set by cancel(). Set it here as a fallback.
task.fail(
RuntimeError(f"session {write_meta.unique_rid} {status.value}, transfer aborted")
)
Expand All @@ -501,7 +511,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
str(self._instance_rank).encode("ascii"),
str(write_meta.unique_rid).encode("ascii"),
str(write_meta.slice_id).encode("ascii"),
b"True", # is_last_slice — ensures receiver resolves its task future
b"True", # is_last_slice — ensures receiver resolves its task event
AgentResult.FAILED.value.encode("ascii"),
]
)
Expand All @@ -518,13 +528,19 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
if timer:
timer.record_transfer_end(write_meta.peer_rank)

## TODO: just last slice need to send task state?
# The receiver always has a single monolithic task (slice_id=0).
# Sender-side chunking is transparent to the receiver: only the
# last chunk carries is_last_slice=True so the receiver knows
# when all data has arrived. Intermediate chunk results are
# sent (not suppressed) so that RDMA failures propagate to the
# receiver immediately rather than requiring a timeout.
receiver_slice_id = 0
self._get_or_connect_thread_dealer(write_meta.peer_endpoint).send(
[
MessageType.KV_AGENT_RESULT,
str(self._instance_rank).encode("ascii"),
str(write_meta.unique_rid).encode("ascii"),
str(write_meta.slice_id).encode("ascii"),
str(receiver_slice_id).encode("ascii"),
str(write_meta.is_last_slice).encode("ascii"),
agent_result.value.encode("ascii"),
]
Expand Down Expand Up @@ -683,10 +699,24 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write
dst_block_ids_per_groups = req_info.block_ids_per_layer_groups
src_block_ids_per_groups = task._slice.block_ids_per_layer_groups

# Aggregate fragments from all matching pools using numpy concatenation
chunk_offset = task._slice.chunk_block_offset
for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items():
src_block_ids = src_block_ids_per_groups[self_lg]
dst_block_ids = dst_block_ids_per_groups[peer_lg]
full_dst_block_ids = dst_block_ids_per_groups[peer_lg]

# When sender uses chunking, the receiver sends all dst
# blocks in a single RecvReqInfo. Slice dst to match
# this task's src chunk position.
if chunk_offset > 0 or not task._slice.is_last_slice:
chunk_end = chunk_offset + len(src_block_ids)
if chunk_end > full_dst_block_ids.size:
raise ValueError(
f"dst chunk range out of bounds: offset={chunk_offset}, "
f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}"
)
dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end]
else:
dst_block_ids = full_dst_block_ids
Comment thread
athena-nv marked this conversation as resolved.

# Speculative decoding: generation may have one extra draft-token block.
block_diff = dst_block_ids.size - src_block_ids.size
Expand Down Expand Up @@ -943,7 +973,7 @@ def _respond_with_kv(self, _send_id: bytes, message: list[bytes]):
self._save_peer_req_info(info)
tasks = list(session.kv_tasks)
# No tasks: no worker will send KV_AGENT_RESULT FAILED to the receiver.
# Send it directly to unblock the receiver's TRANSFERRING task future;
# Send it directly to unblock the receiver's TRANSFERRING task event;
# CANCEL_SESSION alone would leave it stuck indefinitely.
if not tasks and session.status in (SessionStatus.ERROR, SessionStatus.CANCELLED):
self._send_failed_result_to_receiver(info)
Expand Down Expand Up @@ -1115,6 +1145,10 @@ def disagg_request_id(self) -> int:
def status(self) -> SessionStatus:
if self._terminal_status is not None:
return self._terminal_status
if self._exception is not None or any(t.status == TaskStatus.ERROR for t in self.kv_tasks):
return SessionStatus.ERROR
if self.aux_task is not None and self.aux_task.status == TaskStatus.ERROR:
return SessionStatus.ERROR
kv_all_transferred = bool(self.kv_tasks) and all(
t.status == TaskStatus.TRANSFERRED for t in self.kv_tasks
)
Expand Down Expand Up @@ -1702,15 +1736,15 @@ def process_kv_agent_result(
)

def process_aux_agent_result(self, _peer_rank: int, status: AgentResult):
# Aux is session-level (not per-slice); expected_transfers is identical
# across all kv_tasks, so any task provides the right count.
# Aux is session-level (not per-slice); use the final KV task's
# expected transfer count so chunked sessions wait for all senders.
with self.lock:
if not self._kv_tasks:
logger.warning(
f"Aux result received before any KV tasks for request {self.request_id}"
)
return
task = self._kv_tasks[0]
task = self._kv_tasks[-1]
if status == AgentResult.SUCCESS:
self._aux_count += 1

Expand Down Expand Up @@ -1982,7 +2016,18 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp
self._rank_info.sender_endpoints = endpoints
self._rank_info.layer_num_per_pp = layer_num_per_pp

def create_tx_session(self, request: LlmRequest) -> TxSession:
def create_tx_session(
self,
request: LlmRequest,
) -> TxSession:
"""Create a TxSession for the given request.

Args:
request: The LLM request to create a send session for.

Returns:
A new ``TxSession`` ready to accept ``send()`` calls.
"""
params = request.py_disaggregated_params
assert params is not None
return TxSession(
Expand Down
105 changes: 102 additions & 3 deletions tensorrt_llm/_torch/disaggregation/transceiver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import uuid
from collections import defaultdict
from itertools import chain
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(
self._recv_reqs = {}
self._wait_reqs = {}
self._page_table = self._transfer_worker.page_table
self._chunk_size_blocks = cache_transceiver_config.chunk_size_blocks

def _broadcast_instance_name(self) -> str:
if self._dist.rank == 0:
Expand Down Expand Up @@ -221,6 +223,86 @@ def _create_kv_slice(
token_range=token_range,
)

def _collect_base_slice(self, req: LlmRequest) -> KVSlice:
    """Build the un-chunked KV slice for a request.

    Thin wrapper over ``_create_kv_slice``: the slice is handed back
    exactly as that helper produced it, so metadata such as
    ``mamba_state_index`` (needed for hybrid-state model transfers) is
    carried along untouched.

    Args:
        req: The LLM request whose KV cache block IDs are collected.

    Returns:
        The ``KVSlice`` returned by ``_create_kv_slice`` for ``req``.
    """
    whole_slice = self._create_kv_slice(req)
    return whole_slice

def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]:
    """Split a request's KV blocks into the slices that will be sent.

    With ``chunk_size_blocks`` unset (``None``), or when no layer group
    has any blocks, the single un-chunked slice is returned as-is.
    Otherwise every layer group's block-ID list is cut at the same
    ``[start, end)`` positions into pieces of at most
    ``chunk_size_blocks`` blocks, one ``KVSlice`` per position.

    Args:
        req: The LLM request to create slices for.

    Returns:
        The slices in send order. When chunking is applied, only the
        final slice carries ``is_last_slice=True`` and each slice records
        its ``chunk_block_offset``.

    Raises:
        ValueError: If concatenating the chunks of any layer group does
            not reproduce that group's original block IDs.
    """
    whole = self._collect_base_slice(req)
    per_group_ids = whole.block_ids_per_layer_groups
    chunk_blocks = self._chunk_size_blocks

    # Chunking disabled: send everything as one slice.
    if chunk_blocks is None:
        return [whole]

    # Nothing to partition: keep the base slice so its metadata survives.
    longest = max((len(ids) for ids in per_group_ids), default=0)
    if longest == 0:
        return [whole]

    # The longest layer group decides how many chunks are needed.
    chunk_count = math.ceil(longest / chunk_blocks)
    final_idx = chunk_count - 1

    chunks: List[KVSlice] = []
    dst_offset = 0
    for idx in range(chunk_count):
        lo = idx * chunk_blocks
        hi = lo + chunk_blocks
        group_chunks = [ids[lo:hi] for ids in per_group_ids]
        chunks.append(
            KVSlice(
                is_last_slice=(idx == final_idx),
                block_ids_per_layer_groups=group_chunks,
                mamba_state_index=whole.mamba_state_index,
                token_range=whole.token_range,
                chunk_block_offset=dst_offset,
            )
        )
        # One offset is shared by all layer groups and advances by the
        # longest per-group piece of this chunk. The sender-side code in
        # native/transfer.py (`_build_kv_write_meta`) cuts each group's
        # destination list at [offset, offset + len(src_block_ids)).
        # NOTE(review): a group shorter than the offset yields empty
        # pieces here; confirm the bounds check in `_build_kv_write_meta`
        # tolerates an empty source chunk at a non-zero offset.
        dst_offset += max((len(ids) for ids in group_chunks), default=0)

    # Sanity check: the pieces of each layer group must reassemble to the
    # original block-ID list.
    for lg_idx, original_ids in enumerate(per_group_ids):
        reassembled = np.concatenate(
            [piece.block_ids_per_layer_groups[lg_idx] for piece in chunks]
        )
        if not np.array_equal(reassembled, original_ids):
            raise ValueError(
                f"Chunking integrity check failed for layer group {lg_idx}: "
                f"expected {len(original_ids)} blocks, got {len(reassembled)}"
            )

    return chunks

@staticmethod
def _split_packed_beam_block_ids(
block_ids: np.ndarray,
Expand Down Expand Up @@ -446,7 +528,8 @@ def _finalize_send(self, req: LlmRequest, session: TxSessionBase):
def respond_and_send_async(self, req: LlmRequest) -> None:
    """Send the request's KV cache through its send session.

    Obtains (or creates) the send session for ``req``, marks the request
    as ``DISAGG_CONTEXT_TRANS_IN_PROGRESS``, sends every slice produced
    by ``_create_kv_slices`` in order, and then calls ``_finalize_send``.

    Args:
        req: The context request whose KV cache blocks are sent.
    """
    session = self._get_or_create_send_session(req)
    req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
    # ``_create_kv_slices`` already returns the single un-chunked slice
    # when chunking is disabled, so it is the only send path: also sending
    # ``_create_kv_slice(req)`` up front would transmit the data twice.
    for kv_slice in self._create_kv_slices(req):
        session.send(kv_slice)
    self._finalize_send(req, session)

@nvtx_range("KvCacheTransceiverV2.request_and_receive_sync")
Expand Down Expand Up @@ -482,7 +565,17 @@ def request_and_receive_sync(self, req: LlmRequest):
self._recv_reqs.pop(rid, None)

@nvtx_range("KvCacheTransceiverV2.request_and_receive_async")
def request_and_receive_async(self, req: LlmRequest):
def request_and_receive_async(self, req: LlmRequest) -> None:
"""Start background KV cache receive from the context server.

The receiver always uses a single monolithic slice. Chunking is
sender-only: the sender splits its source blocks into chunks and
slices the receiver's destination blocks to match each chunk.

Args:
req: The generation request whose KV cache blocks to receive
into.
"""
rid = get_unique_rid(req)
if rid in self._recv_sessions:
logger.warning(
Expand All @@ -492,7 +585,13 @@ def request_and_receive_async(self, req: LlmRequest):
req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS
session = self._transfer_worker.create_rx_session(req)
self._recv_sessions[rid] = session
session.receive(self._create_kv_slice(req))
base_slice = self._collect_base_slice(req)
full_slice = KVSlice(
is_last_slice=True,
block_ids_per_layer_groups=base_slice.block_ids_per_layer_groups,
mamba_state_index=base_slice.mamba_state_index,
)
session.receive(full_slice)
self._recv_reqs[rid] = req

def check_context_transfer_status(
Expand Down
56 changes: 51 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,62 @@ def create_kv_cache_transceiver(
"MPI CacheTransceiver is deprecated, UCX or NIXL is recommended")
elif cache_transceiver_config.backend == "UCX":
logger.info(
f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting "
f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
f"hangs or lower-than-expected performance.")
"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting "
"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
"hangs or lower-than-expected performance.")

# Auto-select Python transceiver when chunk_size_blocks is set,
# since the C++ transceiver does not support chunked transfer.
# Only applies to NIXL/DEFAULT backends (the Python transceiver
# does not support UCX, MPI, or MOONCAKE).
runtime = cache_transceiver_config.transceiver_runtime
use_python = runtime == "PYTHON"
if (runtime is None
and cache_transceiver_config.chunk_size_blocks is not None):
if cache_transceiver_config.backend in (None, "DEFAULT", "NIXL"):
# Use warning (not info) so users notice the transceiver swap and
# the implied perf / staging-buffer characteristics change. Set
# transceiver_runtime='CPP' explicitly to opt out (and lose
# chunked transfer).
logger.warning(
"chunk_size_blocks is set; auto-selecting the Python "
"transceiver instead of the C++ transceiver to enable "
"chunked KV cache transfer. "
"Set transceiver_runtime='CPP' to disable this auto-selection.")
use_python = True
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
logger.warning(
f"chunk_size_blocks is set but backend "
f"'{cache_transceiver_config.backend}' requires the C++ "
f"transceiver, which does not support chunked transfer. "
f"chunk_size_blocks will be ignored. Use NIXL backend to "
f"enable chunked transfer.")
elif (runtime == "CPP"
and cache_transceiver_config.chunk_size_blocks is not None):
logger.warning("chunk_size_blocks is set but transceiver_runtime='CPP' "
"explicitly disables Python auto-selection; "
"chunk_size_blocks will be ignored.")

# Warn when chunk_size_blocks is below the recommended floor. The Pydantic
# field is PositiveInt (>=1), but values below ~16 push the per-chunk RDMA
# overhead into the regime where it dominates transfer throughput.
_MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS = 16
if (cache_transceiver_config.chunk_size_blocks is not None
and cache_transceiver_config.chunk_size_blocks
< _MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS):
logger.warning(
f"chunk_size_blocks={cache_transceiver_config.chunk_size_blocks} "
f"is below the recommended floor of "
f"{_MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS}; per-chunk RDMA overhead "
f"may dominate transfer throughput. Consider 64-128 for "
f"long-context workloads (ISL >= 32K).")

# Select transceiver implementation based on the resolved ``use_python`` flag:
# transceiver_runtime == "PYTHON", or runtime unset with chunk_size_blocks set
# on a NIXL/DEFAULT backend -> use Python transceiver
# otherwise (transceiver_runtime None or "CPP") -> use C++ transceiver (default)
if cache_transceiver_config.transceiver_runtime == "PYTHON":
if use_python:
# Python transceiver currently only supports NIXL and DEFAULT backend
if cache_transceiver_config.backend not in ("DEFAULT", "NIXL"):
if cache_transceiver_config.backend not in (None, "DEFAULT", "NIXL"):
raise ValueError(
f"Python transceiver currently only supports NIXL or DEFAULT backend, "
f"got {cache_transceiver_config.backend}. "
Expand Down
Loading
Loading