
Commit 7f55586

[OP]Unify MoE op with moe_permute path for bf16 GLM (#7164) (#7282)
1 parent 14598d6 commit 7f55586

5 files changed

Lines changed: 1154 additions & 69 deletions


custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 3 additions & 1 deletion

@@ -522,7 +522,9 @@ std::vector<paddle::Tensor> TextImageGatherScatter(
     const bool is_scatter);

 std::vector<paddle::Tensor> count_tokens_per_expert_func(
-    const paddle::Tensor& topk_ids, int64_t num_experts);
+    const paddle::Tensor& topk_ids,
+    int64_t num_experts,
+    bool compute_padded_cumsum = false);
 void GetPositionIdsAndMaskEncoderBatch(
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& seq_lens_decoder,
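Note: compute_padded_cumsum defaults to false, so existing callers keep the old two-row output. A minimal NumPy sketch of the op's output contract (a reference model for illustration, not the CUDA kernel; count_tokens_per_expert_ref is a hypothetical name):

import numpy as np

def count_tokens_per_expert_ref(topk_ids, num_experts, compute_padded_cumsum=False):
    # Row 0: raw token count per expert.
    counts = np.bincount(topk_ids.reshape(-1), minlength=num_experts).astype(np.int32)
    # Row 1: count rounded up to the 128-row alignment the GEMM path expects.
    padded = (counts + 127) // 128 * 128
    if not compute_padded_cumsum:
        return np.stack([counts, padded])                   # shape (2, num_experts)
    # Row 2: inclusive prefix sum of the padded counts.
    return np.stack([counts, padded, np.cumsum(padded).astype(np.int32)])  # (3, num_experts)

assert count_tokens_per_expert_ref(np.array([[0, 1]]), 4).shape == (2, 4)
assert count_tokens_per_expert_ref(np.array([[0, 1]]), 4, True).shape == (3, 4)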

custom_ops/gpu_ops/moe/deepgemm_preprocess.cu

Lines changed: 52 additions & 18 deletions

@@ -15,10 +15,11 @@
 #include "helper.h"
 #include "paddle/extension.h"

-template <typename scalar_t>
+template <typename scalar_t, bool kComputeCumsum>
 __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
                             int32_t *__restrict__ res,
                             int32_t *__restrict__ res_padded,
+                            int32_t *__restrict__ res_padded_cumsum,
                             size_t numel,
                             int num_experts) {
   extern __shared__ int32_t tokens_per_ep[];

@@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,

   __syncthreads();

-  for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
-    res[i] = tokens_per_ep[i];
-    res_padded[i] = (res[i] + 127) / 128 * 128;
+  if constexpr (kComputeCumsum) {
+    if (threadIdx.x == 0) {
+      int32_t running_sum = 0;
+      for (int i = 0; i < num_experts; i++) {
+        int32_t count = tokens_per_ep[i];
+        int32_t padded = (count + 127) / 128 * 128;
+        res[i] = count;
+        res_padded[i] = padded;
+        running_sum += padded;
+        res_padded_cumsum[i] = running_sum;
+      }
+    }
+  } else {
+    for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
+      res[i] = tokens_per_ep[i];
+      res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128;
+    }
   }
 }

 std::vector<paddle::Tensor> count_tokens_per_expert_func(
-    const paddle::Tensor &topk_ids, int64_t num_experts) {
+    const paddle::Tensor &topk_ids,
+    int64_t num_experts,
+    bool compute_padded_cumsum) {
   int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];

+  int64_t num_rows = compute_padded_cumsum ? 3 : 2;
   auto token_nums_per_expert = paddle::empty(
-      {2, num_experts}, paddle::DataType::INT32, topk_ids.place());
+      {num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place());

   auto stream = topk_ids.stream();
   using scalar_t = int64_t;

-  // CUDA_CHECK(cudaGetLastError());
-  cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
-      topk_ids.data<scalar_t>(),
-      token_nums_per_expert.data<int32_t>(),
-      token_nums_per_expert.data<int32_t>() + num_experts,
-      topk_ids_numel,
-      num_experts);
+  if (compute_padded_cumsum) {
+    cuda_kernel<scalar_t, true>
+        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
+            topk_ids.data<scalar_t>(),
+            token_nums_per_expert.data<int32_t>(),
+            token_nums_per_expert.data<int32_t>() + num_experts,
+            token_nums_per_expert.data<int32_t>() + 2 * num_experts,
+            topk_ids_numel,
+            num_experts);
+  } else {
+    cuda_kernel<scalar_t, false>
+        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
+            topk_ids.data<scalar_t>(),
+            token_nums_per_expert.data<int32_t>(),
+            token_nums_per_expert.data<int32_t>() + num_experts,
+            nullptr,
+            topk_ids_numel,
+            num_experts);
+  }

-  // CUDA_CHECK(cudaGetLastError());
   return {token_nums_per_expert};
 }

 std::vector<paddle::DataType> count_tokens_per_expert_func_infer_dtype(
-    const paddle::DataType &topk_ids_dtype, int64_t num_experts) {
+    const paddle::DataType &topk_ids_dtype,
+    int64_t num_experts,
+    bool compute_padded_cumsum) {
   return {paddle::DataType::INT32};
 }

 std::vector<std::vector<int64_t>> count_tokens_per_expert_func_infer_shape(
-    const std::vector<int64_t> &topk_ids_shape, int64_t num_experts) {
-  return {{2, num_experts}};
+    const std::vector<int64_t> &topk_ids_shape,
+    int64_t num_experts,
+    bool compute_padded_cumsum) {
+  int64_t num_rows = compute_padded_cumsum ? 3 : 2;
+  return {{num_rows, num_experts}};
 }

 PD_BUILD_STATIC_OP(count_tokens_per_expert_func)
     .Inputs({"topk_ids"})
     .Outputs({"token_nums_per_expert"})
-    .Attrs({"num_experts:int64_t"})
+    .Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"})
     .SetKernelFn(PD_KERNEL(count_tokens_per_expert_func))
     .SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape))
     .SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype));
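When compute_padded_cumsum is set, a single thread (threadIdx.x == 0) runs the scan serially; num_experts is small, so a sequential loop is cheap and keeps the padded counts and their prefix sum consistent. A self-contained sketch of the padding and prefix-sum arithmetic the kernel performs:

import numpy as np

counts = np.array([1, 4, 1, 0], dtype=np.int32)   # tokens routed to each of 4 experts
padded = (counts + 127) // 128 * 128              # round each count up to a 128-row tile
cumsum = np.cumsum(padded).astype(np.int32)       # inclusive prefix sum, as the kernel writes it
print(padded)   # [128 128 128   0]
print(cumsum)   # [128 256 384 384]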

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 151 additions & 48 deletions

@@ -28,7 +28,11 @@
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod

 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
+    from fastdeploy.model_executor.ops.gpu import (
+        count_tokens_per_expert_func,
+        moe_expert_dispatch,
+        moe_expert_reduce,
+    )

 try:
     from fastdeploy.model_executor.ops.gpu import (
@@ -145,14 +149,15 @@ def apply_ep_prefill(
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
+        dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
         (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
             recv_num_tokens_per_expert_list,
             handle,
             event,
-        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
+        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)

         if topk_ids_hookfunc is not None:
             topk_ids_hookfunc(topk_ids=topk_idx)
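The expert_alignment=128 argument presumably asks the EP dispatcher to pad each local expert's received block to a multiple of 128 rows, matching the padding_alignment=128 that moe_permute uses below; the exact dispatcher semantics are an assumption here. The implied rounding:

# Hedged sketch of 128-row alignment per local expert (dispatcher semantics assumed):
recv_counts = [5, 130, 0]                         # hypothetical per-expert token counts
aligned = [(c + 127) // 128 * 128 for c in recv_counts]
print(aligned)                                    # [128, 256, 0]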
@@ -165,54 +170,91 @@ def apply_ep_prefill(
         # 3. Compute ffn
         if token_all_num > 0:
             logger.debug(f"token_all_num {token_all_num}")
-            (
-                permute_input,
-                permute_indices_per_token,
-                recv_num_tokens_per_expert_list_cumsum,
-                dst_weights,
-                dst_indices,
-                cumsum_idx_gpu,
-                expert_idx_per_token,
-                dequant_scale,
-            ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
-                recv_x,
-                recv_topk_idx,
-                recv_topk_weights,
-                (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
-                recv_num_tokens_per_expert_list,
-                token_all_num,
-                self.moe_quant_type,
-            )
-            if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
-                # only w4a8 and w4afp8 need expert_idx_per_token
-                # Other need not this tensor, so we make it None.
-                expert_idx_per_token = None
+
+            if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+                # --- moe_permute / moe_unpermute path ---
+                recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
+                (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                )
+
+                token_nums_per_expert_cumsum = count_tokens_per_expert_func(
+                    recv_topk_idx, layer.num_local_experts, True
+                )[2].cast(paddle.int64)
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    token_nums_per_expert_cumsum,
+                    None,
+                    False,
+                    -1,
+                    None,
+                    None,
+                )
+
+                tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                    hidden_states_unzipped=ffn_out,
+                    zipped_expertwise_rowmap=permute_indices_per_token,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    token_prob_unzipped=dst_weights,
+                    total_zipped_tokens=recv_x.shape[0],
+                    num_experts=layer.num_local_experts,
+                    using_weighted_combine=True,
+                )
             else:
-                expert_idx_per_token = expert_idx_per_token.cast("int64")
+                # --- original ep_moe_expert_dispatch / combine path ---
+                (
+                    permute_input,
+                    permute_indices_per_token,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    dst_weights,
+                    dst_indices,
+                    cumsum_idx_gpu,
+                    expert_idx_per_token,
+                    dequant_scale,
+                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
+                    recv_x,
+                    recv_topk_idx,
+                    recv_topk_weights,
+                    (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
+                    recv_num_tokens_per_expert_list,
+                    token_all_num,
+                    self.moe_quant_type,
+                )
+                if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
+                    expert_idx_per_token = None
+                else:
+                    expert_idx_per_token = expert_idx_per_token.cast("int64")

-            if hasattr(layer, "up_gate_proj_in_scale"):
-                dequant_scale = None
+                if hasattr(layer, "up_gate_proj_in_scale"):
+                    dequant_scale = None

-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                recv_num_tokens_per_expert_list_cumsum,
-                expert_idx_per_token,
-                False,
-                -1,
-                dequant_scale,
-            )
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    expert_idx_per_token,
+                    False,
+                    -1,
+                    dequant_scale,
+                )

-            # prmt back per rank
-            tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
-                ffn_out,
-                dst_weights,
-                permute_indices_per_token,
-                dst_indices,
-                None,  # down_proj_bias,
-                False,  # norm_topk_prob
-                1.0,
-            )
+                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
+                    ffn_out,
+                    dst_weights,
+                    permute_indices_per_token,
+                    dst_indices,
+                    None,  # down_proj_bias,
+                    False,  # norm_topk_prob
+                    1.0,
+                )
         else:
             tmp_ffn_out = recv_x

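For readers new to this path, a toy NumPy round trip of what the branch above does: moe_permute gathers each token's top-k copies into per-expert blocks padded to 128 rows, the expert FFN runs on the packed buffer, and moe_unpermute scatters back with a probability-weighted combine. toy_permute and toy_unpermute are hypothetical illustrations of the data movement, not Paddle's actual layout:

import numpy as np

def toy_permute(x, topk_ids, num_experts, align=128):
    # Pack token rows expert-by-expert, padding each expert block to `align` rows.
    rows, origin = [], []
    for e in range(num_experts):
        tokens = np.nonzero(topk_ids == e)[0]     # token indices routed to expert e
        for t in tokens:
            rows.append(x[t])
            origin.append((int(t), e))
        pad = (-len(tokens)) % align
        rows.extend([np.zeros_like(x[0])] * pad)
        origin.extend([(-1, e)] * pad)            # padding rows map to no token
    return np.stack(rows), origin

def toy_unpermute(y, origin, topk_ids, probs, num_tokens):
    # Weighted combine back into token order (using_weighted_combine=True).
    out = np.zeros((num_tokens, y.shape[1]), dtype=y.dtype)
    for row, (t, e) in enumerate(origin):
        if t >= 0:
            k = int(np.argmax(topk_ids[t] == e))  # which top-k slot chose expert e
            out[t] += probs[t, k] * y[row]
    return out

x = np.random.rand(3, 8).astype(np.float32)       # 3 tokens, hidden size 8
topk_ids = np.array([[0, 1], [1, 2], [0, 2]])     # top-2 routing over 3 experts
probs = np.full((3, 2), 0.5, dtype=np.float32)
packed, origin = toy_permute(x, topk_ids, num_experts=3)
restored = toy_unpermute(packed, origin, topk_ids, probs, num_tokens=3)
assert np.allclose(restored, x)                   # identity FFN: 0.5*x + 0.5*x per token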
@@ -292,6 +334,69 @@ def apply_tp(
         Paddle Cutlass compute Fused MoE.
         """
         gate_out = gate(x.cast("float32"))
+        if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+            if layer.topk_method == "noaux_tc":
+                gate_out, topk_weights, topk_idx = get_moe_scores(
+                    gate_out,
+                    layer.n_group,
+                    layer.topk_group,
+                    layer.top_k,
+                    layer.routed_scaling_factor,
+                    layer.gate_correction_bias,
+                    getattr(layer, "renormalize", True),
+                )
+            else:
+                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                    gate_out,
+                    layer.gate_correction_bias,
+                    layer.top_k,
+                    True,  # apply_norm_weight
+                    False,
+                )
+            topk_idx_i32 = topk_idx.astype(paddle.int32)
+            override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = (  # zipped_expertwise_rowmap
+                paddle.nn.functional.moe_permute(
+                    hidden_states=x,
+                    scale=None,
+                    expert_routemap_topk=topk_idx_i32,
+                    expert_prob_topk=topk_weights,
+                    num_experts=layer.num_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=override_buffer_size,
+                )
+            )
+
+            # Row 2 of count_tokens_per_expert_func's output is the prefix sum of the padded per-expert token counts.
+            token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
+                paddle.int64
+            )
+            if topk_ids_hookfunc is not None:
+                topk_ids_hookfunc(topk_ids=topk_idx)
+
+            ffn_out = self.compute_ffn(
+                layer,
+                permute_input,
+                token_nums_per_expert_cumsum,
+                None,  # expert_idx_per_token not needed for w16a16 without bias
+                False,
+                -1,
+                None,  # dequant_scale
+                None,  # max_tokens_per_expert
+            )
+
+            fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                hidden_states_unzipped=ffn_out,
+                zipped_expertwise_rowmap=permute_indices_per_token,
+                expert_routemap_topk=topk_idx_i32,
+                token_prob_unzipped=dst_weights,
+                total_zipped_tokens=x.shape[0],
+                num_experts=layer.num_experts,
+                using_weighted_combine=True,
+            )
+            return fused_moe_out
+
         if layer.topk_method == "noaux_tc":
             gate_out, topk_weights, topk_idx = get_moe_scores(
                 gate_out,
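The override_buffer_size bound deserves a note: the permuted buffer holds top_k rows per token, and padding each expert block up to 128 rows adds at most 127 extra rows per expert, hence tokens * top_k + num_experts * (128 - 1). With hypothetical sizes:

num_tokens, top_k, num_experts = 300, 8, 64       # hypothetical shapes
override_buffer_size = num_tokens * top_k + num_experts * (128 - 1)
print(override_buffer_size)                       # 300*8 + 64*127 = 10528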
@@ -401,7 +506,6 @@ def apply_tp(
             expert_idx_per_token = None
         else:
             expert_idx_per_token = expert_idx_per_token.cast("int64")
-
         ffn_out = self.compute_ffn(
             layer,
             permute_input,
@@ -423,7 +527,6 @@ def apply_tp(
             norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
             routed_scaling_factor=1.0,
         )
-
         return fused_moe_out

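Both new branches in this file key off the same gate. A self-contained sketch of the condition (names follow the diff; FD_USE_PHI_MOE_PERMUTE is assumed to be an env-backed boolean in fastdeploy.envs):

def use_phi_moe_permute(fd_use_phi_moe_permute: bool, moe_quant_type: str) -> bool:
    # The moe_permute path only covers the unquantized bf16/fp16 ("w16a16") case;
    # quantized types keep the original dispatch/combine path.
    return fd_use_phi_moe_permute and moe_quant_type == "w16a16"

assert use_phi_moe_permute(True, "w16a16")
assert not use_phi_moe_permute(True, "w4a8")
assert not use_phi_moe_permute(False, "w16a16")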

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 2 additions & 2 deletions

@@ -341,7 +341,7 @@ def apply_ep_prefill(
                 using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
             )
         else:
-            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts, False)
             (
                 permute_input,
                 permute_scale,

@@ -602,7 +602,7 @@ def apply_tp(
                 using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
             )
         else:
-            tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
+            tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts, False)
             (
                 permute_input,
                 permute_scale,
