Skip to content

Commit fefbcff

Browse files
BingooYangroot
andauthored
[Cherry-Pick] [BugFix] fix all reduce fusion accurate issue (#7923) (#7922)
* fix accurate issue * fix acc issue in ep + tp mode --------- Co-authored-by: root <root@tjzj-inf-sci-k8s-bzz2-0271.tjzj.baidu.com>
1 parent 1e7ee22 commit fefbcff

3 files changed

Lines changed: 238 additions & 13 deletions

File tree

fastdeploy/model_executor/layers/normalization.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def __init__(
124124
self.tp_group = self.fd_config.parallel_config.tp_group
125125
is_input_norm = prefix.endswith(".input_layernorm")
126126
self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and (
127-
("post_attention_layernorm" in prefix) or (("input_layernorm" in prefix and layer_id != 0))
127+
("post_attention_layernorm" in prefix)
128+
or (("input_layernorm" in prefix and layer_id != 0) and not fd_config.parallel_config.use_ep)
128129
)
129130

130131
self.is_last_norm = prefix.endswith(".norm")
@@ -239,14 +240,21 @@ def forward(
239240

240241
if residual_input is None:
241242
residual_out = x
243+
use_allreduce_fused = (
244+
self.enable_all_reduce_fusion
245+
and self.tp_size > 1
246+
and x.shape[0] <= 2048
247+
and residual_input is not None
248+
and current_platform.is_cuda()
249+
)
242250
if proxy_rmsnorm is None:
243251
if current_platform.is_gcu():
244252
if residual_input is None:
245253
norm_out = rms_norm(x, self.weight, self.eps)
246254
return norm_out.astype(x_dtype), residual_out
247255
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
248256
# enable trtllm all reduce fusion
249-
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
257+
elif use_allreduce_fused:
250258
norm_out = flashinfer_allreduce_residual_rmsnorm(
251259
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
252260
)
@@ -272,9 +280,19 @@ def forward(
272280
quant_min_bound=self.quant_min_bound,
273281
)
274282
else:
275-
if residual_input is not None:
276-
x = x + residual_input
277-
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x
283+
if use_allreduce_fused:
284+
norm_out = flashinfer_allreduce_residual_rmsnorm(
285+
fd_config=self.fd_config,
286+
input_tensor=x,
287+
residual=residual_input,
288+
weight=self.weight,
289+
eps=self.eps,
290+
)
291+
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
292+
else:
293+
if residual_input is not None:
294+
x = x + residual_input
295+
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x
278296

279297
out = norm_out[0].astype(x_dtype)
280298
if residual_input is not None:

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def __init__(
6464
reduce_results: bool = True,
6565
) -> None:
6666
super().__init__()
67+
self.enable_all_reduce_fusion = (
68+
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and not reduce_results
69+
)
70+
6771
# shared experts not split when use_sequence_parallel_moe in ep + tp
6872
if (
6973
fd_config.parallel_config.use_sequence_parallel_moe
@@ -101,7 +105,7 @@ def __init__(
101105
output_size=fd_config.model_config.hidden_size,
102106
with_bias=False,
103107
reduce_results=reduce_results,
104-
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
108+
enable_all_reduce_fusion=self.enable_all_reduce_fusion,
105109
)
106110

107111
self.act_fn = SiluAndMul(

tests/layers/trtllm_allreduce_rms_fusion.py

Lines changed: 210 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,13 @@ def setUp(self):
6767
paddle.seed(42)
6868
np.random.seed(42)
6969

70-
self.dtype = paddle.float32
70+
# NOTE: switched fp32 -> bf16 to mirror real model dtype on B GPUs.
71+
# Combined with use_oneshot=None below, this exercises the bf16 +
72+
# oneshot Lamport path, which is the suspected garbled-output path
73+
# on Blackwell (sm100).
74+
self.dtype = paddle.bfloat16
7175
self.token_num = 128
72-
self.hidden_dim = 768
76+
self.hidden_dim = 4096
7377
self.eps = 1e-6
7478
self.epsilon = 1e-6
7579
self.max_token_num = 2048
@@ -144,7 +148,9 @@ def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps):
144148
weight=weight,
145149
eps=eps,
146150
max_token_num=self.max_token_num,
147-
use_oneshot=False,
151+
# NOTE: do NOT pass use_oneshot=False here. We want the auto path
152+
# (use_oneshot=None) so the oneshot Lamport kernel is exercised,
153+
# matching how normalization.py calls it in the real model.
148154
)
149155
return norm_out, residual_out
150156

@@ -235,11 +241,21 @@ def test_accuracy_fused_vs_reference(self):
235241
flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse(
236242
input_tensor.clone(), residual.clone(), weight.clone(), self.eps
237243
)
244+
245+
# bf16 needs much looser tolerance than fp32. Cast to fp32 for
246+
# comparison to avoid numpy bf16 issues.
247+
if self.dtype == paddle.bfloat16:
248+
rtol, atol = 5e-2, 5e-2
249+
to_np = lambda t: t.astype("float32").numpy() # noqa: E731
250+
else:
251+
rtol, atol = 1e-5, 1e-5
252+
to_np = lambda t: t.numpy() # noqa: E731
253+
238254
# Verify results
239-
np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5)
240-
np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5)
241-
np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5)
242-
np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5)
255+
np.testing.assert_allclose(to_np(fused_output), to_np(reference_output), rtol=rtol, atol=atol)
256+
np.testing.assert_allclose(to_np(ref_res), to_np(paddle_res), rtol=rtol, atol=atol)
257+
np.testing.assert_allclose(to_np(flashinfer_output), to_np(reference_output), rtol=rtol, atol=atol)
258+
np.testing.assert_allclose(to_np(ref_res), to_np(flashinfer_res), rtol=rtol, atol=atol)
243259

244260

245261
class TestFlashInferWorkspaceManager(unittest.TestCase):
@@ -569,6 +585,193 @@ def test_cleanup_workspace_function(self):
569585
mock_manager.cleanup.assert_called_once()
570586

571587

588+
class TestRMSNormProxyAllreduceFused(unittest.TestCase):
589+
@classmethod
590+
def setUpClass(cls):
591+
# The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py
592+
# has already done paddle.set_device + init_parallel_env, so we don't
593+
# repeat that here. (unittest.main runs in the same process.)
594+
cls.tp_size = dist.get_world_size()
595+
cls.tp_rank = dist.get_rank()
596+
597+
def _make_fd_config(self, enable_fusion: bool):
598+
"""Mock fd_config with the minimal attributes RMSNorm.__init__ touches."""
599+
fd_config = Mock()
600+
fd_config.parallel_config = Mock()
601+
fd_config.parallel_config.tensor_parallel_size = self.tp_size
602+
fd_config.parallel_config.tensor_parallel_rank = self.tp_rank
603+
fd_config.parallel_config.tp_group = dist.get_group()
604+
fd_config.parallel_config.expert_parallel_size = 1
605+
fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion
606+
fd_config.parallel_config.use_sequence_parallel_moe = False
607+
fd_config.model_config = Mock()
608+
fd_config.model_config.moe_layer_start_index = -1
609+
fd_config.quant_config = None
610+
return fd_config
611+
612+
def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1):
613+
"""Build a real RMSNorm whose enable_all_reduce_fusion resolves to
614+
`enable_fusion` (use post_attention_layernorm prefix to ensure the
615+
prefix-match in __init__ passes)."""
616+
from fastdeploy.model_executor.layers.normalization import RMSNorm
617+
618+
fd_config = self._make_fd_config(enable_fusion=enable_fusion)
619+
norm = RMSNorm(
620+
fd_config=fd_config,
621+
hidden_size=hidden_size,
622+
eps=1e-6,
623+
prefix=f"model.layers.{layer_id}.post_attention_layernorm",
624+
layer_id=layer_id,
625+
dtype="bfloat16",
626+
)
627+
# Initialize weight to a known reproducible value (constant=1.0 by default).
628+
with paddle.no_grad():
629+
paddle.seed(2024)
630+
new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16)
631+
dist.broadcast(new_w, src=0)
632+
norm.weight.set_value(new_w)
633+
return norm
634+
635+
@staticmethod
636+
def _proxy_rmsnorm_fn(x, weight, eps):
637+
"""Stand-in for phi rmsnorm used as proxy_rmsnorm — standard formula
638+
in fp32 to keep reference numerics clean."""
639+
x_fp32 = x.astype("float32")
640+
var = x_fp32.pow(2).mean(axis=-1, keepdim=True)
641+
out = x_fp32 * paddle.rsqrt(var + eps)
642+
out = out * weight.astype("float32")
643+
return out.astype(x.dtype)
644+
645+
def _reference(self, x_partial, residual, weight, eps):
646+
"""Manual: all_reduce(x_partial) + residual, then standard RMSNorm.
647+
Mirrors what proxy path WOULD produce after explicit allreduce+add."""
648+
x = x_partial.clone()
649+
dist.all_reduce(x, op=dist.ReduceOp.SUM)
650+
residual_out = x + residual
651+
norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps)
652+
return norm_out, residual_out
653+
654+
def _make_inputs(self, token_num, hidden_size, seed=123):
655+
"""Each rank gets a different x_partial (simulates RowParallelLinear's
656+
un-reduced output); residual is identical across ranks."""
657+
paddle.seed(seed + self.tp_rank * 7919)
658+
x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1
659+
paddle.seed(seed + 99)
660+
residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16)
661+
dist.broadcast(residual, src=0)
662+
return x_partial, residual
663+
664+
def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""):
665+
a32 = a.astype("float32").numpy()
666+
b32 = b.astype("float32").numpy()
667+
np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg)
668+
669+
# ---------- Tests ----------
670+
671+
def test_proxy_path_takes_fused_branch(self):
672+
"""fusion=on, tp>1, shape<=2048, residual!=None
673+
-> proxy branch picks flashinfer_allreduce_residual_rmsnorm.
674+
Verify by patching the symbol and asserting it was called.
675+
"""
676+
if self.tp_size < 2:
677+
self.skipTest("Requires tp_size >= 2")
678+
hidden = 512
679+
norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
680+
self.assertTrue(norm.enable_all_reduce_fusion)
681+
x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden)
682+
683+
# Patch within the normalization module's namespace.
684+
with patch(
685+
"fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm",
686+
wraps=__import__(
687+
"fastdeploy.model_executor.layers.normalization", fromlist=["flashinfer_allreduce_residual_rmsnorm"]
688+
).flashinfer_allreduce_residual_rmsnorm,
689+
) as spy:
690+
out, res = norm.forward(
691+
x_partial.clone(),
692+
residual_input=residual.clone(),
693+
proxy_rmsnorm=self._proxy_rmsnorm_fn,
694+
)
695+
spy.assert_called_once()
696+
697+
# Numerics: must match reference (allreduce + add + std rmsnorm).
698+
ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps)
699+
self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch")
700+
self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch")
701+
702+
def test_proxy_path_falls_back_when_fusion_disabled(self):
703+
"""fusion=off -> proxy branch must call proxy_rmsnorm directly,
704+
no fused allreduce path used. Input is treated as already-reduced."""
705+
if self.tp_size < 2:
706+
self.skipTest("Requires tp_size >= 2")
707+
hidden = 512
708+
norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden)
709+
self.assertFalse(norm.enable_all_reduce_fusion)
710+
711+
# Each rank uses the SAME x (already-reduced) — that's the contract
712+
# when fusion is off (RowParallelLinear has done its own allreduce).
713+
paddle.seed(777)
714+
x = paddle.randn([64, hidden], dtype=paddle.bfloat16) * 0.1
715+
dist.broadcast(x, src=0)
716+
residual = paddle.randn([64, hidden], dtype=paddle.bfloat16)
717+
dist.broadcast(residual, src=0)
718+
719+
proxy_called = {"n": 0}
720+
721+
def proxy_spy(_x, _w, _eps):
722+
proxy_called["n"] += 1
723+
return self._proxy_rmsnorm_fn(_x, _w, _eps)
724+
725+
with patch(
726+
"fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm"
727+
) as fused_spy:
728+
out, res = norm.forward(
729+
x.clone(),
730+
residual_input=residual.clone(),
731+
proxy_rmsnorm=proxy_spy,
732+
)
733+
fused_spy.assert_not_called()
734+
735+
self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once")
736+
737+
# Reference: x is already full -> just add + rmsnorm, no allreduce.
738+
residual_full = x + residual
739+
ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps)
740+
self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch")
741+
self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch")
742+
743+
def test_proxy_path_falls_back_when_token_too_large(self):
744+
"""fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused;
745+
in this regime upstream RowParallelLinear didn't skip its own
746+
all-reduce, so x is already full and proxy_rmsnorm is invoked directly."""
747+
if self.tp_size < 2:
748+
self.skipTest("Requires tp_size >= 2")
749+
hidden = 256
750+
norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
751+
# shape[0] > 2048 forces use_allreduce_fused=False
752+
token_num = 2049
753+
paddle.seed(555)
754+
x = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) * 0.1
755+
dist.broadcast(x, src=0)
756+
residual = paddle.randn([token_num, hidden], dtype=paddle.bfloat16)
757+
dist.broadcast(residual, src=0)
758+
759+
with patch(
760+
"fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm"
761+
) as fused_spy:
762+
out, res = norm.forward(
763+
x.clone(),
764+
residual_input=residual.clone(),
765+
proxy_rmsnorm=self._proxy_rmsnorm_fn,
766+
)
767+
fused_spy.assert_not_called()
768+
769+
residual_full = x + residual
770+
ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps)
771+
self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch")
772+
self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch")
773+
774+
572775
if __name__ == "__main__":
573776
"""Run tests directly (called by subprocess after distributed launch)"""
574777
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)