2222
2323logger = init_logger (__name__ )
2424_MEGA_MOE_STATES : Dict [Tuple [int , int , int , int ], Dict [str , Any ]] = {}
25+ SUPPORTED_EP_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128" , "deepgemm-fp4fp8-b32" )
2526
2627try :
2728 from deep_ep import Buffer , EventOverlap
@@ -37,10 +38,27 @@ def get_ep_num_sms() -> int:
3738 return getattr (dist_group_manager , "ep_num_sms" , None ) or 0
3839
3940
40- def use_sm100_fp4_moe (quant_method : Any ) -> bool :
41+ def use_sm100_mega_moe (quant_method : Any ) -> bool :
4142 return is_sm100_gpu () and quant_method .method_name == "deepgemm-fp4fp8-b32"
4243
4344
45+ def check_ep_expert_dtype (quant_method : Any ):
46+ expert_dtype = getattr (quant_method , "method_name" , None )
47+ if expert_dtype not in SUPPORTED_EP_EXPERT_DTYPES :
48+ raise ValueError (
49+ "EP MoE requires --expert_dtype to be one of "
50+ f"{ list (SUPPORTED_EP_EXPERT_DTYPES )} , but got `{ expert_dtype } `. "
51+ "Please start with --expert_dtype deepgemm-fp8w8a8-b128 or "
52+ "--expert_dtype deepgemm-fp4fp8-b32. Note that deepgemm-fp4fp8-b32 "
53+ "is only supported on SM100 GPUs."
54+ )
55+ if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu ():
56+ raise RuntimeError (
57+ "--expert_dtype deepgemm-fp4fp8-b32 requires an SM100 GPU for EP MoE; "
58+ "please use --expert_dtype deepgemm-fp8w8a8-b128 on non-SM100 GPUs."
59+ )
60+
61+
4462def masked_group_gemm (
4563 recv_x : Tuple [torch .Tensor , torch .Tensor ],
4664 masked_m : torch .Tensor ,
@@ -155,10 +173,10 @@ def do_fused_experts(
155173 is_prefill : Optional [bool ],
156174 previous_event : Optional [Any ] = None ,
157175):
158- if use_sm100_fp4_moe (quant_method ):
176+ check_ep_expert_dtype (quant_method )
177+ if use_sm100_mega_moe (quant_method ):
159178 return mega_moe_impl (hidden_states , w13 , w2 , topk_weights , topk_idx , quant_method )
160179
161- use_fp8_w8a8 = quant_method .method_name != "none"
162180 buffer = dist_group_manager .ep_buffer if is_prefill else dist_group_manager .ep_low_latency_buffer
163181 return fused_experts_impl (
164182 hidden_states = hidden_states ,
@@ -169,8 +187,8 @@ def do_fused_experts(
169187 num_experts = num_experts ,
170188 buffer = buffer ,
171189 is_prefill = is_prefill ,
172- use_fp8_w8a8 = use_fp8_w8a8 ,
173- use_fp8_all2all = use_fp8_w8a8 ,
190+ use_fp8_w8a8 = True ,
191+ use_fp8_all2all = True ,
174192 use_int8_w8a16 = False ,
175193 w1_scale = w13 .weight_scale ,
176194 w2_scale = w2 .weight_scale ,
0 commit comments