@@ -543,6 +543,193 @@ def test_cleanup_workspace_function(self):
543543 mock_manager .cleanup .assert_called_once ()
544544
545545
class TestRMSNormProxyAllreduceFused(unittest.TestCase):
    """Distributed tests for the ``proxy_rmsnorm`` branch of ``RMSNorm.forward``.

    Three scenarios are covered:

    * fusion enabled and a small token count -> the fused
      ``flashinfer_allreduce_residual_rmsnorm`` symbol is called;
    * fusion disabled -> the fused symbol is not called and ``proxy_rmsnorm``
      is invoked once;
    * fusion enabled but more than ``_MAX_FUSED_TOKENS`` tokens -> the fused
      symbol is not called.

    Every test is skipped unless at least two tensor-parallel ranks are
    present. The class expects the distributed environment to be initialised
    by the launcher before it runs (see ``setUpClass``).
    """

    # Dotted path of the fused kernel *as referenced from the normalization
    # module*; patching must target the namespace where the name is looked up.
    _FUSED_TARGET = "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm"

    # Token-count limit these tests assume for the fused path (shape[0] <= 2048
    # takes the fused branch, anything larger must fall back).
    _MAX_FUSED_TOKENS = 2048

    @classmethod
    def setUpClass(cls):
        """Cache the tensor-parallel world size and rank of this process."""
        # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py
        # has already done paddle.set_device + init_parallel_env, so we don't
        # repeat that here. (unittest.main runs in the same process.)
        cls.tp_size = dist.get_world_size()
        cls.tp_rank = dist.get_rank()

    # ---------- Helpers ----------

    def _require_tensor_parallel(self):
        """Skip the calling test unless at least two ranks are available."""
        if self.tp_size < 2:
            self.skipTest("Requires tp_size >= 2")

    def _make_fd_config(self, enable_fusion: bool):
        """Mock fd_config with the minimal attributes RMSNorm.__init__ touches.

        Args:
            enable_fusion: Value stored in
                ``parallel_config.enable_flashinfer_allreduce_fusion``.

        Returns:
            A ``Mock`` carrying ``parallel_config``, ``model_config`` and
            ``quant_config`` attributes.
        """
        parallel_config = Mock()
        parallel_config.tensor_parallel_size = self.tp_size
        parallel_config.tensor_parallel_rank = self.tp_rank
        parallel_config.tp_group = dist.get_group()
        parallel_config.expert_parallel_size = 1
        parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion
        parallel_config.use_sequence_parallel_moe = False

        model_config = Mock()
        model_config.moe_layer_start_index = -1

        fd_config = Mock()
        fd_config.parallel_config = parallel_config
        fd_config.model_config = model_config
        fd_config.quant_config = None
        return fd_config

    def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1):
        """Build a real RMSNorm with a reproducible, rank-identical weight.

        The ``post_attention_layernorm`` prefix is used so that the
        prefix-match in ``RMSNorm.__init__`` passes and
        ``enable_all_reduce_fusion`` resolves to ``enable_fusion``.

        Args:
            enable_fusion: Whether the mocked config requests allreduce fusion.
            hidden_size: Size of the normalized (last) dimension.
            layer_id: Layer index used in the prefix and passed to RMSNorm.

        Returns:
            The constructed ``RMSNorm`` layer.
        """
        # Imported lazily so the module under test is only loaded when a test
        # actually runs.
        from fastdeploy.model_executor.layers.normalization import RMSNorm

        norm = RMSNorm(
            fd_config=self._make_fd_config(enable_fusion=enable_fusion),
            hidden_size=hidden_size,
            eps=1e-6,
            prefix=f"model.layers.{layer_id}.post_attention_layernorm",
            layer_id=layer_id,
            dtype="bfloat16",
        )
        # Initialize weight to a known reproducible value (constant=1.0 by default).
        with paddle.no_grad():
            paddle.seed(2024)
            new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16)
            # Broadcast from rank 0 so every rank holds the same weight.
            dist.broadcast(new_w, src=0)
            norm.weight.set_value(new_w)
        return norm

    @staticmethod
    def _proxy_rmsnorm_fn(x, weight, eps):
        """Stand-in for phi rmsnorm used as proxy_rmsnorm.

        Implements the standard formula ``x * rsqrt(mean(x^2) + eps) * weight``
        in fp32 to keep reference numerics clean, then casts back to the dtype
        of ``x``.
        """
        x_fp32 = x.astype("float32")
        var = x_fp32.pow(2).mean(axis=-1, keepdim=True)
        out = x_fp32 * paddle.rsqrt(var + eps)
        out = out * weight.astype("float32")
        return out.astype(x.dtype)

    def _reference(self, x_partial, residual, weight, eps):
        """Manual reference: all_reduce(x_partial) + residual, then RMSNorm.

        Mirrors what the proxy path WOULD produce after an explicit
        allreduce + add. This issues a collective, so every rank must call it.

        Returns:
            ``(norm_out, residual_out)``.
        """
        # Reduce a copy so the caller's partial tensor is left untouched.
        x = x_partial.clone()
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        residual_out = x + residual
        norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps)
        return norm_out, residual_out

    def _reference_already_reduced(self, x, residual, norm):
        """Reference for the fallback paths: add + RMSNorm, no allreduce.

        Returns:
            ``(norm_out, residual_out)``.
        """
        residual_out = x + residual
        norm_out = self._proxy_rmsnorm_fn(residual_out, norm.weight, norm.eps)
        return norm_out, residual_out

    def _make_inputs(self, token_num, hidden_size, seed=123):
        """Create inputs for the fused path.

        Each rank gets a different x_partial (simulates RowParallelLinear's
        un-reduced output); residual is identical across ranks.

        Returns:
            ``(x_partial, residual)`` as bfloat16 tensors of shape
            ``[token_num, hidden_size]``.
        """
        # Rank-dependent seed -> rank-dependent partial sums.
        paddle.seed(seed + self.tp_rank * 7919)
        x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1
        paddle.seed(seed + 99)
        residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)
        return x_partial, residual

    def _make_replicated_inputs(self, token_num, hidden_size, seed):
        """Create inputs for the fallback paths.

        Both ``x`` and ``residual`` are broadcast from rank 0, i.e. ``x`` is
        treated as already reduced.

        Returns:
            ``(x, residual)`` as bfloat16 tensors of shape
            ``[token_num, hidden_size]``.
        """
        paddle.seed(seed)
        x = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1
        dist.broadcast(x, src=0)
        residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)
        return x, residual

    def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""):
        """Assert two tensors are close after upcasting both to float32."""
        a32 = a.astype("float32").numpy()
        b32 = b.astype("float32").numpy()
        np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg)

    # ---------- Tests ----------

    def test_proxy_path_takes_fused_branch(self):
        """fusion=on, tp>1, shape<=2048, residual!=None
        -> proxy branch picks flashinfer_allreduce_residual_rmsnorm.
        Verify by patching the symbol and asserting it was called.
        """
        self._require_tensor_parallel()
        # Function-scope import gives direct access to the real fused function
        # so the spy can delegate to it.
        from fastdeploy.model_executor.layers import normalization

        hidden = 512
        norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
        self.assertTrue(norm.enable_all_reduce_fusion)
        x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden)

        # Patch within the normalization module's namespace; ``wraps`` keeps
        # the real implementation running so numerics can still be checked.
        real_fused = normalization.flashinfer_allreduce_residual_rmsnorm
        with patch(self._FUSED_TARGET, wraps=real_fused) as spy:
            out, res = norm.forward(
                x_partial.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=self._proxy_rmsnorm_fn,
            )
            spy.assert_called_once()

        # Numerics: must match reference (allreduce + add + std rmsnorm).
        ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch")
        self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch")

    def test_proxy_path_falls_back_when_fusion_disabled(self):
        """fusion=off -> proxy branch must call proxy_rmsnorm directly,
        no fused allreduce path used. Input is treated as already-reduced."""
        self._require_tensor_parallel()
        hidden = 512
        norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden)
        self.assertFalse(norm.enable_all_reduce_fusion)

        # Each rank uses the SAME x (already-reduced) — that's the contract
        # when fusion is off (RowParallelLinear has done its own allreduce).
        x, residual = self._make_replicated_inputs(token_num=64, hidden_size=hidden, seed=777)

        proxy_called = {"n": 0}

        def proxy_spy(_x, _w, _eps):
            # Count invocations, then defer to the reference implementation.
            proxy_called["n"] += 1
            return self._proxy_rmsnorm_fn(_x, _w, _eps)

        with patch(self._FUSED_TARGET) as fused_spy:
            out, res = norm.forward(
                x.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=proxy_spy,
            )
            fused_spy.assert_not_called()

        self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once")

        # Reference: x is already full -> just add + rmsnorm, no allreduce.
        ref_norm, residual_full = self._reference_already_reduced(x, residual, norm)
        self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch")
        self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch")

    def test_proxy_path_falls_back_when_token_too_large(self):
        """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused;
        in this regime upstream RowParallelLinear didn't skip its own
        all-reduce, so x is already full and proxy_rmsnorm is invoked directly."""
        self._require_tensor_parallel()
        hidden = 256
        norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
        # Precondition: fusion must really be enabled, otherwise the fused
        # symbol would be skipped for the wrong reason and this test would
        # pass without exercising the token-count threshold.
        self.assertTrue(
            norm.enable_all_reduce_fusion,
            "fusion must be enabled so that only the token count disables the fused path",
        )

        # shape[0] > 2048 forces use_allreduce_fused=False
        token_num = self._MAX_FUSED_TOKENS + 1
        x, residual = self._make_replicated_inputs(token_num=token_num, hidden_size=hidden, seed=555)

        with patch(self._FUSED_TARGET) as fused_spy:
            out, res = norm.forward(
                x.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=self._proxy_rmsnorm_fn,
            )
            fused_spy.assert_not_called()

        ref_norm, residual_full = self._reference_already_reduced(x, residual, norm)
        self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch")
        self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch")
732+
if __name__ == "__main__":
    # Run tests directly (called by subprocess after distributed launch).
    # Verbosity 2 prints one result line per test.
    unittest.main(verbosity=2)
0 commit comments