Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
fdc4de1
feat(qwen3_5_mtp): split linear-attn cache state for spec-decode verify
sufubao Jun 4, 2026
24dd3f1
feat(qwen3_5_mtp): qwen3next GDN spec-decode verify path
sufubao Jun 4, 2026
e12cf70
feat(qwen3_5_mtp): basemodel MTP decode CUDA graphs + verify dispatch
sufubao Jun 4, 2026
79a5257
feat(qwen3_5_mtp): scheduler MTP verify backend + accept-len transport
sufubao Jun 4, 2026
75514ec
feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models
sufubao Jun 4, 2026
e49c092
fix(qwen3next): persist mtp full-attn cpu cache slots
sufubao Jun 4, 2026
b5d5476
refactor(qwen3_5_mtp): drop unused _draft_kv_slot attribute
sufubao Jun 5, 2026
48c15de
style: fix black formatting and drop unused var for pre-commit
sufubao Jun 5, 2026
16170f3
style: align formatting with upstream/main and inline mtp accept-len …
sufubao Jun 7, 2026
0d42047
test(static_inference): generalize MTP static benchmark
sufubao Jun 7, 2026
c071ffd
fix(fa3/fp8): narrow decode verify to b_att_seq_len + causal (#4)
sufubao Jun 8, 2026
3c9fba3
test(static_mtp): lock bench exercises verify path via b_num_accepted…
sufubao Jun 8, 2026
f28b94d
fix(linear_att): guard conv width contiguity + accept-len bound in sn…
sufubao Jun 8, 2026
53663a8
fix(infer_batch): bound mtp_accept_len at tail small-page conv offloa…
sufubao Jun 8, 2026
63c5d07
fix(req_manager): zero full SSM block on linear-att init (#17)
sufubao Jun 8, 2026
b0dfbb8
test(qwen3next): pin prefill->first-decode conv column round-trip (#19)
sufubao Jun 8, 2026
3e67d68
perf(qwen3next): drop per-layer D2H sync in causal_conv1d_update seql…
sufubao Jun 8, 2026
89700ab
perf(qwen3next): drop per-step accept-len .all() D2H sync in GDN veri…
sufubao Jun 8, 2026
f989b64
perf(chunked): compute b_req_mtp_start_loc on device via arange (#22)
sufubao Jun 8, 2026
05551c3
fix(mtp): apply review fixes — dp verify-layout + accept-len writebac…
sufubao Jun 8, 2026
4f17bd0
test(static_inference): extend MTP static bench; gitignore benchmark/…
sufubao Jun 8, 2026
530aca7
cleanup(mtp_utils): remove dead gen_b_req_mtp_start_loc kernel (#22)
sufubao Jun 8, 2026
d38a6e0
perf(dp): shrink eagle draft to accepted rows, share builder with chu…
sufubao Jun 8, 2026
6b1466b
refactor(mtp): single source for added-layer count (#9)
sufubao Jun 8, 2026
640694c
refactor(infer_struct): share MTP-verify extra-state block (#12)
sufubao Jun 8, 2026
32cf1b3
cleanup(cuda_graph): remove dead cuda_graph_batch_sizes alias (#13)
sufubao Jun 8, 2026
499251d
refactor(req_manager): derive mtp_step cap from buffer width + real g…
sufubao Jun 8, 2026
6b801b1
refactor(linear_att): drop default-named conv-shape alias (#24)
sufubao Jun 8, 2026
4165298
refactor(mtp): single factory for draft-model selection (#10)
sufubao Jun 8, 2026
0a024b2
refactor(mtp): centralize is_mtp_verify_decode predicate (#21)
sufubao Jun 8, 2026
53b7b94
refactor(mtp): extract BaseMTPModel mixin for shared draft wiring (#25)
sufubao Jun 8, 2026
9a8697c
refactor(mtp): detect draft models via is_mtp_draft_model attr (#23)
sufubao Jun 8, 2026
a27ccde
refactor(qwen3_5_mtp): share weight-retarget mixin across dense+moe (…
sufubao Jun 8, 2026
5d46b47
Merge remote-tracking branch 'upstream/main' into qw35_mtp
sufubao Jun 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/

benchmark/
11 changes: 7 additions & 4 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn


class Fa3AttBackend(BaseAttBackend):
Expand Down Expand Up @@ -125,8 +126,9 @@ class Fa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: Fa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens)

if args_mtp_step > 0:
if is_mtp_verify_decode:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -143,8 +145,9 @@ def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1
att_batch_size = self.infer_state.batch_size // mtp_size
assert self.infer_state.batch_size % mtp_size == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -163,7 +166,7 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_mtp_verify_decode:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
12 changes: 3 additions & 9 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..base_att import AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant
from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
from typing import Union
Expand Down Expand Up @@ -116,12 +115,7 @@ def init_state(self):
super().init_state()
self.backend: Fp8Fa3AttBackend = self.backend

args_mtp_step = get_env_start_args().mtp_step
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
batch_size = self.b_att_seq_len.shape[0]
mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
Expand Down Expand Up @@ -180,11 +174,11 @@ def _fp8_decode_att(
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cache_seqlens=self.b_att_seq_len,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.decode_max_q_seq_len,
causal=False,
causal=True,
window_size=(-1, -1),
softcap=0.0,
q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
Expand Down
11 changes: 7 additions & 4 deletions lightllm/common/basemodel/attention/fa3/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.utils.sgl_utils import flash_attn_varlen_func
from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn


class MlaFa3AttBackend(BaseAttBackend):
Expand Down Expand Up @@ -108,8 +109,9 @@ class MlaFa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: MlaFa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens)

if args_mtp_step > 0:
if is_mtp_verify_decode:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -126,8 +128,9 @@ def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1
att_batch_size = self.infer_state.batch_size // mtp_size
assert self.infer_state.batch_size % mtp_size == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -146,7 +149,7 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_mtp_verify_decode:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
Loading
Loading