@@ -737,7 +737,8 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
 
 @pytest.mark.parametrize("ep_size", [1, 2])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
-def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
+@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
     """Test expert model parallel synchronization."""
     size = torch.cuda.device_count()
     if size < ep_size:
@@ -749,12 +750,15 @@ def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
             _test_layer_sync_moe_local_experts_amax,
             ep_size,
             moe_grouped_gemm,
+            shared_moe_weight_scale,
         ),
         backend="nccl",
     )
 
 
-def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
+def _test_layer_sync_moe_local_experts_amax(
+    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
+):
     initialize_for_megatron(
         tensor_model_parallel_size=1,
         pipeline_model_parallel_size=1,
@@ -776,7 +780,7 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz
 
     # Sync amax across local experts in each layer
     for layer in model.decoder.layers:
-        layer.mlp.experts.layer_sync_moe_local_experts_amax()
+        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
 
     for layer in model.decoder.layers:
         fc1_amax = None
@@ -793,6 +797,22 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz
             else:
                 assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
 
+    if shared_moe_weight_scale:
+        for layer in model.decoder.layers:
+            fc1_amax = None
+            fc2_amax = None
+            for expert in layer.mlp.experts.local_experts:
+                assert expert.linear_fc1.weight_quantizer.amax is not None
+                assert expert.linear_fc2.weight_quantizer.amax is not None
+                if fc1_amax is None:
+                    fc1_amax = expert.linear_fc1.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+                if fc2_amax is None:
+                    fc2_amax = expert.linear_fc2.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+
 
 def _test_expert_model_parallel_amax_sync(
     tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
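For context when reviewing: below is a minimal sketch of what `layer_sync_moe_local_experts_amax(shared_moe_weight_scale)` plausibly does, reverse-engineered from the assertions above. It is an assumption for illustration, not the repository's implementation; only the attribute names (`local_experts`, `linear_fc1`/`linear_fc2`, `input_quantizer`, `weight_quantizer`, `amax`) come from the test itself.

```python
import torch


def layer_sync_moe_local_experts_amax(experts, shared_moe_weight_scale=False):
    """Hypothetical sketch: unify per-expert amax within one MoE layer.

    Input-quantizer amax is always shared across the layer's local experts;
    weight-quantizer amax is shared only when shared_moe_weight_scale is True,
    mirroring what the test asserts.
    """
    for fc_name in ("linear_fc1", "linear_fc2"):
        # Elementwise max over all local experts, written back so every
        # expert ends up quantizing with the same activation scale.
        shared_in = torch.stack(
            [getattr(e, fc_name).input_quantizer.amax for e in experts.local_experts]
        ).amax(dim=0)
        for e in experts.local_experts:
            getattr(e, fc_name).input_quantizer.amax = shared_in.clone()

        if shared_moe_weight_scale:
            # Same reduction, applied to the weight scales as well.
            shared_w = torch.stack(
                [getattr(e, fc_name).weight_quantizer.amax for e in experts.local_experts]
            ).amax(dim=0)
            for e in experts.local_experts:
                getattr(e, fc_name).weight_quantizer.amax = shared_w.clone()
```

Note the sketch covers only the intra-layer step; since the test also runs with `ep_size=2`, the real method presumably reduces across expert-parallel ranks as well.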