Skip to content

Commit 347e52e

Browse files
committed
fix(qwen3next): persist mtp full-attn cpu cache slots
1 parent 2560680 commit 347e52e

7 files changed

Lines changed: 355 additions & 24 deletions

File tree

lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,7 @@ def copy_kv_buffer_to_cpu_cache(
193193
cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1]
194194
full_att_layer_num = gpu_kv_full_att_state.shape[-2]
195195

196-
assert (
197-
full_att_layer_num
198-
== (linear_config.all_layer_num // linear_config.full_attention_interval)
199-
== (linear_config.all_layer_num - linear_config.linear_layer_num)
200-
)
196+
assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num()
201197
assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1]
202198
assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1]
203199
assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1]
@@ -428,6 +424,7 @@ def copy_cpu_cache_to_kv_buffer(
428424
cpu_kv_ssm_tail_dim = cpu_kv_ssm_state.shape[-1]
429425
full_att_layer_num = gpu_full_att_kv_state.shape[-2]
430426

427+
assert full_att_layer_num == linear_config.get_persisted_full_att_layer_num()
431428
assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1]
432429
assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1]
433430
assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1]

lightllm/common/kv_cache_mem_manager/operator/linear_att.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def __init__(self, mem_manager):
2424
super().__init__(mem_manager)
2525
self.linear_config = LinearAttCacheConfig.load_from_args()
2626

27+
@staticmethod
def _get_persisted_full_att_layer_num(mem_manager) -> int:
    """Return how many leading layers of ``mem_manager.kv_buffer`` are persisted.

    An explicit ``persisted_full_att_layer_num`` attribute on ``mem_manager``
    takes precedence. When it is absent (or ``None``) the count is rebuilt as
    ``main_full_att_layer_num + draft_full_att_layers``, where a missing main
    count defaults to the full first dimension of ``kv_buffer`` and a missing
    draft count defaults to 0.

    Args:
        mem_manager: object exposing ``kv_buffer`` (only ``shape[0]`` is read)
            and, optionally, the three layer-count attributes named above.

    Returns:
        The persisted full-attention layer count as a plain ``int``.

    Raises:
        AssertionError: if the count is not in ``(0, kv_buffer.shape[0]]``.
    """
    buffer_layer_num = mem_manager.kv_buffer.shape[0]

    layer_num = getattr(mem_manager, "persisted_full_att_layer_num", None)
    if layer_num is None:
        # Fallback for managers that only carry the main/draft split.
        main_layers = getattr(mem_manager, "main_full_att_layer_num", buffer_layer_num)
        draft_layers = getattr(mem_manager, "draft_full_att_layers", 0)
        layer_num = main_layers + draft_layers

    # The persisted slice is taken as kv_buffer[:layer_num], so it must be
    # non-empty and must not run past the buffer.
    assert 0 < layer_num <= buffer_layer_num
    return int(layer_num)
36+
2737
def load_cpu_cache_to_gpu(
2838
self,
2939
mem_indexes: torch.Tensor,
@@ -76,16 +86,14 @@ def load_cpu_cache_to_gpu(
7686
copy_cpu_cache_to_kv_buffer,
7787
)
7888

79-
# Persist/restore ONLY the main model's full-attn slice. The kv buffer is widened by
80-
# dedicated MTP draft slots [main_full_att, main_full_att + draft) (speculative KV that
81-
# must never touch the CPU/disk cache), so slice them off here.
82-
main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0])
89+
# Restore the persisted full-attn slice: main slots followed by MTP draft slots.
90+
persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager)
8391

8492
copy_cpu_cache_to_kv_buffer(
8593
mem_indexes=mem_indexes,
8694
big_page_buffer_ids=big_page_buffer_ids_gpu,
8795
page_indexes=page_indexes,
88-
gpu_full_att_kv_state=mem_manager.kv_buffer[:main_full_att],
96+
gpu_full_att_kv_state=mem_manager.kv_buffer[:persisted_full_att],
8997
cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer,
9098
cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer,
9199
cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor,
@@ -174,17 +182,15 @@ def offload_gpu_kv_to_cpu_cache(
174182
copy_kv_buffer_to_cpu_cache,
175183
)
176184

177-
# Persist ONLY the main model's full-attn slice. The kv buffer is widened by dedicated
178-
# MTP draft slots [main_full_att, main_full_att + draft) (speculative KV that must never
179-
# be persisted to the CPU/disk cache), so slice them off here.
180-
main_full_att = getattr(mem_manager, "main_full_att_layer_num", mem_manager.kv_buffer.shape[0])
185+
# Persist the full-attn slice used for prefix reuse: main slots followed by MTP draft slots.
186+
persisted_full_att = self._get_persisted_full_att_layer_num(mem_manager)
181187

182188
copy_kv_buffer_to_cpu_cache(
183189
mem_indexes=mem_indexes,
184190
page_indexes=page_indexes,
185191
page_readies=page_readies,
186192
big_page_buffer_ids=big_page_buffer_ids_gpu,
187-
gpu_kv_full_att_state=mem_manager.kv_buffer[:main_full_att],
193+
gpu_kv_full_att_state=mem_manager.kv_buffer[:persisted_full_att],
188194
cpu_kv_conv_state=mem_manager.linear_att_big_page_buffers.conv_state_cache.buffer,
189195
cpu_kv_ssm_state=mem_manager.linear_att_big_page_buffers.ssm_state_cache.buffer,
190196
cpu_cache_tensor=cpu_cache_client.cpu_kv_cache_tensor,

lightllm/common/linear_att_cache_manager/config_objs.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
logger = init_logger(__name__)
99

1010

11+
def get_mtp_draft_full_att_layer_num(args) -> int:
    """Return the number of MTP draft full-attention layers implied by ``args``.

    Args:
        args: any object; ``mtp_mode`` and ``mtp_step`` are read with
            ``getattr`` so either attribute may be missing.

    Returns:
        ``1`` when ``mtp_mode`` is ``"eagle_with_att"``; ``args.mtp_step``
        (``0`` if the attribute is missing) when ``mtp_mode`` is
        ``"vanilla_with_att"``; ``0`` for every other mode, including no mode.
    """
    mode = getattr(args, "mtp_mode", None)
    if mode == "vanilla_with_att":
        # One draft slot per MTP step.
        return getattr(args, "mtp_step", 0)
    return 1 if mode == "eagle_with_att" else 0
18+
19+
1120
@dataclasses.dataclass
1221
class LinearAttCacheConfig:
1322
tp_world_size: int
@@ -28,10 +37,19 @@ class LinearAttCacheConfig:
2837
ssm_state_dtype: torch.dtype
2938
full_attention_interval: int
3039
all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数
40+
draft_full_att_layer_num: int = 0
3141

3242
def get_conv_dim(self):
    """Return the combined conv dimension derived from the linear-attention head sizes.

    Computed as ``2 * (head_linear_k_dim * num_linear_k_heads)
    + head_linear_v_dim * num_linear_v_heads``.
    """
    # NOTE(review): the factor 2 presumably covers two projections that share
    # the key head shape (e.g. query and key) — confirm against the model code.
    k_width = self.head_linear_k_dim * self.num_linear_k_heads
    v_width = self.head_linear_v_dim * self.num_linear_v_heads
    return k_width * 2 + v_width
3444

45+
def get_main_full_att_layer_num(self):
    """Return the main model's full-attention layer count.

    The count is ``all_layer_num - linear_layer_num``.

    Raises:
        AssertionError: if that count differs from
            ``all_layer_num // full_attention_interval``.
    """
    full_att_layers = self.all_layer_num - self.linear_layer_num
    layers_by_interval = self.all_layer_num // self.full_attention_interval
    # Both derivations of the count must agree; a mismatch means the config is inconsistent.
    assert full_att_layers == layers_by_interval
    return full_att_layers
49+
50+
def get_persisted_full_att_layer_num(self):
    """Return the main full-attention layer count plus ``draft_full_att_layer_num``."""
    main_layers = self.get_main_full_att_layer_num()
    return main_layers + self.draft_full_att_layer_num
52+
3553
def get_persisted_conv_state_shape(self):
3654
# NARROW shape used for the CPU/disk persisted page and ALL byte math.
3755
# Persisted state is always the committed (narrow) sliding window.
@@ -71,7 +89,7 @@ def get_cpu_cache_full_att_bytes(self):
7189
)
7290
assert big_page_token_num == get_env_start_args().cpu_cache_token_page_size
7391
full_att_bytes = 2 * self.full_att_all_num_kv_heads * self.full_att_head_dim * self.full_att_dtype.itemsize
74-
a = full_att_bytes * (self.all_layer_num - self.linear_layer_num) * big_page_token_num
92+
a = full_att_bytes * self.get_persisted_full_att_layer_num() * big_page_token_num
7593
return a
7694

7795
def get_cpu_cache_conv_bytes(self):
@@ -116,4 +134,5 @@ def load_from_args() -> "LinearAttCacheConfig":
116134
ssm_state_dtype=get_torch_dtype(args.linear_att_ssm_data_type),
117135
full_attention_interval=llm_config["full_attention_interval"],
118136
all_layer_num=n_layer,
137+
draft_full_att_layer_num=get_mtp_draft_full_att_layer_num(args),
119138
)

lightllm/models/qwen3next/model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager
1717
from lightllm.server.core.objs.start_args_type import StartArgs
1818
from lightllm.common.req_manager import ReqManagerForMamba
19-
from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig
19+
from lightllm.common.linear_att_cache_manager.config_objs import (
20+
LinearAttCacheConfig,
21+
get_mtp_draft_full_att_layer_num,
22+
)
2023

2124
logger = init_logger(__name__)
2225

@@ -58,6 +61,7 @@ def _init_mem_manager(self):
5861
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
5962
start_args: StartArgs = get_env_start_args()
6063
ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32}
64+
draft_full_att_layers = get_mtp_draft_full_att_layer_num(start_args)
6165
self.linear_config = LinearAttCacheConfig(
6266
tp_world_size=self.tp_world_size_,
6367
full_att_all_num_kv_heads=self.config["num_key_value_heads"],
@@ -75,14 +79,11 @@ def _init_mem_manager(self):
7579
ssm_state_dtype=ssm_dtype_dict[start_args.linear_att_ssm_data_type],
7680
full_attention_interval=self.config["full_attention_interval"],
7781
all_layer_num=self.config["n_layer"],
82+
draft_full_att_layer_num=draft_full_att_layers,
7883
)
7984

80-
main_full_att = self.linear_config.all_layer_num - self.linear_config.linear_layer_num
81-
draft_full_att_layers = 0
82-
if start_args.mtp_mode == "eagle_with_att":
83-
draft_full_att_layers = 1
84-
elif start_args.mtp_mode == "vanilla_with_att":
85-
draft_full_att_layers = start_args.mtp_step
85+
main_full_att = self.linear_config.get_main_full_att_layer_num()
86+
persisted_full_att = self.linear_config.get_persisted_full_att_layer_num()
8687
self._main_full_att_layer_num = main_full_att
8788
self._draft_full_att_layers = draft_full_att_layers
8889

@@ -91,12 +92,13 @@ def _init_mem_manager(self):
9192
dtype=self.data_type,
9293
num_kv_heads=self.num_kv_heads,
9394
head_dim=self.config["head_dim"],
94-
full_att_layer_num=main_full_att + draft_full_att_layers,
95+
full_att_layer_num=persisted_full_att,
9596
linear_config=self.linear_config,
9697
mem_fraction=self.mem_fraction,
9798
)
9899
self.mem_manager.main_full_att_layer_num = main_full_att
99100
self.mem_manager.draft_full_att_layers = draft_full_att_layers
101+
self.mem_manager.persisted_full_att_layer_num = persisted_full_att
100102

101103
def _init_req_manager(self):
102104
create_max_seq_len = 0

lightllm/utils/kv_cache_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
121121
if args.mtp_mode is not None:
122122
# TODO 可能会存在不同mtp模式的精度问题
123123
if is_linear_att_mixed_model(args.model_dir):
124+
# Linear mixed models use one packed byte page; MTP draft full-attn
125+
# slots are accounted in LinearAttCacheConfig.get_cpu_cache_big_page_bytes().
124126
pass
125127
else:
126128
cpu_cache_meta.layer_num += get_added_mtp_kv_layer_num()
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Force the CPU KV-cache offload->restore path and check correctness.
2+
3+
GSM8K can't exercise the CPU cache (one shared hot prefix, sub-page tails).
4+
This driver builds N distinct, page-aligned, long prompts that overflow the
5+
GPU KV budget so their KV is offloaded to CPU, then re-requests them so they
6+
are restored from CPU. With greedy decoding the round-2 (CPU-restored) output
7+
MUST be token-identical to round-1 (freshly computed). For the MTP build,
8+
accept-rate (mtp_avg_token_per_step) would also degrade if the draft full-attn
9+
slots were not persisted/restored correctly; this script does not measure it.
10+
"""
11+
import argparse
12+
import sys
13+
import requests
14+
from concurrent.futures import ThreadPoolExecutor
15+
16+
17+
def make_prompts(n, words_per_prompt):
    """Build ``n`` distinct, deterministic prompts.

    Prompt ``i`` embeds ``words_per_prompt`` filler words of the form
    ``item{i}-{j}``, so no two prompts share their filler. The intent is that
    each prompt forms its own radix branch and is long enough to span several
    256-token pages.

    Args:
        n: number of prompts to generate.
        words_per_prompt: number of filler words in each prompt.

    Returns:
        A list of ``n`` prompt strings, in index order.
    """

    def build(idx):
        words = (f"item{idx}-{j}" for j in range(words_per_prompt))
        filler = " ".join(words)
        return (
            f"You are given list number {idx}. The list is: {filler}. "
            f"Question: briefly summarize what list number {idx} contains. Answer:"
        )

    return [build(idx) for idx in range(n)]
28+
29+
30+
def gen(url, prompt, max_tokens):
    """POST one greedy-decoding generation request and return the generated text.

    Args:
        url: full URL of the generate endpoint.
        prompt: prompt text, sent as ``inputs``.
        max_tokens: sent as ``max_new_tokens``.

    Returns:
        The first element of the response's ``generated_text`` field.

    Raises:
        AssertionError: if the HTTP status is not 200; the message carries the
            status code and the response body.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            # temperature 0 with top_k 1 makes decoding deterministic, so two
            # rounds over the same prompt can be compared for equality.
            "temperature": 0.0,
            "max_new_tokens": max_tokens,
            "stop_sequences": None,
            "repetition_penalty": 1.0,
            "top_p": 1.0,
            "top_k": 1,
        },
    }
    resp = requests.post(url, json=payload, timeout=120)
    # Explicit check instead of a bare `assert`: asserts are stripped under
    # `python -O`, which would let an error response reach `.json()` below.
    # AssertionError is kept so the failure type seen by callers is unchanged.
    if resp.status_code != 200:
        raise AssertionError(f"{resp.status_code}: {resp.text}")
    return resp.json()["generated_text"][0]
45+
46+
47+
def run_round(url, prompts, max_tokens, parallel):
    """Send every prompt through ``gen`` on a thread pool and collect the outputs.

    Args:
        url: generate endpoint passed through to ``gen``.
        prompts: sequence of prompt strings.
        max_tokens: passed through to ``gen``.
        parallel: maximum number of worker threads.

    Returns:
        A list of outputs where element ``k`` is the result for ``prompts[k]``.
        An exception raised by any request propagates to the caller.
    """
    with ThreadPoolExecutor(max_workers=parallel) as pool:
        # Submit everything up front, then wait in submission order so that
        # results line up with the prompt indices.
        pending = [pool.submit(gen, url, text, max_tokens) for text in prompts]
        return [future.result() for future in pending]
55+
56+
57+
def main():
    """Run two identical request rounds and fail if any output differs.

    Parses the command-line options, builds the prompts, sends them twice via
    ``run_round`` and compares the two result lists element-wise. Prints a
    summary; on any mismatch prints up to five differing pairs and exits with
    status 1.
    """
    parser = argparse.ArgumentParser()
    options = (
        ("--host", {"default": "http://127.0.0.1"}),
        ("--port", {"type": int, "default": 8088}),
        ("--num-prompts", {"type": int, "default": 24}),
        ("--words-per-prompt", {"type": int, "default": 400}),
        ("--max-tokens", {"type": int, "default": 32}),
        ("--parallel", {"type": int, "default": 8}),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}/generate"
    prompts = make_prompts(args.num_prompts, args.words_per_prompt)
    total = len(prompts)

    print(f"Round 1 (cold compute): {total} distinct prompts", flush=True)
    r1 = run_round(url, prompts, args.max_tokens, args.parallel)
    print("Round 2 (CPU-restored):", flush=True)
    r2 = run_round(url, prompts, args.max_tokens, args.parallel)

    # Indices whose second-round output differs from the first-round output.
    mismatches = [idx for idx, (first, second) in enumerate(zip(r1, r2)) if first != second]
    print("\n=== RESULT ===")
    print(f"prompts: {total} identical: {total - len(mismatches)} mismatches: {len(mismatches)}")

    if not mismatches:
        print("PASS: round-2 (CPU-restored) output is token-identical to round-1.")
        return

    # Show only the first few differing pairs to keep the report readable.
    for idx in mismatches[:5]:
        print(f" [#{idx}] R1={r1[idx]!r}\n R2={r2[idx]!r}")
    sys.exit(1)
83+
84+
85+
if __name__ == "__main__":
86+
main()

0 commit comments

Comments
 (0)