Skip to content

Commit d592fd8

Browse files
committed
Remove ninetoothed.make fallbacks
1 parent 8878595 commit d592fd8

9 files changed

Lines changed: 235 additions & 369 deletions

File tree

ops/ninetoothed/kernels/addmm.py

Lines changed: 32 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -40,57 +40,39 @@ def premake(m, n, k, dtype, block_size_m, block_size_n, block_size_k):
4040
return arrangement_, application, tensors
4141

4242

43-
# Compile the square shapes used by the benchmark; other shapes use the generic
44-
# make-based fallback below.
45-
configs = tuple(
46-
(
47-
(),
48-
{
49-
"m": s,
50-
"n": s,
51-
"k": s,
52-
"dtype": ninetoothed.float16,
53-
"block_size_m": bm,
54-
"block_size_n": bn,
55-
"block_size_k": bk,
56-
},
57-
{"num_warps": nw, "num_stages": 3},
43+
def _configs(m, n, k, dtype):
44+
return tuple(
45+
(
46+
(),
47+
{
48+
"m": m,
49+
"n": n,
50+
"k": k,
51+
"dtype": dtype,
52+
"block_size_m": bm,
53+
"block_size_n": bn,
54+
"block_size_k": bk,
55+
},
56+
{"num_warps": nw, "num_stages": 3},
57+
)
58+
for bm in (64, 128)
59+
for bn in (64, 128)
60+
for bk in (32, 64)
61+
for nw in (4, 8)
5862
)
59-
for s in (128 * i for i in range(2, 33))
60-
for bm in (64, 128)
61-
for bn in (64, 128)
62-
for bk in (32, 64)
63-
for nw in (4, 8)
64-
)
65-
66-
_build_kernel = build(
67-
premake,
68-
configs,
69-
meta_parameters=("block_size_m", "block_size_n", "block_size_k"),
70-
kernel_name="addmm",
71-
)
72-
73-
_BUILD_CONFIGS = frozenset(
74-
(kwargs["m"], kwargs["n"], kwargs["k"], kwargs["dtype"])
75-
for _, kwargs, _ in configs
76-
)
77-
78-
_fallback_kernel = ninetoothed.make(
79-
arrangement,
80-
application,
81-
(
82-
Tensor(2),
83-
Tensor(2),
84-
Tensor(2),
85-
Tensor(0),
86-
Tensor(0),
87-
Tensor(2),
88-
),
89-
)
9063

9164

92-
def kernel(input, mat1, mat2, beta, alpha, output, m, n, k, dtype):
93-
if (m, n, k, dtype) in _BUILD_CONFIGS:
94-
return _build_kernel(input, mat1, mat2, beta, alpha, output, m, n, k, dtype)
65+
@functools.cache
66+
def _kernel(m, n, k, dtype):
67+
return build(
68+
premake,
69+
_configs(m, n, k, dtype),
70+
meta_parameters=("block_size_m", "block_size_n", "block_size_k"),
71+
kernel_name=f"addmm_{m}_{n}_{k}",
72+
)
73+
9574

96-
return _fallback_kernel(input, mat1, mat2, beta, alpha, output)
75+
def kernel(input, mat1, mat2, beta, alpha, output, m, n, k, dtype):
76+
return _kernel(m, n, k, dtype)(
77+
input, mat1, mat2, beta, alpha, output, m, n, k, dtype
78+
)

ops/ninetoothed/kernels/bmm.py

Lines changed: 33 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import functools
22

3-
import ninetoothed
4-
from ninetoothed import Tensor, block_size
3+
from ninetoothed import Tensor
54

6-
from ops.ninetoothed.kernels._common import DTYPES, build
5+
from ops.ninetoothed.kernels._common import build
76
from ops.ninetoothed.kernels.mm import application
87

98

@@ -33,87 +32,51 @@ def arrangement(
3332
return input_arranged, other_arranged, output_arranged
3433

3534

36-
def premake(k, n, dtype, block_size_m, block_size_n, block_size_k):
35+
def premake(batch, m, k, n, dtype, block_size_m, block_size_n, block_size_k):
3736
arrangement_ = functools.partial(
3837
arrangement,
3938
block_size_m=block_size_m,
4039
block_size_n=block_size_n,
4140
block_size_k=block_size_k,
4241
)
43-
shape_options = ({"upper_bound": 4}, None, None)
4442
tensors = (
45-
Tensor(shape=(None, None, k), shape_options=shape_options, dtype=dtype),
46-
Tensor(shape=(None, k, n), shape_options=shape_options, dtype=dtype),
47-
Tensor(shape=(None, None, n), shape_options=shape_options, dtype=dtype),
43+
Tensor(shape=(batch, m, k), dtype=dtype),
44+
Tensor(shape=(batch, k, n), dtype=dtype),
45+
Tensor(shape=(batch, m, n), dtype=dtype),
4846
)
4947

5048
return arrangement_, application, tensors
5149

5250

53-
_SHAPES = (
54-
(4096, 4096),
55-
(4096, 1024),
56-
(4096, 14336),
57-
(14336, 4096),
58-
(4096, 128256),
59-
)
60-
61-
configs = tuple(
62-
(
63-
(),
64-
{
65-
"k": k,
66-
"n": n,
67-
"dtype": dtype,
68-
"block_size_m": bm,
69-
"block_size_n": bn,
70-
"block_size_k": bk,
71-
},
72-
{"num_warps": nw, "num_stages": ns},
51+
def _configs(batch, m, k, n, dtype):
52+
return (
53+
(
54+
(),
55+
{
56+
"batch": batch,
57+
"m": m,
58+
"k": k,
59+
"n": n,
60+
"dtype": dtype,
61+
"block_size_m": 16,
62+
"block_size_n": 64,
63+
"block_size_k": 32,
64+
},
65+
{"num_warps": 4, "num_stages": 3},
66+
),
7367
)
74-
for k, n in _SHAPES
75-
for dtype in DTYPES
76-
for bm in (16, 64)
77-
for bn in (64, 128)
78-
for bk in (32, 64)
79-
for nw in (4, 8)
80-
for ns in (3, 4)
81-
)
8268

83-
_build_kernel = build(
84-
premake,
85-
configs,
86-
meta_parameters=("block_size_m", "block_size_n", "block_size_k"),
87-
kernel_name="bmm",
88-
)
89-
90-
91-
_BUILD_KN = frozenset(_SHAPES)
92-
93-
94-
_BLOCK_SIZE_M = block_size()
95-
_BLOCK_SIZE_N = block_size()
96-
_BLOCK_SIZE_K = block_size()
97-
98-
99-
def _fallback_arrangement(
100-
input,
101-
other,
102-
output,
103-
BLOCK_SIZE_M=_BLOCK_SIZE_M,
104-
BLOCK_SIZE_N=_BLOCK_SIZE_N,
105-
BLOCK_SIZE_K=_BLOCK_SIZE_K,
106-
):
107-
return arrangement(input, other, output, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
108-
109-
110-
_fallback_kernel = ninetoothed.make(
111-
_fallback_arrangement, application, (Tensor(3), Tensor(3), Tensor(3))
112-
)
11369

70+
@functools.cache
71+
def _kernel(batch, m, k, n, dtype):
72+
return build(
73+
premake,
74+
_configs(batch, m, k, n, dtype),
75+
kernel_name=f"bmm_{batch}_{m}_{k}_{n}",
76+
)
11477

115-
def kernel(lhs, rhs, output, k, n, dtype):
116-
if (k, n) in _BUILD_KN:
117-
return _build_kernel(lhs, rhs, output, k, n, dtype)
11878

119-
return _fallback_kernel(lhs, rhs, output)
79+
def kernel(lhs, rhs, output, batch, m, k, n, dtype):
80+
return _kernel(batch, m, k, n, dtype)(
81+
lhs, rhs, output, batch, m, k, n, dtype, 16, 64, 32, 4, 3
82+
)

ops/ninetoothed/kernels/conv2d.py

Lines changed: 36 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
import ops.ninetoothed.kernels.mm as mm
@@ -31,10 +30,8 @@ def premake(n, c, h, w, k, r, s, dtype, block_size_m, block_size_n, block_size_k
3130
block_size_n=block_size_n,
3231
block_size_k=block_size_k,
3332
)
34-
3533
p = h - r + 1
3634
q = w - s + 1
37-
3835
tensors = (
3936
Tensor(shape=(n, c, h, w), dtype=dtype),
4037
Tensor(shape=(k, c, r, s), dtype=dtype),
@@ -44,65 +41,44 @@ def premake(n, c, h, w, k, r, s, dtype, block_size_m, block_size_n, block_size_k
4441
return arrangement_, mm.application, tensors
4542

4643

47-
# Block sweep approximating the JIT auto-tuner default range. Conv2d's im2col
48-
# arrangement produces a tall-skinny matmul, so wider bn (up to 256) and
49-
# longer-pipelined num_stages help match the JIT-tuned kernel.
50-
configs = tuple(
51-
(
52-
(),
53-
{
54-
"n": n,
55-
"c": 512,
56-
"h": 14,
57-
"w": 14,
58-
"k": 512,
59-
"r": 3,
60-
"s": 3,
61-
"dtype": ninetoothed.float16,
62-
"block_size_m": bm,
63-
"block_size_n": bn,
64-
"block_size_k": bk,
65-
},
66-
{"num_warps": 8, "num_stages": ns},
44+
def _configs(n, c, h, w, k, r, s, dtype):
45+
return tuple(
46+
(
47+
(),
48+
{
49+
"n": n,
50+
"c": c,
51+
"h": h,
52+
"w": w,
53+
"k": k,
54+
"r": r,
55+
"s": s,
56+
"dtype": dtype,
57+
"block_size_m": bm,
58+
"block_size_n": bn,
59+
"block_size_k": bk,
60+
},
61+
{"num_warps": 8, "num_stages": ns},
62+
)
63+
for bm in (64, 128)
64+
for bn in (128, 256)
65+
for bk in (32, 64)
66+
for ns in (3, 5)
67+
if bm * bn <= 32768 and bm * bk <= 32768 and bn * bk <= 32768
6768
)
68-
for n in (2, 4, 8, 16, 32, 64, 128, 256, 512, 1024)
69-
for bm in (64, 128)
70-
for bn in (128, 256)
71-
for bk in (32, 64)
72-
for ns in (3, 5)
73-
if bm * bn <= 32768 and bm * bk <= 32768 and bn * bk <= 32768
74-
)
75-
76-
_build_kernel = build(
77-
premake,
78-
configs,
79-
meta_parameters=("block_size_m", "block_size_n", "block_size_k"),
80-
kernel_name="conv2d",
81-
)
82-
83-
_BUILD_CONFIGS = frozenset(
84-
(
85-
kwargs["n"],
86-
kwargs["c"],
87-
kwargs["h"],
88-
kwargs["w"],
89-
kwargs["k"],
90-
kwargs["r"],
91-
kwargs["s"],
92-
kwargs["dtype"],
93-
)
94-
for _, kwargs, _ in configs
95-
)
9669

97-
_fallback_kernel = ninetoothed.make(
98-
arrangement,
99-
mm.application,
100-
tuple(Tensor(4, shape_options={"constexpr": True}) for _ in range(3)),
101-
)
10270

71+
@functools.cache
72+
def _kernel(n, c, h, w, k, r, s, dtype):
73+
return build(
74+
premake,
75+
_configs(n, c, h, w, k, r, s, dtype),
76+
meta_parameters=("block_size_m", "block_size_n", "block_size_k"),
77+
kernel_name=f"conv2d_{n}_{c}_{h}_{w}_{k}_{r}_{s}",
78+
)
10379

104-
def kernel(input, filter, output, n, c, h, w, k, r, s, dtype):
105-
if (n, c, h, w, k, r, s, dtype) in _BUILD_CONFIGS:
106-
return _build_kernel(input, filter, output, n, c, h, w, k, r, s, dtype)
10780

108-
return _fallback_kernel(input, filter, output)
81+
def kernel(input, filter, output, n, c, h, w, k, r, s, dtype):
82+
return _kernel(n, c, h, w, k, r, s, dtype)(
83+
input, filter, output, n, c, h, w, k, r, s, dtype
84+
)

0 commit comments

Comments
 (0)