Commit 39ff38a

[OP]Unify MoE op with moe_permute path for bf16 GLM (#7164)
1 parent 33682c6 commit 39ff38a

5 files changed, +444 −69 lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 3 additions & 1 deletion

@@ -537,7 +537,9 @@ std::vector<paddle::Tensor> TextImageGatherScatter(
     const bool is_scatter);
 
 std::vector<paddle::Tensor> count_tokens_per_expert_func(
-    const paddle::Tensor& topk_ids, int64_t num_experts);
+    const paddle::Tensor& topk_ids,
+    int64_t num_experts,
+    bool compute_padded_cumsum = false);
 void GetPositionIdsAndMaskEncoderBatch(
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& seq_lens_decoder,
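
Because compute_padded_cumsum defaults to false, existing call sites keep their two-row output. A minimal usage sketch, assuming the op is built and exposed under fastdeploy.model_executor.ops.gpu as in the backend diff below (shapes assume topk_ids is [num_tokens, top_k] int64):

import paddle
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func

topk_ids = paddle.randint(0, 64, shape=[1024, 8], dtype="int64")  # [num_tokens, top_k]

stats = count_tokens_per_expert_func(topk_ids, 64, False)         # [2, 64]: counts, padded counts
stats_cumsum = count_tokens_per_expert_func(topk_ids, 64, True)   # [3, 64]: adds cumsum of padded counts
padded_cumsum = stats_cumsum[2].cast("int64")                     # row 2, as consumed by compute_ffn below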

custom_ops/gpu_ops/moe/deepgemm_preprocess.cu

Lines changed: 52 additions & 18 deletions

@@ -15,10 +15,11 @@
 #include "helper.h"
 #include "paddle/extension.h"
 
-template <typename scalar_t>
+template <typename scalar_t, bool kComputeCumsum>
 __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
                             int32_t *__restrict__ res,
                             int32_t *__restrict__ res_padded,
+                            int32_t *__restrict__ res_padded_cumsum,
                             size_t numel,
                             int num_experts) {
   extern __shared__ int32_t tokens_per_ep[];
@@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
 
   __syncthreads();
 
-  for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
-    res[i] = tokens_per_ep[i];
-    res_padded[i] = (res[i] + 127) / 128 * 128;
+  if constexpr (kComputeCumsum) {
+    if (threadIdx.x == 0) {
+      int32_t running_sum = 0;
+      for (int i = 0; i < num_experts; i++) {
+        int32_t count = tokens_per_ep[i];
+        int32_t padded = (count + 127) / 128 * 128;
+        res[i] = count;
+        res_padded[i] = padded;
+        running_sum += padded;
+        res_padded_cumsum[i] = running_sum;
+      }
+    }
+  } else {
+    for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
+      res[i] = tokens_per_ep[i];
+      res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128;
+    }
   }
 }
 
 std::vector<paddle::Tensor> count_tokens_per_expert_func(
-    const paddle::Tensor &topk_ids, int64_t num_experts) {
+    const paddle::Tensor &topk_ids,
+    int64_t num_experts,
+    bool compute_padded_cumsum) {
   int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
 
+  int64_t num_rows = compute_padded_cumsum ? 3 : 2;
   auto token_nums_per_expert = paddle::empty(
-      {2, num_experts}, paddle::DataType::INT32, topk_ids.place());
+      {num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place());
 
   auto stream = topk_ids.stream();
   using scalar_t = int64_t;
 
-  // CUDA_CHECK(cudaGetLastError());
-  cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
-      topk_ids.data<scalar_t>(),
-      token_nums_per_expert.data<int32_t>(),
-      token_nums_per_expert.data<int32_t>() + num_experts,
-      topk_ids_numel,
-      num_experts);
+  if (compute_padded_cumsum) {
+    cuda_kernel<scalar_t, true>
+        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
+            topk_ids.data<scalar_t>(),
+            token_nums_per_expert.data<int32_t>(),
+            token_nums_per_expert.data<int32_t>() + num_experts,
+            token_nums_per_expert.data<int32_t>() + 2 * num_experts,
+            topk_ids_numel,
+            num_experts);
+  } else {
+    cuda_kernel<scalar_t, false>
+        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
+            topk_ids.data<scalar_t>(),
+            token_nums_per_expert.data<int32_t>(),
+            token_nums_per_expert.data<int32_t>() + num_experts,
+            nullptr,
+            topk_ids_numel,
+            num_experts);
+  }
 
-  // CUDA_CHECK(cudaGetLastError());
   return {token_nums_per_expert};
 }
 
 std::vector<paddle::DataType> count_tokens_per_expert_func_infer_dtype(
-    const paddle::DataType &topk_ids_dtype, int64_t num_experts) {
+    const paddle::DataType &topk_ids_dtype,
+    int64_t num_experts,
+    bool compute_padded_cumsum) {
   return {paddle::DataType::INT32};
 }
 
 std::vector<std::vector<int64_t>> count_tokens_per_expert_func_infer_shape(
-    const std::vector<int64_t> &topk_ids_shape, int64_t num_experts) {
-  return {{2, num_experts}};
+    const std::vector<int64_t> &topk_ids_shape,
+    int64_t num_experts,
+    bool compute_padded_cumsum) {
+  int64_t num_rows = compute_padded_cumsum ? 3 : 2;
+  return {{num_rows, num_experts}};
 }
 
 PD_BUILD_STATIC_OP(count_tokens_per_expert_func)
     .Inputs({"topk_ids"})
     .Outputs({"token_nums_per_expert"})
-    .Attrs({"num_experts:int64_t"})
+    .Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"})
    .SetKernelFn(PD_KERNEL(count_tokens_per_expert_func))
    .SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape))
    .SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype));
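
For checking the kernel's semantics, a plain NumPy reference of what the three output rows contain (a sketch only; helper name is illustrative, and the real op runs a single block with a shared-memory histogram, serializing the cumsum branch on thread 0):

import numpy as np

def count_tokens_per_expert_ref(topk_ids: np.ndarray, num_experts: int,
                                compute_padded_cumsum: bool) -> np.ndarray:
    # Row 0: tokens routed to each expert.
    # Row 1: row 0 rounded up to a multiple of 128, i.e. (count + 127) / 128 * 128.
    # Row 2 (only if requested): inclusive running sum of row 1.
    counts = np.bincount(topk_ids.ravel(), minlength=num_experts).astype(np.int32)
    padded = (counts + 127) // 128 * 128
    rows = [counts, padded]
    if compute_padded_cumsum:
        rows.append(np.cumsum(padded).astype(np.int32))
    return np.stack(rows)

topk_ids = np.array([[0, 2], [2, 3], [3, 3]])  # 3 tokens, top_k = 2, 4 experts
print(count_tokens_per_expert_ref(topk_ids, num_experts=4, compute_padded_cumsum=True))
# [[  1   0   2   3]
#  [128   0 128 128]
#  [128 128 256 384]]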

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 152 additions & 48 deletions

@@ -28,7 +28,11 @@
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod
 
 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
+    from fastdeploy.model_executor.ops.gpu import (
+        count_tokens_per_expert_func,
+        moe_expert_dispatch,
+        moe_expert_reduce,
+    )
 
 try:
     from fastdeploy.model_executor.ops.gpu import (
@@ -126,14 +130,15 @@ def apply_ep_prefill(
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
         # 2. EP Dispatch
+        dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
         (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
             recv_num_tokens_per_expert_list,
             handle,
             event,
-        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
+        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)
 
         if topk_ids_hookfunc is not None:
             topk_ids_hookfunc(topk_ids=topk_idx)
@@ -146,54 +151,91 @@
         # 3. Compute ffn
         if token_all_num > 0:
             logger.debug(f"token_all_num {token_all_num}")
-            (
-                permute_input,
-                permute_indices_per_token,
-                recv_num_tokens_per_expert_list_cumsum,
-                dst_weights,
-                dst_indices,
-                cumsum_idx_gpu,
-                expert_idx_per_token,
-                dequant_scale,
-            ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
-                recv_x,
-                recv_topk_idx,
-                recv_topk_weights,
-                (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
-                recv_num_tokens_per_expert_list,
-                token_all_num,
-                self.moe_quant_type,
-            )
-            if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
-                # only w4a8 and w4afp8 need expert_idx_per_token
-                # Other need not this tensor, so we make it None.
-                expert_idx_per_token = None
+
+            if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+                # --- moe_permute / moe_unpermute path ---
+                recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
+                (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                )
+
+                token_nums_per_expert_cumsum = count_tokens_per_expert_func(
+                    recv_topk_idx, layer.num_local_experts, True
+                )[2].cast(paddle.int64)
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    token_nums_per_expert_cumsum,
+                    None,
+                    False,
+                    -1,
+                    None,
+                    None,
+                )
+
+                tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                    hidden_states_unzipped=ffn_out,
+                    zipped_expertwise_rowmap=permute_indices_per_token,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    token_prob_unzipped=dst_weights,
+                    total_zipped_tokens=recv_x.shape[0],
+                    num_experts=layer.num_local_experts,
+                    using_weighted_combine=True,
+                )
             else:
-                expert_idx_per_token = expert_idx_per_token.cast("int64")
+                # --- original ep_moe_expert_dispatch / combine path ---
+                (
+                    permute_input,
+                    permute_indices_per_token,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    dst_weights,
+                    dst_indices,
+                    cumsum_idx_gpu,
+                    expert_idx_per_token,
+                    dequant_scale,
+                ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
+                    recv_x,
+                    recv_topk_idx,
+                    recv_topk_weights,
+                    (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
+                    recv_num_tokens_per_expert_list,
+                    token_all_num,
+                    self.moe_quant_type,
+                )
+                if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
+                    expert_idx_per_token = None
+                else:
+                    expert_idx_per_token = expert_idx_per_token.cast("int64")
 
-            if hasattr(layer, "up_gate_proj_in_scale"):
-                dequant_scale = None
+                if hasattr(layer, "up_gate_proj_in_scale"):
+                    dequant_scale = None
 
-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                recv_num_tokens_per_expert_list_cumsum,
-                expert_idx_per_token,
-                False,
-                -1,
-                dequant_scale,
-            )
+                ffn_out = self.compute_ffn(
+                    layer,
+                    permute_input,
+                    recv_num_tokens_per_expert_list_cumsum,
+                    expert_idx_per_token,
+                    False,
+                    -1,
+                    dequant_scale,
+                )
 
-            # prmt back per rank
-            tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
-                ffn_out,
-                dst_weights,
-                permute_indices_per_token,
-                dst_indices,
-                None,  # down_proj_bias,
-                False,  # norm_topk_prob
-                1.0,
-            )
+                tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
+                    ffn_out,
+                    dst_weights,
+                    permute_indices_per_token,
+                    dst_indices,
+                    None,  # down_proj_bias,
+                    False,  # norm_topk_prob
+                    1.0,
+                )
         else:
             tmp_ffn_out = recv_x
 
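The contract this branch relies on is that moe_unpermute with using_weighted_combine=True inverts moe_permute up to the top-k probability weighting. A NumPy sketch of that round trip (conceptual only: helper names are illustrative, argument roles mirror the Paddle calls above, the 128-row padding is omitted, and the real kernels' buffer layout may differ):

import numpy as np

def permute_ref(x, topk_idx, num_experts):
    # Gather token copies grouped by expert id; rowmap[t, k] records which
    # permuted row the k-th expert copy of token t landed in.
    rowmap = np.full(topk_idx.shape, -1, dtype=np.int64)
    order = []
    for e in range(num_experts):
        for t, k in zip(*np.nonzero(topk_idx == e)):
            rowmap[t, k] = len(order)
            order.append(t)
    return x[np.array(order)], rowmap

def unpermute_ref(ffn_out, rowmap, probs):
    # Weighted combine: each token sums its expert outputs, scaled by its top-k probs.
    num_tokens, top_k = rowmap.shape
    out = np.zeros((num_tokens, ffn_out.shape[1]), dtype=ffn_out.dtype)
    for t in range(num_tokens):
        for k in range(top_k):
            out[t] += probs[t, k] * ffn_out[rowmap[t, k]]
    return out

x = np.random.rand(4, 8).astype(np.float32)            # 4 tokens, hidden size 8
topk_idx = np.array([[0, 1], [1, 2], [0, 2], [2, 1]])  # top_k = 2, 3 experts
probs = np.random.rand(4, 2).astype(np.float32)

permuted, rowmap = permute_ref(x, topk_idx, num_experts=3)
# With an identity "FFN", unpermute must reduce to prob-weighted copies of x.
combined = unpermute_ref(permuted, rowmap, probs)
assert np.allclose(combined, probs.sum(axis=1, keepdims=True) * x, atol=1e-6)
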
@@ -276,6 +318,69 @@ def apply_tp(
         """
         gate_out = gate(x)
         gate_out = gate_out.cast("float32")
+        if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
+            if layer.topk_method == "noaux_tc":
+                gate_out, topk_weights, topk_idx = get_moe_scores(
+                    gate_out,
+                    layer.n_group,
+                    layer.topk_group,
+                    layer.top_k,
+                    layer.routed_scaling_factor,
+                    layer.gate_correction_bias,
+                    getattr(layer, "renormalize", True),
+                )
+            else:
+                topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                    gate_out,
+                    layer.gate_correction_bias,
+                    layer.top_k,
+                    True,  # apply_norm_weight
+                    False,
+                )
+            topk_idx_i32 = topk_idx.astype(paddle.int32)
+            override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = (  # zipped_expertwise_rowmap
+                paddle.nn.functional.moe_permute(
+                    hidden_states=x,
+                    scale=None,
+                    expert_routemap_topk=topk_idx_i32,
+                    expert_prob_topk=topk_weights,
+                    num_experts=layer.num_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=override_buffer_size,
+                )
+            )
+
+            # Row 2 of count_tokens_per_expert_func is the prefix sum of the padded token_nums_per_expert.
+            token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
+                paddle.int64
+            )
+            if topk_ids_hookfunc is not None:
+                topk_ids_hookfunc(topk_ids=topk_idx)
+
+            ffn_out = self.compute_ffn(
+                layer,
+                permute_input,
+                token_nums_per_expert_cumsum,
+                None,  # expert_idx_per_token not needed for w16a16 without bias
+                False,
+                -1,
+                None,  # dequant_scale
+                None,  # max_tokens_per_expert
+            )
+
+            fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
+                hidden_states_unzipped=ffn_out,
+                zipped_expertwise_rowmap=permute_indices_per_token,
+                expert_routemap_topk=topk_idx_i32,
+                token_prob_unzipped=dst_weights,
+                total_zipped_tokens=x.shape[0],
+                num_experts=layer.num_experts,
+                using_weighted_combine=True,
+            )
+            return fused_moe_out
+
         if layer.topk_method == "noaux_tc":
             gate_out, topk_weights, topk_idx = get_moe_scores(
                 gate_out,
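
On the override_buffer_size above: moe_permute writes every routed (token, expert) copy plus per-expert padding up to a multiple of 128, so x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) bounds the worst case (each expert group gains at most 127 padding rows). A quick sanity check of that bound, assuming only the arithmetic in the diff (helper name is illustrative):

import numpy as np

def worst_case_buffer_rows(num_tokens: int, top_k: int, num_experts: int) -> int:
    # All routed (token, expert) copies, plus at most 127 padding rows per
    # expert group when each group is rounded up to a multiple of 128.
    return num_tokens * top_k + num_experts * (128 - 1)

rng = np.random.default_rng(0)
num_tokens, top_k, num_experts = 1024, 8, 64
topk_idx = rng.integers(0, num_experts, size=(num_tokens, top_k))
counts = np.bincount(topk_idx.ravel(), minlength=num_experts)
padded_total = ((counts + 127) // 128 * 128).sum()
assert padded_total <= worst_case_buffer_rows(num_tokens, top_k, num_experts)  # 16320 here
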
@@ -287,6 +392,7 @@ def apply_tp(
                 getattr(layer, "renormalize", True),
                 topk_reduce_func=getattr(layer, "topk_reduce_func", None),
             )
+
         (
             permute_input,
             token_nums_per_expert,
@@ -341,7 +447,6 @@ def apply_tp(
             expert_idx_per_token = None
         else:
             expert_idx_per_token = expert_idx_per_token.cast("int64")
-
         ffn_out = self.compute_ffn(
             layer,
             permute_input,
@@ -363,7 +468,6 @@ def apply_tp(
             norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
             routed_scaling_factor=1.0,
         )
-
         return fused_moe_out
 
 