Skip to content

Commit b355fe7

Browse files
authored
fix _block_scaled_block_gemm kernel, remove config class (#1239)
1 parent a7c925c commit b355fe7

1 file changed

Lines changed: 9 additions & 63 deletions

File tree

lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py

Lines changed: 9 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -1,67 +1,8 @@
11
import torch
22
import triton
33
import triton.language as tl
4-
5-
from lightllm.common.kernel_config import KernelConfigs
6-
from frozendict import frozendict
7-
from functools import lru_cache
8-
from typing import Any, Dict, List, Optional, Tuple
9-
from triton import Config
104
from lightllm.common.triton_utils.autotuner import autotune
11-
12-
13-
class Fp8BlockMMKernelConfig(KernelConfigs):
14-
kernel_name: str = "fp8_block_mm"
15-
16-
@classmethod
17-
@lru_cache(maxsize=200)
18-
def try_to_get_best_config(
19-
cls,
20-
M: int,
21-
N: int,
22-
K: int,
23-
block_size: Tuple[int, int],
24-
out_dtype: str,
25-
) -> dict:
26-
key_params = {
27-
"N": N,
28-
"K": K,
29-
"block_size": block_size,
30-
"out_dtype": str(out_dtype),
31-
}
32-
key_params = frozendict(key_params)
33-
34-
finded_config = cls.get_the_config(key_params)
35-
36-
if finded_config:
37-
# find by M
38-
config: dict = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))]
39-
return config
40-
else:
41-
config = {
42-
"BLOCK_M": 64,
43-
"BLOCK_N": block_size[0],
44-
"BLOCK_K": block_size[1],
45-
"GROUP_M": 32,
46-
"num_warps": 4,
47-
"num_stages": 3,
48-
}
49-
return config
50-
51-
@classmethod
52-
def save_config(
53-
cls, N: int, K: int, block_size: Tuple[int, int], out_dtype: str, config_json: Dict[int, Dict[int, Dict]]
54-
):
55-
56-
key_params = {
57-
"N": N,
58-
"K": K,
59-
"block_size": block_size,
60-
"out_dtype": str(out_dtype),
61-
}
62-
key_params = frozendict(key_params)
63-
64-
return cls.store_config(key_params, config_json)
5+
from typing import List
656

667

678
@triton.jit
@@ -215,9 +156,14 @@ def w8a8_block_fp8_matmul(
215156
assert triton.cdiv(K, block_k) == Ascale.shape[-1] and Ascale.shape[-1] == Bscale.shape[0]
216157
assert triton.cdiv(N, block_n) == Bscale.shape[1]
217158
if not run_config:
218-
run_config = Fp8BlockMMKernelConfig.try_to_get_best_config(
219-
M=M, N=N, K=K, block_size=block_size, out_dtype=dtype
220-
)
159+
run_config = {
160+
"BLOCK_M": 64,
161+
"BLOCK_N": block_size[0],
162+
"BLOCK_K": block_size[1],
163+
"GROUP_M": 32,
164+
"num_warps": 4,
165+
"num_stages": 3,
166+
}
221167
grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),)
222168
_block_scaled_block_gemm[grid](
223169
A,

0 commit comments

Comments (0)