Skip to content

Commit 59705de

Browse files
committed
Fix Qwen3Next MTP linear-att page moves
1 parent b0fab9d commit 59705de

3 files changed

Lines changed: 130 additions & 11 deletions

File tree

lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ def write_req_to_page(
208208
dp_mems: List["Qwen3NextMemManager"],
209209
):
210210
conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
211-
req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
211+
conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx)
212212
for tp_index, mem in enumerate(dp_mems):
213-
self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
213+
self._write_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page)
214214
return
215215

216216
def read_page_to_req(
@@ -220,21 +220,27 @@ def read_page_to_req(
220220
dp_mems: List["Qwen3NextMemManager"],
221221
):
222222
conv_page, ssm_page = self.view_page_to_linear_att_state(page_index)
223-
req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1)
223+
conv_req_idx, ssm_req_idx = self._get_req_state_indexes(req_idx)
224224
for tp_index, mem in enumerate(dp_mems):
225-
self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page)
225+
self._read_one_rank(mem, tp_index, conv_req_idx, ssm_req_idx, conv_page, ssm_page)
226226
return
227227

228+
def _get_req_state_indexes(self, req_idx: int):
    """Return the conv-state and SSM-state buffer indexes for ``req_idx``.

    The two per-request buffers are addressed differently: the conv index
    is ``req_idx`` itself, while the SSM index is
    ``req_idx * (mtp_step + 1)``.

    Args:
        req_idx: Request slot index.

    Returns:
        A ``(conv_req_idx, ssm_req_idx)`` tuple.
    """
    mtp_size = get_env_start_args().mtp_step + 1
    # Conv is one widened slot per request; SSM keeps the historical S+1 block layout.
    return req_idx, req_idx * mtp_size
232+
228233
def _write_one_rank(
    self,
    mem: "Qwen3NextMemManager",
    tp_index: int,
    conv_req_idx: int,
    ssm_req_idx: int,
    conv_page: torch.Tensor,
    ssm_page: torch.Tensor,
):
    """Copy one rank's request state (conv + SSM) into the page views.

    Args:
        mem: Memory manager of the rank whose request buffers are read.
        tp_index: Position of ``mem`` in the ``dp_mems`` list being iterated.
        conv_req_idx: Index into dim 1 of ``mem.req_to_conv_state.buffer``.
        ssm_req_idx: Index into dim 1 of ``mem.req_to_ssm_state.buffer``.
        conv_page: Destination page view for the conv state.
        ssm_page: Destination page view for the SSM state.
    """
    # Only the first ``conv_shape[-1]`` entries of the last dim are moved to
    # the page; any extra trailing columns in the request buffer are skipped.
    conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]]
    ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...]
    self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index)
    self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index)
    return
def _read_one_rank(
    self,
    mem: "Qwen3NextMemManager",
    tp_index: int,
    conv_req_idx: int,
    ssm_req_idx: int,
    conv_page: torch.Tensor,
    ssm_page: torch.Tensor,
):
    """Copy the page views (conv + SSM) back into one rank's request state.

    Args:
        mem: Memory manager of the rank whose request buffers are written.
        tp_index: Position of ``mem`` in the ``dp_mems`` list being iterated.
        conv_req_idx: Index into dim 1 of ``mem.req_to_conv_state.buffer``.
        ssm_req_idx: Index into dim 1 of ``mem.req_to_ssm_state.buffer``.
        conv_page: Source page view for the conv state.
        ssm_page: Source page view for the SSM state.
    """
    # Only the first ``conv_shape[-1]`` entries of the last dim are the
    # restore target; any extra trailing columns are left untouched.
    conv_state = mem.req_to_conv_state.buffer[:, conv_req_idx, ..., : self.conv_shape[-1]]
    ssm_state = mem.req_to_ssm_state.buffer[:, ssm_req_idx, ...]
    self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index)
    self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index)
    return

unit_tests/common/basemodel/test_mtp_decode_cuda_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class Qwen3_5MOETpPartModel:
103103
pass
104104

105105
class Qwen3_5MoeMTPModel:
    # Test stand-in class; the only thing it defines is this class-level flag.
    is_mtp_draft_model = True
107107

108108
graph = CudaGraph.__new__(CudaGraph)
109109
graph.mtp_step = 2
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from types import SimpleNamespace
2+
3+
import torch
4+
5+
6+
class _Buf:
    """Minimal holder that exposes the wrapped object as ``.buffer``."""

    def __init__(self, tensor):
        self.buffer = tensor
9+
10+
11+
def _make_config():
    """Build a tiny linear-attention config namespace for the page-helper tests.

    Returns:
        A fresh ``SimpleNamespace`` with one layer, one k/v head, a conv
        kernel size of 4, and head dims of 2 (k) and 3 (v).
    """
    return SimpleNamespace(
        tp_world_size=1,
        linear_layer_num=1,
        conv_kernel_size=4,
        global_linear_k_heads=1,
        global_linear_v_heads=1,
        num_linear_k_heads=1,
        num_linear_v_heads=1,
        head_linear_k_dim=2,
        head_linear_v_dim=3,
    )
23+
24+
25+
def _make_mem(mtp_step=2, req_slots=4):
    """Build a fake memory-manager namespace with sentinel-filled state buffers.

    Args:
        mtp_step: Number of MTP steps; widens the conv buffer's last dim by
            ``mtp_step`` and multiplies the SSM slot count by ``mtp_step + 1``.
        req_slots: Number of request slots.

    Returns:
        A ``SimpleNamespace`` with ``linear_config``, ``req_to_conv_state``,
        ``req_to_ssm_state`` (both wrapped in ``_Buf``) and ``kv_move_buffer``.
    """
    config = _make_config()
    conv_dim = (
        2 * config.num_linear_k_heads * config.head_linear_k_dim
        + config.num_linear_v_heads * config.head_linear_v_dim
    )
    narrow_w = config.conv_kernel_size - 1
    # Conv buffer: one slot per request, last dim is narrow_w + mtp_step wide.
    # Filled with -9.0 so untouched entries are recognizable.
    conv = torch.full(
        (config.linear_layer_num, req_slots, conv_dim, narrow_w + mtp_step),
        -9.0,
        dtype=torch.float32,
    )
    # SSM buffer: (mtp_step + 1) slots per request, filled with -11.0.
    ssm = torch.full(
        (
            config.linear_layer_num,
            req_slots * (mtp_step + 1),
            config.num_linear_v_heads,
            config.head_linear_k_dim,
            config.head_linear_v_dim,
        ),
        -11.0,
        dtype=torch.float32,
    )
    return SimpleNamespace(
        linear_config=config,
        req_to_conv_state=_Buf(conv),
        req_to_ssm_state=_Buf(ssm),
        kv_move_buffer=torch.zeros((1, 4096), dtype=torch.uint8),
    )
54+
55+
56+
def test_page_helper_writes_req_conv_slot_and_narrow_width(monkeypatch):
    """``write_req_to_page`` copies the request's narrow conv columns and SSM slot.

    The conv state is read from slot ``req_idx`` (first ``narrow_w`` columns
    only) and the SSM state from slot ``req_idx * (mtp_step + 1)``.
    """
    import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager
    from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper

    mtp_step = 2
    req_idx = 2
    monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step))

    mem = _make_mem(mtp_step=mtp_step)
    helper = Qwen3NextLinearAttPageHelper(mem)
    # Resize the move buffer to exactly one page of state bytes.
    mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8)

    narrow_w = helper.conv_shape[-1]
    marker_conv = torch.arange(
        helper.conv_shape[0] * helper.conv_shape[1] * narrow_w,
        dtype=torch.float32,
    ).view(helper.conv_shape)
    marker_ssm = torch.arange(
        helper.ssm_shape[0] * helper.ssm_shape[1] * helper.ssm_shape[2] * helper.ssm_shape[3],
        dtype=torch.float32,
    ).view(helper.ssm_shape)

    mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w] = marker_conv
    # Poison the extra columns: they must not end up in the page.
    mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] = 999.0
    mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...] = marker_ssm

    helper.write_req_to_page(page_index=0, req_idx=req_idx, dp_mems=[mem])

    conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0)
    torch.testing.assert_close(conv_page, marker_conv)
    torch.testing.assert_close(ssm_page, marker_ssm)
87+
88+
89+
def test_page_helper_restores_narrow_conv_to_req_slot(monkeypatch):
    """``read_page_to_req`` restores only the narrow conv columns and the SSM slot.

    After the read, the first ``narrow_w`` conv columns of slot ``req_idx``
    hold the page contents, the remaining columns keep the ``-9.0`` fill
    value, and SSM slot ``req_idx * (mtp_step + 1)`` holds the page contents.
    """
    import lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager as qwen3next_mem_manager
    from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextLinearAttPageHelper

    mtp_step = 2
    req_idx = 2
    monkeypatch.setattr(qwen3next_mem_manager, "get_env_start_args", lambda: SimpleNamespace(mtp_step=mtp_step))

    mem = _make_mem(mtp_step=mtp_step)
    helper = Qwen3NextLinearAttPageHelper(mem)
    # Resize the move buffer to exactly one page of state bytes.
    mem.kv_move_buffer = torch.zeros((1, helper.state_nbytes), dtype=torch.uint8)
    conv_page, ssm_page = helper.view_page_to_linear_att_state(page_index=0)

    marker_conv = torch.arange(conv_page.numel(), dtype=torch.float32).view_as(conv_page)
    marker_ssm = torch.arange(ssm_page.numel(), dtype=torch.float32).view_as(ssm_page)
    conv_page.copy_(marker_conv)
    ssm_page.copy_(marker_ssm)

    helper.read_page_to_req(page_index=0, req_idx=req_idx, dp_mems=[mem])

    narrow_w = helper.conv_shape[-1]
    torch.testing.assert_close(mem.req_to_conv_state.buffer[:, req_idx, :, :narrow_w], marker_conv)
    # Columns beyond narrow_w must still carry the initial fill value.
    assert torch.all(mem.req_to_conv_state.buffer[:, req_idx, :, narrow_w:] == -9.0)
    torch.testing.assert_close(mem.req_to_ssm_state.buffer[:, req_idx * (mtp_step + 1), ...], marker_ssm)

0 commit comments

Comments
 (0)