Skip to content

Commit c42f9a5

Browse files
authored
fix kernel configs. (#1240)
1 parent b355fe7 commit c42f9a5

14 files changed

Lines changed: 106 additions & 1237 deletions

File tree

lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from lightllm.utils.log_utils import init_logger
2626
from lightllm.utils.vllm_utils import vllm_ops
2727
from lightllm.utils.device_utils import triton_support_tensor_descriptor
28-
from .moe_kernel_configs import MoeGroupedGemmKernelConfig
2928
from .moe_silu_and_mul import silu_and_mul_fwd
3029
from .moe_sum_reduce import moe_sum_reduce
3130
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8
@@ -726,16 +725,26 @@ def grouped_matmul(
726725
block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
727726

728727
if run_config is None:
729-
run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
730-
M=token_inputs.shape[0],
731-
N=n,
732-
K=k,
733-
topk_num=topk_num,
734-
expert_num=expert_num,
735-
mul_routed_weight=mul_routed_weight,
736-
use_fp8_w8a8=use_fp8_w8a8,
737-
out_dtype=str(out.dtype),
738-
)
728+
if token_inputs.shape[0] <= expert_num:
729+
run_config = {
730+
"BLOCK_SIZE_M": 16,
731+
"BLOCK_SIZE_N": 32,
732+
"BLOCK_SIZE_K": 64,
733+
"GROUP_SIZE_M": 1,
734+
"NEED_TRANS": False,
735+
"num_warps": 4,
736+
"num_stages": 1,
737+
}
738+
else:
739+
run_config = {
740+
"BLOCK_SIZE_M": 64,
741+
"BLOCK_SIZE_N": 64,
742+
"BLOCK_SIZE_K": 32,
743+
"GROUP_SIZE_M": 8,
744+
"NEED_TRANS": False,
745+
"num_warps": 4,
746+
"num_stages": 1,
747+
}
739748

740749
BLOCK_SIZE_M = run_config["BLOCK_SIZE_M"]
741750
BLOCK_SIZE_N = run_config["BLOCK_SIZE_N"]

lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import triton
44
import triton.language as tl
5-
from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
65
from lightllm.common.triton_utils.autotuner import autotune
76

87

@@ -121,7 +120,10 @@ def silu_and_mul_fwd(
121120
size_n = input.shape[-1] // 2
122121

123122
if not run_config:
124-
run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype))
123+
if size_m < 256:
124+
run_config = {"BLOCK_M": 1, "BLOCK_N": 128, "num_warps": 1, "NUM_STAGES": 1}
125+
else:
126+
run_config = {"BLOCK_M": 16, "BLOCK_N": 128, "num_warps": 4, "NUM_STAGES": 5}
125127

126128
BLOCK_M = run_config["BLOCK_M"]
127129
BLOCK_N = run_config["BLOCK_N"]

lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import triton
44
import triton.language as tl
5-
from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
65

76

87
@triton.jit

lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import torch
2-
32
import triton
43
import triton.language as tl
5-
from .moe_sum_recude_config import MoeSumReduceKernelConfig
6-
from typing import Any, Callable, Dict, Optional, Tuple
4+
from typing import Dict
75
from lightllm.common.triton_utils.autotuner import autotune
86

97

@@ -77,9 +75,12 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict =
7775
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
7876

7977
if not run_config:
80-
run_config = MoeSumReduceKernelConfig.try_to_get_best_config(
81-
M=token_num, topk_num=topk_num, hidden_dim=hidden_dim, out_dtype=str(output.dtype)
82-
)
78+
run_config = {
79+
"BLOCK_M": 1,
80+
"BLOCK_DIM": 128,
81+
"NUM_STAGE": 1,
82+
"num_warps": 2,
83+
}
8384

8485
BLOCK_M = run_config["BLOCK_M"]
8586
BLOCK_DIM = run_config["BLOCK_DIM"]

0 commit comments

Comments
 (0)