|
1 | 1 | import torch |
2 | 2 | import triton |
3 | 3 | import triton.language as tl |
4 | | - |
5 | | -from lightllm.common.kernel_config import KernelConfigs |
6 | | -from frozendict import frozendict |
7 | | -from functools import lru_cache |
8 | | -from typing import Any, Dict, List, Optional, Tuple |
9 | | -from triton import Config |
10 | 4 | from lightllm.common.triton_utils.autotuner import autotune |
11 | | - |
12 | | - |
13 | | -class Fp8BlockMMKernelConfig(KernelConfigs): |
14 | | - kernel_name: str = "fp8_block_mm" |
15 | | - |
16 | | - @classmethod |
17 | | - @lru_cache(maxsize=200) |
18 | | - def try_to_get_best_config( |
19 | | - cls, |
20 | | - M: int, |
21 | | - N: int, |
22 | | - K: int, |
23 | | - block_size: Tuple[int, int], |
24 | | - out_dtype: str, |
25 | | - ) -> dict: |
26 | | - key_params = { |
27 | | - "N": N, |
28 | | - "K": K, |
29 | | - "block_size": block_size, |
30 | | - "out_dtype": str(out_dtype), |
31 | | - } |
32 | | - key_params = frozendict(key_params) |
33 | | - |
34 | | - finded_config = cls.get_the_config(key_params) |
35 | | - |
36 | | - if finded_config: |
37 | | - # find by M |
38 | | - config: dict = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] |
39 | | - return config |
40 | | - else: |
41 | | - config = { |
42 | | - "BLOCK_M": 64, |
43 | | - "BLOCK_N": block_size[0], |
44 | | - "BLOCK_K": block_size[1], |
45 | | - "GROUP_M": 32, |
46 | | - "num_warps": 4, |
47 | | - "num_stages": 3, |
48 | | - } |
49 | | - return config |
50 | | - |
51 | | - @classmethod |
52 | | - def save_config( |
53 | | - cls, N: int, K: int, block_size: Tuple[int, int], out_dtype: str, config_json: Dict[int, Dict[int, Dict]] |
54 | | - ): |
55 | | - |
56 | | - key_params = { |
57 | | - "N": N, |
58 | | - "K": K, |
59 | | - "block_size": block_size, |
60 | | - "out_dtype": str(out_dtype), |
61 | | - } |
62 | | - key_params = frozendict(key_params) |
63 | | - |
64 | | - return cls.store_config(key_params, config_json) |
| 5 | +from typing import List |
65 | 6 |
|
66 | 7 |
|
67 | 8 | @triton.jit |
@@ -215,9 +156,14 @@ def w8a8_block_fp8_matmul( |
215 | 156 | assert triton.cdiv(K, block_k) == Ascale.shape[-1] and Ascale.shape[-1] == Bscale.shape[0] |
216 | 157 | assert triton.cdiv(N, block_n) == Bscale.shape[1] |
217 | 158 | if not run_config: |
218 | | - run_config = Fp8BlockMMKernelConfig.try_to_get_best_config( |
219 | | - M=M, N=N, K=K, block_size=block_size, out_dtype=dtype |
220 | | - ) |
| 159 | + run_config = { |
| 160 | + "BLOCK_M": 64, |
| 161 | + "BLOCK_N": block_size[0], |
| 162 | + "BLOCK_K": block_size[1], |
| 163 | + "GROUP_M": 32, |
| 164 | + "num_warps": 4, |
| 165 | + "num_stages": 3, |
| 166 | + } |
221 | 167 | grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),) |
222 | 168 | _block_scaled_block_gemm[grid]( |
223 | 169 | A, |
|
0 commit comments