fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py (105 changes: 53 additions & 52 deletions)
@@ -28,11 +28,7 @@
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod
 
 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import (
-        count_tokens_per_expert_func,
-        moe_expert_dispatch,
-        moe_expert_reduce,
-    )
+    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
 
     try:
         from fastdeploy.model_executor.ops.gpu import (
@@ -43,6 +39,7 @@
         logger.warning("import w4afp8_gemm_scale_permute Failed!")
 
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops
 from fastdeploy.model_executor.utils import (
     TensorTracker,
     free_tensor,
@@ -52,6 +49,12 @@
 )
 
 
+def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token):
+    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
+    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
+    return out
+
+
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
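For orientation, here is a hedged usage sketch of the new helper. The shapes are assumptions inferred from out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype) and from the call sites below in this diff; the diff itself does not state them.

    import paddle

    # Assumed shapes (inferred, not confirmed by the diff):
    #   x:                    [total_tokens, hidden]         bf16 rows, grouped contiguously by expert
    #   y:                    [num_experts, hidden, inter]   per-expert weights; last dim is the GEMM N dim
    #   expert_idx_per_token: [total_tokens]                 expert id owning each row of x
    x = paddle.randn([1024, 4096]).astype("bfloat16")
    y = paddle.randn([8, 4096, 14336]).astype("bfloat16")
    expert_idx_per_token = paddle.randint(0, 8, [1024], dtype="int32")

    out = m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token)  # -> [1024, 14336]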
@@ -155,29 +158,26 @@ def apply_ep_prefill(
         if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
             # --- moe_permute / moe_unpermute path ---
             recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
-            (permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
-                hidden_states=recv_x,
-                scale=None,
-                expert_routemap_topk=recv_topk_idx_i32,
-                expert_prob_topk=recv_topk_weights,
-                num_experts=layer.num_local_experts,
-                tokens_per_expert=[],
-                padding_alignment=128,
-                override_buffer_size=token_all_num,
-            )
+            (permute_input, permute_indices_per_token, dst_weights, _scale_out, expert_idx_per_token) = (
+                paddle.nn.functional.moe_permute(
+                    hidden_states=recv_x,
+                    scale=None,
+                    expert_routemap_topk=recv_topk_idx_i32,
+                    expert_prob_topk=recv_topk_weights,
+                    num_experts=layer.num_local_experts,
+                    tokens_per_expert=[],
+                    padding_alignment=128,
+                    override_buffer_size=token_all_num,
+                    return_expert_indices=True,
+                )
+            )
 
-            token_nums_per_expert_cumsum = count_tokens_per_expert_func(
-                recv_topk_idx, layer.num_local_experts, True
-            )[2].cast(paddle.int64)
-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                token_nums_per_expert_cumsum,
-                None,
-                False,
-                -1,
-                None,
-                None,
-            )
+            out = m_grouped_bf16_gemm_nn_contiguous(
+                permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token
+            )
+            out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
+            ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+                out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token
+            )
🟡 Suggestion: apply_ep_prefill is missing handling for layer.with_bias

The new code path in apply_ep_prefill (FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16") does not handle the layer.with_bias case, which is inconsistent with the handling in apply_tp.

The original compute_ffn method has bias-handling logic at lines 115-117, and apply_tp applies the same handling at lines 370-372; the same bias-handling code should be added to apply_ep_prefill for consistency.
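A minimal sketch of the suggested fix, mirroring the apply_tp bias handling later in this diff; it assumes ffn_out and expert_idx_per_token are still in scope after the second grouped GEMM:

    # Hedged sketch: reuse apply_tp's bias handling inside apply_ep_prefill,
    # right after the second grouped GEMM produces ffn_out.
    if layer.with_bias:
        # Gather each token's per-expert bias row, then add it to the FFN output.
        down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
        ffn_out = paddle.add(ffn_out, down_proj_bias_expand)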


             tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
@@ -187,7 +187,7 @@ def apply_ep_prefill(
                 token_prob_unzipped=dst_weights,
                 total_zipped_tokens=recv_x.shape[0],
                 num_experts=layer.num_local_experts,
-                using_weighted_combine=True,
+                using_weighted_combine=False,
             )
         else:
             # --- original ep_moe_expert_dispatch / combine path ---
@@ -339,36 +339,37 @@ def apply_tp(
         )
         topk_idx_i32 = topk_idx.astype(paddle.int32)
         override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
-        (permute_input, permute_indices_per_token, dst_weights, _scale_out) = (  # zipped_expertwise_rowmap
-            paddle.nn.functional.moe_permute(
-                hidden_states=x,
-                scale=None,
-                expert_routemap_topk=topk_idx_i32,
-                expert_prob_topk=topk_weights,
-                num_experts=layer.num_experts,
-                tokens_per_expert=[],
-                padding_alignment=128,
-                override_buffer_size=override_buffer_size,
-            )
-        )
+        (
+            permute_input,
+            permute_indices_per_token,
+            dst_weights,
+            _scale_out,
+            expert_idx_per_token,
+        ) = paddle.nn.functional.moe_permute(  # zipped_expertwise_rowmap
+            hidden_states=x,
+            scale=None,
+            expert_routemap_topk=topk_idx_i32,
+            expert_prob_topk=topk_weights,
+            num_experts=layer.num_experts,
+            tokens_per_expert=[],
+            padding_alignment=128,
+            override_buffer_size=override_buffer_size,
+            return_expert_indices=True,
+        )
 
-        # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
-        token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
-            paddle.int64
-        )
         if topk_ids_hookfunc is not None:
             topk_ids_hookfunc(topk_ids=topk_idx)
 
-        ffn_out = self.compute_ffn(
-            layer,
-            permute_input,
-            token_nums_per_expert_cumsum,
-            None,  # expert_idx_per_token not needed for w16a16 without bias
-            False,
-            -1,
-            None,  # dequant_scale
-            None,  # max_tokens_per_expert
-        )
+        out = m_grouped_bf16_gemm_nn_contiguous(
+            permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token
+        )
+        out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights)
+        ffn_out = m_grouped_bf16_gemm_nn_contiguous(
+            out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token
+        )
+        if layer.with_bias:
+            down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
+            ffn_out = paddle.add(ffn_out, down_proj_bias_expand)
 
         fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
             hidden_states_unzipped=ffn_out,
@@ -377,7 +378,7 @@ def apply_tp(
             token_prob_unzipped=dst_weights,
             total_zipped_tokens=x.shape[0],
             num_experts=layer.num_experts,
-            using_weighted_combine=True,
+            using_weighted_combine=False,
         )
         return fused_moe_out

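Taken together, the new w16a16 path replaces the compute_ffn call with two grouped bf16 GEMMs around a fused SwiGLU. Below is a schematic sketch of the expert-FFN stage only; the permute/unpermute plumbing is exactly as in the hunks above and is omitted, the weight parameter names are hypothetical stand-ins for getattr(layer, self.added_weight_attrs[0]) and [1], and it is an assumption (inferred from using_weighted_combine flipping to False) that fused_swiglu_scale folds the routing probabilities into the activation.

    import paddle

    from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops


    def expert_ffn_w16a16(permute_input, dst_weights, expert_idx_per_token, up_gate_proj_w, down_proj_w):
        # GEMM 1: per-expert up/gate projection over expert-contiguous token rows.
        h = m_grouped_bf16_gemm_nn_contiguous(permute_input, up_gate_proj_w, expert_idx_per_token)
        # Fused SwiGLU activation plus per-row scaling by the routing probabilities
        # (assumed semantics of fused_swiglu_scale, based on the combine change).
        h = paddlefleet_ops.fused_swiglu_scale(h, dst_weights)
        # GEMM 2: per-expert down projection back to the hidden size.
        return m_grouped_bf16_gemm_nn_contiguous(h, down_proj_w, expert_idx_per_token)

The result then flows into paddle.nn.functional.moe_unpermute with using_weighted_combine=False, since on this reading the probability weighting has already been applied inside fused_swiglu_scale.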