Skip to content

Commit 596d453

Browse files
chienchunhungathena-nv
authored andcommitted
Move chunk_block_offset into KVSlice dataclass
Per reviewer feedback (chuangz0, Shixiaowei02): chunk_block_offset belongs as a member of KVSlice rather than a function parameter on send(). The KVSlice dataclass was designed to carry all slice metadata. - Add chunk_block_offset: int = 0 to KVSlice dataclass - Remove chunk_block_offset from TxSessionBase.send() signature - Remove chunk_block_offset from TxSession.send() signature - Remove chunk_block_offset from KVSendTask.__init__ - Read chunk_block_offset from task._slice in _build_kv_write_meta and _deliver_kv_to_agent callback - Set chunk_block_offset on each KVSlice in _create_kv_slices - Update all tests accordingly Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Made-with: Cursor Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent 2b36bca commit 596d453

8 files changed

Lines changed: 144 additions & 31 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,17 @@ class GenerationRequest
681681
++mNumFrontBlocksRemovedPerWindow.at(windowSize);
682682
}
683683

684+
//! \brief Advance ``mNumFrontBlocksRemoved`` without touching cache blocks.
685+
//! \details Used by ``BlockManager::releasePrefixBlocks`` to advance the
686+
//! shared front-block counter once after every ``WindowBlockManager`` has
687+
//! processed the same prefix range. Has clearer intent than calling
688+
//! ``removeFrontBlock`` with a sentinel ``windowSize`` value, and is robust
689+
//! to future changes that consume the ``windowSize`` argument.
690+
void incrementNumFrontBlocksRemoved()
691+
{
692+
++mNumFrontBlocksRemoved;
693+
}
694+
684695
void removeLastBlock(SizeType32 windowSize)
685696
{
686697
for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2916,9 +2916,12 @@ void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 n
29162916
manager.releasePrefixBlocks(sequence, startIdx, numBlocks);
29172917
}
29182918
// Advance the shared counter once, after all managers have released.
2919+
// Uses incrementNumFrontBlocksRemoved (counter-only) instead of
2920+
// removeFrontBlock so the intent is explicit and we do not depend on
2921+
// removeFrontBlock ignoring its windowSize argument.
29192922
while (sequence.getNumFrontBlocksRemoved() < numBlocks)
29202923
{
2921-
sequence.removeFrontBlock(0);
2924+
sequence.incrementNumFrontBlocksRemoved();
29222925
}
29232926
}
29242927

@@ -3942,6 +3945,16 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
39423945

39433946
void KVCacheManager::releasePrefixBlocks(RequestIdType requestId, SizeType32 numBlocks)
39443947
{
3948+
// Hard precondition: BlockManager::releasePrefixBlocks advances the shared
3949+
// mNumFrontBlocksRemoved counter to numBlocks for every WindowBlockManager,
3950+
// even when a window has fewer than numBlocks allocated. Under variable
3951+
// sliding window attention (VSWA), that would cause WindowBlockManager::
3952+
// releaseBlocks (called during removeSequence) to underrun rbegin() and
3953+
// skip tail blocks for the smaller window. Disagg serving already gates
3954+
// VSWA out, but we enforce the assumption here so the C++ API contract is
3955+
// self-defending instead of relying on caller discipline.
3956+
TLLM_CHECK_WITH_INFO(
3957+
!mBlockManager.isVariableWindow(), "releasePrefixBlocks does not support variable sliding window attention");
39453958
if (numBlocks <= 0)
39463959
{
39473960
return;

tensorrt_llm/_torch/disaggregation/base/transfer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class KVSlice:
6565
) # Physical block IDs per layer group, each np.ndarray(dtype=np.int64)
6666
is_last_slice: bool = False
6767
mamba_state_index: Optional[int] = None
68+
chunk_block_offset: int = 0
6869

6970

7071
class SessionStatus(Enum):
@@ -158,15 +159,14 @@ def __init__(self, sender: SenderBase, args: SessionArgsBase):
158159
self._sender = sender
159160

160161
@abstractmethod
161-
def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None:
162+
def send(self, slice: KVSlice) -> None:
162163
"""Send a KV slice.
163164
164165
Args:
165166
slice: The KV slice describing which source blocks to send.
166-
chunk_block_offset: Block offset into the receiver's full
167-
destination block list for this chunk. Used by sender-side
168-
chunking to slice the receiver's destination blocks correctly.
169-
Defaults to 0 for monolithic transfer.
167+
The slice's ``chunk_block_offset`` field indicates the offset
168+
into the receiver's destination block list for sender-side
169+
chunking.
170170
"""
171171
...
172172

tensorrt_llm/_torch/disaggregation/native/transfer.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,10 @@ class KVSendTask(SendTaskBase):
208208
209209
Args:
210210
kv_slice: The KV slice describing which blocks to transfer.
211+
The slice's ``chunk_block_offset`` field indicates the
212+
offset into the receiver's destination block list.
211213
params: Disaggregated serving parameters for this request.
212214
slice_id: Index of this slice within the session's task list.
213-
chunk_block_offset: Block offset into the receiver's full
214-
destination block list. Used by sender-side chunking to
215-
slice the receiver's destination blocks correctly.
216215
"""
217216

218217
def __init__(
@@ -222,15 +221,13 @@ def __init__(
222221
slice_id: int,
223222
prompt_len: Optional[int] = None,
224223
beam_width: int = 1,
225-
chunk_block_offset: int = 0,
226224
) -> None:
227225
super().__init__(params)
228226
self.slice_id = slice_id
229227
self.transferred_count = 0
230228
self._slice = kv_slice
231229
self._prompt_len = prompt_len
232230
self._beam_width = beam_width
233-
self.chunk_block_offset = chunk_block_offset
234231

235232

236233
class Sender(SenderBase):
@@ -587,7 +584,7 @@ def _deliver_kv_to_agent(self, write_meta: WriteMeta):
587584
)
588585
session._on_chunk_transferred(
589586
request_id=session.request_id,
590-
chunk_block_offset=task.chunk_block_offset,
587+
chunk_block_offset=task._slice.chunk_block_offset,
591588
num_blocks=num_blocks,
592589
)
593590
except Exception as e:
@@ -727,7 +724,7 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write
727724
dst_block_ids_per_groups = req_info.block_ids_per_layer_groups
728725
src_block_ids_per_groups = task._slice.block_ids_per_layer_groups
729726

730-
chunk_offset = task.chunk_block_offset
727+
chunk_offset = task._slice.chunk_block_offset
731728
for (self_lg, self_pi), (peer_lg, peer_pi) in pool_mapping.items():
732729
src_block_ids = src_block_ids_per_groups[self_lg]
733730
full_dst_block_ids = dst_block_ids_per_groups[peer_lg]
@@ -1182,7 +1179,7 @@ def status(self) -> SessionStatus:
11821179
return SessionStatus.TRANSFERRING
11831180
return SessionStatus.READY if self.receiver_ready else SessionStatus.INIT
11841181

1185-
def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None:
1182+
def send(self, slice: KVSlice) -> None:
11861183
with self.lock:
11871184
params = self._base_args.params
11881185
slice_id = len(self.kv_tasks)
@@ -1192,7 +1189,6 @@ def send(self, slice: KVSlice, chunk_block_offset: int = 0) -> None:
11921189
slice_id,
11931190
prompt_len=self._base_args.prompt_len,
11941191
beam_width=self._base_args.beam_width,
1195-
chunk_block_offset=chunk_block_offset,
11961192
)
11971193
task._unique_rid = self.disagg_request_id
11981194
self.kv_tasks.append(task)

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ def shutdown(self):
146146
if getattr(self, "_shutdown", False):
147147
return
148148
self._shutdown = True
149+
# Drain any pending prefix-release entries before tearing down sessions
150+
# so memory frees in the same shutdown step instead of leaking until
151+
# removeSequence cleans up at session close.
152+
self._drain_pending_releases()
149153
for session in list(self._send_sessions.values()):
150154
session.close()
151155
for session in list(self._recv_sessions.values()):
@@ -272,6 +276,7 @@ def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]:
272276

273277
num_chunks = math.ceil(max_blocks / self._chunk_size_blocks)
274278
slices: List[KVSlice] = []
279+
block_offset = 0
275280
for chunk_idx in range(num_chunks):
276281
start = chunk_idx * self._chunk_size_blocks
277282
end = start + self._chunk_size_blocks
@@ -284,8 +289,16 @@ def _create_kv_slices(self, req: LlmRequest) -> List[KVSlice]:
284289
block_ids_per_layer_groups=chunk_block_ids,
285290
mamba_state_index=base_slice.mamba_state_index,
286291
token_range=base_slice.token_range,
292+
chunk_block_offset=block_offset,
287293
)
288294
)
295+
# Use the max length across layer groups to advance the receiver
296+
# offset. This is the contract that lets receiver-side slicing in
297+
# native/transfer.py (`_build_kv_write_meta`) trim the per-LG dst
298+
# range with `len(src_block_ids)`, so asymmetric layer groups still
299+
# land at the right destination position even though the offset is
300+
# shared across groups.
301+
block_offset += max((len(ids) for ids in chunk_block_ids), default=0)
289302

290303
for lg_idx, original_ids in enumerate(all_block_ids):
291304
reassembled = np.concatenate([s.block_ids_per_layer_groups[lg_idx] for s in slices])
@@ -318,8 +331,24 @@ def _make_chunk_callback(self) -> Optional[Callable]:
318331
"""
319332
if self._chunk_size_blocks is None:
320333
return None
334+
manager_name = type(self._kv_cache_manager).__name__
321335
if not hasattr(self._kv_cache_manager, "release_prefix_blocks"):
336+
# Surface the gate decision in logs so a typo or missing wrapper on
337+
# the manager side is observable at startup, not silent.
338+
logger.warning(
339+
"Chunked KV transfer is enabled (chunk_size_blocks=%s) but %s "
340+
"does not implement release_prefix_blocks; early prefix block "
341+
"release is disabled. Blocks will be freed at session teardown.",
342+
self._chunk_size_blocks,
343+
manager_name,
344+
)
322345
return None
346+
logger.info(
347+
"Chunked KV transfer with early prefix block release enabled "
348+
"(chunk_size_blocks=%s, manager=%s).",
349+
self._chunk_size_blocks,
350+
manager_name,
351+
)
323352

324353
release_queue = self._pending_prefix_releases
325354

@@ -515,6 +544,11 @@ def _build_to_process(
515544
return to_process
516545

517546
def _close_failed_sessions(self, sessions: dict, reqs: dict, failed: list):
547+
# Drain pending prefix releases before closing failed sessions so that
548+
# already-completed chunks of healthy sister sessions free memory now
549+
# rather than waiting for the next check_context_transfer_status pass.
550+
# No-op when the queue is empty, including on the gen-side path.
551+
self._drain_pending_releases()
518552
for rid in failed:
519553
reqs[rid].state = LlmRequestState.DISAGG_TRANS_ERROR
520554
sessions[rid].close()
@@ -574,12 +608,8 @@ def _finalize_send(self, req: LlmRequest, session: TxSessionBase):
574608
def respond_and_send_async(self, req: LlmRequest):
575609
session = self._get_or_create_send_session(req)
576610
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
577-
chunk_block_offset = 0
578611
for kv_slice in self._create_kv_slices(req):
579-
session.send(kv_slice, chunk_block_offset=chunk_block_offset)
580-
chunk_block_offset += max(
581-
(len(ids) for ids in kv_slice.block_ids_per_layer_groups), default=0
582-
)
612+
session.send(kv_slice)
583613
self._finalize_send(req, session)
584614

585615
@nvtx_range("KvCacheTransceiverV2.request_and_receive_sync")

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,15 @@ def create_kv_cache_transceiver(
7878
if (not use_python
7979
and cache_transceiver_config.chunk_size_blocks is not None):
8080
if cache_transceiver_config.backend in (None, "DEFAULT", "NIXL"):
81-
logger.info(
82-
"chunk_size_blocks is set; auto-selecting Python transceiver "
83-
"for chunked KV cache transfer support")
81+
# Use warning (not info) so users notice the transceiver swap and
82+
# the implied perf / staging-buffer characteristics change. Set
83+
# transceiver_runtime='CPP' explicitly to opt out (and lose
84+
# chunked transfer + early block release).
85+
logger.warning(
86+
"chunk_size_blocks is set; auto-selecting the Python "
87+
"transceiver instead of the C++ transceiver to enable "
88+
"chunked KV cache transfer + early block release. "
89+
"Set transceiver_runtime='CPP' to disable this auto-selection.")
8490
use_python = True
8591
else:
8692
logger.warning(
@@ -90,6 +96,20 @@ def create_kv_cache_transceiver(
9096
f"chunk_size_blocks will be ignored. Use NIXL backend to "
9197
f"enable chunked transfer.")
9298

99+
# Warn when chunk_size_blocks is below the recommended floor. The Pydantic
100+
# field is PositiveInt (>=1), but values below ~16 push the per-chunk RDMA
101+
# overhead into the regime where it dominates transfer throughput.
102+
_MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS = 16
103+
if (cache_transceiver_config.chunk_size_blocks is not None
104+
and cache_transceiver_config.chunk_size_blocks
105+
< _MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS):
106+
logger.warning(
107+
f"chunk_size_blocks={cache_transceiver_config.chunk_size_blocks} "
108+
f"is below the recommended floor of "
109+
f"{_MIN_RECOMMENDED_CHUNK_SIZE_BLOCKS}; per-chunk RDMA overhead "
110+
f"may dominate transfer throughput. Consider 64-128 for "
111+
f"long-context workloads (ISL >= 32K).")
112+
93113
# Select transceiver implementation based on transceiver_runtime
94114
# transceiver_runtime == None or "CPP" -> use C++ transceiver (default)
95115
# transceiver_runtime == "PYTHON" -> use Python transceiver

tests/unittest/disaggregated/test_chunked_transfer.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def _make_tx_session(num_slices: int, rid: int = 42, **kwargs) -> TxSession:
7474
s = KVSlice(
7575
is_last_slice=(i == num_slices - 1),
7676
block_ids_per_layer_groups=[[i]],
77+
chunk_block_offset=i,
7778
)
78-
session.send(s, chunk_block_offset=i)
79+
session.send(s)
7980
return session
8081

8182

@@ -103,19 +104,19 @@ def _make_rx_session(num_slices: int, rid: int = 42) -> RxSession:
103104

104105

105106
def test_kv_send_task_chunk_block_offset():
106-
"""KVSendTask stores chunk_block_offset correctly."""
107-
s = KVSlice(is_last_slice=False, block_ids_per_layer_groups=[[0, 1]])
108-
task = KVSendTask(s, _make_params(), slice_id=1, chunk_block_offset=512)
109-
assert task.chunk_block_offset == 512
107+
"""KVSendTask reads chunk_block_offset from the slice."""
108+
s = KVSlice(is_last_slice=False, block_ids_per_layer_groups=[[0, 1]], chunk_block_offset=512)
109+
task = KVSendTask(s, _make_params(), slice_id=1)
110+
assert task._slice.chunk_block_offset == 512
110111
assert task.slice_id == 1
111112
assert task._slice is s
112113

113114

114115
def test_kv_send_task_default_offset():
115-
"""Default chunk_block_offset is 0."""
116+
"""Default chunk_block_offset on KVSlice is 0."""
116117
s = KVSlice(is_last_slice=True, block_ids_per_layer_groups=[[0]])
117118
task = KVSendTask(s, _make_params(), slice_id=0)
118-
assert task.chunk_block_offset == 0
119+
assert task._slice.chunk_block_offset == 0
119120

120121

121122
# ---------------------------------------------------------------------------
@@ -281,6 +282,47 @@ def test_drain_pending_releases():
281282
assert calls[2].args == (20, 32)
282283

283284

285+
def test_drain_pending_releases_tolerates_stale_rid():
286+
"""A pending release for a request that was already removed must be a no-op.
287+
288+
Models the production race where the sender worker enqueues a release
289+
after the main thread has already torn the sequence down via
290+
``removeSequence``. ``KVCacheManager.release_prefix_blocks`` returns
291+
early in that case, so ``_drain_pending_releases`` must not raise.
292+
"""
293+
from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2
294+
295+
transceiver = MagicMock()
296+
transceiver._pending_prefix_releases = queue.Queue()
297+
transceiver._kv_cache_manager = MagicMock()
298+
# Manager wrapper is a no-op for unknown rids; drain must propagate that
299+
# no-op semantics rather than crashing.
300+
transceiver._kv_cache_manager.release_prefix_blocks = MagicMock(return_value=None)
301+
302+
transceiver._pending_prefix_releases.put((9999, 64)) # unknown rid
303+
transceiver._pending_prefix_releases.put((9999, 128))
304+
305+
KvCacheTransceiverV2._drain_pending_releases(transceiver)
306+
307+
calls = transceiver._kv_cache_manager.release_prefix_blocks.call_args_list
308+
assert len(calls) == 2
309+
assert calls[0].args == (9999, 64)
310+
assert calls[1].args == (9999, 128)
311+
312+
313+
def test_drain_pending_releases_empty_queue_is_noop():
314+
"""Drain on an empty queue is a no-op and never calls the manager."""
315+
from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2
316+
317+
transceiver = MagicMock()
318+
transceiver._pending_prefix_releases = queue.Queue()
319+
transceiver._kv_cache_manager = MagicMock()
320+
321+
KvCacheTransceiverV2._drain_pending_releases(transceiver)
322+
323+
transceiver._kv_cache_manager.release_prefix_blocks.assert_not_called()
324+
325+
284326
@pytest.mark.parametrize(
285327
"has_release,chunk_size,expected_none",
286328
[

tests/unittest/disaggregated/test_kv_transfer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1564,8 +1564,9 @@ def add_and_verify_chunked_request(
15641564
kv_slice = KVSlice(
15651565
is_last_slice=is_last,
15661566
block_ids_per_layer_groups=chunk_block_ids,
1567+
chunk_block_offset=chunk_offset,
15671568
)
1568-
sender_session.send(kv_slice, chunk_block_offset=chunk_offset)
1569+
sender_session.send(kv_slice)
15691570
chunk_offset += max(len(ids) for ids in chunk_block_ids)
15701571

15711572
receiver_sessions = [

0 commit comments

Comments
 (0)