Skip to content

Commit a7d5e95

Browse files
WANDY666, shihaobai, wangzaijun
authored
Optimize qwen3 moe (#1207)
Co-authored-by: shihaobai <1798930569@qq.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
1 parent ef57a56 commit a7d5e95

File tree

22 files changed

+565
-64
lines changed

22 files changed

+565
-64
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def _select_experts(
4646
num_expert_group=num_expert_group,
4747
scoring_func=scoring_func,
4848
)
49-
topk_weights.mul_(self.routed_scaling_factor)
49+
if self.routed_scaling_factor != 1.0:
50+
topk_weights.mul_(self.routed_scaling_factor)
5051
if self.redundancy_expert_num > 0:
5152
redundancy_topk_ids_repair(
5253
topk_ids=topk_ids,

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def _select_experts(
5757
num_expert_group=num_expert_group,
5858
scoring_func=scoring_func,
5959
)
60-
topk_weights.mul_(self.routed_scaling_factor)
60+
if self.routed_scaling_factor != 1.0:
61+
topk_weights.mul_(self.routed_scaling_factor)
6162
if self.num_fused_shared_experts > 0:
6263
pad_topk_ids = (
6364
torch.arange(

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size
55
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
66
from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
7-
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_forward
7+
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
88
from .platform_op import PlatformAwareOp
99

1010

@@ -195,47 +195,84 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
195195
self.weight += 1
196196

197197

198-
class QKRMSNORMWeight(RMSNormWeight):
199-
def __init__(self, dim: int, weight_name: str, data_type: torch.dtype):
200-
super().__init__(dim=dim, weight_name=weight_name, data_type=data_type)
198+
class QKRMSNORMWeight(BaseWeightTpl, PlatformAwareOp):
    """Holds the per-head RMSNorm scale weights for Q and K and applies
    QK-norm in place on both tensors via a platform-dispatched `_forward`.

    `dim` is the per-head dimension (head_dim): each weight vector has
    shape (dim,) and is shared across all heads of its tensor.
    """

    def __init__(self, dim: int, q_weight_name: str, k_weight_name: str, data_type: torch.dtype):
        # Norm weights are replicated, not sharded, so tp_rank=0 / world_size=1.
        super().__init__(tp_rank=0, tp_world_size=1)
        self.dim = dim
        self.q_weight_name = q_weight_name
        self.k_weight_name = k_weight_name
        self.data_type_ = data_type
        self._create_weight()

    def _create_weight(self):
        # Allocate empty (head_dim,) buffers; real values arrive in load_hf_weights.
        # NOTE(review): `load_ok` is a dynamic attribute tacked onto the tensors
        # to track load status — checked later by verify_load.
        self.q_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_)
        self.k_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_)
        self.q_weight.load_ok = False
        self.k_weight.load_ok = False

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        # Copy the matching checkpoint entries (if present) into the
        # pre-allocated buffers and mark them as loaded.
        if self.q_weight_name in weights:
            self.q_weight.copy_(weights[self.q_weight_name])
            self.q_weight.load_ok = True
        if self.k_weight_name in weights:
            self.k_weight.copy_(weights[self.k_weight_name])
            self.k_weight.load_ok = True

    def verify_load(self):
        # Both Q and K scale vectors must have been filled from the checkpoint.
        return self.q_weight.load_ok and self.k_weight.load_ok

    def _native_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        eps: float,
    ) -> None:
        """Pure-PyTorch reference path: RMSNorm q and k in place, per head.

        q/k are 2D (tokens, num_heads * head_dim); the last dimension must be
        a multiple of head_dim. Math is done in fp32 and cast back to
        self.data_type_ before the in-place copy.
        """
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        assert (
            q.shape[-1] % self.dim == 0
        ), f"Expected hidden_size to be multiple of {self.dim}, but found: {q.shape[-1]}"
        assert (
            k.shape[-1] % self.dim == 0
        ), f"Expected hidden_size to be multiple of {self.dim}, but found: {k.shape[-1]}"

        head_dim = self.q_weight.shape[0]

        def _norm_inplace(t: torch.Tensor, weight: torch.Tensor):
            # View as (tokens * heads, head_dim) so the mean is per head.
            t_fp32 = t.to(torch.float32)
            t_fp32 = t_fp32.view(-1, head_dim)
            variance = t_fp32.pow(2).mean(dim=-1, keepdim=True)
            t_fp32 = t_fp32 * torch.rsqrt(variance + eps)
            t_fp32 = (t_fp32 * weight).to(self.data_type_)
            t_fp32 = t_fp32.view(-1, t.shape[-1])
            t.copy_(t_fp32)

        _norm_inplace(q, self.q_weight)
        _norm_inplace(k, self.k_weight)
        return

    def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        # Fused kernel normalizes Q and K in one launch; returns (q, k).
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps)

    def _cuda_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        eps: float,
    ) -> None:
        # CUDA path reuses the triton kernel; result is in place, so the
        # returned tuple is intentionally discarded.
        self._triton_forward(q=q, k=k, eps=eps)
        return

    def _musa_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        # musa implementation is supported by musa triton on musa platform
        return self._triton_forward(q=q, k=k, eps=eps)

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        eps: float,
    ) -> None:
        # Dispatch to the platform-specific implementation (PlatformAwareOp).
        return self._forward(q=q, k=k, eps=eps)

lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def fused_topk(
4646
sgl_ops.topk_softmax(
4747
topk_weights,
4848
topk_ids,
49-
gating_output.float(), # TODO(woosuk): Optimize this.
49+
gating_output,
5050
renormalize=renormalize,
5151
)
5252
return topk_weights, topk_ids

lightllm/common/basemodel/triton_kernel/norm/qk_norm.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,141 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
6464
num_warps=4,
6565
)
6666
return x
67+
68+
69+
@triton.jit
def _qk_rms_norm_fused_kernel(
    # Q pointers & strides
    Q_ptr,
    WQ_ptr,
    stride_q_row,
    stride_q_col,
    # K pointers & strides
    K_ptr,
    WK_ptr,
    stride_k_row,
    stride_k_col,
    # Dimensions
    num_heads_q: tl.constexpr,  # number of Q heads (boundary between Q and K programs)
    head_dim: tl.constexpr,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    # Fused in-place RMSNorm over Q and K: each program normalizes exactly one
    # head of one token. The launcher guarantees BLOCK_SIZE == head_dim, so
    # loads/stores below need no masking.
    # program_id(0): which token (row) to process
    row_idx = tl.program_id(0)
    # program_id(1): which head (combined index)
    # range is [0, num_heads_q + num_heads_k)
    combo_head_idx = tl.program_id(1)

    # shared in-head offsets (0 .. head_dim)
    offs = tl.arange(0, BLOCK_SIZE)

    # === branch: does this program handle a Q head or a K head? ===
    if combo_head_idx < num_heads_q:
        # ------------------ Q path ------------------
        # pointer arithmetic:
        # for Q the actual head index is combo_head_idx itself
        Q_ptr += row_idx * stride_q_row

        # locate the Q slice: base + row offset + head offset + column offset
        q_ptr_offset = (combo_head_idx * head_dim + offs) * stride_q_col

        # load Q values
        x = tl.load(Q_ptr + q_ptr_offset).to(tl.float32)
        # RMSNorm statistics in fp32
        var = tl.sum(x * x, axis=0) / head_dim
        rstd = 1 / tl.sqrt(var + eps)

        # load Q scale weights (all heads share one (head_dim,) weight vector)
        w = tl.load(WQ_ptr + offs)

        x *= rstd
        # cast back to the weight dtype before scaling
        y = x.to(w.dtype) * w

        # write Q back in place
        tl.store(Q_ptr + q_ptr_offset, y)

    else:
        # ------------------ K path ------------------
        # remap to a 0-based K head index
        k_head_idx = combo_head_idx - num_heads_q

        # pointer arithmetic
        K_ptr += row_idx * stride_k_row
        k_ptr_offset = (k_head_idx * head_dim + offs) * stride_k_col

        # load K values
        x = tl.load(K_ptr + k_ptr_offset).to(tl.float32)
        # RMSNorm statistics in fp32
        var = tl.sum(x * x, axis=0) / head_dim
        rstd = 1 / tl.sqrt(var + eps)

        # load K scale weights
        w = tl.load(WK_ptr + offs)
        x *= rstd

        y = x.to(w.dtype) * w

        # write K back in place
        tl.store(K_ptr + k_ptr_offset, y)
144+
145+
146+
def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6):
    """Apply RMSNorm to Q and K in place with a single fused kernel launch.

    Supports GQA: Q and K may have different head counts, but must share the
    head_dim implied by the weight vectors.

    Args:
        q: (total_tokens, hidden_q) or any shape flattenable to it.
        k: (total_tokens, hidden_k)
        w_q: (head_dim,) scale parameter for Q.
        w_k: (head_dim,) scale parameter for K.
        eps: numerical-stability term added to the variance.

    Returns:
        The (q, k) pair, normalized in place.
    """
    # Flatten both inputs to 2D (total_tokens, hidden_size) views.
    q2 = q.view(-1, q.shape[-1])
    k2 = k.view(-1, k.shape[-1])

    assert w_q.is_contiguous() and w_k.is_contiguous()

    num_tokens = q2.shape[0]
    assert k2.shape[0] == num_tokens, "Q and K must have the same number of tokens"

    head_dim = w_q.shape[0]
    assert w_k.shape[0] == head_dim, "Head dim of Q and K must match"

    # Derive per-tensor head counts; hidden sizes must tile evenly by head_dim.
    assert q2.shape[1] % head_dim == 0
    assert k2.shape[1] % head_dim == 0
    q_heads = q2.shape[1] // head_dim
    k_heads = k2.shape[1] // head_dim

    # The kernel loads a full head without masking, so head_dim itself must be
    # a power of two.
    block_size = triton.next_power_of_2(head_dim)
    assert block_size == head_dim, "Currently only supports head_dim power of 2 (e.g., 64, 128)"

    # One program per (token, head); Q heads occupy the first q_heads slots of
    # axis 1, K heads the rest.
    launch_grid = (num_tokens, q_heads + k_heads)
    _qk_rms_norm_fused_kernel[launch_grid](
        q2,
        w_q,
        q2.stride(0),
        q2.stride(1),
        k2,
        w_k,
        k2.stride(0),
        k2.stride(1),
        num_heads_q=q_heads,
        head_dim=head_dim,
        eps=eps,
        BLOCK_SIZE=block_size,
        num_warps=4,
    )

    return q, k

lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
"num_stages": 3,
1818
"num_warps": 4
1919
},
20+
"192": {
21+
"BLOCK_SIZE_K": 64,
22+
"BLOCK_SIZE_M": 16,
23+
"BLOCK_SIZE_N": 128,
24+
"GROUP_SIZE_M": 64,
25+
"NEED_TRANS": false,
26+
"num_stages": 2,
27+
"num_warps": 4
28+
},
2029
"2048": {
2130
"BLOCK_SIZE_K": 32,
2231
"BLOCK_SIZE_M": 32,
@@ -35,6 +44,15 @@
3544
"num_stages": 2,
3645
"num_warps": 4
3746
},
47+
"384": {
48+
"BLOCK_SIZE_K": 64,
49+
"BLOCK_SIZE_M": 16,
50+
"BLOCK_SIZE_N": 128,
51+
"GROUP_SIZE_M": 16,
52+
"NEED_TRANS": false,
53+
"num_stages": 2,
54+
"num_warps": 4
55+
},
3856
"512": {
3957
"BLOCK_SIZE_K": 64,
4058
"BLOCK_SIZE_M": 16,
@@ -53,6 +71,24 @@
5371
"num_stages": 2,
5472
"num_warps": 4
5573
},
74+
"640": {
75+
"BLOCK_SIZE_K": 64,
76+
"BLOCK_SIZE_M": 16,
77+
"BLOCK_SIZE_N": 128,
78+
"GROUP_SIZE_M": 1,
79+
"NEED_TRANS": false,
80+
"num_stages": 2,
81+
"num_warps": 4
82+
},
83+
"768": {
84+
"BLOCK_SIZE_K": 64,
85+
"BLOCK_SIZE_M": 16,
86+
"BLOCK_SIZE_N": 128,
87+
"GROUP_SIZE_M": 64,
88+
"NEED_TRANS": false,
89+
"num_stages": 2,
90+
"num_warps": 4
91+
},
5692
"8": {
5793
"BLOCK_SIZE_K": 32,
5894
"BLOCK_SIZE_M": 16,
@@ -79,5 +115,23 @@
79115
"NEED_TRANS": false,
80116
"num_stages": 2,
81117
"num_warps": 4
118+
},
119+
"896": {
120+
"BLOCK_SIZE_K": 32,
121+
"BLOCK_SIZE_M": 16,
122+
"BLOCK_SIZE_N": 128,
123+
"GROUP_SIZE_M": 64,
124+
"NEED_TRANS": false,
125+
"num_stages": 3,
126+
"num_warps": 4
127+
},
128+
"96": {
129+
"BLOCK_SIZE_K": 64,
130+
"BLOCK_SIZE_M": 16,
131+
"BLOCK_SIZE_N": 64,
132+
"GROUP_SIZE_M": 32,
133+
"NEED_TRANS": false,
134+
"num_stages": 4,
135+
"num_warps": 4
82136
}
83137
}

0 commit comments

Comments
 (0)