Skip to content

Commit dde0206

Browse files
committed
Decouple the kernels
1 parent 2098fe0 commit dde0206

29 files changed

Lines changed: 248 additions & 268 deletions

src/ntops/__init__.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +0,0 @@
1-
from ntops.abs import abs
2-
from ntops.add import add
3-
from ntops.addmm import addmm
4-
from ntops.bmm import bmm
5-
from ntops.div import div
6-
from ntops.exp import exp
7-
from ntops.gelu import gelu
8-
from ntops.mm import mm
9-
from ntops.mul import mul
10-
from ntops.rsqrt import rsqrt
11-
12-
__all__ = ["abs", "add", "addmm", "bmm", "div", "exp", "gelu", "mm", "mul", "rsqrt"]

src/ntops/abs.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/ntops/add.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/ntops/exp.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/ntops/kernels/__init__.py

Whitespace-only changes.

src/ntops/kernels/abs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
from ntops.kernels.element_wise import arrangement
8+
9+
10+
def application(input, output):
11+
output = ntl.abs(input) # noqa: F841
12+
13+
14+
@functools.cache
15+
def make(ndim):
16+
return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))

src/ntops/kernels/add.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import functools
2+
3+
import ninetoothed
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.element_wise import arrangement
7+
8+
9+
def application(input, other, alpha, output):
10+
output = input + alpha * other # noqa: F841
11+
12+
13+
@functools.cache
14+
def make(ndim):
15+
tensors = (Tensor(ndim), Tensor(ndim), Tensor(0), Tensor(ndim))
16+
17+
return ninetoothed.make(arrangement, application, tensors)
Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import ninetoothed
44
import ninetoothed.language as ntl
5-
import torch
65
from ninetoothed import Tensor
76

8-
import ntops.mm as mm
7+
import ntops.kernels.mm as mm
98

109

1110
def arrangement(input, x, y, beta, alpha, output):
@@ -22,22 +21,8 @@ def application(input, x, y, beta, alpha, output):
2221
output = beta * input + alpha * mm_output
2322

2423

25-
def addmm(input, x, y, beta, alpha, output=None):
26-
m, _ = x.shape
27-
_, n = y.shape
28-
29-
if output is None:
30-
output = torch.empty((m, n), dtype=input.dtype, device=input.device)
31-
32-
kernel = _make()
33-
34-
kernel(input, x, y, beta, alpha, output)
35-
36-
return output
37-
38-
3924
@functools.cache
40-
def _make():
25+
def make():
4126
tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2))
4227

4328
return ninetoothed.make(arrangement, application, tensors)
Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import functools
22

33
import ninetoothed
4-
import torch
54
from ninetoothed import Tensor
65

7-
from ntops.mm import BLOCK_SIZE_K, BLOCK_SIZE_M, BLOCK_SIZE_N, application
6+
from ntops.kernels.mm import BLOCK_SIZE_K, BLOCK_SIZE_M, BLOCK_SIZE_N, application
87

98

109
def arrangement(input, other, output):
@@ -26,20 +25,6 @@ def arrangement(input, other, output):
2625
return input_arranged, other_arranged, output_arranged
2726

2827

29-
def bmm(input, other, output=None):
30-
b, m, _ = input.shape
31-
_, _, n = other.shape
32-
33-
if output is None:
34-
output = torch.empty((b, m, n), dtype=input.dtype, device=input.device)
35-
36-
kernel = _make()
37-
38-
kernel(input, other, output)
39-
40-
return output
41-
42-
4328
@functools.cache
44-
def _make():
29+
def make():
4530
return ninetoothed.make(arrangement, application, (Tensor(3), Tensor(3), Tensor(3)))
Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import ninetoothed
44
import ninetoothed.language as ntl
5-
import torch
65
from ninetoothed import Tensor
76

8-
from ntops import element_wise
7+
from ntops.kernels.element_wise import arrangement
98

109

1110
def default_application(input, other, output):
@@ -20,26 +19,15 @@ def floor_application(input, other, output):
2019
output = ntl.floor(input / other) # noqa: F841
2120

2221

23-
def div(input, other, rounding_mode=None, output=None):
24-
if output is None:
25-
output = torch.empty_like(input)
26-
27-
kernel = _make(input.ndim, rounding_mode)
28-
29-
kernel(input, other, output)
30-
31-
return output
32-
33-
3422
@functools.cache
35-
def _make(ndim, rounding_mode):
36-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
37-
23+
def make(ndim, rounding_mode):
3824
if rounding_mode == "trunc":
3925
application = trunc_application
4026
elif rounding_mode == "floor":
4127
application = floor_application
4228
else:
4329
application = default_application
4430

45-
return ninetoothed.make(element_wise.arrangement, application, tensors)
31+
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
32+
33+
return ninetoothed.make(arrangement, application, tensors)

0 commit comments

Comments
 (0)