Skip to content

Commit 92fdcf7

Browse files
authored
[Others] fix allreduce fusion accuracy issue in ep + tp mode (#7947)
* fix accuracy issue
* fix allreduce accuracy issue in ep + tp mode
* add test
* fix conflict
1 parent ed93530 commit 92fdcf7

3 files changed

Lines changed: 91 additions & 2 deletions

File tree

fastdeploy/model_executor/layers/normalization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def __init__(
127127
self.tp_group = self.fd_config.parallel_config.tp_group
128128
is_input_norm = prefix.endswith(".input_layernorm")
129129
self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and (
130-
("post_attention_layernorm" in prefix) or (("input_layernorm" in prefix and layer_id != 0))
130+
("post_attention_layernorm" in prefix)
131+
or (("input_layernorm" in prefix and layer_id != 0) and not fd_config.parallel_config.use_ep)
131132
)
132133

133134
self.is_last_norm = prefix.endswith(".norm")

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def __init__(
6464
reduce_results: bool = True,
6565
) -> None:
6666
super().__init__()
67+
self.enable_all_reduce_fusion = (
68+
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and not reduce_results
69+
)
6770
# shared experts not split when use_sequence_parallel_moe in ep + tp
6871
if (
6972
fd_config.parallel_config.use_sequence_parallel_moe
@@ -101,7 +104,7 @@ def __init__(
101104
output_size=fd_config.model_config.hidden_size,
102105
with_bias=False,
103106
reduce_results=reduce_results,
104-
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
107+
enable_all_reduce_fusion=self.enable_all_reduce_fusion,
105108
)
106109

107110
self.act_fn = SiluAndMul(

tests/layers/trtllm_allreduce_rms_fusion.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,91 @@
2323
import paddle.distributed as dist
2424

2525

26+
class TestGlm4MoeMLPEnableAllReduceFusion(unittest.TestCase):
    """Check how Glm4MoeMLP.__init__ derives ``enable_all_reduce_fusion``.

    The constructor is expected to compute::

        self.enable_all_reduce_fusion = (
            fd_config.parallel_config.enable_flashinfer_allreduce_fusion and not reduce_results
        )

    and to forward that same value to the layer it builds through
    ``RowParallelLinear`` (recorded here under the key ``"down_proj"``).
    All linear/activation layers are replaced with no-op stubs, so only the
    constructor's flag wiring is exercised.
    """

    def _make_fd_config(self, enable_fusion: bool):
        """Return a SimpleNamespace stand-in for FDConfig.

        Only the ``model_config`` / ``parallel_config`` attributes set below
        are provided; ``enable_fusion`` is stored as
        ``parallel_config.enable_flashinfer_allreduce_fusion``.
        """
        from types import SimpleNamespace

        model_config = SimpleNamespace(
            hidden_size=16,
            hidden_act="silu",
            moe_layer_start_index=0,
        )
        parallel_config = SimpleNamespace(
            tensor_parallel_size=1,
            expert_parallel_size=1,
            tensor_parallel_rank=0,
            tp_group=None,
            enable_flashinfer_allreduce_fusion=enable_fusion,
            use_sequence_parallel_moe=False,
        )
        return SimpleNamespace(model_config=model_config, parallel_config=parallel_config)

    def _build_mlp(self, enable_fusion: bool, reduce_results: bool):
        """Construct Glm4MoeMLP with every heavy layer stubbed out.

        Returns:
            A ``(mlp, captured)`` tuple where ``captured["down_proj"]`` holds
            the keyword arguments that were passed to ``RowParallelLinear``.
        """
        from fastdeploy.model_executor.models import glm4_moe

        captured = {}

        class _StubLinear(paddle.nn.Layer):
            """Layer that accepts any constructor arguments and is an identity."""

            def __init__(self, *args, **kwargs):
                super().__init__()

            def forward(self, x):
                return x

        class _RowRecorder(_StubLinear):
            """Stub that additionally records the kwargs it was built with."""

            def __init__(self, *args, **kwargs):
                captured["down_proj"] = kwargs
                super().__init__(*args, **kwargs)

        # Patch the names as looked up inside glm4_moe so the constructor
        # instantiates the stubs instead of the real parallel layers.
        with (
            patch.object(glm4_moe, "MergedColumnParallelLinear", _StubLinear),
            patch.object(glm4_moe, "RowParallelLinear", _RowRecorder),
            patch.object(glm4_moe, "MergedReplicatedLinear", _StubLinear),
            patch.object(glm4_moe, "ReplicatedLinear", _StubLinear),
            patch.object(glm4_moe, "SiluAndMul", _StubLinear),
        ):
            mlp = glm4_moe.Glm4MoeMLP(
                fd_config=self._make_fd_config(enable_fusion=enable_fusion),
                intermediate_size=8,
                layer_id=0,
                reduce_results=reduce_results,
            )
        return mlp, captured

    def test_fusion_true_when_flag_on_and_reduce_results_false(self):
        """True iff flashinfer fusion is enabled AND reduce_results=False."""
        mlp, captured = self._build_mlp(enable_fusion=True, reduce_results=False)
        self.assertTrue(mlp.enable_all_reduce_fusion)
        # Flag must be forwarded into down_proj.
        self.assertTrue(captured["down_proj"]["enable_all_reduce_fusion"])
        self.assertFalse(captured["down_proj"]["reduce_results"])

    def test_fusion_false_when_reduce_results_true(self):
        """reduce_results=True forces fusion off even if flag is set."""
        mlp, captured = self._build_mlp(enable_fusion=True, reduce_results=True)
        self.assertFalse(mlp.enable_all_reduce_fusion)
        self.assertFalse(captured["down_proj"]["enable_all_reduce_fusion"])
        self.assertTrue(captured["down_proj"]["reduce_results"])

    def test_fusion_false_when_flag_disabled(self):
        """flashinfer fusion flag off -> fusion off regardless of reduce_results."""
        # Exercise both values so the "regardless" claim is actually checked.
        for reduce_results in (False, True):
            with self.subTest(reduce_results=reduce_results):
                mlp, captured = self._build_mlp(enable_fusion=False, reduce_results=reduce_results)
                self.assertFalse(mlp.enable_all_reduce_fusion)
                self.assertFalse(captured["down_proj"]["enable_all_reduce_fusion"])
26111
class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase):
27112
"""Test FlashInfer AllReduce + Residual + RMSNorm fused operator"""
28113

0 commit comments

Comments
 (0)