Skip to content

Commit f535bbe

Browse files
committed
feat(scheduler): MTP verify backend + accept-len transport
Wire the verify path through the inference backends: a single draft-model factory keyed on (model_type, mtp_mode); build the (mtp_step+1)-expanded verify decode batch; run the eagle + vanilla draft decode; verify accepted tokens; and thread per-request accept-lengths (b_num_accepted_tokens) from the chunked-prefill and dp backends into the model verify forward.
1 parent 7057e0e commit f535bbe

5 files changed

Lines changed: 340 additions & 112 deletions

File tree

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
361361
if not self.is_linear_att_mixed_model:
362362
return
363363

364+
# 当 dynamic prompt cache 被禁用时 radix_cache 为 None,没有大页/小页缓冲可写,
365+
# 线性层状态仅存于 req_manager 的 GPU buffer 即可,直接跳过跨请求缓存拷贝。
366+
if self.radix_cache is None:
367+
return
368+
364369
# 大页对应的 linear att 的拷贝
365370
big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num
366371
big_page_buffer_ids = []
@@ -384,6 +389,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
384389

385390
from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer
386391

392+
b_num_accepted_tokens = torch.tensor(
393+
[req.mtp_accept_len for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu"
394+
).cuda(non_blocking=True)
395+
387396
copy_linear_att_state_to_kv_buffer(
388397
b_req_idx=b_req_idx,
389398
big_page_buffer_ids=big_page_buffer_ids,
@@ -392,6 +401,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
392401
cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer,
393402
cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer,
394403
mtp_step=self.args.mtp_step,
404+
b_num_accepted_tokens=b_num_accepted_tokens,
395405
)
396406

397407
assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model"
@@ -407,9 +417,18 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
407417
self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache()
408418
)
409419
if req.tail_linear_att_small_page_buffer_id is not None:
410-
src_buffer_idx = req.req_idx * (self.args.mtp_step + 1)
411-
gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...]
412-
gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...]
420+
assert 1 <= req.mtp_accept_len <= self.args.mtp_step + 1, (
421+
f"mtp_accept_len={req.mtp_accept_len} out of range "
422+
f"[1, {self.args.mtp_step + 1}]; would slice past the widened conv slot"
423+
)
424+
canonical_off = req.mtp_accept_len - 1
425+
conv_src_idx = req.req_idx
426+
ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off
427+
narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1]
428+
gpu_conv_state = self.req_manager.req_to_conv_state.buffer[
429+
:, conv_src_idx, ..., canonical_off : canonical_off + narrow_w
430+
]
431+
gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, ssm_src_idx, ...]
413432
dst_buffer_idx = req.tail_linear_att_small_page_buffer_id
414433

415434
dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache(
@@ -559,6 +578,8 @@ def __init__(
559578
else:
560579
self.decode_need_token_num = self._normal_decode_need_token_num
561580

581+
self.mtp_accept_len: int = 1
582+
562583
if g_infer_context.is_linear_att_mixed_model:
563584
self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att
564585
self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import copy
23
import numpy as np
34
import torch
45
import time
@@ -41,10 +42,6 @@
4142
)
4243
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
4344
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
44-
from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
45-
from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
46-
from lightllm.models.mistral_mtp.model import MistralMTPModel
47-
from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel
4845
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
4946
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
5047
from lightllm.server.pd_io_struct import PDChunckedTransTaskRet
@@ -328,22 +325,11 @@ def init_mtp_draft_model(self, main_kvargs: dict):
328325
"mtp_previous_draft_models": self.draft_models.copy(),
329326
}
330327

331-
# Select MTP model class based on model type
328+
# Select MTP model class based on model type (single source of truth: #10).
329+
from lightllm.server.router.model_infer.mode_backend.mtp_model_factory import create_mtp_draft_model
330+
332331
model_type = mtp_model_cfg.get("model_type", "")
333-
if model_type == "deepseek_v3":
334-
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
335-
self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
336-
elif model_type == "qwen3_moe":
337-
assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
338-
self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
339-
elif model_type == "mistral":
340-
assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
341-
self.draft_models.append(MistralMTPModel(mtp_model_kvargs))
342-
elif mtp_model_cfg["model_type"] == "glm4_moe_lite":
343-
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
344-
self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
345-
else:
346-
raise ValueError(f"Unsupported MTP model type: {model_type}")
332+
self.draft_models.append(create_mtp_draft_model(model_type, self.args.mtp_mode, mtp_model_kvargs))
347333

348334
self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
349335
return
@@ -584,7 +570,6 @@ def _get_classed_reqs(
584570
can_alloc_token_num = g_infer_context.get_can_alloc_token_num()
585571

586572
for req_obj in ready_reqs:
587-
588573
if req_obj.filter_mark:
589574
finished_reqs.append(req_obj)
590575
continue
@@ -761,20 +746,79 @@ def _verify_mtp_v2(
761746
)
762747
return mtp_accept_len, accepted_index
763748

749+
def _build_eagle_accepted_draft_input(
750+
self,
751+
main_model_input: ModelInput,
752+
main_model_output: ModelOutput,
753+
next_token_ids: torch.Tensor,
754+
mtp_accept_len: torch.Tensor,
755+
b_req_mtp_start_loc: torch.Tensor,
756+
):
757+
accepted_row_idx = b_req_mtp_start_loc + mtp_accept_len - 1
758+
accepted_row_idx_long = accepted_row_idx.long()
759+
760+
draft_model_input = copy.copy(main_model_input)
761+
draft_model_input.batch_size = accepted_row_idx.shape[0]
762+
draft_model_input.total_token_num = draft_model_input.batch_size * main_model_input.max_kv_seq_len
763+
draft_model_input.input_ids = next_token_ids.index_select(0, accepted_row_idx_long)
764+
draft_model_input.mtp_draft_input_hiddens = main_model_output.mtp_main_output_hiddens.index_select(
765+
0, accepted_row_idx_long
766+
)
767+
draft_model_input.b_req_idx = main_model_input.b_req_idx.index_select(0, accepted_row_idx_long)
768+
draft_model_input.b_mtp_index = main_model_input.b_mtp_index.index_select(0, accepted_row_idx_long)
769+
draft_model_input.b_seq_len = main_model_input.b_seq_len.index_select(0, accepted_row_idx_long)
770+
draft_model_input.b_num_accepted_tokens = None
771+
if main_model_input.mem_indexes is not None:
772+
draft_model_input.mem_indexes = main_model_input.mem_indexes.index_select(0, accepted_row_idx_long)
773+
draft_model_input.mem_indexes_cpu = None
774+
if main_model_input.b_shared_seq_len is not None:
775+
draft_model_input.b_shared_seq_len = main_model_input.b_shared_seq_len.index_select(
776+
0, accepted_row_idx_long
777+
)
778+
if main_model_input.b_mark_shared_group is not None:
779+
draft_model_input.b_mark_shared_group = main_model_input.b_mark_shared_group.index_select(
780+
0, accepted_row_idx_long
781+
)
782+
783+
if accepted_row_idx.device.type == "cpu":
784+
selected_rows = accepted_row_idx.tolist()
785+
draft_model_input.multimodal_params = [main_model_input.multimodal_params[i] for i in selected_rows]
786+
else:
787+
draft_model_input.multimodal_params = [
788+
{"images": [], "audios": []} for _ in range(draft_model_input.batch_size)
789+
]
790+
791+
accepted_next_token_ids = draft_model_input.input_ids
792+
accepted_req_idx = draft_model_input.b_req_idx
793+
return draft_model_input, accepted_next_token_ids, accepted_req_idx
794+
795+
def _scatter_accepted_next_token_ids(self, accepted_req_idx: torch.Tensor, all_next_token_ids: torch.Tensor):
796+
req_to_next_token_ids = self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids
797+
width = all_next_token_ids.shape[1]
798+
req_to_next_token_ids[:, :width].index_copy_(
799+
0,
800+
accepted_req_idx.long(),
801+
all_next_token_ids.to(dtype=req_to_next_token_ids.dtype),
802+
)
803+
return
804+
764805
def _update_mtp_accept_ratio(
765806
self,
766807
decode_reqs: List[InferReq],
767808
mtp_accept_len_cpu: torch.Tensor,
768809
):
810+
# Master-only accept-ratio statistics. Unlike the phase-2 mtp_accept_len commit
811+
# (inlined in decode_mtp) this only feeds metrics, so it may stay in phase 3.
769812
if self.is_master_in_dp:
770813
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
771814
req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1)
772815
return
773816

774817
def _gen_argmax_token_ids(self, model_output: ModelOutput):
775818
logits = model_output.logits
776-
probs = torch.softmax(logits, dim=-1)
777-
draft_next_token_ids_gpu = torch.argmax(probs, dim=-1)
819+
# softmax is strictly monotonic, so argmax(softmax(logits)) == argmax(logits);
820+
# skip the softmax to shorten the per-step MTP draft critical chain (need-to-fix #16).
821+
draft_next_token_ids_gpu = torch.argmax(logits, dim=-1)
778822
return draft_next_token_ids_gpu
779823

780824
def _sample_and_scatter_token(
@@ -787,7 +831,6 @@ def _sample_and_scatter_token(
787831
b_prefill_has_output_cpu: torch.Tensor = None,
788832
mask_func: Optional[Callable] = None,
789833
):
790-
791834
if mask_func is not None:
792835
assert len(run_reqs) == logits.shape[0]
793836
mask_func(run_reqs, logits)

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import time
3+
import copy
34
from typing import List, Optional, Callable, Dict, Any
45
from queue import Queue
56
from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
@@ -240,17 +241,23 @@ def decode_mtp(
240241
"""
241242
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
242243

244+
if self.mtp_step > 0:
245+
accept_lens = [req.mtp_accept_len for req in decode_reqs]
246+
model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list(
247+
key="b_num_accepted_tokens",
248+
data=accept_lens,
249+
dtype=torch.int32,
250+
)
251+
243252
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
244-
b_mtp_index_cpu = model_input.b_mtp_index
245253
model_output = self.model.forward(model_input)
246254
next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
247-
# verify the next_token_ids
248-
b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
249-
b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list(
250-
key="b_req_mtp_start_loc",
251-
data=b_req_mtp_start_loc,
252-
dtype=torch.int32,
253-
).cuda(non_blocking=True)
255+
# verify the next_token_ids. The chunked decode batch is the contiguous
256+
# (mtp_step+1)-expanded layout, so request starts are structurally
257+
# arange(n_real)*(mtp_step+1). Compute on device instead of a per-step Python
258+
# list-comp + pinned pack + H2D (#22).
259+
n_real = model_input.batch_size // (self.mtp_step + 1)
260+
b_req_mtp_start_loc = torch.arange(n_real, dtype=torch.int32, device="cuda") * (self.mtp_step + 1)
254261

255262
mtp_accept_len, accepted_index = self._verify_mtp_v2(
256263
new_next_token_ids=next_token_ids,
@@ -292,6 +299,8 @@ def decode_mtp(
292299
# 第二阶段
293300
event_pack.notify_post_handle_and_wait_pre_post_handle()
294301
verify_event.synchronize()
302+
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
303+
req.mtp_accept_len = int(accept_len)
295304
verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1]
296305
update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
297306

@@ -344,15 +353,19 @@ def _draft_decode_vanilla(
344353
mtp_accept_len: torch.Tensor,
345354
b_req_mtp_start_loc: torch.Tensor,
346355
):
347-
# share some inference info with the main model
348-
draft_model_input = main_model_input
356+
# share some inference info with the main model. copy.copy 后清空 b_num_accepted_tokens,
357+
# 使 draft (MTP) forward 走普通 decode 布局 (bs, False);否则会沿用主模型 decode_mtp 设置的
358+
# verify 布局,命中 MTP draft 模型从未捕获的 cudagraph key (bs, True) -> KeyError
359+
# (cudagraph 关闭时则会在扁平的 draft batch 上误用 S+1 分组的 verify attention)。
360+
# 镜像 eagle 路径 _build_eagle_accepted_draft_input 中清空 b_num_accepted_tokens 的处理。
361+
draft_model_input = copy.copy(main_model_input)
362+
draft_model_input.b_num_accepted_tokens = None
349363
draft_model_output = main_model_output
350364
draft_next_token_ids = next_token_ids
351365
all_next_token_ids = []
352366
all_next_token_ids.append(next_token_ids)
353367
# process the draft model output
354368
for draft_model_idx in range(self.mtp_step):
355-
356369
draft_model_input.input_ids = draft_next_token_ids
357370
draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens
358371
# spec decode: MTP
@@ -379,44 +392,47 @@ def _draft_decode_eagle(
379392
mtp_accept_len: torch.Tensor,
380393
b_req_mtp_start_loc: torch.Tensor,
381394
):
382-
batch_size = main_model_input.batch_size
383-
num_reqs = batch_size // (self.mtp_step + 1)
395+
num_reqs = b_req_mtp_start_loc.shape[0]
384396
if g_infer_context.radix_cache is not None:
385397
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step)
386398
eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step)
387399
eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True)
388400

389-
# share some inference info with the main model
390-
draft_model_input = main_model_input
401+
(draft_model_input, draft_next_token_ids, accepted_req_idx,) = self._build_eagle_accepted_draft_input(
402+
main_model_input=main_model_input,
403+
main_model_output=main_model_output,
404+
next_token_ids=next_token_ids,
405+
mtp_accept_len=mtp_accept_len,
406+
b_req_mtp_start_loc=b_req_mtp_start_loc,
407+
)
391408
draft_model_output = main_model_output
392-
draft_next_token_ids = next_token_ids
393409
all_next_token_ids = []
394-
all_next_token_ids.append(next_token_ids)
395-
# process the draft model output
396-
for _step in range(self.mtp_step):
410+
all_next_token_ids.append(draft_next_token_ids)
411+
412+
mtp_size = self.mtp_step + 1
413+
main_mem_indexes = main_model_input.mem_indexes.view(num_reqs, mtp_size)
414+
eagle_mem_indexes_by_req = eagle_mem_indexes.view(self.mtp_step, num_reqs).transpose(0, 1).contiguous()
415+
mem_index_plan = torch.cat([main_mem_indexes, eagle_mem_indexes_by_req], dim=1)
416+
accepted_offsets = mtp_accept_len.long() - 1
417+
req_offsets = torch.arange(num_reqs, dtype=torch.long, device=mtp_accept_len.device)
397418

419+
for _step in range(self.mtp_step):
398420
draft_model_input.input_ids = draft_next_token_ids
399-
draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens
421+
if _step > 0:
422+
draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens
423+
draft_model_input.mem_indexes = mem_index_plan[req_offsets, accepted_offsets + _step]
400424
# spec decode: MTP
401425
draft_model_idx = _step % self.num_mtp_models
402426
draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
403427
draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
404428
draft_model_input.b_seq_len += 1
405429
draft_model_input.max_kv_seq_len += 1
406-
eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs]
407-
draft_model_input.mem_indexes = torch.cat(
408-
[draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)],
409-
dim=1,
410-
).view(-1)
411430
all_next_token_ids.append(draft_next_token_ids)
412431

413432
all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1]
414433

415-
mtp_scatter_next_token_ids(
416-
req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
417-
b_req_mtp_start_loc=b_req_mtp_start_loc,
434+
self._scatter_accepted_next_token_ids(
435+
accepted_req_idx=accepted_req_idx,
418436
all_next_token_ids=all_next_token_ids,
419-
b_req_idx=main_model_input.b_req_idx,
420-
mtp_accept_len=mtp_accept_len,
421437
)
422438
return eagle_mem_indexes_cpu

0 commit comments

Comments
 (0)