Skip to content
Open
14 changes: 12 additions & 2 deletions tensorrt_llm/_torch/disaggregation/base/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tensorrt_llm import DisaggregatedParams
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest


import torch
@dataclass
class TokenRange:
"""Range of tokens in the sequence dimension."""
Expand Down Expand Up @@ -65,6 +65,7 @@ class KVSlice:
) # Physical block IDs per layer group, each np.ndarray(dtype=np.int64)
is_last_slice: bool = False
mamba_state_index: Optional[int] = None
chunk_block_offset: int = 0


class SessionStatus(Enum):
Expand Down Expand Up @@ -158,7 +159,16 @@ def __init__(self, sender: SenderBase, args: SessionArgsBase):
self._sender = sender

@abstractmethod
def send(self, slice: KVSlice) -> None: ...
def send(self, slice: KVSlice) -> None:
"""Send a KV slice.

Args:
slice: The KV slice describing which source blocks to send.
The slice's ``chunk_block_offset`` field indicates the offset
into the receiver's destination block list for sender-side
chunking.
"""
...


class RxSessionBase(_SessionBase):
Expand Down
74 changes: 62 additions & 12 deletions tensorrt_llm/_torch/disaggregation/native/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,24 @@ def __init__(self, params: DisaggregatedParams, slot: Optional[int]):


class KVSendTask(SendTaskBase):
"""A per-slice send task within a TxSession.

Args:
kv_slice: The KV slice describing which blocks to transfer.
The slice's ``chunk_block_offset`` field indicates the
offset into the receiver's destination block list.
params: Disaggregated serving parameters for this request.
slice_id: Index of this slice within the session's task list.
"""

def __init__(
self,
kv_slice: KVSlice,
params: DisaggregatedParams,
slice_id: int,
prompt_len: Optional[int] = None,
beam_width: int = 1,
):
) -> None:
super().__init__(params)
self.slice_id = slice_id
self.transferred_count = 0
Expand Down Expand Up @@ -474,6 +484,11 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
assert write_meta.slice_id is not None
task = session.kv_tasks[write_meta.slice_id]
timer = task._perf_timer

# For pipelined prefill-transfer: wait for the GPU forward
# to finish writing KV data before starting RDMA. This
# blocks only this worker thread, not the GPU or main thread.

if timer:
timer.record_push_end(write_meta.peer_rank)
# Hold session.lock to serialize the INIT→TRANSFERRING transition with
Expand All @@ -493,7 +508,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
f"in {status.value} state; sending FAILED to receiver"
)
# Task may have been enqueued after cancel() already iterated kv_tasks,
# so its future was never set by cancel(). Set it here as a fallback.
# so its event was never set by cancel(). Set it here as a fallback.
task.fail(
RuntimeError(f"session {write_meta.unique_rid} {status.value}, transfer aborted")
)
Expand All @@ -503,7 +518,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
str(self._instance_rank).encode("ascii"),
str(write_meta.unique_rid).encode("ascii"),
str(write_meta.slice_id).encode("ascii"),
b"True", # is_last_slice — ensures receiver resolves its task future
b"True", # is_last_slice — ensures receiver resolves its task event
Comment on lines 520 to +521

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Send abort results to receiver slice 0 as well.

The success path makes sender-side chunking transparent by reporting receiver_slice_id = 0; this abort path still sends write_meta.slice_id, so a cancelled/failed later chunk can trip the receiver’s single-task slice assertion instead of unblocking it.

Suggested fix
-                    str(write_meta.slice_id).encode("ascii"),
+                    b"0",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
str(write_meta.slice_id).encode("ascii"),
b"True", # is_last_slice — ensures receiver resolves its task future
b"True", # is_last_slice — ensures receiver resolves its task event
b"0",
b"True", # is_last_slice — ensures receiver resolves its task event
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/disaggregation/native/transfer.py` around lines 523 -
524, The abort/result notification path in transfer.py still uses
write_meta.slice_id, which can conflict with the receiver’s single-task slice
handling. Update the abort send logic in the relevant transfer routine to mirror
the success path by reporting receiver_slice_id as 0 for aborts too, so the
receiver does not see a later-chunk slice ID and hit its slice assertion. Keep
the existing task/event unblocking behavior intact while ensuring the
aborted/failure result is always sent to receiver slice 0.

AgentResult.FAILED.value.encode("ascii"),
]
)
Expand Down Expand Up @@ -536,13 +551,19 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
if timer:
timer.record_transfer_end(write_meta.peer_rank)

## TODO: just last slice need to send task state?
# The receiver always has a single monolithic task (slice_id=0).
# Sender-side chunking is transparent to the receiver: only the
# last chunk carries is_last_slice=True so the receiver knows
# when all data has arrived. Intermediate chunk results are
# sent (not suppressed) so that RDMA failures propagate to the
# receiver immediately rather than requiring a timeout.
receiver_slice_id = 0
self._get_or_connect_thread_dealer(write_meta.peer_endpoint).send(
[
MessageType.KV_AGENT_RESULT,
str(self._instance_rank).encode("ascii"),
str(write_meta.unique_rid).encode("ascii"),
str(write_meta.slice_id).encode("ascii"),
str(receiver_slice_id).encode("ascii"),
str(write_meta.is_last_slice).encode("ascii"),
agent_result.value.encode("ascii"),
]
Expand Down Expand Up @@ -701,10 +722,24 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write
dst_block_ids_per_groups = req_info.block_ids_per_layer_groups
src_block_ids_per_groups = task._slice.block_ids_per_layer_groups

# Aggregate fragments from all matching pools using numpy concatenation
chunk_offset = task._slice.chunk_block_offset
for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items():
src_block_ids = src_block_ids_per_groups[self_lg]
dst_block_ids = dst_block_ids_per_groups[peer_lg]
full_dst_block_ids = dst_block_ids_per_groups[peer_lg]

# When sender uses chunking, the receiver sends all dst
# blocks in a single RecvReqInfo. Slice dst to match
# this task's src chunk position.
if chunk_offset > 0 or not task._slice.is_last_slice:
chunk_end = chunk_offset + len(src_block_ids)
if chunk_end > full_dst_block_ids.size:
raise ValueError(
f"dst chunk range out of bounds: offset={chunk_offset}, "
f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}"
)
dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end]
Comment on lines +733 to +740

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Allow exhausted layer groups to map to empty destination slices.

With asymmetric layer-group lengths, later chunks can have len(src_block_ids) == 0 while chunk_offset has advanced past that group’s destination length. That should be a no-op, not an out-of-bounds error.

Suggested fix
                 if chunk_offset > 0 or not task._slice.is_last_slice:
+                    if len(src_block_ids) == 0:
+                        dst_block_ids = full_dst_block_ids[:0]
+                        continue
                     chunk_end = chunk_offset + len(src_block_ids)
                     if chunk_end > full_dst_block_ids.size:
                         raise ValueError(
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if chunk_offset > 0 or not task._slice.is_last_slice:
chunk_end = chunk_offset + len(src_block_ids)
if chunk_end > full_dst_block_ids.size:
raise ValueError(
f"dst chunk range out of bounds: offset={chunk_offset}, "
f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}"
)
dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end]
if chunk_offset > 0 or not task._slice.is_last_slice:
if len(src_block_ids) == 0:
dst_block_ids = full_dst_block_ids[:0]
continue
chunk_end = chunk_offset + len(src_block_ids)
if chunk_end > full_dst_block_ids.size:
raise ValueError(
f"dst chunk range out of bounds: offset={chunk_offset}, "
f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}"
)
dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/disaggregation/native/transfer.py` around lines 736 -
743, The chunk-to-destination mapping in transfer.py is too strict for exhausted
layer groups: when len(src_block_ids) is 0, the current bounds check in the
chunk slicing logic still raises on advanced chunk_offset values. Update the
chunk handling around the dst_block_ids slice so empty source chunks become a
no-op and do not trigger the out-of-bounds error; keep the existing bounds
validation for non-empty chunks in the same chunk offset/full_dst_block_ids
path.

else:
dst_block_ids = full_dst_block_ids
Comment on lines +725 to +742

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift

Keep chunk token ranges consistent with chunk_block_offset.

This path now slices destination blocks by chunk, but the alignment below still infers token starts as if each chunk’s block list were the suffix ending at token_range.end. If chunked slices carry the full request range, prefix-cache/SWA cases can align and write the wrong blocks. Either require per-chunk KVSlice.token_range from callers or derive the token starts from chunk_block_offset.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/disaggregation/native/transfer.py` around lines 728 -
745, The chunked destination slicing in transfer logic is now using chunk
offsets, but the token alignment still assumes each chunk maps to a suffix
ending at token_range.end. Update the code around the chunked path in
transfer.py and the downstream token-start calculation to derive starts from
chunk_block_offset, or require callers to provide per-chunk KVSlice.token_range
for each chunk. Make sure the block selection and token-range alignment stay
consistent for prefix-cache and SWA cases so the written blocks match the
intended chunk.


# Speculative decoding: generation may have one extra draft-token block.
block_diff = dst_block_ids.size - src_block_ids.size
Expand Down Expand Up @@ -965,7 +1000,7 @@ def _respond_with_kv(self, _send_id: bytes, message: list[bytes]):
self._save_peer_req_info(info)
tasks = list(session.kv_tasks)
# No tasks: no worker will send KV_AGENT_RESULT FAILED to the receiver.
# Send it directly to unblock the receiver's TRANSFERRING task future;
# Send it directly to unblock the receiver's TRANSFERRING task event;
# CANCEL_SESSION alone would leave it stuck indefinitely.
if not tasks and session.status in (SessionStatus.ERROR, SessionStatus.CANCELLED):
self._send_failed_result_to_receiver(info)
Expand Down Expand Up @@ -1143,6 +1178,10 @@ def disagg_request_id(self) -> int:
def status(self) -> SessionStatus:
if self._terminal_status is not None:
return self._terminal_status
if self._exception is not None or any(t.status == TaskStatus.ERROR for t in self.kv_tasks):
return SessionStatus.ERROR
if self.aux_task is not None and self.aux_task.status == TaskStatus.ERROR:
return SessionStatus.ERROR
kv_all_transferred = bool(self.kv_tasks) and all(
t.status == TaskStatus.TRANSFERRED for t in self.kv_tasks
)
Expand Down Expand Up @@ -1732,15 +1771,15 @@ def process_kv_agent_result(
)

def process_aux_agent_result(self, _peer_rank: int, status: AgentResult):
# Aux is session-level (not per-slice); expected_transfers is identical
# across all kv_tasks, so any task provides the right count.
# Aux is session-level (not per-slice); use the final KV task's
# expected transfer count so chunked sessions wait for all senders.
with self.lock:
if not self._kv_tasks:
logger.warning(
f"Aux result received before any KV tasks for request {self.request_id}"
)
return
task = self._kv_tasks[0]
task = self._kv_tasks[-1]
if status == AgentResult.SUCCESS:
self._aux_count += 1

Expand Down Expand Up @@ -1996,7 +2035,18 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp
self._rank_info.sender_endpoints = endpoints
self._rank_info.layer_num_per_pp = layer_num_per_pp

def create_tx_session(self, request: LlmRequest) -> TxSession:
def create_tx_session(
self,
request: LlmRequest,
) -> TxSession:
"""Create a TxSession for the given request.

Args:
request: The LLM request to create a send session for.

Returns:
A new ``TxSession`` ready to accept ``send()`` calls.
"""
params = request.py_disaggregated_params
assert params is not None
return TxSession(
Expand Down
Loading