@@ -735,37 +735,63 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
735735 )
736736
737737
@pytest.mark.parametrize("ep_size", [1, 2])
@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
    """Test expert model parallel synchronization."""
    # Guard clause: the multiprocess job spans every visible GPU, so we need
    # at least `ep_size` of them to form the expert-parallel group.
    size = torch.cuda.device_count()
    if size < ep_size:
        pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")

    # Fan out one worker per GPU over NCCL; each worker runs the actual check.
    worker = partial(
        _test_layer_sync_moe_local_experts_amax,
        ep_size,
        moe_grouped_gemm,
    )
    spawn_multiprocess_job(size=size, job=worker, backend="nccl")
def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
    """Per-rank worker: quantize a MoE GPT model, sync expert amax, verify equality."""
    initialize_for_megatron(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        expert_model_parallel_size=ep_size,
        expert_tensor_parallel_size=1,
        seed=SEED,
    )
    model = _gpt_model_provider(
        tp_size=1,
        ep_size=ep_size,
        etp_size=1,
        hidden_size=256,
        moe_grouped_gemm=moe_grouped_gemm,
        use_te=moe_grouped_gemm,
        num_moe_experts=8,
        transformer_impl="modelopt",
    )
    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))

    # Sync amax across local experts in each layer
    for layer in model.decoder.layers:
        layer.mlp.experts.layer_sync_moe_local_experts_amax()

    # After the sync, every local expert within a layer must hold an identical
    # input-quantizer amax for both linear_fc1 and linear_fc2.
    for layer in model.decoder.layers:
        reference = {"linear_fc1": None, "linear_fc2": None}
        for expert in layer.mlp.experts.local_experts:
            for fc_name in ("linear_fc1", "linear_fc2"):
                amax = getattr(expert, fc_name).input_quantizer.amax
                assert amax is not None
                if reference[fc_name] is None:
                    reference[fc_name] = amax
                else:
                    assert torch.allclose(reference[fc_name], amax)
769795
770796
771797def _test_expert_model_parallel_amax_sync (
@@ -815,6 +841,7 @@ def _test_expert_model_parallel_amax_sync(
815841
816842 # quantize the model
817843 model = mtq .quantize (model , config , forward )
844+
818845 # Check initial sync status
819846 initial_sync , quantizer_type , rank_values = compare_amax_sync_across_expert_parallel (model )
820847 assert initial_sync , (
0 commit comments