2626from paddleformers .utils .log import logger
2727
2828from fastdeploy .config import FDConfig
29+ from fastdeploy .distributed .communication import tensor_model_parallel_all_reduce
2930from fastdeploy .model_executor .forward_meta import ForwardMeta
3031from fastdeploy .model_executor .graph_optimization .decorator import (
3132 support_graph_optimization ,
@@ -160,8 +161,16 @@ def __init__(
160161 default_initializer = paddle .nn .initializer .Constant (0 ),
161162 )
162163
164+ # In pure-TP mode (tp>1, ep=1) both branches return partial sums, so we
165+ # defer the all-reduce to after combining them — saving one collective.
166+ # In all other modes (EP, EP+attn-TP, no parallelism) each branch handles
167+ # its own reduction internally (reduce_results default=True), so we must
168+ # NOT add an extra all-reduce here.
169+ self .merge_ffn_tp = self .use_tp and not self .use_ep
170+
163171 self .experts = FusedMoE (
164172 fd_config ,
173+ reduce_results = not self .merge_ffn_tp ,
165174 renormalize = self .norm_topk_prob ,
166175 moe_intermediate_size = fd_config .model_config .moe_intermediate_size ,
167176 num_experts = fd_config .model_config .n_routed_experts ,
@@ -182,14 +191,16 @@ def __init__(
182191 intermediate_size = shared_experts_intermediate_size ,
183192 layer_id = layer_id ,
184193 prefix = f"{ prefix } .shared_experts" ,
194+ reduce_results = not self .merge_ffn_tp ,
185195 )
186196
def forward(self, x, forward_meta: "ForwardMeta" = None):
    """Run the MoE block: routed experts plus optional shared experts.

    Args:
        x: Input hidden states for this layer.
        forward_meta: Per-step forward metadata, passed through to the
            fused-MoE experts call.

    Returns:
        The combined expert output. When ``self.merge_ffn_tp`` is set
        (pure tensor-parallel mode — see ``__init__``), both the routed
        and shared branches were built with ``reduce_results=False`` and
        return partial sums, so the two are added first and a single
        tensor-parallel all-reduce is applied here, saving one collective
        versus reducing each branch separately.
    """
    out = self.experts(x, self.gate, forward_meta)
    if self.n_shared_experts > 0:
        # Shared experts process the same input; their (possibly partial)
        # result is simply summed with the routed-expert output.
        out = out + self.shared_experts(x)
    if self.merge_ffn_tp:
        # Both branches produced partial sums; combine first, then one
        # all-reduce over the tensor-parallel group.
        out = tensor_model_parallel_all_reduce(out, self.tp_group)
    return out
194205
195206
0 commit comments