|
14 | 14 | # limitations under the License. |
15 | 15 | """Tests for CacheReuseAdapter, _create_kv_slice SWA trim, and Sender token-start derivation.""" |
16 | 16 |
|
| 17 | +from types import SimpleNamespace |
| 18 | + |
17 | 19 | import numpy as np |
18 | 20 | import pytest |
19 | 21 |
|
20 | 22 | from tensorrt_llm._torch.disaggregation.base.transfer import TokenRange |
21 | 23 | from tensorrt_llm._torch.disaggregation.native.transfer import Sender |
22 | 24 | from tensorrt_llm._torch.disaggregation.resource.cache_reuse import CacheReuseAdapter |
23 | 25 | from tensorrt_llm._torch.disaggregation.resource.page import AttentionLayerGroup |
| 26 | +from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 |
24 | 27 |
|
25 | 28 | # --------------------------------------------------------------------------- |
26 | 29 | # _align_kv_blocks: contract unchanged. |
@@ -114,6 +117,91 @@ def test_start_eq_end_rejected(self): |
114 | 117 | TokenRange(start=256, end=256) |
115 | 118 |
|
116 | 119 |
|
| 120 | +# --------------------------------------------------------------------------- |
| 121 | +# _create_kv_slice: default TokenRange spans prompt_len + num_extra_kv_tokens |
| 122 | +# so transferred KV matches what resize_context / _get_context_bytes allocate. |
| 123 | +# --------------------------------------------------------------------------- |
| 124 | + |
| 125 | + |
| 126 | +def _build_transceiver_for_kv_slice(num_extra_kv_tokens: int, prompt_len: int): |
| 127 | + """Stub a KvCacheTransceiverV2 so _create_kv_slice runs without dist setup. |
| 128 | +
|
| 129 | + Wires only the attributes the method touches: |
| 130 | + - reuse adapter: tokens_per_block, per-layer-group cached count, block ids |
| 131 | + - page table: layer groups |
| 132 | + - cache manager: num_extra_kv_tokens (read in this code path) |
| 133 | + """ |
| 134 | + tokens_per_block = 8 |
| 135 | + layer_group = AttentionLayerGroup(pool_group_idx=0, kv_head_num_per_rank=1) |
| 136 | + total_blocks = (prompt_len + num_extra_kv_tokens + tokens_per_block - 1) // tokens_per_block |
| 137 | + block_ids = np.arange(total_blocks, dtype=np.int64) |
| 138 | + |
| 139 | + reuse_adapter = SimpleNamespace( |
| 140 | + tokens_per_block=tokens_per_block, |
| 141 | + get_cached_token_count_per_layer_group=lambda req, layer_groups: [0] * len(layer_groups), |
| 142 | + get_block_ids=lambda req, idx, lg: block_ids, |
| 143 | + ) |
| 144 | + page_table = SimpleNamespace(layer_groups=[layer_group]) |
| 145 | + cache_manager = SimpleNamespace(num_extra_kv_tokens=num_extra_kv_tokens) |
| 146 | + |
| 147 | + transceiver = object.__new__(KvCacheTransceiverV2) |
| 148 | + transceiver._reuse_adapter = reuse_adapter |
| 149 | + transceiver._page_table = page_table |
| 150 | + transceiver._kv_cache_manager = cache_manager |
| 151 | + |
| 152 | + req = SimpleNamespace( |
| 153 | + prompt_len=prompt_len, |
| 154 | + py_request_id=0, |
| 155 | + is_generation_only_request=lambda: False, |
| 156 | + ) |
| 157 | + return transceiver, req |
| 158 | + |
| 159 | + |
| 160 | +class TestCreateKvSliceTokenRange: |
| 161 | + """Default TokenRange built by _create_kv_slice must align with KV-cache allocation. |
| 162 | +
|
| 163 | + KV cache allocation in resize_context (V2) and prepare_resources (V1) reserves |
| 164 | + prompt_len + num_extra_kv_tokens slots whenever speculative decoding (e.g. |
| 165 | + EAGLE3, MTP) consumes extra KV positions per request. The transferred token |
| 166 | + range must cover the same span, otherwise the receiver under-receives KV. |
| 167 | + """ |
| 168 | + |
| 169 | + def test_includes_num_extra_kv_tokens(self): |
| 170 | + prompt_len = 17 |
| 171 | + num_extra_kv_tokens = 7 |
| 172 | + transceiver, req = _build_transceiver_for_kv_slice(num_extra_kv_tokens, prompt_len) |
| 173 | + |
| 174 | + kv_slice = transceiver._create_kv_slice(req) |
| 175 | + |
| 176 | + assert kv_slice.token_range is not None |
| 177 | + assert (kv_slice.token_range.start, kv_slice.token_range.end) == ( |
| 178 | + 0, |
| 179 | + prompt_len + num_extra_kv_tokens, |
| 180 | + ) |
| 181 | + |
| 182 | + def test_defaults_to_prompt_len_when_no_extra(self): |
| 183 | + prompt_len = 17 |
| 184 | + transceiver, req = _build_transceiver_for_kv_slice( |
| 185 | + num_extra_kv_tokens=0, prompt_len=prompt_len |
| 186 | + ) |
| 187 | + |
| 188 | + kv_slice = transceiver._create_kv_slice(req) |
| 189 | + |
| 190 | + assert kv_slice.token_range is not None |
| 191 | + assert (kv_slice.token_range.start, kv_slice.token_range.end) == (0, prompt_len) |
| 192 | + |
| 193 | + def test_respects_explicit_token_range(self): |
| 194 | + prompt_len = 17 |
| 195 | + transceiver, req = _build_transceiver_for_kv_slice( |
| 196 | + num_extra_kv_tokens=7, prompt_len=prompt_len |
| 197 | + ) |
| 198 | + explicit = TokenRange(start=0, end=8) |
| 199 | + |
| 200 | + kv_slice = transceiver._create_kv_slice(req, token_range=explicit) |
| 201 | + |
| 202 | + assert kv_slice.token_range is explicit |
| 203 | + |
| 204 | + |
117 | 205 | # --------------------------------------------------------------------------- |
118 | 206 | # CacheReuseAdapter.get_cached_token_count_per_layer_group: SWA clamp. |
119 | 207 | # --------------------------------------------------------------------------- |
|
0 commit comments