|
9 | 9 | get_deepep_num_max_dispatch_tokens_per_rank_decode, |
10 | 10 | ) |
11 | 11 | from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( |
12 | | - do_fused_experts, |
| 12 | + fused_experts, |
13 | 13 | get_ep_num_sms, |
14 | 14 | masked_group_gemm, |
15 | 15 | deepgemm_grouped_fp8_nt_contiguous, |
16 | | - use_sm100_mega_moe, |
| 16 | + quantize_fused_experts_input, |
17 | 17 | ) |
18 | 18 | from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( |
19 | 19 | per_token_group_quant_fp8, |
@@ -77,7 +77,7 @@ def _fused_experts( |
77 | 77 | router_logits: Optional[torch.Tensor] = None, |
78 | 78 | is_prefill: Optional[bool] = None, |
79 | 79 | ): |
80 | | - output = do_fused_experts( |
| 80 | + output = fused_experts( |
81 | 81 | hidden_states=input_tensor, |
82 | 82 | w13=w13, |
83 | 83 | w2=w2, |
@@ -152,24 +152,8 @@ def select_experts_and_quant_input( |
152 | 152 | num_expert_group=n_group, |
153 | 153 | scoring_func=scoring_func, |
154 | 154 | ) |
155 | | - w13_weight, w13_scale = w13.weight, w13.weight_scale |
156 | | - if use_sm100_mega_moe(self.quant_method): |
157 | | - from deep_gemm.utils import per_token_cast_to_fp8 |
158 | | - |
159 | | - qinput_tensor = per_token_cast_to_fp8( |
160 | | - hidden_states, |
161 | | - use_ue8m0=True, |
162 | | - gran_k=self.quant_method.block_size, |
163 | | - use_packed_ue8m0=True, |
164 | | - ) |
165 | | - return topk_weights, topk_idx.to(torch.long), qinput_tensor |
166 | | - |
167 | | - block_size_k = 0 |
168 | | - if w13_weight.ndim == 3: |
169 | | - block_size_k = w13_weight.shape[2] // w13_scale.shape[2] |
170 | | - assert block_size_k == 128, "block_size_k must be 128" |
171 | | - qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype) |
172 | | - return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) |
| 155 | + qinput_tensor = quantize_fused_experts_input(hidden_states, w13, self.quant_method) |
| 156 | + return topk_weights, topk_idx.to(torch.long), qinput_tensor |
173 | 157 |
|
174 | 158 | def dispatch( |
175 | 159 | self, |
|
0 commit comments