Skip to content

Commit 2560680

Browse files
committed
feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models
Add the MTP draft model packages and register them: - qwen3_5_mtp: a forced single full-attn-layer draft model, with the MTP pre-layer infer (embed/hidden norm + fc fusion) and pre/post + transformer-layer weight loaders reading the mtp.* namespace. - qwen3_5_moe_mtp: the MoE variant draft weight loaders + model. - register qwen3_5 / qwen3_5_moe MTP draft models with per-block draft_idx, plus the qwen3_5 verify infer_struct. Unit tests scaffold the MTP draft layer and the hybrid verify forward.
1 parent e4b8a08 commit 2560680

14 files changed

Lines changed: 490 additions & 0 deletions

File tree

lightllm/models/qwen3_5/infer_struct.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,39 @@ def init_some_extra_state(self, model):
1616
mtp_step = get_env_start_args().mtp_step
1717

1818
self.b_buffer_idx = self.b_req_idx * (mtp_step + 1) + self.b_mtp_index
19+
# conv buffer is now ONE widened slot per request (indexed by req_idx),
20+
# dropping the *(S+1) + mtp_index addressing used by the SSM block.
21+
self.b_conv_buffer_idx = self.b_req_idx
22+
# MTP verify batch: decode-mode, S+1 expanded, and gated on the
23+
# per-real-request accept tensor that decode_mtp threads in. Gating on
24+
# b_num_accepted_tokens (vs only b_mtp_index, which is set for any decode)
25+
# distinguishes the main-model verify forward from draft/plain decode.
26+
self.is_mtp_verify = (
27+
(mtp_step > 0)
28+
and (not self.is_prefill)
29+
and (self.b_mtp_index is not None)
30+
and (self.b_num_accepted_tokens is not None)
31+
)
32+
self.b_gdn_verify_cu_seqlens = None
33+
self.b_ssm_index_rows = None
34+
# b_num_accepted_tokens is threaded onto the infer_state from ModelInput by
35+
# _create_inferstate (mirrors b_mtp_index) BEFORE this runs; nothing to do here.
36+
if self.is_mtp_verify:
37+
step = mtp_step + 1
38+
n_real = self.b_req_idx.shape[0] // step
39+
self.b_gdn_verify_cu_seqlens = torch.arange(
40+
0, (n_real + 1) * step, step, dtype=torch.int32, device=self.b_req_idx.device
41+
)
42+
req_first = self.b_req_idx.view(n_real, step)[:, 0]
43+
base = (req_first * step).view(n_real, 1)
44+
self.b_ssm_index_rows = base + torch.arange(step, device=base.device, dtype=base.dtype).view(1, step)
45+
assert self.b_ssm_index_rows.shape == (n_real, step)
46+
# The spec conv kernel is per-SEQUENCE (one program per real request),
47+
# indexed by conv_state_indices[idx_seq] with idx_seq in [0, n_real),
48+
# aligned 1:1 with b_gdn_verify_cu_seqlens / b_num_accepted_tokens. The
49+
# default b_conv_buffer_idx = b_req_idx has the expanded length n_real*step,
50+
# which launches n_real*step conv programs and reads num_accepted/
51+
# query_start_loc out of bounds for idx_seq >= n_real, corrupting the
52+
# committed conv slot. Narrow it to one widened conv slot per request.
53+
self.b_conv_buffer_idx = req_first
1954
return
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from lightllm.models.qwen3_5_moe_mtp.model import Qwen3_5MoeMTPModel
2+
3+
__all__ = ["Qwen3_5MoeMTPModel"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import (
2+
Qwen3_5MoeMTPTransformerLayerWeight,
3+
)
4+
5+
__all__ = ["Qwen3_5MoeMTPTransformerLayerWeight"]
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from lightllm.common.basemodel.layer_weights.meta_weights import (
2+
COLMMWeight,
3+
FusedMoeWeight,
4+
QKVROWNMMWeight,
5+
ROWMMWeight,
6+
)
7+
from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import (
8+
Qwen35MOETransformerLayerWeight,
9+
)
10+
from lightllm.utils.envs_utils import get_env_start_args
11+
12+
13+
class Qwen3_5MoeMTPTransformerLayerWeight(Qwen35MOETransformerLayerWeight):
14+
_MAIN_PREFIX = "model.layers."
15+
_MTP_PREFIX = "mtp.layers."
16+
17+
def _retarget(self, name):
18+
if name is None:
19+
return None
20+
return name.replace(self._MAIN_PREFIX, self._MTP_PREFIX, 1)
21+
22+
def _init_weight_names(self):
23+
super()._init_weight_names()
24+
self._q_weight_name = self._retarget(self._q_weight_name)
25+
self._q_norm_name = self._retarget(self._q_norm_name)
26+
self._q_bias_name = self._retarget(self._q_bias_name)
27+
self._k_weight_name = self._retarget(self._k_weight_name)
28+
self._k_norm_name = self._retarget(self._k_norm_name)
29+
self._k_bias_name = self._retarget(self._k_bias_name)
30+
self._v_weight_name = self._retarget(self._v_weight_name)
31+
self._v_bias_name = self._retarget(self._v_bias_name)
32+
self._kv_weight_name = self._retarget(self._kv_weight_name)
33+
self._kv_bias_name = self._retarget(self._kv_bias_name)
34+
self._o_weight_name = self._retarget(self._o_weight_name)
35+
self._o_bias_name = self._retarget(self._o_bias_name)
36+
self._att_norm_weight_name = self._retarget(self._att_norm_weight_name)
37+
self._att_norm_bias_name = self._retarget(self._att_norm_bias_name)
38+
self._ffn_norm_weight_name = self._retarget(self._ffn_norm_weight_name)
39+
self._ffn_norm_bias_name = self._retarget(self._ffn_norm_bias_name)
40+
41+
def _init_qkv(self):
42+
in_dim = self.n_embed
43+
q_out_dim = self.q_head_num_ * self.head_dim
44+
self.qkv_proj = QKVROWNMMWeight(
45+
in_dim=in_dim,
46+
q_head_num=self.q_head_num_,
47+
kv_head_num=self.k_head_num_,
48+
head_dim=self.head_dim,
49+
weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name],
50+
data_type=self.data_type_,
51+
bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name],
52+
quant_method=self.get_quant_method("qkv_proj"),
53+
)
54+
self._o_gate_weight_name = f"{self._MTP_PREFIX}{self.layer_num_}.self_attn.o_gate_proj.weight"
55+
self._o_gate_proj = ROWMMWeight(
56+
in_dim=in_dim,
57+
out_dims=[q_out_dim],
58+
weight_names=[self._o_gate_weight_name],
59+
data_type=self.data_type_,
60+
bias_names=None,
61+
quant_method=self.get_quant_method("o_gate_proj"),
62+
)
63+
64+
def _init_moe(self):
65+
moe_intermediate_size = self.network_config_["moe_intermediate_size"]
66+
self.moe_gate = ROWMMWeight(
67+
in_dim=self.network_config_["hidden_size"],
68+
out_dims=[self.n_routed_experts],
69+
weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.gate.weight",
70+
data_type=self.data_type_,
71+
quant_method=None,
72+
tp_rank=0,
73+
tp_world_size=1,
74+
)
75+
self.experts = FusedMoeWeight(
76+
gate_proj_name="gate_proj",
77+
down_proj_name="down_proj",
78+
up_proj_name="up_proj",
79+
e_score_correction_bias_name="",
80+
weight_prefix=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.experts",
81+
n_routed_experts=self.n_routed_experts,
82+
hidden_size=self.network_config_["hidden_size"],
83+
moe_intermediate_size=moe_intermediate_size,
84+
data_type=self.data_type_,
85+
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
86+
layer_num=self.layer_num_,
87+
network_config=self.network_config_,
88+
)
89+
self._init_gated_ffn()
90+
91+
def _init_gated_ffn(self):
92+
hidden_size = self.network_config_["hidden_size"]
93+
if "shared_expert_intermediate_size" not in self.network_config_:
94+
return
95+
96+
prefix = f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert"
97+
inter_size = self.network_config_["shared_expert_intermediate_size"]
98+
if get_env_start_args().enable_ep_moe:
99+
self.gate_up_proj = ROWMMWeight(
100+
in_dim=hidden_size,
101+
out_dims=[inter_size, inter_size],
102+
weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"],
103+
data_type=self.data_type_,
104+
quant_method=self.get_quant_method("gate_up_proj"),
105+
tp_rank=0,
106+
tp_world_size=1,
107+
)
108+
self.down_proj = COLMMWeight(
109+
in_dim=inter_size,
110+
out_dims=[hidden_size],
111+
weight_names=f"{prefix}.down_proj.weight",
112+
data_type=self.data_type_,
113+
quant_method=self.get_quant_method("down_proj"),
114+
tp_rank=0,
115+
tp_world_size=1,
116+
)
117+
else:
118+
self.gate_up_proj = ROWMMWeight(
119+
in_dim=hidden_size,
120+
out_dims=[inter_size, inter_size],
121+
weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"],
122+
data_type=self.data_type_,
123+
quant_method=self.get_quant_method("gate_up_proj"),
124+
)
125+
self.down_proj = COLMMWeight(
126+
in_dim=inter_size,
127+
out_dims=[hidden_size],
128+
weight_names=f"{prefix}.down_proj.weight",
129+
data_type=self.data_type_,
130+
quant_method=self.get_quant_method("down_proj"),
131+
)
132+
133+
self.ffn_gate = ROWMMWeight(
134+
in_dim=hidden_size,
135+
out_dims=[1],
136+
weight_names=f"{self._MTP_PREFIX}{self.layer_num_}.mlp.shared_expert_gate.weight",
137+
data_type=self.data_type_,
138+
bias_names=None,
139+
quant_method=None,
140+
tp_rank=0,
141+
tp_world_size=1,
142+
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from lightllm.models.qwen3_5_mtp.model import Qwen3_5MTPModel
2+
from lightllm.models.qwen3_5_moe_mtp.layer_weights.transformer_layer_weight import (
3+
Qwen3_5MoeMTPTransformerLayerWeight,
4+
)
5+
6+
7+
class Qwen3_5MoeMTPModel(Qwen3_5MTPModel):
8+
transformer_weight_class = Qwen3_5MoeMTPTransformerLayerWeight

lightllm/models/qwen3_5_mtp/__init__.py

Whitespace-only changes.

lightllm/models/qwen3_5_mtp/layer_infer/__init__.py

Whitespace-only changes.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
3+
from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
4+
from lightllm.models.qwen3_5_mtp.layer_weights.pre_and_post_layer_weight import Qwen3_5MTPPreAndPostLayerWeight
5+
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
6+
7+
8+
class Qwen3_5MTPPreLayerInfer(Qwen3VLMultimodalPreLayerInfer):
9+
10+
def __init__(self, network_config):
11+
super().__init__(network_config)
12+
self.eps_ = network_config["rms_norm_eps"]
13+
self.hidden_size = network_config["hidden_size"]
14+
return
15+
16+
def _mtp_fuse(
17+
self,
18+
input_embdings: torch.Tensor,
19+
infer_state: LlamaInferStateInfo,
20+
layer_weight: Qwen3_5MTPPreAndPostLayerWeight,
21+
) -> torch.Tensor:
22+
tgt_embdings = infer_state.mtp_draft_input_hiddens
23+
assert (
24+
input_embdings.shape[0] == tgt_embdings.shape[0]
25+
), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}"
26+
27+
layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings)
28+
layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings)
29+
cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
30+
31+
return layer_weight.eh_proj_weight_.mm(cat_embdings)
32+
33+
def context_forward(
34+
self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight
35+
):
36+
input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
37+
return self._mtp_fuse(input_embdings, infer_state, layer_weight)
38+
39+
def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Qwen3_5MTPPreAndPostLayerWeight):
40+
input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
41+
return self._mtp_fuse(input_embdings, infer_state, layer_weight)

lightllm/models/qwen3_5_mtp/layer_weights/__init__.py

Whitespace-only changes.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from lightllm.common.basemodel import PreAndPostLayerWeight
2+
from lightllm.common.basemodel.layer_weights.meta_weights import (
3+
EmbeddingWeight,
4+
LMHeadWeight,
5+
NoTpGEMMANormWeight,
6+
ROWMMWeight,
7+
)
8+
from lightllm.common.quantization import Quantcfg
9+
10+
11+
class Qwen3_5MTPPreAndPostLayerWeight(PreAndPostLayerWeight):
12+
13+
def __init__(self, data_type, network_config, quant_cfg: Quantcfg):
14+
super().__init__(data_type, network_config)
15+
self.quant_cfg: Quantcfg = quant_cfg
16+
hidden_size = network_config["hidden_size"]
17+
18+
self.eh_proj_weight_ = ROWMMWeight(
19+
in_dim=hidden_size * 2,
20+
out_dims=[hidden_size],
21+
weight_names="mtp.fc.weight",
22+
data_type=self.data_type_,
23+
quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"),
24+
tp_rank=0,
25+
tp_world_size=1,
26+
)
27+
self.enorm_weight_ = NoTpGEMMANormWeight(
28+
dim=hidden_size,
29+
weight_name="mtp.pre_fc_norm_embedding.weight",
30+
data_type=self.data_type_,
31+
)
32+
self.hnorm_weight_ = NoTpGEMMANormWeight(
33+
dim=hidden_size,
34+
weight_name="mtp.pre_fc_norm_hidden.weight",
35+
data_type=self.data_type_,
36+
)
37+
self.final_norm_weight_ = NoTpGEMMANormWeight(
38+
dim=hidden_size,
39+
weight_name="mtp.norm.weight",
40+
data_type=self.data_type_,
41+
)
42+
43+
# Shared with the main Qwen3.5 model, injected by the model class (not loaded here).
44+
self.wte_weight_: EmbeddingWeight = None
45+
self.lm_head_weight_: LMHeadWeight = None
46+
return

0 commit comments

Comments
 (0)