Skip to content

Commit 3b60dd4

Browse files
committed
slime deepgemm_impl
1 parent f911956 commit 3b60dd4

2 files changed

Lines changed: 29 additions & 22 deletions

File tree

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
get_deepep_num_max_dispatch_tokens_per_rank_decode,
1010
)
1111
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import (
12-
do_fused_experts,
12+
fused_experts,
1313
get_ep_num_sms,
1414
masked_group_gemm,
1515
deepgemm_grouped_fp8_nt_contiguous,
16-
use_sm100_mega_moe,
16+
quantize_fused_experts_input,
1717
)
1818
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import (
1919
per_token_group_quant_fp8,
@@ -77,7 +77,7 @@ def _fused_experts(
7777
router_logits: Optional[torch.Tensor] = None,
7878
is_prefill: Optional[bool] = None,
7979
):
80-
output = do_fused_experts(
80+
output = fused_experts(
8181
hidden_states=input_tensor,
8282
w13=w13,
8383
w2=w2,
@@ -152,24 +152,8 @@ def select_experts_and_quant_input(
152152
num_expert_group=n_group,
153153
scoring_func=scoring_func,
154154
)
155-
w13_weight, w13_scale = w13.weight, w13.weight_scale
156-
if use_sm100_mega_moe(self.quant_method):
157-
from deep_gemm.utils import per_token_cast_to_fp8
158-
159-
qinput_tensor = per_token_cast_to_fp8(
160-
hidden_states,
161-
use_ue8m0=True,
162-
gran_k=self.quant_method.block_size,
163-
use_packed_ue8m0=True,
164-
)
165-
return topk_weights, topk_idx.to(torch.long), qinput_tensor
166-
167-
block_size_k = 0
168-
if w13_weight.ndim == 3:
169-
block_size_k = w13_weight.shape[2] // w13_scale.shape[2]
170-
assert block_size_k == 128, "block_size_k must be 128"
171-
qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype)
172-
return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale)
155+
qinput_tensor = quantize_fused_experts_input(hidden_states, w13, self.quant_method)
156+
return topk_weights, topk_idx.to(torch.long), qinput_tensor
173157

174158
def dispatch(
175159
self,

lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,30 @@ def mega_moe_impl(
162162
return output
163163

164164

165-
def do_fused_experts(
165+
def quantize_fused_experts_input(
    hidden_states: torch.Tensor,
    w13: Any,
    quant_method: Any,
):
    """Quantize the fused-experts input activations to FP8.

    Two paths are supported, selected by ``use_sm100_mega_moe(quant_method)``:

    * mega-MoE path: ``hidden_states`` is cast with DeepGEMM's
      ``per_token_cast_to_fp8`` using packed UE8M0 scales and a K granularity
      of ``quant_method.block_size``.
    * grouped path: ``hidden_states`` is quantized with
      ``per_token_group_quant_fp8`` using the K block size derived from the
      ``w13`` weight / weight-scale shapes, which must be 128.

    Args:
        hidden_states: Activations to quantize.
        w13: Expert weight holder; ``weight`` and ``weight_scale`` are read on
            the grouped path only.
        quant_method: Quantization method object; validated by
            ``check_ep_expert_dtype`` and, on the mega-MoE path, queried for
            ``block_size``.

    Returns:
        Whatever the selected quantization helper returns.
        NOTE(review): presumably a single packed tensor on the mega-MoE path
        and a ``(quantized, scale)`` pair on the grouped path - confirm
        against the helpers.

    Raises:
        AssertionError: On the grouped path when the derived K block size is
            not 128 (including when ``w13.weight`` is not 3-dimensional).
    """
    check_ep_expert_dtype(quant_method)

    if not use_sm100_mega_moe(quant_method):
        expert_weight = w13.weight
        if expert_weight.ndim == 3:
            # K block size = weight K extent divided by the number of scale blocks along K.
            block_size_k = expert_weight.shape[2] // w13.weight_scale.shape[2]
        else:
            # Non-3D weights leave the block size at 0, which the assert below rejects.
            block_size_k = 0
        assert block_size_k == 128, "block_size_k must be 128"
        return per_token_group_quant_fp8(hidden_states, block_size_k, dtype=expert_weight.dtype)

    # Imported lazily so deep_gemm is only required when the mega-MoE path is taken.
    from deep_gemm.utils import per_token_cast_to_fp8

    cast_kwargs = dict(
        use_ue8m0=True,
        gran_k=quant_method.block_size,
        use_packed_ue8m0=True,
    )
    return per_token_cast_to_fp8(hidden_states, **cast_kwargs)
186+
187+
188+
def fused_experts(
166189
hidden_states: torch.Tensor,
167190
w13: Any,
168191
w2: Any,

0 commit comments

Comments
 (0)