@@ -67,9 +67,13 @@ def setUp(self):
6767 paddle .seed (42 )
6868 np .random .seed (42 )
6969
70- self .dtype = paddle .float32
70+ # NOTE: switched fp32 -> bf16 to mirror real model dtype on B GPUs.
71+ # Combined with use_oneshot=None below, this exercises the bf16 +
72+ # oneshot Lamport path, which is the suspected garbled-output path
73+ # on Blackwell (sm100).
74+ self .dtype = paddle .bfloat16
7175 self .token_num = 128
72- self .hidden_dim = 768
76+ self .hidden_dim = 4096
7377 self .eps = 1e-6
7478 self .epsilon = 1e-6
7579 self .max_token_num = 2048
@@ -144,7 +148,9 @@ def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps):
144148 weight = weight ,
145149 eps = eps ,
146150 max_token_num = self .max_token_num ,
147- use_oneshot = False ,
151+ # NOTE: do NOT pass use_oneshot=False here. We want the auto path
152+ # (use_oneshot=None) so the oneshot Lamport kernel is exercised,
153+ # matching how normalization.py calls it in the real model.
148154 )
149155 return norm_out , residual_out
150156
@@ -235,11 +241,21 @@ def test_accuracy_fused_vs_reference(self):
235241 flashinfer_output , flashinfer_res = self .flashinfer_rms_fuse (
236242 input_tensor .clone (), residual .clone (), weight .clone (), self .eps
237243 )
244+
245+ # bf16 needs much looser tolerance than fp32. Cast to fp32 for
246+ # comparison to avoid numpy bf16 issues.
247+ if self .dtype == paddle .bfloat16 :
248+ rtol , atol = 5e-2 , 5e-2
249+ to_np = lambda t : t .astype ("float32" ).numpy () # noqa: E731
250+ else :
251+ rtol , atol = 1e-5 , 1e-5
252+ to_np = lambda t : t .numpy () # noqa: E731
253+
238254 # Verify results
239- np .testing .assert_allclose (fused_output . numpy ( ), reference_output . numpy ( ), rtol = 1e-5 , atol = 1e-5 )
240- np .testing .assert_allclose (ref_res . numpy ( ), paddle_res . numpy ( ), rtol = 1e-5 , atol = 1e-5 )
241- np .testing .assert_allclose (flashinfer_output . numpy ( ), reference_output . numpy ( ), rtol = 1e-5 , atol = 1e-5 )
242- np .testing .assert_allclose (ref_res . numpy ( ), flashinfer_res . numpy ( ), rtol = 1e-5 , atol = 1e-5 )
255+ np .testing .assert_allclose (to_np ( fused_output ), to_np ( reference_output ), rtol = rtol , atol = atol )
256+ np .testing .assert_allclose (to_np ( ref_res ), to_np ( paddle_res ), rtol = rtol , atol = atol )
257+ np .testing .assert_allclose (to_np ( flashinfer_output ), to_np ( reference_output ), rtol = rtol , atol = atol )
258+ np .testing .assert_allclose (to_np ( ref_res ), to_np ( flashinfer_res ), rtol = rtol , atol = atol )
243259
244260
245261class TestFlashInferWorkspaceManager (unittest .TestCase ):
@@ -569,6 +585,193 @@ def test_cleanup_workspace_function(self):
569585 mock_manager .cleanup .assert_called_once ()
570586
571587
588+ class TestRMSNormProxyAllreduceFused (unittest .TestCase ):
589+ @classmethod
590+ def setUpClass (cls ):
591+ # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py
592+ # has already done paddle.set_device + init_parallel_env, so we don't
593+ # repeat that here. (unittest.main runs in the same process.)
594+ cls .tp_size = dist .get_world_size ()
595+ cls .tp_rank = dist .get_rank ()
596+
597+ def _make_fd_config (self , enable_fusion : bool ):
598+ """Mock fd_config with the minimal attributes RMSNorm.__init__ touches."""
599+ fd_config = Mock ()
600+ fd_config .parallel_config = Mock ()
601+ fd_config .parallel_config .tensor_parallel_size = self .tp_size
602+ fd_config .parallel_config .tensor_parallel_rank = self .tp_rank
603+ fd_config .parallel_config .tp_group = dist .get_group ()
604+ fd_config .parallel_config .expert_parallel_size = 1
605+ fd_config .parallel_config .enable_flashinfer_allreduce_fusion = enable_fusion
606+ fd_config .parallel_config .use_sequence_parallel_moe = False
607+ fd_config .model_config = Mock ()
608+ fd_config .model_config .moe_layer_start_index = - 1
609+ fd_config .quant_config = None
610+ return fd_config
611+
612+ def _build_rmsnorm (self , enable_fusion : bool , hidden_size : int , layer_id : int = 1 ):
613+ """Build a real RMSNorm whose enable_all_reduce_fusion resolves to
614+ `enable_fusion` (use post_attention_layernorm prefix to ensure the
615+ prefix-match in __init__ passes)."""
616+ from fastdeploy .model_executor .layers .normalization import RMSNorm
617+
618+ fd_config = self ._make_fd_config (enable_fusion = enable_fusion )
619+ norm = RMSNorm (
620+ fd_config = fd_config ,
621+ hidden_size = hidden_size ,
622+ eps = 1e-6 ,
623+ prefix = f"model.layers.{ layer_id } .post_attention_layernorm" ,
624+ layer_id = layer_id ,
625+ dtype = "bfloat16" ,
626+ )
627+ # Initialize weight to a known reproducible value (constant=1.0 by default).
628+ with paddle .no_grad ():
629+ paddle .seed (2024 )
630+ new_w = paddle .randn ([hidden_size ], dtype = paddle .bfloat16 )
631+ dist .broadcast (new_w , src = 0 )
632+ norm .weight .set_value (new_w )
633+ return norm
634+
635+ @staticmethod
636+ def _proxy_rmsnorm_fn (x , weight , eps ):
637+ """Stand-in for phi rmsnorm used as proxy_rmsnorm — standard formula
638+ in fp32 to keep reference numerics clean."""
639+ x_fp32 = x .astype ("float32" )
640+ var = x_fp32 .pow (2 ).mean (axis = - 1 , keepdim = True )
641+ out = x_fp32 * paddle .rsqrt (var + eps )
642+ out = out * weight .astype ("float32" )
643+ return out .astype (x .dtype )
644+
645+ def _reference (self , x_partial , residual , weight , eps ):
646+ """Manual: all_reduce(x_partial) + residual, then standard RMSNorm.
647+ Mirrors what proxy path WOULD produce after explicit allreduce+add."""
648+ x = x_partial .clone ()
649+ dist .all_reduce (x , op = dist .ReduceOp .SUM )
650+ residual_out = x + residual
651+ norm_out = self ._proxy_rmsnorm_fn (residual_out , weight , eps )
652+ return norm_out , residual_out
653+
654+ def _make_inputs (self , token_num , hidden_size , seed = 123 ):
655+ """Each rank gets a different x_partial (simulates RowParallelLinear's
656+ un-reduced output); residual is identical across ranks."""
657+ paddle .seed (seed + self .tp_rank * 7919 )
658+ x_partial = paddle .randn ([token_num , hidden_size ], dtype = paddle .bfloat16 ) * 0.1
659+ paddle .seed (seed + 99 )
660+ residual = paddle .randn ([token_num , hidden_size ], dtype = paddle .bfloat16 )
661+ dist .broadcast (residual , src = 0 )
662+ return x_partial , residual
663+
664+ def _assert_close_bf16 (self , a , b , rtol = 5e-2 , atol = 5e-2 , msg = "" ):
665+ a32 = a .astype ("float32" ).numpy ()
666+ b32 = b .astype ("float32" ).numpy ()
667+ np .testing .assert_allclose (a32 , b32 , rtol = rtol , atol = atol , err_msg = msg )
668+
669+ # ---------- Tests ----------
670+
671+ def test_proxy_path_takes_fused_branch (self ):
672+ """fusion=on, tp>1, shape<=2048, residual!=None
673+ -> proxy branch picks flashinfer_allreduce_residual_rmsnorm.
674+ Verify by patching the symbol and asserting it was called.
675+ """
676+ if self .tp_size < 2 :
677+ self .skipTest ("Requires tp_size >= 2" )
678+ hidden = 512
679+ norm = self ._build_rmsnorm (enable_fusion = True , hidden_size = hidden )
680+ self .assertTrue (norm .enable_all_reduce_fusion )
681+ x_partial , residual = self ._make_inputs (token_num = 64 , hidden_size = hidden )
682+
683+ # Patch within the normalization module's namespace.
684+ with patch (
685+ "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" ,
686+ wraps = __import__ (
687+ "fastdeploy.model_executor.layers.normalization" , fromlist = ["flashinfer_allreduce_residual_rmsnorm" ]
688+ ).flashinfer_allreduce_residual_rmsnorm ,
689+ ) as spy :
690+ out , res = norm .forward (
691+ x_partial .clone (),
692+ residual_input = residual .clone (),
693+ proxy_rmsnorm = self ._proxy_rmsnorm_fn ,
694+ )
695+ spy .assert_called_once ()
696+
697+ # Numerics: must match reference (allreduce + add + std rmsnorm).
698+ ref_norm , ref_res = self ._reference (x_partial , residual , norm .weight , norm .eps )
699+ self ._assert_close_bf16 (out , ref_norm , msg = "proxy fused-branch norm output mismatch" )
700+ self ._assert_close_bf16 (res , ref_res , msg = "proxy fused-branch residual mismatch" )
701+
702+ def test_proxy_path_falls_back_when_fusion_disabled (self ):
703+ """fusion=off -> proxy branch must call proxy_rmsnorm directly,
704+ no fused allreduce path used. Input is treated as already-reduced."""
705+ if self .tp_size < 2 :
706+ self .skipTest ("Requires tp_size >= 2" )
707+ hidden = 512
708+ norm = self ._build_rmsnorm (enable_fusion = False , hidden_size = hidden )
709+ self .assertFalse (norm .enable_all_reduce_fusion )
710+
711+ # Each rank uses the SAME x (already-reduced) — that's the contract
712+ # when fusion is off (RowParallelLinear has done its own allreduce).
713+ paddle .seed (777 )
714+ x = paddle .randn ([64 , hidden ], dtype = paddle .bfloat16 ) * 0.1
715+ dist .broadcast (x , src = 0 )
716+ residual = paddle .randn ([64 , hidden ], dtype = paddle .bfloat16 )
717+ dist .broadcast (residual , src = 0 )
718+
719+ proxy_called = {"n" : 0 }
720+
721+ def proxy_spy (_x , _w , _eps ):
722+ proxy_called ["n" ] += 1
723+ return self ._proxy_rmsnorm_fn (_x , _w , _eps )
724+
725+ with patch (
726+ "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm"
727+ ) as fused_spy :
728+ out , res = norm .forward (
729+ x .clone (),
730+ residual_input = residual .clone (),
731+ proxy_rmsnorm = proxy_spy ,
732+ )
733+ fused_spy .assert_not_called ()
734+
735+ self .assertEqual (proxy_called ["n" ], 1 , "proxy_rmsnorm must be invoked exactly once" )
736+
737+ # Reference: x is already full -> just add + rmsnorm, no allreduce.
738+ residual_full = x + residual
739+ ref_norm = self ._proxy_rmsnorm_fn (residual_full , norm .weight , norm .eps )
740+ self ._assert_close_bf16 (out , ref_norm , msg = "fallback norm output mismatch" )
741+ self ._assert_close_bf16 (res , residual_full , msg = "fallback residual mismatch" )
742+
743+ def test_proxy_path_falls_back_when_token_too_large (self ):
744+ """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused;
745+ in this regime upstream RowParallelLinear didn't skip its own
746+ all-reduce, so x is already full and proxy_rmsnorm is invoked directly."""
747+ if self .tp_size < 2 :
748+ self .skipTest ("Requires tp_size >= 2" )
749+ hidden = 256
750+ norm = self ._build_rmsnorm (enable_fusion = True , hidden_size = hidden )
751+ # shape[0] > 2048 forces use_allreduce_fused=False
752+ token_num = 2049
753+ paddle .seed (555 )
754+ x = paddle .randn ([token_num , hidden ], dtype = paddle .bfloat16 ) * 0.1
755+ dist .broadcast (x , src = 0 )
756+ residual = paddle .randn ([token_num , hidden ], dtype = paddle .bfloat16 )
757+ dist .broadcast (residual , src = 0 )
758+
759+ with patch (
760+ "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm"
761+ ) as fused_spy :
762+ out , res = norm .forward (
763+ x .clone (),
764+ residual_input = residual .clone (),
765+ proxy_rmsnorm = self ._proxy_rmsnorm_fn ,
766+ )
767+ fused_spy .assert_not_called ()
768+
769+ residual_full = x + residual
770+ ref_norm = self ._proxy_rmsnorm_fn (residual_full , norm .weight , norm .eps )
771+ self ._assert_close_bf16 (out , ref_norm , msg = "large-shape fallback norm mismatch" )
772+ self ._assert_close_bf16 (res , residual_full , msg = "large-shape fallback residual mismatch" )
773+
774+
572775if __name__ == "__main__" :
573776 """Run tests directly (called by subprocess after distributed launch)"""
574777 unittest .main (verbosity = 2 )
0 commit comments