|
1 | 1 | import functools |
2 | 2 |
|
3 | | -import ninetoothed |
4 | | -from ninetoothed import Tensor, block_size |
| 3 | +from ninetoothed import Tensor |
5 | 4 |
|
6 | | -from ops.ninetoothed.kernels._common import DTYPES, build |
| 5 | +from ops.ninetoothed.kernels._common import build |
7 | 6 | from ops.ninetoothed.kernels.mm import application |
8 | 7 |
|
9 | 8 |
|
@@ -33,87 +32,51 @@ def arrangement( |
33 | 32 | return input_arranged, other_arranged, output_arranged |
34 | 33 |
|
35 | 34 |
|
36 | | -def premake(k, n, dtype, block_size_m, block_size_n, block_size_k): |
| 35 | +def premake(batch, m, k, n, dtype, block_size_m, block_size_n, block_size_k): |
37 | 36 | arrangement_ = functools.partial( |
38 | 37 | arrangement, |
39 | 38 | block_size_m=block_size_m, |
40 | 39 | block_size_n=block_size_n, |
41 | 40 | block_size_k=block_size_k, |
42 | 41 | ) |
43 | | - shape_options = ({"upper_bound": 4}, None, None) |
44 | 42 | tensors = ( |
45 | | - Tensor(shape=(None, None, k), shape_options=shape_options, dtype=dtype), |
46 | | - Tensor(shape=(None, k, n), shape_options=shape_options, dtype=dtype), |
47 | | - Tensor(shape=(None, None, n), shape_options=shape_options, dtype=dtype), |
| 43 | + Tensor(shape=(batch, m, k), dtype=dtype), |
| 44 | + Tensor(shape=(batch, k, n), dtype=dtype), |
| 45 | + Tensor(shape=(batch, m, n), dtype=dtype), |
48 | 46 | ) |
49 | 47 |
|
50 | 48 | return arrangement_, application, tensors |
51 | 49 |
|
52 | 50 |
|
53 | | -_SHAPES = ( |
54 | | - (4096, 4096), |
55 | | - (4096, 1024), |
56 | | - (4096, 14336), |
57 | | - (14336, 4096), |
58 | | - (4096, 128256), |
59 | | -) |
60 | | - |
61 | | -configs = tuple( |
62 | | - ( |
63 | | - (), |
64 | | - { |
65 | | - "k": k, |
66 | | - "n": n, |
67 | | - "dtype": dtype, |
68 | | - "block_size_m": bm, |
69 | | - "block_size_n": bn, |
70 | | - "block_size_k": bk, |
71 | | - }, |
72 | | - {"num_warps": nw, "num_stages": ns}, |
| 51 | +def _configs(batch, m, k, n, dtype): |
| 52 | + return ( |
| 53 | + ( |
| 54 | + (), |
| 55 | + { |
| 56 | + "batch": batch, |
| 57 | + "m": m, |
| 58 | + "k": k, |
| 59 | + "n": n, |
| 60 | + "dtype": dtype, |
| 61 | + "block_size_m": 16, |
| 62 | + "block_size_n": 64, |
| 63 | + "block_size_k": 32, |
| 64 | + }, |
| 65 | + {"num_warps": 4, "num_stages": 3}, |
| 66 | + ), |
73 | 67 | ) |
74 | | - for k, n in _SHAPES |
75 | | - for dtype in DTYPES |
76 | | - for bm in (16, 64) |
77 | | - for bn in (64, 128) |
78 | | - for bk in (32, 64) |
79 | | - for nw in (4, 8) |
80 | | - for ns in (3, 4) |
81 | | -) |
82 | 68 |
|
83 | | -_build_kernel = build( |
84 | | - premake, |
85 | | - configs, |
86 | | - meta_parameters=("block_size_m", "block_size_n", "block_size_k"), |
87 | | - kernel_name="bmm", |
88 | | -) |
89 | | - |
90 | | - |
91 | | -_BUILD_KN = frozenset(_SHAPES) |
92 | | - |
93 | | - |
94 | | -_BLOCK_SIZE_M = block_size() |
95 | | -_BLOCK_SIZE_N = block_size() |
96 | | -_BLOCK_SIZE_K = block_size() |
97 | | - |
98 | | - |
99 | | -def _fallback_arrangement( |
100 | | - input, |
101 | | - other, |
102 | | - output, |
103 | | - BLOCK_SIZE_M=_BLOCK_SIZE_M, |
104 | | - BLOCK_SIZE_N=_BLOCK_SIZE_N, |
105 | | - BLOCK_SIZE_K=_BLOCK_SIZE_K, |
106 | | -): |
107 | | - return arrangement(input, other, output, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) |
108 | | - |
109 | | - |
110 | | -_fallback_kernel = ninetoothed.make( |
111 | | - _fallback_arrangement, application, (Tensor(3), Tensor(3), Tensor(3)) |
112 | | -) |
113 | 69 |
|
| 70 | +@functools.cache |
| 71 | +def _kernel(batch, m, k, n, dtype): |
| 72 | + return build( |
| 73 | + premake, |
| 74 | + _configs(batch, m, k, n, dtype), |
| 75 | + kernel_name=f"bmm_{batch}_{m}_{k}_{n}", |
| 76 | + ) |
114 | 77 |
|
115 | | -def kernel(lhs, rhs, output, k, n, dtype): |
116 | | - if (k, n) in _BUILD_KN: |
117 | | - return _build_kernel(lhs, rhs, output, k, n, dtype) |
118 | 78 |
|
119 | | - return _fallback_kernel(lhs, rhs, output) |
| 79 | +def kernel(lhs, rhs, output, batch, m, k, n, dtype): |
| 80 | + return _kernel(batch, m, k, n, dtype)( |
| 81 | + lhs, rhs, output, batch, m, k, n, dtype, 16, 64, 32, 4, 3 |
| 82 | + ) |
0 commit comments