Changes from all commits (55 commits)
6cf58b4
fix
hiworldwzj Apr 9, 2026
f1251a3
add gitignore
flyinglandlord Mar 12, 2026
d9b1fdd
finish usable mtp kernel
flyinglandlord Mar 17, 2026
315366a
end-to-end finish
flyinglandlord Mar 19, 2026
73ea125
fix cudagraph support
flyinglandlord Mar 19, 2026
dc91e59
save runnable version of dynamic mtp
flyinglandlord Mar 26, 2026
aefe67e
save runnable version of dynamic mtp
flyinglandlord Mar 26, 2026
3c28fb0
fix
hiworldwzj Apr 9, 2026
2dc933e
save fixed dynamic mtp
flyinglandlord Mar 27, 2026
0b08de8
save
flyinglandlord Mar 30, 2026
952ec15
save
flyinglandlord Mar 30, 2026
3750118
add experiment script
flyinglandlord Apr 1, 2026
4bf4287
update mtp kernel support BLOCK_BATCH < max_verify_group_size
flyinglandlord Apr 1, 2026
05e0dfd
fix implementation issues
flyinglandlord Apr 3, 2026
c2b7569
save
flyinglandlord Apr 4, 2026
80219af
save
flyinglandlord Apr 8, 2026
41180d3
fix
hiworldwzj Apr 9, 2026
8afd7a8
fix
hiworldwzj Apr 9, 2026
2b277fa
fix
hiworldwzj Apr 9, 2026
93c2ada
fix
hiworldwzj Apr 9, 2026
6395447
fix
hiworldwzj Apr 9, 2026
fc20624
fix
hiworldwzj Apr 9, 2026
1b08d15
fix
hiworldwzj Apr 9, 2026
775adbd
fix
hiworldwzj Apr 9, 2026
d4830ff
fix
hiworldwzj Apr 9, 2026
da944f8
fix
hiworldwzj Apr 9, 2026
22c5996
fix
hiworldwzj Apr 9, 2026
6fbe8d8
fix
hiworldwzj Apr 9, 2026
c4a9f74
fix lightllm/server/router/model_infer/mode_backend/generic_pre_proce…
flyinglandlord Apr 9, 2026
5eb4889
update generic_pre_process.py
flyinglandlord Apr 9, 2026
379f256
fix
flyinglandlord Apr 9, 2026
e943a43
fix
flyinglandlord Apr 9, 2026
e121b9d
refactor qwen3_eagle3
flyinglandlord Apr 9, 2026
1723230
add stage1
hiworldwzj Apr 9, 2026
6890bc0
fix
hiworldwzj Apr 9, 2026
1e1fb98
fix
hiworldwzj Apr 9, 2026
29535b2
fix
hiworldwzj Apr 9, 2026
5b925a6
fix
hiworldwzj Apr 9, 2026
538200d
fix
hiworldwzj Apr 9, 2026
2831c70
fix
hiworldwzj Apr 9, 2026
158c7a3
fix
hiworldwzj Apr 9, 2026
4c08120
fix
hiworldwzj Apr 10, 2026
a837cbb
fix
hiworldwzj Apr 10, 2026
17ed333
fix
hiworldwzj Apr 10, 2026
7927e15
fix
hiworldwzj Apr 10, 2026
53a7077
fix base_backend.py
flyinglandlord Apr 10, 2026
322f713
fix
hiworldwzj Apr 10, 2026
c3e46c9
fix
hiworldwzj Apr 10, 2026
2f0c250
fix
hiworldwzj Apr 10, 2026
8ab2ab4
fix
hiworldwzj Apr 10, 2026
76ca4ce
fix
hiworldwzj Apr 10, 2026
1d0f18e
fix
hiworldwzj Apr 10, 2026
6e69701
fix
hiworldwzj Apr 10, 2026
1565698
fix
hiworldwzj Apr 10, 2026
5e857c8
fix
hiworldwzj Apr 10, 2026
6 changes: 6 additions & 0 deletions .gitignore
@@ -7,3 +7,9 @@ dist
.vscode
tmp/
requirements-musa.txt

hf_datasets_cache/
wandb/
datasets/
trace/
experiment_results/
12,002 changes: 12,002 additions & 0 deletions datasets/gsm8k.json

Large diffs are not rendered by default.

51 changes: 49 additions & 2 deletions lightllm/common/basemodel/attention/triton/fp.py
@@ -1,5 +1,7 @@
import dataclasses
import torch

from lightllm.utils.envs_utils import enable_dynamic_mtp_verify, get_env_start_args, enable_triton_mtp_kernel
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from typing import Optional

@@ -80,8 +82,17 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,

@dataclasses.dataclass
class TritonDecodeAttState(BaseDecodeAttState):
# MTP related state variables
b_mark_shared_group: torch.Tensor = None

def init_state(self):
pass
args_mtp_step = get_env_start_args().mtp_step

if args_mtp_step > 0:
# MTP mode initialization
self.b_mark_shared_group = self.infer_state.b_mark_shared_group
else:
self.b_mark_shared_group = None

def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"):
super().copy_for_decode_cuda_graph(new_state)
@@ -99,9 +110,17 @@ def decode_att(
assert att_control.tp_alibi is not None
return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:

args_mtp_step = get_env_start_args().mtp_step

q_head_num = q.shape[1]
k_head_num = k.shape[1]
if q_head_num == k_head_num:

if args_mtp_step > 0 and (enable_dynamic_mtp_verify() or enable_triton_mtp_kernel()):
# MTP mode: use mtp diverse attention
assert q_head_num >= k_head_num, "MTP diverse attention requires q_head_num >= k_head_num"
return self._dynamic_mtp_decode_gqa_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num == k_head_num:
return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num > k_head_num:
return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
@@ -182,6 +201,34 @@ def _normal_decode_gqa_flash_decoding_att(

return out

def _dynamic_mtp_decode_gqa_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
alloc_func=torch.empty,
):
from ...triton_kernel.att.decode_att.gqa.mtp_diverse import (
token_decode_attention_mtp_diverse_single_token,
)

b_seq_len = self.infer_state.b_seq_len
# In dynamic MTP verify mode, infer_state.b_mark_shared_group is used (passed in via model_input).
# In static MTP mode, self.b_mark_shared_group is used (initialized in init_state).
b_mark_shared_group = self.infer_state.b_mark_shared_group
out = token_decode_attention_mtp_diverse_single_token(
q=q,
k=k,
v=v,
Req_to_tokens=self.infer_state.req_manager.req_to_token_indexs,
B_req_idx=self.infer_state.b_req_idx,
b_seq_len=b_seq_len,
b_mark_shared_group=b_mark_shared_group,
alloc_tensor_func=alloc_func,
)

return out

def _normal_decode_gqa_flash_decoding_att_vsm(
self,
q: torch.Tensor,
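For readability, the decode-attention routing introduced above can be summarized by the sketch below. This is a paraphrase of decode_att for illustration only, not additional code from this PR; the returned strings are placeholders for the corresponding method calls.

import torch
from lightllm.utils.envs_utils import enable_dynamic_mtp_verify, enable_triton_mtp_kernel, get_env_start_args

def select_decode_att_path(q_head_num: int, k_head_num: int) -> str:
    # Mirrors the branch order in TritonDecodeAttState.decode_att after this change.
    mtp_step = get_env_start_args().mtp_step
    if mtp_step > 0 and (enable_dynamic_mtp_verify() or enable_triton_mtp_kernel()):
        # MTP diverse attention covers both MHA and GQA layouts, provided q_head_num >= k_head_num.
        return "_dynamic_mtp_decode_gqa_att"
    if q_head_num == k_head_num:
        return "_normal_decode_flash_decoding_att"
    if q_head_num > k_head_num:
        return "_normal_decode_gqa_flash_decoding_att"
    # The q_head_num < k_head_num case falls outside the hunk shown above.
    raise NotImplementedError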
100 changes: 80 additions & 20 deletions lightllm/common/basemodel/basemodel.py
@@ -25,14 +25,23 @@
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
from lightllm.utils.envs_utils import (
enable_triton_mtp_kernel,
get_env_start_args,
get_llm_data_type,
get_added_mtp_kv_layer_num,
)
from lightllm.distributed.communication_op import dist_group_manager
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput, OutHiddenState
from lightllm.common.triton_utils.autotuner import AutotuneLevel
from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.utils.envs_utils import (
set_model_init_status,
enable_diverse_mode_gqa_decode_fast_kernel,
enable_dynamic_mtp_verify,
)
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
from lightllm.utils.infer_utils import calculate_time, post_empty_cache
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend

@@ -93,16 +102,11 @@ def __init__(self, kvargs):
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
]
self.prefill_graph: PrefillCudaGraph = None

self._init_config()
self._init_speculative_algo(kvargs)

self._verify_must()
self._verify_params()
self._init_quant()
@@ -137,6 +141,38 @@ def __init__(self, kvargs):
set_model_init_status(True)
return

def _init_speculative_algo(self, kvargs):
self.is_mtp_draft_model = kvargs.get("is_mtp_draft_model", False)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
"eagle3",
]
if not self.is_mtp_mode:
self.output_hidden_layers = []
return

if not self.is_mtp_draft_model:
# The main model needs to output hidden states for the draft model's MTP prediction.
if self.args.mtp_mode == "eagle3":
assert not self.args.enable_prefill_cudagraph, "eagle3 mode does not support prefill cudagraph"
assert (
not self.args.enable_decode_microbatch_overlap
), "eagle3 mode does not support decode microbatch overlap"
assert (
not self.args.enable_prefill_microbatch_overlap
), "eagle3 mode does not support prefill microbatch overlap"
self.output_hidden_layers = [1, self.config["n_layer"] // 2 - 1, self.config["n_layer"] - 4]
else:
self.output_hidden_layers = [self.config["n_layer"] - 1]
else:
# The draft model needs to output hidden states for multi-step MTP prediction.
self.output_hidden_layers = [self.config["n_layer"] - 1]

return

def _wait_other_modules_ready(self):
for event in self.wait_events:
event.wait()
@@ -151,6 +187,12 @@ def _init_config(self):
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
if self.finetune_config:
self.config["vocab_size"] = self.finetune_config.vocab_size

# In eagle3 mode, vocab_size must be replaced with draft_vocab_size; in all other cases
# this code has no effect.
if "draft_vocab_size" in self.config.keys():
self.config["target_vocab_size"] = self.config["vocab_size"]
self.config["vocab_size"] = self.config["draft_vocab_size"]
return

@final
@@ -314,6 +356,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
if enable_diverse_mode_gqa_decode_fast_kernel():
infer_state.b_shared_seq_len = model_input.b_shared_seq_len
infer_state.b_mark_shared_group = model_input.b_mark_shared_group
elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel():
infer_state.b_mark_shared_group = model_input.b_mark_shared_group

infer_state.multimodal_params = model_input.multimodal_params

@@ -377,6 +421,11 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_model_input.b_mark_shared_group = F.pad(
new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1
)
elif enable_dynamic_mtp_verify() or enable_triton_mtp_kernel():
assert new_model_input.b_mark_shared_group is not None
new_model_input.b_mark_shared_group = F.pad(
new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=0
)

# Special padding for the special variables of special models in special modes
if new_model_input.mtp_draft_input_hiddens is not None:
@@ -561,12 +610,23 @@ def _context_forward(self, infer_state: InferStateInfo):
input_tensors = [input_embs]

def prefill_func(input_tensors, infer_state):
mtp_out_hidden_state = OutHiddenState(selected_layers=self.output_hidden_layers)
_input_embs = input_tensors[0]
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
_input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
return [_input_embs]
mtp_out_hidden_state.add_hidden(
layer_index=i,
layer_num=self.layers_num,
hidden=_input_embs,
)

capture_hiddens = mtp_out_hidden_state.get_captured_hiddens()
if capture_hiddens is not None:
return [_input_embs, capture_hiddens]
else:
return [_input_embs]

handle_token_num = input_ids.shape[0]

@@ -596,8 +656,8 @@ def prefill_func(input_tensors, infer_state):
model_output = ModelOutput(logits=predict_logits)

# Extra outputs for special models in special modes
if self.is_mtp_mode:
model_output.mtp_main_output_hiddens = input_embs
if self.is_mtp_mode and len(output_tensors) > 1:
model_output.mtp_main_output_hiddens = output_tensors[1]

# When deepep is enabled, clear_deepep_buffer must be called to release resources;
# when it is not enabled, this call is a no-op.
@@ -611,22 +671,21 @@ def _token_forward(self, infer_state: InferStateInfo):
cuda_input_ids = input_ids
pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
mtp_out_hidden_state = OutHiddenState(selected_layers=self.output_hidden_layers)
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
mtp_out_hidden_state.add_hidden(layer_index=i, layer_num=self.layers_num, hidden=input_embs)

capture_hiddens = mtp_out_hidden_state.get_captured_hiddens()
post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight)

if self.is_mtp_mode:
graph_out_hiddens = input_embs.contiguous()

model_output = ModelOutput(logits=predict_logits.contiguous())

# Extra outputs for special models in special modes
if self.is_mtp_mode:
model_output.mtp_main_output_hiddens = graph_out_hiddens
if self.is_mtp_mode and capture_hiddens is not None:
model_output.mtp_main_output_hiddens = capture_hiddens.contiguous()

# In cuda graph mode, outputs are converted to no-ref tensors to improve mem pool reuse and reduce GPU memory usage.
if infer_state.is_cuda_graph:
@@ -1027,6 +1086,7 @@ def _gen_special_model_input(self, token_num: int):
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
or "Qwen3EagleModel" in str(self.__class__)
)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
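As a worked example of the layer selection in _init_speculative_algo above (the layer count of 28 is a hypothetical value, not one used in this PR):

n_layer = 28  # hypothetical model depth
eagle3_layers = [1, n_layer // 2 - 1, n_layer - 4]  # -> [1, 13, 24]: three hidden states captured for the eagle3 draft model
default_mtp_layers = [n_layer - 1]                  # -> [27]: only the last layer for the other MTP modes and for draft models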
38 changes: 37 additions & 1 deletion lightllm/common/basemodel/batch_objs.py
@@ -2,7 +2,11 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import List
from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.utils.envs_utils import (
enable_diverse_mode_gqa_decode_fast_kernel,
enable_dynamic_mtp_verify,
enable_triton_mtp_kernel,
)
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor


@@ -69,6 +73,12 @@ def to_cuda(self):
self.b_shared_seq_len = torch.zeros(size=(batch_size,), dtype=torch.int32, device="cuda")
else:
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)
elif not self.is_prefill and (enable_dynamic_mtp_verify() or enable_triton_mtp_kernel()):
batch_size = len(self.b_req_idx)
if self.b_mark_shared_group is None:
self.b_mark_shared_group = torch.ones(size=(batch_size,), dtype=torch.int32, device="cuda")
else:
self.b_mark_shared_group = self.b_mark_shared_group.cuda(non_blocking=True)

def __post_init__(self):
self.check_input()
@@ -96,3 +106,29 @@ def to_no_ref_tensor(self):
self.logits = tensor_to_no_ref_tensor(self.logits)
if self.mtp_main_output_hiddens is not None:
self.mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.mtp_main_output_hiddens)


class OutHiddenState:
def __init__(self, selected_layers: List[int]):
self.selected_layers = selected_layers
self.capture_hiddens = []

def add_hidden(self, layer_index: int, layer_num: int, hidden: torch.Tensor):
if layer_index in self.selected_layers:
is_last_layer = layer_index == layer_num - 1
if not is_last_layer:
self.capture_hiddens.append(hidden.clone())
else:
# The last layer does not need to be cloned; using it directly improves performance
self.capture_hiddens.append(hidden)

def get_captured_hiddens(self) -> Optional[torch.Tensor]:
if self.capture_hiddens:
if len(self.capture_hiddens) > 1:
self.capture_hiddens = torch.cat(self.capture_hiddens, dim=-1)
else:
# Saves one clone operation, which improves performance
self.capture_hiddens = self.capture_hiddens[0]
else:
self.capture_hiddens = None
return self.capture_hiddens
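A minimal usage sketch of the OutHiddenState helper defined above, mirroring how _token_forward feeds it layer by layer; the layer indices, batch size, and hidden size are illustrative values, not values from this PR:

import torch
from lightllm.common.basemodel.batch_objs import OutHiddenState

capturer = OutHiddenState(selected_layers=[1, 13, 24])
for layer_index in range(28):
    hidden = torch.randn(4, 4096)  # stand-in for the layer output of a 28-layer model
    capturer.add_hidden(layer_index=layer_index, layer_num=28, hidden=hidden)

# The three selected hidden states are concatenated on the last dim: shape (4, 3 * 4096).
captured = capturer.get_captured_hiddens()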
14 changes: 14 additions & 0 deletions lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/mtp_diverse/__init__.py
@@ -0,0 +1,14 @@
"""
MTP Diverse Attention Module

MTP (Multi-Token Prediction) Diverse Attention 的实现。
"""
from .mtp_diverse_attn import token_decode_attention_mtp_diverse_single_token
from .stage1_single_token import mtp_diverse_stage1_single_token
from .stage2_single_token import mtp_diverse_stage2_single_token

__all__ = [
"token_decode_attention_mtp_diverse_single_token",
"mtp_diverse_stage1_single_token",
"mtp_diverse_stage2_single_token",
]
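For reference, the wrapper exported here is invoked from _dynamic_mtp_decode_gqa_att in fp.py (earlier in this diff) roughly as sketched below; q, k, v and infer_state are assumed to come from the decode attention state, and the absolute import path is inferred from the relative import in fp.py:

import torch
from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.mtp_diverse import (
    token_decode_attention_mtp_diverse_single_token,
)

out = token_decode_attention_mtp_diverse_single_token(
    q=q,
    k=k,
    v=v,
    Req_to_tokens=infer_state.req_manager.req_to_token_indexs,
    B_req_idx=infer_state.b_req_idx,
    b_seq_len=infer_state.b_seq_len,
    b_mark_shared_group=infer_state.b_mark_shared_group,
    alloc_tensor_func=torch.empty,
)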