@@ -735,6 +735,85 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     )
 
 
+@pytest.mark.parametrize("ep_size", [1, 2])
+@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
+@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
+    """Test that amax is synchronized across the local experts within each MoE layer."""
+    size = torch.cuda.device_count()
+    if size < ep_size:
+        pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")
+
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(
+            _test_layer_sync_moe_local_experts_amax,
+            ep_size,
+            moe_grouped_gemm,
+            shared_moe_weight_scale,
+        ),
+        backend="nccl",
+    )
+
+
+def _test_layer_sync_moe_local_experts_amax(
+    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
+):
+    initialize_for_megatron(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        expert_model_parallel_size=ep_size,
+        expert_tensor_parallel_size=1,
+        seed=SEED,
+    )
+    model = _gpt_model_provider(
+        tp_size=1,
+        ep_size=ep_size,
+        etp_size=1,
+        hidden_size=256,
+        moe_grouped_gemm=moe_grouped_gemm,
+        use_te=moe_grouped_gemm,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
+    )
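+    # Quantize with the default FP8 config; the calibration forward pass populates per-quantizer amax values.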
+    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
+
+    # Sync amax across local experts in each layer
+    for layer in model.decoder.layers:
+        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
+
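+    # After the sync, every local expert in a layer should hold identical input-quantizer amax values.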
+    for layer in model.decoder.layers:
+        fc1_amax = None
+        fc2_amax = None
+        for expert in layer.mlp.experts.local_experts:
+            assert expert.linear_fc1.input_quantizer.amax is not None
+            assert expert.linear_fc2.input_quantizer.amax is not None
+            if fc1_amax is None:
+                fc1_amax = expert.linear_fc1.input_quantizer.amax
+            else:
+                assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
+            if fc2_amax is None:
+                fc2_amax = expert.linear_fc2.input_quantizer.amax
+            else:
+                assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
+
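+    # When the MoE weight scale is shared, weight-quantizer amax values should also match across local experts.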
+    if shared_moe_weight_scale:
+        for layer in model.decoder.layers:
+            fc1_amax = None
+            fc2_amax = None
+            for expert in layer.mlp.experts.local_experts:
+                assert expert.linear_fc1.weight_quantizer.amax is not None
+                assert expert.linear_fc2.weight_quantizer.amax is not None
+                if fc1_amax is None:
+                    fc1_amax = expert.linear_fc1.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+                if fc2_amax is None:
+                    fc2_amax = expert.linear_fc2.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+
+
 def _test_expert_model_parallel_amax_sync(
     tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
 ):
@@ -815,9 +894,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     if size < ep_size * etp_size:
         pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
 
-    if moe_grouped_gemm:
-        pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-
     spawn_multiprocess_job(
         size=size,
         job=partial(