Skip to content

Commit 6594cf7

Browse files
authored
fix accurate issue (#7923)
1 parent 8bb4479 commit 6594cf7

2 files changed

Lines changed: 208 additions & 4 deletions

File tree

fastdeploy/model_executor/layers/normalization.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,21 @@ def forward(
243243

244244
if residual_input is None:
245245
residual_out = x
246+
use_allreduce_fused = (
247+
self.enable_all_reduce_fusion
248+
and self.tp_size > 1
249+
and x.shape[0] <= 2048
250+
and residual_input is not None
251+
and current_platform.is_cuda()
252+
)
246253
if proxy_rmsnorm is None:
247254
if current_platform.is_gcu():
248255
if residual_input is None:
249256
norm_out = rms_norm(x, self.weight, self.eps)
250257
return norm_out.astype(x_dtype), residual_out
251258
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
252259
# enable trtllm all reduce fusion
253-
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
260+
elif use_allreduce_fused:
254261
norm_out = flashinfer_allreduce_residual_rmsnorm(
255262
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
256263
)
@@ -276,9 +283,19 @@ def forward(
276283
quant_min_bound=self.quant_min_bound,
277284
)
278285
else:
279-
if residual_input is not None:
280-
x = x + residual_input
281-
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x
286+
if use_allreduce_fused:
287+
norm_out = flashinfer_allreduce_residual_rmsnorm(
288+
fd_config=self.fd_config,
289+
input_tensor=x,
290+
residual=residual_input,
291+
weight=self.weight,
292+
eps=self.eps,
293+
)
294+
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
295+
else:
296+
if residual_input is not None:
297+
x = x + residual_input
298+
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x
282299

283300
out = norm_out[0].astype(x_dtype)
284301
if residual_input is not None:

tests/layers/trtllm_allreduce_rms_fusion.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,193 @@ def test_cleanup_workspace_function(self):
543543
mock_manager.cleanup.assert_called_once()
544544

545545

546+
class TestRMSNormProxyAllreduceFused(unittest.TestCase):
    """Distributed tests for the ``proxy_rmsnorm`` branch of ``RMSNorm.forward``.

    The tests check which implementation the proxy branch selects:

    * the fused ``flashinfer_allreduce_residual_rmsnorm`` kernel when fusion
      is enabled, tp_size > 1, the token count is <= 2048 and a residual is
      supplied;
    * the caller-provided ``proxy_rmsnorm`` callable otherwise.

    Every test skips itself when fewer than two ranks are available.
    """

    # Location of the fused kernel as seen from the normalization module.
    # Patching it there lets each test observe whether the fused branch ran.
    _FUSED_FN_MODULE = "fastdeploy.model_executor.layers.normalization"
    _FUSED_FN_NAME = "flashinfer_allreduce_residual_rmsnorm"
    _FUSED_FN_TARGET = f"{_FUSED_FN_MODULE}.{_FUSED_FN_NAME}"

    @classmethod
    def setUpClass(cls):
        # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py
        # has already done paddle.set_device + init_parallel_env, so we don't
        # repeat that here. (unittest.main runs in the same process.)
        cls.tp_size = dist.get_world_size()
        cls.tp_rank = dist.get_rank()

    def _make_fd_config(self, enable_fusion: bool):
        """Mock fd_config with the minimal attributes RMSNorm.__init__ touches.

        Args:
            enable_fusion: Value stored in
                ``parallel_config.enable_flashinfer_allreduce_fusion``.

        Returns:
            A ``Mock`` whose ``parallel_config`` reflects the current
            distributed world (size, rank, default group) and whose
            ``quant_config`` is ``None``.
        """
        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = self.tp_size
        fd_config.parallel_config.tensor_parallel_rank = self.tp_rank
        fd_config.parallel_config.tp_group = dist.get_group()
        fd_config.parallel_config.expert_parallel_size = 1
        fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion
        fd_config.parallel_config.use_sequence_parallel_moe = False
        fd_config.model_config = Mock()
        fd_config.model_config.moe_layer_start_index = -1
        fd_config.quant_config = None
        return fd_config

    def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1):
        """Build a real RMSNorm whose enable_all_reduce_fusion resolves to
        `enable_fusion` (use post_attention_layernorm prefix to ensure the
        prefix-match in __init__ passes).

        Args:
            enable_fusion: Fusion flag forwarded to the mocked fd_config.
            hidden_size: Size of the normalized (last) dimension.
            layer_id: Layer index used both in the prefix and as ``layer_id``.

        Returns:
            The constructed ``RMSNorm`` with its weight replaced by seeded
            random bfloat16 values.
        """
        from fastdeploy.model_executor.layers.normalization import RMSNorm

        fd_config = self._make_fd_config(enable_fusion=enable_fusion)
        norm = RMSNorm(
            fd_config=fd_config,
            hidden_size=hidden_size,
            eps=1e-6,
            prefix=f"model.layers.{layer_id}.post_attention_layernorm",
            layer_id=layer_id,
            dtype="bfloat16",
        )
        # Replace the initial weight (presumably constant 1.0 by default) with
        # seeded random values so the weight multiply is actually exercised.
        # Broadcasting from rank 0 keeps the weight identical on every rank.
        with paddle.no_grad():
            paddle.seed(2024)
            new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16)
            dist.broadcast(new_w, src=0)
            norm.weight.set_value(new_w)
        return norm

    @staticmethod
    def _proxy_rmsnorm_fn(x, weight, eps):
        """Stand-in for phi rmsnorm used as proxy_rmsnorm — standard formula
        in fp32 to keep reference numerics clean.

        Computes ``x * rsqrt(mean(x^2, axis=-1) + eps) * weight`` in float32
        and casts the result back to the dtype of ``x``.
        """
        x_fp32 = x.astype("float32")
        var = x_fp32.pow(2).mean(axis=-1, keepdim=True)
        out = x_fp32 * paddle.rsqrt(var + eps)
        out = out * weight.astype("float32")
        return out.astype(x.dtype)

    def _reference(self, x_partial, residual, weight, eps):
        """Manual: all_reduce(x_partial) + residual, then standard RMSNorm.
        Mirrors what proxy path WOULD produce after explicit allreduce+add.

        Returns:
            ``(norm_out, residual_out)`` where ``residual_out`` is the summed
            input plus the residual and ``norm_out`` is its RMSNorm.
        """
        # Work on a copy: all_reduce writes its result into the tensor it is given.
        x = x_partial.clone()
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        residual_out = x + residual
        norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps)
        return norm_out, residual_out

    def _make_inputs(self, token_num, hidden_size, seed=123):
        """Each rank gets a different x_partial (simulates RowParallelLinear's
        un-reduced output); residual is identical across ranks.

        Returns:
            ``(x_partial, residual)``, both bfloat16 of shape
            ``[token_num, hidden_size]``.
        """
        # Rank-dependent seed -> rank-dependent partial activations.
        paddle.seed(seed + self.tp_rank * 7919)
        x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1
        # Rank-independent seed plus a broadcast -> shared residual.
        paddle.seed(seed + 99)
        residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)
        return x_partial, residual

    def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""):
        """Assert two tensors are close after upcasting both to float32.

        Uses ``np.testing.assert_allclose`` with the given tolerances;
        ``msg`` is attached to the failure report.
        """
        a32 = a.astype("float32").numpy()
        b32 = b.astype("float32").numpy()
        np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg)

    # ---------- Tests ----------

    def test_proxy_path_takes_fused_branch(self):
        """fusion=on, tp>1, shape<=2048, residual!=None
        -> proxy branch picks flashinfer_allreduce_residual_rmsnorm.
        Verify by patching the symbol and asserting it was called.
        """
        import importlib

        if self.tp_size < 2:
            self.skipTest("Requires tp_size >= 2")
        hidden = 512
        norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
        self.assertTrue(norm.enable_all_reduce_fusion)
        x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden)

        # Resolve the real kernel first so the patch can wrap it: the spy
        # records the call while still delegating to the real implementation.
        norm_module = importlib.import_module(self._FUSED_FN_MODULE)
        real_fused = getattr(norm_module, self._FUSED_FN_NAME)

        # Patch within the normalization module's namespace.
        with patch(self._FUSED_FN_TARGET, wraps=real_fused) as spy:
            out, res = norm.forward(
                x_partial.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=self._proxy_rmsnorm_fn,
            )
            spy.assert_called_once()

        # Numerics: must match reference (allreduce + add + std rmsnorm).
        ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch")
        self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch")

    def test_proxy_path_falls_back_when_fusion_disabled(self):
        """fusion=off -> proxy branch must call proxy_rmsnorm directly,
        no fused allreduce path used. Input is treated as already-reduced."""
        if self.tp_size < 2:
            self.skipTest("Requires tp_size >= 2")
        hidden = 512
        norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden)
        self.assertFalse(norm.enable_all_reduce_fusion)

        # Each rank uses the SAME x (already-reduced) — that's the contract
        # when fusion is off (RowParallelLinear has done its own allreduce).
        paddle.seed(777)
        x = paddle.randn([64, hidden], dtype=paddle.bfloat16) * 0.1
        dist.broadcast(x, src=0)
        residual = paddle.randn([64, hidden], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)

        # Count invocations of the proxy callable while delegating to the
        # reference implementation.
        proxy_called = {"n": 0}

        def proxy_spy(_x, _w, _eps):
            proxy_called["n"] += 1
            return self._proxy_rmsnorm_fn(_x, _w, _eps)

        with patch(self._FUSED_FN_TARGET) as fused_spy:
            out, res = norm.forward(
                x.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=proxy_spy,
            )
            fused_spy.assert_not_called()

        self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once")

        # Reference: x is already full -> just add + rmsnorm, no allreduce.
        residual_full = x + residual
        ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch")
        self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch")

    def test_proxy_path_falls_back_when_token_too_large(self):
        """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused;
        in this regime upstream RowParallelLinear didn't skip its own
        all-reduce, so x is already full and proxy_rmsnorm is invoked directly."""
        if self.tp_size < 2:
            self.skipTest("Requires tp_size >= 2")
        hidden = 256
        norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
        # Guard against a vacuous pass: the fallback must be caused by the
        # token count, not by the fusion flag silently resolving to False.
        self.assertTrue(norm.enable_all_reduce_fusion)
        # shape[0] > 2048 forces use_allreduce_fused=False
        token_num = 2049
        paddle.seed(555)
        x = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) * 0.1
        dist.broadcast(x, src=0)
        residual = paddle.randn([token_num, hidden], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)

        with patch(self._FUSED_FN_TARGET) as fused_spy:
            out, res = norm.forward(
                x.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=self._proxy_rmsnorm_fn,
            )
            fused_spy.assert_not_called()

        residual_full = x + residual
        ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch")
        self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch")
731+
732+
546733
if __name__ == "__main__":
    # Run tests directly (called by subprocess after distributed launch).
    unittest.main(verbosity=2)

0 commit comments

Comments
 (0)