Skip to content

Commit ce4bff6

Browse files
authored
Merge pull request #41 from InfiniTensor/develop-premake-mechanism
Refactor kernels to use `premake` instead of `make`
2 parents 4ddc09c + feb4ca4 commit ce4bff6

36 files changed

Lines changed: 418 additions & 220 deletions

src/ntops/kernels/abs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -11,6 +10,9 @@ def application(input, output):
1110
output = ntl.abs(input) # noqa: F841
1211

1312

14-
@functools.cache
15-
def make(ndim):
16-
return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
15+
16+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
17+
18+
return arrangement_, application, tensors

src/ntops/kernels/add.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.element_wise import arrangement
@@ -10,8 +9,14 @@ def application(input, other, alpha, output):
109
output = input + alpha * other # noqa: F841
1110

1211

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(0), Tensor(ndim))
12+
def premake(ndim, dtype=None, block_size=None):
13+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1614

17-
return ninetoothed.make(arrangement, application, tensors)
15+
tensors = (
16+
Tensor(ndim, dtype=dtype),
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(0, dtype=dtype),
19+
Tensor(ndim, dtype=dtype),
20+
)
21+
22+
return arrangement_, application, tensors

src/ntops/kernels/addmm.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,39 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

76
import ntops.kernels.mm as mm
87

98

10-
def arrangement(input, x, y, beta, alpha, output):
11-
_, _, input_arranged = mm.arrangement(x, y, input)
9+
def arrangement(
10+
input,
11+
x,
12+
y,
13+
beta,
14+
alpha,
15+
output,
16+
block_size_m=None,
17+
block_size_n=None,
18+
block_size_k=None,
19+
):
20+
if block_size_m is None:
21+
block_size_m = mm.BLOCK_SIZE_M
22+
23+
if block_size_n is None:
24+
block_size_n = mm.BLOCK_SIZE_N
25+
26+
if block_size_k is None:
27+
block_size_k = mm.BLOCK_SIZE_K
28+
29+
_, _, input_arranged = mm.arrangement(
30+
x,
31+
y,
32+
input,
33+
block_size_m=block_size_m,
34+
block_size_n=block_size_n,
35+
block_size_k=block_size_k,
36+
)
1237

1338
x_arranged, y_arranged, output_arranged = mm.arrangement(x, y, output)
1439

@@ -21,8 +46,21 @@ def application(input, x, y, beta, alpha, output):
2146
output = beta * input + alpha * mm_output
2247

2348

24-
@functools.cache
25-
def make():
26-
tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2))
49+
def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
50+
arrangement_ = functools.partial(
51+
arrangement,
52+
block_size_m=block_size_m,
53+
block_size_n=block_size_n,
54+
block_size_k=block_size_k,
55+
)
56+
57+
tensors = (
58+
Tensor(2, dtype=dtype),
59+
Tensor(2, dtype=dtype),
60+
Tensor(2, dtype=dtype),
61+
Tensor(0, dtype=dtype),
62+
Tensor(0, dtype=dtype),
63+
Tensor(2, dtype=dtype),
64+
)
2765

28-
return ninetoothed.make(arrangement, application, tensors)
66+
return arrangement_, application, tensors

src/ntops/kernels/bitwise_and.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.element_wise import arrangement
@@ -10,8 +9,13 @@ def application(input, other, output):
109
output = input & other # noqa: F841
1110

1211

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
12+
def premake(ndim, dtype=None, block_size=None):
13+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1614

17-
return ninetoothed.make(arrangement, application, tensors)
15+
tensors = (
16+
Tensor(ndim, dtype=dtype),
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
)
20+
21+
return arrangement_, application, tensors

src/ntops/kernels/bitwise_not.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -15,10 +14,11 @@ def logical_application(input, output):
1514
output = ntl.where(input, False, True) # noqa: F841
1615

1716

18-
@functools.cache
19-
def make(ndim, logical=False):
20-
tensors = (Tensor(ndim), Tensor(ndim))
17+
def premake(ndim, logical=False, dtype=None, block_size=None):
18+
arrangement_ = functools.partial(arrangement, block_size=block_size)
2119

2220
application = logical_application if logical else bitwise_application
2321

24-
return ninetoothed.make(arrangement, application, tensors)
22+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
23+
24+
return arrangement_, application, tensors

src/ntops/kernels/bitwise_or.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.element_wise import arrangement
@@ -10,8 +9,13 @@ def application(input, other, output):
109
output = input | other # noqa: F841
1110

1211

13-
@functools.cache
14-
def make(ndim):
15-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
12+
def premake(ndim, dtype=None, block_size=None):
13+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1614

17-
return ninetoothed.make(arrangement, application, tensors)
15+
tensors = (
16+
Tensor(ndim, dtype=dtype),
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
)
20+
21+
return arrangement_, application, tensors

src/ntops/kernels/bmm.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
import functools
22

3-
import ninetoothed
43
from ninetoothed import Tensor
54

65
from ntops.kernels.mm import BLOCK_SIZE_K, BLOCK_SIZE_M, BLOCK_SIZE_N, application
76

87

9-
def arrangement(input, other, output):
10-
output_arranged = output.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_N))
8+
def arrangement(
9+
input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
10+
):
11+
if block_size_m is None:
12+
block_size_m = BLOCK_SIZE_M
13+
14+
if block_size_n is None:
15+
block_size_n = BLOCK_SIZE_N
16+
17+
if block_size_k is None:
18+
block_size_k = BLOCK_SIZE_K
19+
20+
output_arranged = output.tile((1, block_size_m, block_size_n))
1121
output_arranged.dtype = output_arranged.dtype.squeeze(0)
1222

13-
input_arranged = input.tile((1, BLOCK_SIZE_M, BLOCK_SIZE_K))
23+
input_arranged = input.tile((1, block_size_m, block_size_k))
1424
input_arranged = input_arranged.tile((1, 1, -1))
1525
input_arranged = input_arranged.expand((-1, -1, output_arranged.shape[-1]))
1626
input_arranged.dtype = input_arranged.dtype.squeeze((0, 1))
1727
input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze(0)
1828

19-
other_arranged = other.tile((1, BLOCK_SIZE_K, BLOCK_SIZE_N))
29+
other_arranged = other.tile((1, block_size_k, block_size_n))
2030
other_arranged = other_arranged.tile((1, -1, 1))
2131
other_arranged = other_arranged.expand((-1, output_arranged.shape[-2], -1))
2232
other_arranged.dtype = other_arranged.dtype.squeeze((0, 2))
@@ -25,6 +35,14 @@ def arrangement(input, other, output):
2535
return input_arranged, other_arranged, output_arranged
2636

2737

28-
@functools.cache
29-
def make():
30-
return ninetoothed.make(arrangement, application, (Tensor(3), Tensor(3), Tensor(3)))
38+
def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
39+
arrangement_ = functools.partial(
40+
arrangement,
41+
block_size_m=block_size_m,
42+
block_size_n=block_size_n,
43+
block_size_k=block_size_k,
44+
)
45+
46+
tensors = (Tensor(3, dtype=dtype), Tensor(3, dtype=dtype), Tensor(3, dtype=dtype))
47+
48+
return arrangement_, application, tensors

src/ntops/kernels/clamp.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -11,8 +10,14 @@ def application(input, min_val, max_val, output):
1110
output = ntl.clamp(input, min_val, max_val) # noqa: F841
1211

1312

14-
@functools.cache
15-
def make(ndim):
16-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim), Tensor(ndim))
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
1715

18-
return ninetoothed.make(arrangement, application, tensors)
16+
tensors = (
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
Tensor(ndim, dtype=dtype),
20+
Tensor(ndim, dtype=dtype),
21+
)
22+
23+
return arrangement_, application, tensors

src/ntops/kernels/cos.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -11,6 +10,9 @@ def application(input, output):
1110
output = ntl.cos(input) # noqa: F841
1211

1312

14-
@functools.cache
15-
def make(ndim):
16-
return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
15+
16+
tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
17+
18+
return arrangement_, application, tensors

src/ntops/kernels/div.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

@@ -19,15 +18,20 @@ def floor_application(input, other, output):
1918
output = ntl.floor(input / other) # noqa: F841
2019

2120

22-
@functools.cache
23-
def make(ndim, rounding_mode):
21+
def premake(ndim, rounding_mode, dtype=None, block_size=None):
22+
arrangement_ = functools.partial(arrangement, block_size=block_size)
23+
2424
if rounding_mode == "trunc":
2525
application = trunc_application
2626
elif rounding_mode == "floor":
2727
application = floor_application
2828
else:
2929
application = default_application
3030

31-
tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim))
31+
tensors = (
32+
Tensor(ndim, dtype=dtype),
33+
Tensor(ndim, dtype=dtype),
34+
Tensor(ndim, dtype=dtype),
35+
)
3236

33-
return ninetoothed.make(arrangement, application, tensors)
37+
return arrangement_, application, tensors

0 commit comments

Comments
 (0)