Skip to content

Commit 82f0d7f

Browse files
committed
RL: support bf16 DeepGEMM MoE
1 parent 5c9fa43 commit 82f0d7f

1 file changed

Lines changed: 46 additions & 54 deletions

File tree

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 46 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@
2828
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
2929

3030
if current_platform.is_cuda():
31-
from fastdeploy.model_executor.ops.gpu import (
32-
count_tokens_per_expert_func,
33-
moe_expert_dispatch,
34-
moe_expert_reduce,
35-
)
31+
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
3632

3733
try:
3834
from fastdeploy.model_executor.ops.gpu import (
@@ -43,6 +39,7 @@
4339
logger.warning("import w4afp8_gemm_scale_permute Failed!")
4440

4541
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
42+
from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops
4643
from fastdeploy.model_executor.utils import (
4744
TensorTracker,
4845
free_tensor,
@@ -52,6 +49,12 @@
5249
)
5350

5451

52+
def deep_batch_gemm(x, y, expert_idx_per_token):
53+
out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
54+
paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
55+
return out
56+
57+
5558
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
5659
"""
5760
Use Cutlass Group Gemm to compute Fused MoE.
@@ -155,30 +158,22 @@ def apply_ep_prefill(
155158
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
156159
# --- moe_permute / moe_unpermute path ---
157160
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
158-
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
159-
hidden_states=recv_x,
160-
scale=None,
161-
expert_routemap_topk=recv_topk_idx_i32,
162-
expert_prob_topk=recv_topk_weights,
163-
num_experts=layer.num_local_experts,
164-
tokens_per_expert=[],
165-
padding_alignment=128,
166-
override_buffer_size=token_all_num,
161+
(permute_input, permute_indices_per_token, dst_weights, _scale_out, expert_idx_per_token) = (
162+
paddle.nn.functional.moe_permute(
163+
hidden_states=recv_x,
164+
scale=None,
165+
expert_routemap_topk=recv_topk_idx_i32,
166+
expert_prob_topk=recv_topk_weights,
167+
num_experts=layer.num_local_experts,
168+
tokens_per_expert=[],
169+
padding_alignment=128,
170+
override_buffer_size=token_all_num,
171+
)
167172
)
168173

169-
token_nums_per_expert_cumsum = count_tokens_per_expert_func(
170-
recv_topk_idx, layer.num_local_experts, True
171-
)[2].cast(paddle.int64)
172-
ffn_out = self.compute_ffn(
173-
layer,
174-
permute_input,
175-
token_nums_per_expert_cumsum,
176-
None,
177-
False,
178-
-1,
179-
None,
180-
None,
181-
)
174+
out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token)
175+
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
176+
ffn_out = deep_batch_gemm(out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token)
182177

183178
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
184179
hidden_states_unzipped=ffn_out,
@@ -187,7 +182,7 @@ def apply_ep_prefill(
187182
token_prob_unzipped=dst_weights,
188183
total_zipped_tokens=recv_x.shape[0],
189184
num_experts=layer.num_local_experts,
190-
using_weighted_combine=True,
185+
using_weighted_combine=False,
191186
)
192187
else:
193188
# --- original ep_moe_expert_dispatch / combine path ---
@@ -339,36 +334,33 @@ def apply_tp(
339334
)
340335
topk_idx_i32 = topk_idx.astype(paddle.int32)
341336
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
342-
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap
343-
paddle.nn.functional.moe_permute(
344-
hidden_states=x,
345-
scale=None,
346-
expert_routemap_topk=topk_idx_i32,
347-
expert_prob_topk=topk_weights,
348-
num_experts=layer.num_experts,
349-
tokens_per_expert=[],
350-
padding_alignment=128,
351-
override_buffer_size=override_buffer_size,
352-
)
337+
(
338+
permute_input,
339+
permute_indices_per_token,
340+
dst_weights,
341+
_scale_out,
342+
expert_idx_per_token,
343+
) = paddle.nn.functional.moe_permute( # zipped_expertwise_rowmap
344+
hidden_states=x,
345+
scale=None,
346+
expert_routemap_topk=topk_idx_i32,
347+
expert_prob_topk=topk_weights,
348+
num_experts=layer.num_experts,
349+
tokens_per_expert=[],
350+
padding_alignment=128,
351+
override_buffer_size=override_buffer_size,
352+
return_expert_indices=True,
353353
)
354354

355-
# Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
356-
token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
357-
paddle.int64
358-
)
359355
if topk_ids_hookfunc is not None:
360356
topk_ids_hookfunc(topk_ids=topk_idx)
361357

362-
ffn_out = self.compute_ffn(
363-
layer,
364-
permute_input,
365-
token_nums_per_expert_cumsum,
366-
None, # expert_idx_per_token not needed for w16a16 without bias
367-
False,
368-
-1,
369-
None, # dequant_scale
370-
None, # max_tokens_per_expert
371-
)
358+
out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token)
359+
out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
360+
ffn_out = deep_batch_gemm(out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token)
361+
if layer.with_bias:
362+
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
363+
ffn_out = paddle.add(ffn_out, down_proj_bias_expand)
372364

373365
fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
374366
hidden_states_unzipped=ffn_out,
@@ -377,7 +369,7 @@ def apply_tp(
377369
token_prob_unzipped=dst_weights,
378370
total_zipped_tokens=x.shape[0],
379371
num_experts=layer.num_experts,
380-
using_weighted_combine=True,
372+
using_weighted_combine=False,
381373
)
382374
return fused_moe_out
383375

0 commit comments

Comments (0)