Skip to content

Commit 326f280

Browse files
chuangz0 authored and Shixiaowei02 committed
[None][fix] fix token_range_end add extra_kv_num_tokens (#14258)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
1 parent f60357a commit 326f280

2 files changed

Lines changed: 93 additions & 1 deletion

File tree

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -181,7 +181,12 @@ def _create_kv_slice(
181181
)
182182

183183
if token_range is None and req.prompt_len > 0:
184-
token_range = TokenRange(start=0, end=req.prompt_len)
184+
# Align with KV cache allocation (resize_context /
185+
# _get_context_bytes), which reserves prompt_len +
186+
# num_extra_kv_tokens slots for speculative decoding methods
187+
# (e.g. EAGLE3) that consume extra KV positions per request.
188+
num_extra_kv_tokens = getattr(self._kv_cache_manager, "num_extra_kv_tokens", 0) or 0
189+
token_range = TokenRange(start=0, end=req.prompt_len + num_extra_kv_tokens)
185190

186191
groups = []
187192
for idx, lg in enumerate(layer_groups):

tests/unittest/disaggregated/test_cache_reuse_adapter.py

Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
"""Tests for CacheReuseAdapter, _create_kv_slice SWA trim, and Sender token-start derivation."""
1616

17+
from types import SimpleNamespace
18+
1719
import numpy as np
1820
import pytest
1921

@@ -265,6 +267,91 @@ def test_start_eq_end_rejected(self):
265267
TokenRange(start=256, end=256)
266268

267269

270+
# ---------------------------------------------------------------------------
271+
# _create_kv_slice: default TokenRange spans prompt_len + num_extra_kv_tokens
272+
# so transferred KV matches what resize_context / _get_context_bytes allocate.
273+
# ---------------------------------------------------------------------------
274+
275+
276+
def _build_transceiver_for_kv_slice(num_extra_kv_tokens: int, prompt_len: int):
277+
"""Stub a KvCacheTransceiverV2 so _create_kv_slice runs without dist setup.
278+
279+
Wires only the attributes the method touches:
280+
- reuse adapter: tokens_per_block, per-layer-group cached count, block ids
281+
- page table: layer groups
282+
- cache manager: num_extra_kv_tokens (read in this code path)
283+
"""
284+
tokens_per_block = 8
285+
layer_group = AttentionLayerGroup(pool_group_idx=0, kv_head_num_per_rank=1)
286+
total_blocks = (prompt_len + num_extra_kv_tokens + tokens_per_block - 1) // tokens_per_block
287+
block_ids = np.arange(total_blocks, dtype=np.int64)
288+
289+
reuse_adapter = SimpleNamespace(
290+
tokens_per_block=tokens_per_block,
291+
get_cached_token_count_per_layer_group=lambda req, layer_groups: [0] * len(layer_groups),
292+
get_block_ids=lambda req, idx, lg: block_ids,
293+
)
294+
page_table = SimpleNamespace(layer_groups=[layer_group])
295+
cache_manager = SimpleNamespace(num_extra_kv_tokens=num_extra_kv_tokens)
296+
297+
transceiver = object.__new__(KvCacheTransceiverV2)
298+
transceiver._reuse_adapter = reuse_adapter
299+
transceiver._page_table = page_table
300+
transceiver._kv_cache_manager = cache_manager
301+
302+
req = SimpleNamespace(
303+
prompt_len=prompt_len,
304+
py_request_id=0,
305+
is_generation_only_request=lambda: False,
306+
)
307+
return transceiver, req
308+
309+
310+
class TestCreateKvSliceTokenRange:
311+
"""Default TokenRange built by _create_kv_slice must align with KV-cache allocation.
312+
313+
KV cache allocation in resize_context (V2) and prepare_resources (V1) reserves
314+
prompt_len + num_extra_kv_tokens slots whenever speculative decoding (e.g.
315+
EAGLE3, MTP) consumes extra KV positions per request. The transferred token
316+
range must cover the same span, otherwise the receiver under-receives KV.
317+
"""
318+
319+
def test_includes_num_extra_kv_tokens(self):
320+
prompt_len = 17
321+
num_extra_kv_tokens = 7
322+
transceiver, req = _build_transceiver_for_kv_slice(num_extra_kv_tokens, prompt_len)
323+
324+
kv_slice = transceiver._create_kv_slice(req)
325+
326+
assert kv_slice.token_range is not None
327+
assert (kv_slice.token_range.start, kv_slice.token_range.end) == (
328+
0,
329+
prompt_len + num_extra_kv_tokens,
330+
)
331+
332+
def test_defaults_to_prompt_len_when_no_extra(self):
333+
prompt_len = 17
334+
transceiver, req = _build_transceiver_for_kv_slice(
335+
num_extra_kv_tokens=0, prompt_len=prompt_len
336+
)
337+
338+
kv_slice = transceiver._create_kv_slice(req)
339+
340+
assert kv_slice.token_range is not None
341+
assert (kv_slice.token_range.start, kv_slice.token_range.end) == (0, prompt_len)
342+
343+
def test_respects_explicit_token_range(self):
344+
prompt_len = 17
345+
transceiver, req = _build_transceiver_for_kv_slice(
346+
num_extra_kv_tokens=7, prompt_len=prompt_len
347+
)
348+
explicit = TokenRange(start=0, end=8)
349+
350+
kv_slice = transceiver._create_kv_slice(req, token_range=explicit)
351+
352+
assert kv_slice.token_range is explicit
353+
354+
268355
# ---------------------------------------------------------------------------
269356
# CacheReuseAdapter.get_cached_token_count_per_layer_group: SWA clamp.
270357
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)