Skip to content

Commit 1629cd8

Browse files
chuangz0Shixiaowei02
authored andcommitted
[None][fix] fix token_range_end add extra_kv_num_tokens (#14258)
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
1 parent edcce32 commit 1629cd8

2 files changed

Lines changed: 94 additions & 1 deletion

File tree

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,12 @@ def _create_kv_slice(
181181
)
182182

183183
if token_range is None and req.prompt_len > 0:
184-
token_range = TokenRange(start=0, end=req.prompt_len)
184+
# Align with KV cache allocation (resize_context /
185+
# _get_context_bytes), which reserves prompt_len +
186+
# num_extra_kv_tokens slots for speculative decoding methods
187+
# (e.g. EAGLE3) that consume extra KV positions per request.
188+
num_extra_kv_tokens = getattr(self._kv_cache_manager, "num_extra_kv_tokens", 0) or 0
189+
token_range = TokenRange(start=0, end=req.prompt_len + num_extra_kv_tokens)
185190

186191
groups = []
187192
for idx, lg in enumerate(layer_groups):

tests/unittest/disaggregated/test_cache_reuse_adapter.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
# limitations under the License.
1515
"""Tests for CacheReuseAdapter, _create_kv_slice SWA trim, and Sender token-start derivation."""
1616

17+
from types import SimpleNamespace
18+
1719
import numpy as np
1820
import pytest
1921

2022
from tensorrt_llm._torch.disaggregation.base.transfer import TokenRange
2123
from tensorrt_llm._torch.disaggregation.native.transfer import Sender
2224
from tensorrt_llm._torch.disaggregation.resource.cache_reuse import CacheReuseAdapter
2325
from tensorrt_llm._torch.disaggregation.resource.page import AttentionLayerGroup
26+
from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2
2427

2528
# ---------------------------------------------------------------------------
2629
# _align_kv_blocks: contract unchanged.
@@ -114,6 +117,91 @@ def test_start_eq_end_rejected(self):
114117
TokenRange(start=256, end=256)
115118

116119

120+
# ---------------------------------------------------------------------------
121+
# _create_kv_slice: default TokenRange spans prompt_len + num_extra_kv_tokens
122+
# so transferred KV matches what resize_context / _get_context_bytes allocate.
123+
# ---------------------------------------------------------------------------
124+
125+
126+
def _build_transceiver_for_kv_slice(num_extra_kv_tokens: int, prompt_len: int):
127+
"""Stub a KvCacheTransceiverV2 so _create_kv_slice runs without dist setup.
128+
129+
Wires only the attributes the method touches:
130+
- reuse adapter: tokens_per_block, per-layer-group cached count, block ids
131+
- page table: layer groups
132+
- cache manager: num_extra_kv_tokens (read in this code path)
133+
"""
134+
tokens_per_block = 8
135+
layer_group = AttentionLayerGroup(pool_group_idx=0, kv_head_num_per_rank=1)
136+
total_blocks = (prompt_len + num_extra_kv_tokens + tokens_per_block - 1) // tokens_per_block
137+
block_ids = np.arange(total_blocks, dtype=np.int64)
138+
139+
reuse_adapter = SimpleNamespace(
140+
tokens_per_block=tokens_per_block,
141+
get_cached_token_count_per_layer_group=lambda req, layer_groups: [0] * len(layer_groups),
142+
get_block_ids=lambda req, idx, lg: block_ids,
143+
)
144+
page_table = SimpleNamespace(layer_groups=[layer_group])
145+
cache_manager = SimpleNamespace(num_extra_kv_tokens=num_extra_kv_tokens)
146+
147+
transceiver = object.__new__(KvCacheTransceiverV2)
148+
transceiver._reuse_adapter = reuse_adapter
149+
transceiver._page_table = page_table
150+
transceiver._kv_cache_manager = cache_manager
151+
152+
req = SimpleNamespace(
153+
prompt_len=prompt_len,
154+
py_request_id=0,
155+
is_generation_only_request=lambda: False,
156+
)
157+
return transceiver, req
158+
159+
160+
class TestCreateKvSliceTokenRange:
161+
"""Default TokenRange built by _create_kv_slice must align with KV-cache allocation.
162+
163+
KV cache allocation in resize_context (V2) and prepare_resources (V1) reserves
164+
prompt_len + num_extra_kv_tokens slots whenever speculative decoding (e.g.
165+
EAGLE3, MTP) consumes extra KV positions per request. The transferred token
166+
range must cover the same span, otherwise the receiver under-receives KV.
167+
"""
168+
169+
def test_includes_num_extra_kv_tokens(self):
170+
prompt_len = 17
171+
num_extra_kv_tokens = 7
172+
transceiver, req = _build_transceiver_for_kv_slice(num_extra_kv_tokens, prompt_len)
173+
174+
kv_slice = transceiver._create_kv_slice(req)
175+
176+
assert kv_slice.token_range is not None
177+
assert (kv_slice.token_range.start, kv_slice.token_range.end) == (
178+
0,
179+
prompt_len + num_extra_kv_tokens,
180+
)
181+
182+
def test_defaults_to_prompt_len_when_no_extra(self):
183+
prompt_len = 17
184+
transceiver, req = _build_transceiver_for_kv_slice(
185+
num_extra_kv_tokens=0, prompt_len=prompt_len
186+
)
187+
188+
kv_slice = transceiver._create_kv_slice(req)
189+
190+
assert kv_slice.token_range is not None
191+
assert (kv_slice.token_range.start, kv_slice.token_range.end) == (0, prompt_len)
192+
193+
def test_respects_explicit_token_range(self):
194+
prompt_len = 17
195+
transceiver, req = _build_transceiver_for_kv_slice(
196+
num_extra_kv_tokens=7, prompt_len=prompt_len
197+
)
198+
explicit = TokenRange(start=0, end=8)
199+
200+
kv_slice = transceiver._create_kv_slice(req, token_range=explicit)
201+
202+
assert kv_slice.token_range is explicit
203+
204+
117205
# ---------------------------------------------------------------------------
118206
# CacheReuseAdapter.get_cached_token_count_per_layer_group: SWA clamp.
119207
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)