Skip to content

Commit 0d15236

Browse files
committed
refactor(mtp): GPU-resident req_to_accept_len + simplify verify-decode plumbing
- is_mtp_verify: drop the redundant `b_num_accepted_tokens is not None` clause (after the grouped revert it is implied by mtp_step > 0 ∧ ¬prefill).
- Replace the per-step host round-trip for b_num_accepted_tokens with a GPU-resident ReqManager.req_to_accept_len: a Triton scatter_mtp_accept_len after verify, plus a GDN-only gather in init_mtp_verify_extra_state. This removes the gen_from_list H2D rebuild, the phase-2 req.mtp_accept_len writeback, and the host attribute (linear-att offload and resets now read/write the buffer).
- Drop the redundant `if mtp_step > 0` guard inside decode_mtp/decode_overlap_mtp.
- config_objs: inline the MTP draft-layer count, dropping the _mtp_added_layer_num helper (get_added_mtp_kv_layer_num is kept, inlined in envs_utils).
- cpu_cache_meta: don't bump layer_num for linear-att models (the draft full-att slots are already included in LinearAttCacheConfig.get_cpu_cache_big_page_bytes()).

Static checks pass (ast, flake8). The req_to_accept_len refactor is not yet runtime-verified; pending a hybrid GSM8K + cudagraph-ON parity run.
1 parent f71bcc9 commit 0d15236

15 files changed

Lines changed: 152 additions & 242 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ requirements-musa.txt
1010
logs/
1111

1212
/benchmark/
13+
artifacts/

lightllm/common/basemodel/basemodel.py

Lines changed: 22 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@
2626
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token, gather_token_prefill_decode_mixed
2727
from lightllm.utils.log_utils import init_logger
2828
from lightllm.utils.dist_utils import get_dp_world_size
29-
from lightllm.utils.envs_utils import (
30-
get_env_start_args,
31-
get_llm_data_type,
32-
get_added_mtp_kv_layer_num,
33-
)
29+
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
3430
from lightllm.distributed.communication_op import dist_group_manager
3531
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
3632
from lightllm.common.triton_utils.autotuner import AutotuneLevel
@@ -381,105 +377,36 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
381377
is_mtp_grouped_decode = (not model_input.is_prefill) and self.args.mtp_step > 0
382378
if is_mtp_grouped_decode:
383379
mtp_size = self.args.mtp_step + 1
384-
assert model_input.batch_size % mtp_size == 0
385-
assert new_batch_size % mtp_size == 0
386380
assert padded_batch_size % mtp_size == 0
387381
padded_req_num = padded_batch_size // mtp_size
388-
389-
pad_mtp_index = torch.arange(
390-
mtp_size,
391-
dtype=new_model_input.b_mtp_index.dtype,
392-
device=new_model_input.b_mtp_index.device,
393-
).repeat(padded_req_num)
394-
pad_seq_len = torch.arange(
395-
2,
396-
mtp_size + 2,
397-
dtype=new_model_input.b_seq_len.dtype,
398-
device=new_model_input.b_seq_len.device,
399-
).repeat(padded_req_num)
400382
new_model_input.total_token_num += padded_req_num * (mtp_size * (mtp_size + 3) // 2)
401383
new_model_input.max_kv_seq_len = max(mtp_size + 1, model_input.max_kv_seq_len)
402-
new_model_input.input_ids = torch.cat(
403-
(
404-
new_model_input.input_ids,
405-
torch.ones(
406-
padded_batch_size,
407-
dtype=new_model_input.input_ids.dtype,
408-
device=new_model_input.input_ids.device,
409-
),
410-
),
411-
dim=0,
412-
)
413-
new_model_input.b_req_idx = torch.cat(
414-
(
415-
new_model_input.b_req_idx,
416-
torch.full(
417-
(padded_batch_size,),
418-
self.req_manager.HOLD_REQUEST_ID,
419-
dtype=new_model_input.b_req_idx.dtype,
420-
device=new_model_input.b_req_idx.device,
421-
),
422-
),
423-
dim=0,
424-
)
425-
new_model_input.b_mtp_index = torch.cat((new_model_input.b_mtp_index, pad_mtp_index), dim=0)
384+
pad_seq_len = torch.arange(
385+
2, mtp_size + 2, dtype=new_model_input.b_seq_len.dtype, device=new_model_input.b_seq_len.device
386+
).repeat(padded_req_num)
426387
new_model_input.b_seq_len = torch.cat((new_model_input.b_seq_len, pad_seq_len), dim=0)
427-
new_model_input.mem_indexes = torch.cat(
428-
(
429-
new_model_input.mem_indexes,
430-
torch.full(
431-
(padded_batch_size,),
432-
self.mem_manager.HOLD_TOKEN_MEMINDEX,
433-
dtype=new_model_input.mem_indexes.dtype,
434-
device=new_model_input.mem_indexes.device,
435-
),
436-
),
437-
dim=0,
438-
)
439-
new_model_input.b_num_accepted_tokens = torch.cat(
440-
(
441-
new_model_input.b_num_accepted_tokens,
442-
torch.ones(
443-
padded_req_num,
444-
dtype=new_model_input.b_num_accepted_tokens.dtype,
445-
device=new_model_input.b_num_accepted_tokens.device,
446-
),
447-
),
448-
dim=0,
449-
)
388+
# b_num_accepted_tokens 不再随 model_input 流转/补齐:它在 GDN 的 init_mtp_verify_extra_state
389+
# 里按 req_first 从 req_to_accept_len gather,padding 组 req_first=HOLD(槽恒为 1)自然得 1。
450390
else:
451391
new_model_input.total_token_num += padded_batch_size * 2
452392
new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len)
453-
new_model_input.input_ids = F.pad(
454-
new_model_input.input_ids,
455-
(0, padded_batch_size),
456-
mode="constant",
457-
value=1,
458-
)
459-
new_model_input.b_req_idx = F.pad(
460-
new_model_input.b_req_idx,
461-
(0, padded_batch_size),
462-
mode="constant",
463-
value=self.req_manager.HOLD_REQUEST_ID,
464-
)
465-
new_model_input.b_mtp_index = F.pad(
466-
new_model_input.b_mtp_index,
467-
(0, padded_batch_size),
468-
mode="constant",
469-
value=0,
470-
)
471393
new_model_input.b_seq_len = F.pad(
472-
new_model_input.b_seq_len,
473-
(0, padded_batch_size),
474-
mode="constant",
475-
value=2,
476-
)
477-
new_model_input.mem_indexes = F.pad(
478-
new_model_input.mem_indexes,
479-
(0, padded_batch_size),
480-
mode="constant",
481-
value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
394+
new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2
482395
)
396+
397+
new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1)
398+
new_model_input.b_req_idx = F.pad(
399+
new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
400+
)
401+
new_model_input.b_mtp_index = F.pad(
402+
new_model_input.b_mtp_index, (0, padded_batch_size), mode="constant", value=0
403+
)
404+
new_model_input.mem_indexes = F.pad(
405+
new_model_input.mem_indexes,
406+
(0, padded_batch_size),
407+
mode="constant",
408+
value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
409+
)
483410
new_model_input.multimodal_params = new_model_input.multimodal_params + [
484411
{"images": [], "audios": []} for _ in range(padded_batch_size)
485412
]
@@ -698,6 +625,7 @@ def _decode(
698625

699626
@final
700627
def _context_forward(self, infer_state: InferStateInfo):
628+
701629
input_embs = self.pre_infer.context_forward(infer_state.input_ids, infer_state, self.pre_post_weight)
702630
if self.args.enable_dp_prefill_balance:
703631
assert not self.args.enable_prefill_cudagraph, "not support now"

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def _build_warmup_decode_model_input(
102102
real_batch_size = batch_size // mtp_size
103103
b_mtp_index = torch.arange(mtp_size, dtype=torch.int32, device=device).repeat(real_batch_size)
104104
b_seq_len = torch.arange(2, mtp_size + 2, dtype=torch.int32, device=device).repeat(real_batch_size)
105-
b_num_accepted_tokens = torch.ones(real_batch_size, dtype=torch.int32, device=device)
105+
# b_num_accepted_tokens 不再随 model_input 传入:GDN 的 init_mtp_verify_extra_state 会按
106+
# req_first(全 HOLD,槽恒为 1) gather,warmup/capture 自然得到全 1,等价旧的 torch.ones。
106107
total_token_num = real_batch_size * (mtp_size * (mtp_size + 3) // 2)
107108
else:
108109
seq_len = 2

lightllm/common/basemodel/mtp_verify_extra_state.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,14 @@
33
from lightllm.utils.envs_utils import get_env_start_args
44

55

6-
def init_mtp_verify_extra_state(self):
7-
"""Shared MTP-verify decode metadata, used by qwen3_5 and qwen3next infer-struct classes (#12).
8-
Call AFTER super().init_some_extra_state(model). `self` is the InferStateInfo instance."""
6+
def init_mtp_verify_extra_state(self, model):
97
self.b_att_seq_len = self.b_seq_len
108
mtp_step = get_env_start_args().mtp_step
119
self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index
12-
# conv buffer is now ONE widened slot per request (indexed by req_idx),
13-
# dropping the *(S+1) + mtp_index addressing used by the SSM block.
1410
self.b_conv_buffer_idx = self.b_req_idx
15-
# MTP verify batch: decode-mode, S+1 expanded, and gated on the
16-
# per-real-request accept tensor that decode_mtp threads in. Gating on
17-
# b_num_accepted_tokens (vs only b_mtp_index, which is set for any decode)
18-
# distinguishes the main-model verify forward from draft/plain decode.
19-
self.is_mtp_verify = (
20-
(mtp_step > 0)
21-
and (not self.is_prefill)
22-
and (self.b_mtp_index is not None)
23-
and (self.b_num_accepted_tokens is not None)
24-
)
11+
self.is_mtp_verify = (mtp_step > 0) and (not self.is_prefill) and (self.b_mtp_index is not None)
2512
self.b_gdn_verify_cu_seqlens = None
2613
self.b_ssm_index_rows = None
27-
# b_num_accepted_tokens is threaded onto the infer_state from ModelInput by
28-
# _create_inferstate (mirrors b_mtp_index) BEFORE this runs; nothing to do here.
2914
if self.is_mtp_verify:
3015
step = mtp_step + 1
3116
n_real = self.b_req_idx.shape[0] // step
@@ -36,12 +21,6 @@ def init_mtp_verify_extra_state(self):
3621
base = (req_first * step).view(n_real, 1)
3722
self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step)
3823
assert self.b_ssm_index_rows.shape == (n_real, step)
39-
# The spec conv kernel is per-SEQUENCE (one program per real request),
40-
# indexed by conv_state_indices[idx_seq] with idx_seq in [0, n_real),
41-
# aligned 1:1 with b_gdn_verify_cu_seqlens / b_num_accepted_tokens. The
42-
# default b_conv_buffer_idx = b_req_idx has the expanded length n_real*step,
43-
# which launches n_real*step conv programs and reads num_accepted/
44-
# query_start_loc out of bounds for idx_seq >= n_real, corrupting the
45-
# committed conv slot. Narrow it to one widened conv slot per request.
4624
self.b_conv_buffer_idx = req_first
25+
self.b_num_accepted_tokens = model.req_manager.req_to_accept_len[req_first]
4726
return

lightllm/common/basemodel/triton_kernel/mtp_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,51 @@ def mtp_scatter_next_token_ids(
148148
)
149149

150150

151+
@triton.jit
152+
def _fwd_kernel_scatter_accept_len(
153+
req_to_accept_len,
154+
b_req_mtp_start_loc,
155+
b_req_idx,
156+
mtp_accept_len,
157+
):
158+
cur_index = tl.program_id(0)
159+
req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
160+
cur_req_idx = tl.load(b_req_idx + req_start_loc)
161+
accept_len = tl.load(mtp_accept_len + cur_index)
162+
tl.store(req_to_accept_len + cur_req_idx, accept_len)
163+
return
164+
165+
166+
def scatter_mtp_accept_len(
167+
req_to_accept_len: torch.Tensor,
168+
b_req_mtp_start_loc: torch.Tensor,
169+
b_req_idx: torch.Tensor,
170+
mtp_accept_len: torch.Tensor,
171+
):
172+
"""
173+
将本步每个真实请求(组首)的 accept 数量写入 GPU 常驻的 req_to_accept_len[req_idx]。
174+
融合 `req_to_accept_len[b_req_idx[b_req_mtp_start_loc]] = mtp_accept_len` 的 gather+scatter
175+
为单次 launch、无中间张量。每个 program 处理一个真实请求。
176+
Args:
177+
req_to_accept_len: (max_req_num + 1,)
178+
b_req_mtp_start_loc: (num_reqs,) 每组首行在 batch 中的偏移
179+
b_req_idx: (batch_size,) grouped 布局的 req_idx(组首即该请求的 req_idx)
180+
mtp_accept_len: (num_reqs,)
181+
"""
182+
num_reqs = mtp_accept_len.shape[0]
183+
if num_reqs == 0:
184+
return
185+
grid = (num_reqs,)
186+
_fwd_kernel_scatter_accept_len[grid](
187+
req_to_accept_len=req_to_accept_len,
188+
b_req_mtp_start_loc=b_req_mtp_start_loc,
189+
b_req_idx=b_req_idx,
190+
mtp_accept_len=mtp_accept_len,
191+
num_warps=1,
192+
num_stages=1,
193+
)
194+
195+
151196
def test_mtp_verify():
152197
req_to_next_token_ids = torch.tensor(
153198
[[1, 2, -2, -1, -1], [1, 2, 0, -1, -1], [1, 3, 4, 4, 5]], dtype=torch.int32, device="cuda"

lightllm/common/linear_att_cache_manager/config_objs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
import torch
22
import dataclasses
33
import triton
4-
from lightllm.utils.envs_utils import get_env_start_args, _mtp_added_layer_num
4+
from lightllm.utils.envs_utils import get_env_start_args
55
from lightllm.utils.log_utils import init_logger
66
from lightllm.utils.torch_dtype_utils import get_torch_dtype
77

88
logger = init_logger(__name__)
99

1010

1111
def get_mtp_draft_full_att_layer_num(args) -> int:
12-
# Delegates to the single source of truth in envs_utils (#9).
13-
return _mtp_added_layer_num(getattr(args, "mtp_mode", None), getattr(args, "mtp_step", 0))
12+
# mtp_mode -> draft model 增加的 full-att KV 层数(与 envs_utils.get_added_mtp_kv_layer_num 同口径)。
13+
mtp_mode = getattr(args, "mtp_mode", None)
14+
if mtp_mode == "eagle_with_att":
15+
return 1
16+
if mtp_mode == "vanilla_with_att":
17+
return getattr(args, "mtp_step", 0)
18+
return 0
1419

1520

1621
@dataclasses.dataclass

lightllm/common/req_manager.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
8686
self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
8787
self.max_request_num = max_request_num
8888
self.HOLD_REQUEST_ID = max_request_num
89+
# MTP verify decode 的 per-req accept 数量:GPU 常驻、按 req_idx 索引(含 HOLD 槽)。
90+
# 取代旧的 req.mtp_accept_len host 属性 —— verify 后在 GPU 上 scatter,下一步在 GDN 的
91+
# init_mtp_verify_extra_state 里按 req_first gather 成 b_num_accepted_tokens,省掉每步的
92+
# host 回写 + H2D 重建。HOLD 槽恒为 1,使 padding 组 gather 到 1。仅 mtp_step>0 时分配。
93+
self.req_to_accept_len = (
94+
torch.ones((max_request_num + 1,), dtype=torch.int32, device="cuda")
95+
if get_env_start_args().mtp_step > 0
96+
else None
97+
)
8998

9099
def alloc(self):
91100
return self.req_list.alloc()
@@ -274,7 +283,8 @@ def init_linear_att_state(self, req: "InferReq"):
274283
# #17: zero the FULL (mtp_step + 1)-row SSM block, not just canonical row +0, so a future
275284
# first-step verify reading offset>0 after fresh init never hits a never-written row (NaN).
276285
self.req_to_ssm_state.buffer[:, ssm_start : ssm_start + (self.mtp_step + 1), ...].fill_(0)
277-
req.mtp_accept_len = 1
286+
if self.req_to_accept_len is not None:
287+
self.req_to_accept_len[req.req_idx] = 1
278288
return
279289

280290
def get_mamba_cache(self, layer_idx_in_all: int):
@@ -298,7 +308,8 @@ def copy_big_page_buffer_to_linear_att_state(self, big_page_buffer_idx: int, req
298308
narrow_w = conv_state.shape[-1] # persisted (narrow) width
299309
self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state
300310
self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state
301-
req.mtp_accept_len = 1
311+
if self.req_to_accept_len is not None:
312+
self.req_to_accept_len[req.req_idx] = 1
302313
return
303314

304315
def copy_small_page_buffer_to_linear_att_state(
@@ -314,5 +325,6 @@ def copy_small_page_buffer_to_linear_att_state(
314325
# 同时,非连续对象的拷贝,可能存在效率问题。
315326
self.req_to_conv_state.buffer[:, conv_dest, ..., :narrow_w] = conv_state
316327
self.req_to_ssm_state.buffer[:, ssm_dest, ...] = ssm_state
317-
req.mtp_accept_len = 1
328+
if self.req_to_accept_len is not None:
329+
self.req_to_accept_len[req.req_idx] = 1
318330
return

lightllm/models/qwen3_5/infer_struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ def init_some_extra_state(self, model):
1010
super().init_some_extra_state(model)
1111
from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state
1212

13-
init_mtp_verify_extra_state(self)
13+
init_mtp_verify_extra_state(self, model)
1414
return

lightllm/models/qwen3next/infer_struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ def init_some_extra_state(self, model):
1010
super().init_some_extra_state(model)
1111
from lightllm.common.basemodel.mtp_verify_extra_state import init_mtp_verify_extra_state
1212

13-
init_mtp_verify_extra_state(self)
13+
init_mtp_verify_extra_state(self, model)
1414
return

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
389389

390390
from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer
391391

392-
b_num_accepted_tokens = torch.tensor(
393-
[req.mtp_accept_len for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu"
392+
# accept 数量改由 GPU 常驻的 req_to_accept_len 按 req_idx gather(不再读 req.mtp_accept_len)。
393+
req_idxs = torch.tensor(
394+
[req.req_idx for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu"
394395
).cuda(non_blocking=True)
396+
b_num_accepted_tokens = self.req_manager.req_to_accept_len[req_idxs]
395397

396398
copy_linear_att_state_to_kv_buffer(
397399
b_req_idx=b_req_idx,
@@ -417,11 +419,13 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
417419
self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache()
418420
)
419421
if req.tail_linear_att_small_page_buffer_id is not None:
420-
assert 1 <= req.mtp_accept_len <= self.args.mtp_step + 1, (
421-
f"mtp_accept_len={req.mtp_accept_len} out of range "
422+
# 冷路径(prefill 跨小页边界):单标量从 GPU buffer 读回做 Python 切片下标。
423+
accept_len = int(self.req_manager.req_to_accept_len[req.req_idx].item())
424+
assert 1 <= accept_len <= self.args.mtp_step + 1, (
425+
f"mtp_accept_len={accept_len} out of range "
422426
f"[1, {self.args.mtp_step + 1}]; would slice past the widened conv slot"
423427
)
424-
canonical_off = req.mtp_accept_len - 1
428+
canonical_off = accept_len - 1
425429
conv_src_idx = req.req_idx
426430
ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off
427431
narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1]
@@ -578,8 +582,6 @@ def __init__(
578582
else:
579583
self.decode_need_token_num = self._normal_decode_need_token_num
580584

581-
self.mtp_accept_len: int = 1
582-
583585
if g_infer_context.is_linear_att_mixed_model:
584586
self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att
585587
self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att

0 commit comments

Comments
 (0)