@@ -734,8 +734,7 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
 
 @pytest.mark.parametrize("ep_size", [1, 2])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
-@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
-def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
     size = torch.cuda.device_count()
     if size < ep_size:
@@ -747,15 +746,12 @@ def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe
             _test_layer_sync_moe_local_experts_amax,
             ep_size,
             moe_grouped_gemm,
-            shared_moe_weight_scale,
         ),
         backend="nccl",
     )
 
 
-def _test_layer_sync_moe_local_experts_amax(
-    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
-):
+def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
     initialize_for_megatron(
         tensor_model_parallel_size=1,
         pipeline_model_parallel_size=1,
@@ -774,16 +770,14 @@ def _test_layer_sync_moe_local_experts_amax(
         transformer_impl="modelopt",
     )
     quant_cfg = mtq.FP8_DEFAULT_CFG
-    if not shared_moe_weight_scale:
-        quant_cfg = copy.deepcopy(quant_cfg)
-        quant_cfg["algorithm"] = {"method": "max", "shared_moe_weight_scale": False}
     model = mtq.quantize(model, quant_cfg, get_forward(model))
 
     # Does layer_sync_moe_local_experts_amax happen in mtq.quantize when EP=1?
     for layer in model.decoder.layers:
-        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
+        layer.mlp.experts.layer_sync_moe_local_experts_amax()
 
     for layer in model.decoder.layers:
+        # Check input quantizer amax is synced across local experts
         fc1_amax = None
         fc2_amax = None
         for expert in layer.mlp.experts.local_experts:
@@ -798,25 +792,20 @@ def _test_layer_sync_moe_local_experts_amax(
             else:
                 assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
 
-    for layer in model.decoder.layers:
+        # Check weight quantizer amax is different across local experts
         fc1_amax = None
         fc2_amax = None
         for expert in layer.mlp.experts.local_experts:
             assert expert.linear_fc1.weight_quantizer.amax is not None
             assert expert.linear_fc2.weight_quantizer.amax is not None
             if fc1_amax is None:
                 fc1_amax = expert.linear_fc1.weight_quantizer.amax
-            elif shared_moe_weight_scale:
-                assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
             else:
                 assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
-                fc1_amax = expert.linear_fc1.weight_quantizer.amax  # update most recent amax
-
             if fc2_amax is None:
                 fc2_amax = expert.linear_fc2.weight_quantizer.amax
-            elif shared_moe_weight_scale:
-                assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
-                # FC2 amaxes are the same since the input to the layer is all the same
+            else:
+                assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
 
 
 def _test_expert_model_parallel_amax_sync(