From 82f0d7f393d7633a062a22bee26a6f7f560922ae Mon Sep 17 00:00:00 2001 From: ckl117 Date: Fri, 10 Apr 2026 16:30:34 +0800 Subject: [PATCH 1/3] RL support bf16 deep_gemm moe --- .../layers/moe/fused_moe_cutlass_backend.py | 100 ++++++++---------- 1 file changed, 46 insertions(+), 54 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index f927cd8c5ee..e22631052fd 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -28,11 +28,7 @@ from .fused_moe_backend_base import UnquantizedFusedMoEMethod if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import ( - count_tokens_per_expert_func, - moe_expert_dispatch, - moe_expert_reduce, - ) + from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce try: from fastdeploy.model_executor.ops.gpu import ( @@ -43,6 +39,7 @@ logger.warning("import w4afp8_gemm_scale_permute Failed!") from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops from fastdeploy.model_executor.utils import ( TensorTracker, free_tensor, @@ -52,6 +49,12 @@ ) +def deep_batch_gemm(x, y, expert_idx_per_token): + out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype) + paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token) + return out + + class CutlassMoEMethod(UnquantizedFusedMoEMethod): """ Use Cutlass Group Gemm to compute Fused MoE. @@ -155,30 +158,22 @@ def apply_ep_prefill( if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": # --- moe_permute / moe_unpermute path --- recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32) - (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute( - hidden_states=recv_x, - scale=None, - expert_routemap_topk=recv_topk_idx_i32, - expert_prob_topk=recv_topk_weights, - num_experts=layer.num_local_experts, - tokens_per_expert=[], - padding_alignment=128, - override_buffer_size=token_all_num, + (permute_input, permute_indices_per_token, dst_weights, _scale_out, expert_idx_per_token) = ( + paddle.nn.functional.moe_permute( + hidden_states=recv_x, + scale=None, + expert_routemap_topk=recv_topk_idx_i32, + expert_prob_topk=recv_topk_weights, + num_experts=layer.num_local_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=token_all_num, + ) ) - token_nums_per_expert_cumsum = count_tokens_per_expert_func( - recv_topk_idx, layer.num_local_experts, True - )[2].cast(paddle.int64) - ffn_out = self.compute_ffn( - layer, - permute_input, - token_nums_per_expert_cumsum, - None, - False, - -1, - None, - None, - ) + out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token) + out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) + ffn_out = deep_batch_gemm(out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token) tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute( hidden_states_unzipped=ffn_out, @@ -187,7 +182,7 @@ def apply_ep_prefill( token_prob_unzipped=dst_weights, total_zipped_tokens=recv_x.shape[0], num_experts=layer.num_local_experts, - using_weighted_combine=True, + using_weighted_combine=False, ) else: # --- original ep_moe_expert_dispatch / combine path --- @@ -339,36 +334,33 @@ def apply_tp( ) topk_idx_i32 = topk_idx.astype(paddle.int32) override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) - (permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap - paddle.nn.functional.moe_permute( - hidden_states=x, - scale=None, - expert_routemap_topk=topk_idx_i32, - expert_prob_topk=topk_weights, - num_experts=layer.num_experts, - tokens_per_expert=[], - padding_alignment=128, - override_buffer_size=override_buffer_size, - ) + ( + permute_input, + permute_indices_per_token, + dst_weights, + _scale_out, + expert_idx_per_token, + ) = paddle.nn.functional.moe_permute( # zipped_expertwise_rowmap + hidden_states=x, + scale=None, + expert_routemap_topk=topk_idx_i32, + expert_prob_topk=topk_weights, + num_experts=layer.num_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=override_buffer_size, + return_expert_indices=True, ) - # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert. - token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast( - paddle.int64 - ) if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) - ffn_out = self.compute_ffn( - layer, - permute_input, - token_nums_per_expert_cumsum, - None, # expert_idx_per_token not needed for w16a16 without bias - False, - -1, - None, # dequant_scale - None, # max_tokens_per_expert - ) + out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token) + out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) + ffn_out = deep_batch_gemm(out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token) + if layer.with_bias: + down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0) + ffn_out = paddle.add(ffn_out, down_proj_bias_expand) fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute( hidden_states_unzipped=ffn_out, @@ -377,7 +369,7 @@ def apply_tp( token_prob_unzipped=dst_weights, total_zipped_tokens=x.shape[0], num_experts=layer.num_experts, - using_weighted_combine=True, + using_weighted_combine=False, ) return fused_moe_out From 8eb2f1e741348c5bad398f510df39d05eaab497f Mon Sep 17 00:00:00 2001 From: ckl117 Date: Fri, 10 Apr 2026 17:43:25 +0800 Subject: [PATCH 2/3] add return_expert_indices=True in ep_prefill --- .../model_executor/layers/moe/fused_moe_cutlass_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index e22631052fd..0ea28ef6dd1 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -168,6 +168,7 @@ def apply_ep_prefill( tokens_per_expert=[], padding_alignment=128, override_buffer_size=token_all_num, + return_expert_indices=True, ) ) From 7a6ed79ad14efdc53a953aed6d43f769d5dece5f Mon Sep 17 00:00:00 2001 From: ckl117 Date: Fri, 10 Apr 2026 17:44:35 +0800 Subject: [PATCH 3/3] m_grouped_bf16_gemm_nn_contiguous --- .../layers/moe/fused_moe_cutlass_backend.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 0ea28ef6dd1..67a5d29bf53 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -49,7 +49,7 @@ ) -def deep_batch_gemm(x, y, expert_idx_per_token): +def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token): out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype) paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token) return out @@ -172,9 +172,13 @@ def apply_ep_prefill( ) ) - out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token) + out = m_grouped_bf16_gemm_nn_contiguous( + permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token + ) out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) - ffn_out = deep_batch_gemm(out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token) + ffn_out = m_grouped_bf16_gemm_nn_contiguous( + out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token + ) tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute( hidden_states_unzipped=ffn_out, @@ -356,9 +360,13 @@ def apply_tp( if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) - out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token) + out = m_grouped_bf16_gemm_nn_contiguous( + permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token + ) out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) - ffn_out = deep_batch_gemm(out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token) + ffn_out = m_grouped_bf16_gemm_nn_contiguous( + out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token + ) if layer.with_bias: down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0) ffn_out = paddle.add(ffn_out, down_proj_bias_expand)