Skip to content

Commit c778385

Browse files
committed
feat(qwen3_5_mtp): scheduler MTP verify backend + accept-len transport
Drive the draft/verify loop from the scheduler: - carry a canonical InferReq.mtp_accept_len pointer and persist the per-request accept_len across steps; build per-req b_num_accepted_tokens in decode_mtp and commit it in phase 2 so the next step reads a fresh count. - extend the chunked_prefill backend / base_backend with the MTP verify dispatch and the partial-accept read offset.
1 parent bb5ffd9 commit c778385

3 files changed

Lines changed: 191 additions & 33 deletions

File tree

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,11 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
357357
if not self.is_linear_att_mixed_model:
358358
return
359359

360+
# 当 dynamic prompt cache 被禁用时 radix_cache 为 None,没有大页/小页缓冲可写,
361+
# 线性层状态仅存于 req_manager 的 GPU buffer 即可,直接跳过跨请求缓存拷贝。
362+
if self.radix_cache is None:
363+
return
364+
360365
# 大页对应的 linear att 的拷贝
361366
big_page_token_num = self.args.linear_att_hash_page_size * self.args.linear_att_page_block_num
362367
big_page_buffer_ids = []
@@ -377,6 +382,10 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
377382

378383
from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer
379384

385+
b_num_accepted_tokens = torch.tensor(
386+
[req.mtp_accept_len for req in reqs], dtype=torch.int32, requires_grad=False, device="cpu"
387+
).cuda(non_blocking=True)
388+
380389
copy_linear_att_state_to_kv_buffer(
381390
b_req_idx=b_req_idx,
382391
big_page_buffer_ids=big_page_buffer_ids,
@@ -385,6 +394,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
385394
cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer,
386395
cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer,
387396
mtp_step=self.args.mtp_step,
397+
b_num_accepted_tokens=b_num_accepted_tokens,
388398
)
389399

390400
assert not self.args.disable_chunked_prefill, "chunked prefill mode must be enabled for linear att mixed model"
@@ -400,9 +410,14 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L
400410
self.radix_cache.linear_att_small_page_buffers.alloc_one_state_cache()
401411
)
402412
if req.tail_linear_att_small_page_buffer_id is not None:
403-
src_buffer_idx = req.req_idx * (self.args.mtp_step + 1)
404-
gpu_conv_state = self.req_manager.req_to_conv_state.buffer[:, src_buffer_idx, ...]
405-
gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, src_buffer_idx, ...]
413+
canonical_off = req.mtp_accept_len - 1
414+
conv_src_idx = req.req_idx
415+
ssm_src_idx = req.req_idx * (self.args.mtp_step + 1) + canonical_off
416+
narrow_w = self.req_manager.linear_config.get_persisted_conv_state_shape()[-1]
417+
gpu_conv_state = self.req_manager.req_to_conv_state.buffer[
418+
:, conv_src_idx, ..., canonical_off : canonical_off + narrow_w
419+
]
420+
gpu_ssm_state = self.req_manager.req_to_ssm_state.buffer[:, ssm_src_idx, ...]
406421
dst_buffer_idx = req.tail_linear_att_small_page_buffer_id
407422

408423
dst_conv_state, dst_ssm_state = self.radix_cache.linear_att_small_page_buffers.get_state_cache(
@@ -558,6 +573,8 @@ def __init__(
558573
else:
559574
self.decode_need_token_num = self._normal_decode_need_token_num
560575

576+
self.mtp_accept_len: int = 1
577+
561578
if g_infer_context.is_linear_att_mixed_model:
562579
self.get_chuncked_input_token_len = self.get_chuncked_input_token_len_for_linear_att
563580
self.get_chuncked_input_token_ids = self.get_chuncked_input_token_ids_for_linear_att

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,16 @@ def init_mtp_draft_model(self, main_kvargs: dict):
357357
elif mtp_model_cfg["model_type"] == "glm4_moe_lite":
358358
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
359359
self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
360+
elif model_type in ("qwen3_5", "qwen3_5_text"):
361+
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
362+
from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel
363+
364+
self.draft_models.append(Qwen3_5MTPModel(mtp_model_kvargs))
365+
elif model_type in ("qwen3_5_moe", "qwen3_5_moe_text"):
366+
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
367+
from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel
368+
369+
self.draft_models.append(Qwen3_5MoeMTPModel(mtp_model_kvargs))
360370
else:
361371
raise ValueError(f"Unsupported MTP model type: {model_type}")
362372

@@ -602,7 +612,6 @@ def _get_classed_reqs(
602612
can_alloc_token_num = g_infer_context.get_can_alloc_token_num()
603613

604614
for req_obj in ready_reqs:
605-
606615
if req_obj.filter_mark:
607616
finished_reqs.append(req_obj)
608617
continue
@@ -783,11 +792,35 @@ def _verify_mtp_v2(
783792
)
784793
return mtp_accept_len, accepted_index
785794

795+
def _commit_mtp_accept_len(
796+
self,
797+
decode_reqs: List[InferReq],
798+
mtp_accept_len_cpu: torch.Tensor,
799+
):
800+
# Carry the per-req accept count into the NEXT step as the canonical
801+
# pointer (design §3.1). This must run on every rank (not only master):
802+
# the kernels on this rank read req.mtp_accept_len.
803+
#
804+
# CRITICAL ordering (overlap scheduler): the next step's decode_mtp reads
805+
# req.mtp_accept_len (to build b_num_accepted_tokens) the moment its
806+
# wait_to_forward() is released, which happens at THIS step's
807+
# notify_forward_and_wait_post_handle() (start of phase 3). So this carry
808+
# MUST be committed in phase 2 (pre_post_handle), before that release —
809+
# otherwise the next step reads a one-step-stale accept count. The error
810+
# is invisible while accept_len is constant (==1) and corrupts the GDN
811+
# conv/ssm committed-state read-offset the instant a multi-token accept
812+
# (accept_len>=2) occurs.
813+
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
814+
req.mtp_accept_len = int(accept_len)
815+
return
816+
786817
def _update_mtp_accept_ratio(
787818
self,
788819
decode_reqs: List[InferReq],
789820
mtp_accept_len_cpu: torch.Tensor,
790821
):
822+
# Master-only accept-ratio statistics. Unlike _commit_mtp_accept_len this
823+
# only feeds metrics, so it may stay in the phase-3 post_handle region.
791824
if self.is_master_in_dp:
792825
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
793826
req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1)
@@ -809,7 +842,6 @@ def _sample_and_scatter_token(
809842
b_prefill_has_output_cpu: torch.Tensor = None,
810843
mask_func: Optional[Callable] = None,
811844
):
812-
813845
if mask_func is not None:
814846
assert len(run_reqs) == logits.shape[0]
815847
mask_func(run_reqs, logits)

0 commit comments

Comments
 (0)