-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[TRTLLM-12499][feat] (WIP) Add support for pipelined KVCache transfer for disaggregated serving in Python Cache Transceiver #15727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
80ed3a8
5912964
6ff2b02
e582cb0
4bd5e6e
b622099
ea62778
8ccc142
9aed65c
7705b98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -202,14 +202,24 @@ def __init__(self, params: DisaggregatedParams, slot: Optional[int]): | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| class KVSendTask(SendTaskBase): | ||||||||||||||||||||||||||||||||||||||||
| """A per-slice send task within a TxSession. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||
| kv_slice: The KV slice describing which blocks to transfer. | ||||||||||||||||||||||||||||||||||||||||
| The slice's ``chunk_block_offset`` field indicates the | ||||||||||||||||||||||||||||||||||||||||
| offset into the receiver's destination block list. | ||||||||||||||||||||||||||||||||||||||||
| params: Disaggregated serving parameters for this request. | ||||||||||||||||||||||||||||||||||||||||
| slice_id: Index of this slice within the session's task list. | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||
| kv_slice: KVSlice, | ||||||||||||||||||||||||||||||||||||||||
| params: DisaggregatedParams, | ||||||||||||||||||||||||||||||||||||||||
| slice_id: int, | ||||||||||||||||||||||||||||||||||||||||
| prompt_len: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||
| beam_width: int = 1, | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||
| super().__init__(params) | ||||||||||||||||||||||||||||||||||||||||
| self.slice_id = slice_id | ||||||||||||||||||||||||||||||||||||||||
| self.transferred_count = 0 | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -474,6 +484,11 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): | |||||||||||||||||||||||||||||||||||||||
| assert write_meta.slice_id is not None | ||||||||||||||||||||||||||||||||||||||||
| task = session.kv_tasks[write_meta.slice_id] | ||||||||||||||||||||||||||||||||||||||||
| timer = task._perf_timer | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # For pipelined prefill-transfer: wait for the GPU forward | ||||||||||||||||||||||||||||||||||||||||
| # to finish writing KV data before starting RDMA. This | ||||||||||||||||||||||||||||||||||||||||
| # blocks only this worker thread, not the GPU or main thread. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if timer: | ||||||||||||||||||||||||||||||||||||||||
| timer.record_push_end(write_meta.peer_rank) | ||||||||||||||||||||||||||||||||||||||||
| # Hold session.lock to serialize the INIT→TRANSFERRING transition with | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -493,7 +508,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): | |||||||||||||||||||||||||||||||||||||||
| f"in {status.value} state; sending FAILED to receiver" | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| # Task may have been enqueued after cancel() already iterated kv_tasks, | ||||||||||||||||||||||||||||||||||||||||
| # so its future was never set by cancel(). Set it here as a fallback. | ||||||||||||||||||||||||||||||||||||||||
| # so its event was never set by cancel(). Set it here as a fallback. | ||||||||||||||||||||||||||||||||||||||||
| task.fail( | ||||||||||||||||||||||||||||||||||||||||
| RuntimeError(f"session {write_meta.unique_rid} {status.value}, transfer aborted") | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -503,7 +518,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): | |||||||||||||||||||||||||||||||||||||||
| str(self._instance_rank).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| str(write_meta.unique_rid).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| str(write_meta.slice_id).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| b"True", # is_last_slice — ensures receiver resolves its task future | ||||||||||||||||||||||||||||||||||||||||
| b"True", # is_last_slice — ensures receiver resolves its task event | ||||||||||||||||||||||||||||||||||||||||
| AgentResult.FAILED.value.encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -536,13 +551,19 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta): | |||||||||||||||||||||||||||||||||||||||
| if timer: | ||||||||||||||||||||||||||||||||||||||||
| timer.record_transfer_end(write_meta.peer_rank) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ## TODO: just last slice need to send task state? | ||||||||||||||||||||||||||||||||||||||||
| # The receiver always has a single monolithic task (slice_id=0). | ||||||||||||||||||||||||||||||||||||||||
| # Sender-side chunking is transparent to the receiver: only the | ||||||||||||||||||||||||||||||||||||||||
| # last chunk carries is_last_slice=True so the receiver knows | ||||||||||||||||||||||||||||||||||||||||
| # when all data has arrived. Intermediate chunk results are | ||||||||||||||||||||||||||||||||||||||||
| # sent (not suppressed) so that RDMA failures propagate to the | ||||||||||||||||||||||||||||||||||||||||
| # receiver immediately rather than requiring a timeout. | ||||||||||||||||||||||||||||||||||||||||
| receiver_slice_id = 0 | ||||||||||||||||||||||||||||||||||||||||
| self._get_or_connect_thread_dealer(write_meta.peer_endpoint).send( | ||||||||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||||||
| MessageType.KV_AGENT_RESULT, | ||||||||||||||||||||||||||||||||||||||||
| str(self._instance_rank).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| str(write_meta.unique_rid).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| str(write_meta.slice_id).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| str(receiver_slice_id).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| str(write_meta.is_last_slice).encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| agent_result.value.encode("ascii"), | ||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -701,10 +722,24 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write | |||||||||||||||||||||||||||||||||||||||
| dst_block_ids_per_groups = req_info.block_ids_per_layer_groups | ||||||||||||||||||||||||||||||||||||||||
| src_block_ids_per_groups = task._slice.block_ids_per_layer_groups | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Aggregate fragments from all matching pools using numpy concatenation | ||||||||||||||||||||||||||||||||||||||||
| chunk_offset = task._slice.chunk_block_offset | ||||||||||||||||||||||||||||||||||||||||
| for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items(): | ||||||||||||||||||||||||||||||||||||||||
| src_block_ids = src_block_ids_per_groups[self_lg] | ||||||||||||||||||||||||||||||||||||||||
| dst_block_ids = dst_block_ids_per_groups[peer_lg] | ||||||||||||||||||||||||||||||||||||||||
| full_dst_block_ids = dst_block_ids_per_groups[peer_lg] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # When sender uses chunking, the receiver sends all dst | ||||||||||||||||||||||||||||||||||||||||
| # blocks in a single RecvReqInfo. Slice dst to match | ||||||||||||||||||||||||||||||||||||||||
| # this task's src chunk position. | ||||||||||||||||||||||||||||||||||||||||
| if chunk_offset > 0 or not task._slice.is_last_slice: | ||||||||||||||||||||||||||||||||||||||||
| chunk_end = chunk_offset + len(src_block_ids) | ||||||||||||||||||||||||||||||||||||||||
| if chunk_end > full_dst_block_ids.size: | ||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||
| f"dst chunk range out of bounds: offset={chunk_offset}, " | ||||||||||||||||||||||||||||||||||||||||
| f"len={len(src_block_ids)}, dst_blocks={full_dst_block_ids.size}" | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| dst_block_ids = full_dst_block_ids[chunk_offset:chunk_end] | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+733
to
+740
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win Allow exhausted layer groups to map to empty destination slices. With asymmetric layer-group lengths, later chunks can have Suggested fix if chunk_offset > 0 or not task._slice.is_last_slice:
+ if len(src_block_ids) == 0:
+ dst_block_ids = full_dst_block_ids[:0]
+ continue
chunk_end = chunk_offset + len(src_block_ids)
if chunk_end > full_dst_block_ids.size:
raise ValueError(📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| dst_block_ids = full_dst_block_ids | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+725
to
+742
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift Keep chunk token ranges consistent with This path now slices destination blocks by chunk, but the alignment below still infers token starts as if each chunk’s block list were the suffix ending at 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Speculative decoding: generation may have one extra draft-token block. | ||||||||||||||||||||||||||||||||||||||||
| block_diff = dst_block_ids.size - src_block_ids.size | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -965,7 +1000,7 @@ def _respond_with_kv(self, _send_id: bytes, message: list[bytes]): | |||||||||||||||||||||||||||||||||||||||
| self._save_peer_req_info(info) | ||||||||||||||||||||||||||||||||||||||||
| tasks = list(session.kv_tasks) | ||||||||||||||||||||||||||||||||||||||||
| # No tasks: no worker will send KV_AGENT_RESULT FAILED to the receiver. | ||||||||||||||||||||||||||||||||||||||||
| # Send it directly to unblock the receiver's TRANSFERRING task future; | ||||||||||||||||||||||||||||||||||||||||
| # Send it directly to unblock the receiver's TRANSFERRING task event; | ||||||||||||||||||||||||||||||||||||||||
| # CANCEL_SESSION alone would leave it stuck indefinitely. | ||||||||||||||||||||||||||||||||||||||||
| if not tasks and session.status in (SessionStatus.ERROR, SessionStatus.CANCELLED): | ||||||||||||||||||||||||||||||||||||||||
| self._send_failed_result_to_receiver(info) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1143,6 +1178,10 @@ def disagg_request_id(self) -> int: | |||||||||||||||||||||||||||||||||||||||
| def status(self) -> SessionStatus: | ||||||||||||||||||||||||||||||||||||||||
| if self._terminal_status is not None: | ||||||||||||||||||||||||||||||||||||||||
| return self._terminal_status | ||||||||||||||||||||||||||||||||||||||||
| if self._exception is not None or any(t.status == TaskStatus.ERROR for t in self.kv_tasks): | ||||||||||||||||||||||||||||||||||||||||
| return SessionStatus.ERROR | ||||||||||||||||||||||||||||||||||||||||
| if self.aux_task is not None and self.aux_task.status == TaskStatus.ERROR: | ||||||||||||||||||||||||||||||||||||||||
| return SessionStatus.ERROR | ||||||||||||||||||||||||||||||||||||||||
| kv_all_transferred = bool(self.kv_tasks) and all( | ||||||||||||||||||||||||||||||||||||||||
| t.status == TaskStatus.TRANSFERRED for t in self.kv_tasks | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1732,15 +1771,15 @@ def process_kv_agent_result( | |||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def process_aux_agent_result(self, _peer_rank: int, status: AgentResult): | ||||||||||||||||||||||||||||||||||||||||
| # Aux is session-level (not per-slice); expected_transfers is identical | ||||||||||||||||||||||||||||||||||||||||
| # across all kv_tasks, so any task provides the right count. | ||||||||||||||||||||||||||||||||||||||||
| # Aux is session-level (not per-slice); use the final KV task's | ||||||||||||||||||||||||||||||||||||||||
| # expected transfer count so chunked sessions wait for all senders. | ||||||||||||||||||||||||||||||||||||||||
| with self.lock: | ||||||||||||||||||||||||||||||||||||||||
| if not self._kv_tasks: | ||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||
| f"Aux result received before any KV tasks for request {self.request_id}" | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||
| task = self._kv_tasks[0] | ||||||||||||||||||||||||||||||||||||||||
| task = self._kv_tasks[-1] | ||||||||||||||||||||||||||||||||||||||||
| if status == AgentResult.SUCCESS: | ||||||||||||||||||||||||||||||||||||||||
| self._aux_count += 1 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1996,7 +2035,18 @@ def populate_instance_and_rank_info(self, endpoints: list[str], layer_num_per_pp | |||||||||||||||||||||||||||||||||||||||
| self._rank_info.sender_endpoints = endpoints | ||||||||||||||||||||||||||||||||||||||||
| self._rank_info.layer_num_per_pp = layer_num_per_pp | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def create_tx_session(self, request: LlmRequest) -> TxSession: | ||||||||||||||||||||||||||||||||||||||||
| def create_tx_session( | ||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||
| request: LlmRequest, | ||||||||||||||||||||||||||||||||||||||||
| ) -> TxSession: | ||||||||||||||||||||||||||||||||||||||||
| """Create a TxSession for the given request. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||
| request: The LLM request to create a send session for. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||
| A new ``TxSession`` ready to accept ``send()`` calls. | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
| params = request.py_disaggregated_params | ||||||||||||||||||||||||||||||||||||||||
| assert params is not None | ||||||||||||||||||||||||||||||||||||||||
| return TxSession( | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Send abort results to receiver slice
0as well.The success path makes sender-side chunking transparent by reporting
receiver_slice_id = 0; this abort path still sendswrite_meta.slice_id, so a cancelled/failed later chunk can trip the receiver’s single-task slice assertion instead of unblocking it.Suggested fix
📝 Committable suggestion
🤖 Prompt for AI Agents