diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index f927cd8c5ee..67a5d29bf53 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -28,11 +28,7 @@
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod
 
 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import (
-        count_tokens_per_expert_func,
-        moe_expert_dispatch,
-        moe_expert_reduce,
-    )
+    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
 
 try:
     from fastdeploy.model_executor.ops.gpu import (
@@ -43,6 +39,7 @@
     logger.warning("import w4afp8_gemm_scale_permute Failed!")
 
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops
 from fastdeploy.model_executor.utils import (
     TensorTracker,
     free_tensor,
@@ -52,6 +49,12 @@
 )
 
 
+def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
+    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
+    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
+    return out
+
+
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
@@ -155,29 +158,26 @@ def apply_ep_prefill(
         if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
             # --- moe_permute / moe_unpermute path ---
             recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
 
-            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
-                hidden_states=recv_x,
-                scale=None,
-                expert_routemap_topk=recv_topk_idx_i32,
-                expert_prob_topk=recv_topk_weights,
-                num_experts=layer.num_local_experts,
-                tokens_per_expert=[],
-                padding_alignment=128,
-                override_buffer_size=token_all_num,
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out, expert_idx_per_token) = (
+                paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                    return_expert_indices=True,
+                )
             )
-            token_nums_per_expert_cumsum = count_tokens_per_expert_func(
-                recv_topk_idx, layer.num_local_experts, True
-            )[2].cast(paddle.int64)
-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                token_nums_per_expert_cumsum,
-                None,
-                False,
-                -1,
-                None,
-                None,
+            out = m_grouped_bf16_gemm_nn_contiguous(
+                permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token
+            )
+            out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
+            ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+                out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token
             )
 
             tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
@@ -187,7 +187,7 @@ def apply_ep_prefill(
                 token_prob_unzipped=dst_weights,
                 total_zipped_tokens=recv_x.shape[0],
                 num_experts=layer.num_local_experts,
-                using_weighted_combine=True,
+                using_weighted_combine=False,
             )
         else:
             # --- original ep_moe_expert_dispatch / combine path ---
@@ -339,36 +339,37 @@ def apply_tp(
         )
         topk_idx_i32 = topk_idx.astype(paddle.int32)
         override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
-        (permute_input, permute_indices_per_token, dst_weights, _scale_out) = (  # zipped_expertwise_rowmap
-            paddle.nn.functional.moe_permute(
-                hidden_states=x,
-                scale=None,
-                expert_routemap_topk=topk_idx_i32,
-                expert_prob_topk=topk_weights,
-                num_experts=layer.num_experts,
-                tokens_per_expert=[],
-                padding_alignment=128,
-                override_buffer_size=override_buffer_size,
-            )
+        (
+            permute_input,
+            permute_indices_per_token,
+            dst_weights,
+            _scale_out,
+            expert_idx_per_token,
+        ) = paddle.nn.functional.moe_permute(  # zipped_expertwise_rowmap
+            hidden_states=x,
+            scale=None,
+            expert_routemap_topk=topk_idx_i32,
+            expert_prob_topk=topk_weights,
+            num_experts=layer.num_experts,
+            tokens_per_expert=[],
+            padding_alignment=128,
+            override_buffer_size=override_buffer_size,
+            return_expert_indices=True,
        )
-        # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
-        token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
-            paddle.int64
-        )
 
         if topk_ids_hookfunc is not None:
             topk_ids_hookfunc(topk_ids=topk_idx)
 
-        ffn_out = self.compute_ffn(
-            layer,
-            permute_input,
-            token_nums_per_expert_cumsum,
-            None,  # expert_idx_per_token not needed for w16a16 without bias
-            False,
-            -1,
-            None,  # dequant_scale
-            None,  # max_tokens_per_expert
+        out = m_grouped_bf16_gemm_nn_contiguous(
+            permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token
+        )
+        out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
+        ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+            out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token
         )
+        if layer.with_bias:
+            down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
+            ffn_out = paddle.add(ffn_out, down_proj_bias_expand)
 
         fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
             hidden_states_unzipped=ffn_out,
@@ -377,7 +378,7 @@ def apply_tp(
             token_prob_unzipped=dst_weights,
             total_zipped_tokens=x.shape[0],
             num_experts=layer.num_experts,
-            using_weighted_combine=True,
+            using_weighted_combine=False,
         )
 
         return fused_moe_out