Skip to content

Commit 6bc90d5

Browse files
authored
Develop ResNet operators (#73)
* Add `max_pool2d` operator
* Add `avg_pool2d` operator
* Add `pooling.arrangement`
* Add `ntops.torch.pooling._calculate_output_size`
* Add `conv2d` operator
1 parent 10c2cac commit 6bc90d5

File tree

13 files changed

+636
-0
lines changed

13 files changed

+636
-0
lines changed

src/ntops/kernels/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
abs,
33
add,
44
addmm,
5+
avg_pool2d,
56
bitwise_and,
67
bitwise_not,
78
bitwise_or,
89
bmm,
910
clamp,
11+
conv2d,
1012
cos,
1113
div,
1214
dropout,
@@ -20,6 +22,7 @@
2022
layer_norm,
2123
le,
2224
lt,
25+
max_pool2d,
2326
mm,
2427
mul,
2528
ne,
@@ -42,11 +45,13 @@
4245
"abs",
4346
"add",
4447
"addmm",
48+
"avg_pool2d",
4549
"bitwise_and",
4650
"bitwise_not",
4751
"bitwise_or",
4852
"bmm",
4953
"clamp",
54+
"conv2d",
5055
"cos",
5156
"div",
5257
"dropout",
@@ -60,6 +65,7 @@
6065
"layer_norm",
6166
"le",
6267
"lt",
68+
"max_pool2d",
6369
"mm",
6470
"mul",
6571
"ne",

src/ntops/kernels/avg_pool2d.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import functools
2+
3+
import ninetoothed.language as ntl
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.pooling import arrangement
7+
8+
9+
def application(input, output):
    # Average pooling body: sum over the trailing axis — which, after
    # `pooling.arrangement`, holds the flattened pooling window — and divide
    # by the window's element count to get the mean.
    # NOTE(review): the bare assignment to `output` appears to be the
    # ninetoothed convention for emitting the kernel result (hence the F841
    # suppression) — confirm against the ninetoothed tracing docs.
    output = ntl.sum(input, axis=-1) / input.shape[-1]  # noqa: F841
11+
12+
13+
def premake(
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    dtype=None,
    block_size=None,
):
    """Assemble the (arrangement, application, tensors) triple for avg_pool2d.

    Every pooling hyper-parameter is forwarded to ``pooling.arrangement``;
    any argument left as ``None`` falls back to that function's own defaults.
    """
    window_options = {
        "kernel_size_h": kernel_size_h,
        "kernel_size_w": kernel_size_w,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": padding_h,
        "padding_w": padding_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "ceil_mode": ceil_mode,
        "block_size": block_size,
    }
    arrangement_ = functools.partial(arrangement, **window_options)

    # One 4-D input and one 4-D output specification, both in the requested dtype.
    input_spec = Tensor(4, dtype=dtype)
    output_spec = Tensor(4, dtype=dtype)

    return arrangement_, application, (input_spec, output_spec)

src/ntops/kernels/conv2d.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import copy
2+
import functools
3+
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Symbol, Tensor
6+
7+
from ntops.kernels import mm
8+
9+
10+
def arrangement(
    input,
    weight,
    bias,
    output,
    input_precision,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Arrange conv2d operands as an implicit-GEMM (im2col-style) matmul.

    The padded input is tiled into convolution windows and flattened into a
    2-D matrix, the weight is flattened and transposed, and the bias/output
    are laid out to match, so that ``mm.arrangement`` can be reused.
    Hyper-parameters left as ``None`` become constexpr ninetoothed Symbols
    (or the mm default block sizes).
    """
    if stride_h is None:
        stride_h = Symbol("stride_h", constexpr=True)

    if stride_w is None:
        stride_w = Symbol("stride_w", constexpr=True)

    if padding_h is None:
        padding_h = Symbol("padding_h", constexpr=True)

    if padding_w is None:
        padding_w = Symbol("padding_w", constexpr=True)

    if dilation_h is None:
        dilation_h = Symbol("dilation_h", constexpr=True)

    if dilation_w is None:
        dilation_w = Symbol("dilation_w", constexpr=True)

    if block_size_m is None:
        block_size_m = mm.BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = mm.BLOCK_SIZE_N

    if block_size_k is None:
        block_size_k = mm.BLOCK_SIZE_K

    # Reuse the matmul arrangement with the chosen (or default) tile sizes.
    mm_arrangement = functools.partial(
        mm.arrangement,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
        block_size_k=block_size_k,
    )

    # Zero-pad only the two spatial dims (NCHW layout is implied by the
    # padding spec — TODO confirm).
    input_arranged = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    # Cut the padded input into windows of shape (1, C_in, kH, kW).
    # NOTE(review): the -1 strides presumably mean "non-overlapping / full
    # extent" along batch and channel — verify with the ninetoothed tile docs.
    input_arranged = input_arranged.tile(
        (1, *weight.shape[1:]),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=True,
    )
    input_arranged = input_arranged.squeeze(1)
    # Drop the singleton batch axis inside each window block as well.
    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    input_arranged = input_arranged.ravel()
    # Collapse to a 2-D matrix: rows = (N * out_H * out_W), cols = C_in*kH*kW.
    input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)

    # Weight becomes the (C_in*kH*kW, C_out) right-hand matrix.
    weight_arranged = weight.flatten(start_dim=1)
    weight_arranged = weight_arranged.permute((1, 0))

    # Broadcast the 1-D bias across batch and spatial dims, then flatten it
    # into the same row layout as the output matrix.
    bias_arranged = bias[None, :, None, None].expand(
        (output.shape[0], -1, output.shape[2], output.shape[3])
    )
    bias_arranged = bias_arranged.permute((0, 2, 3, 1)).flatten(end_dim=3)

    # Output in NHWC row order so rows line up with the input matrix rows.
    output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3)

    # Run the mm arrangement once purely to arrange the bias like an mm
    # output; deep copies keep the real operands from being mutated here.
    _, _, bias_arranged, _ = mm_arrangement(
        copy.deepcopy(input_arranged),
        copy.deepcopy(weight_arranged),
        bias_arranged,
        copy.deepcopy(input_precision),
    )

    # The real arrangement of input, weight, output, and precision.
    input_arranged, weight_arranged, output_arranged, input_precision_arranged = (
        mm_arrangement(
            input_arranged, weight_arranged, output_arranged, input_precision
        )
    )

    return (
        input_arranged,
        weight_arranged,
        bias_arranged,
        output_arranged,
        input_precision_arranged,
    )
104+
105+
106+
def application(input, weight, bias, output, input_precision):
    """Compute one conv2d output tile as a matmul followed by a bias add."""
    # Accumulate the implicit-GEMM product in float32, whatever the I/O dtype.
    mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
    mm.application(input, weight, mm_output, input_precision)
    # NOTE(review): assigning to `output` appears to be how ninetoothed
    # records the kernel result (same pattern as the pooling kernels) —
    # confirm against the ninetoothed tracing docs.
    output = mm_output + bias
110+
111+
112+
def premake(
    input_precision=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Assemble the (arrangement, application, tensors) triple for conv2d.

    Convolution hyper-parameters and mm block sizes are forwarded to
    ``arrangement``; arguments left as ``None`` fall back to its defaults.
    """
    conv_options = {
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": padding_h,
        "padding_w": padding_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "block_size_m": block_size_m,
        "block_size_n": block_size_n,
        "block_size_k": block_size_k,
    }
    arrangement_ = functools.partial(arrangement, **conv_options)

    # 4-D input/weight/output, 1-D bias, and a 0-D constexpr precision flag.
    input_spec = Tensor(4, dtype=dtype)
    weight_spec = Tensor(4, dtype=dtype)
    output_spec = Tensor(4, dtype=dtype)
    bias_spec = Tensor(1, dtype=dtype)
    precision_spec = Tensor(0, dtype=dtype, constexpr=True, value=input_precision)

    tensors = (input_spec, weight_spec, bias_spec, output_spec, precision_spec)

    return arrangement_, application, tensors

src/ntops/kernels/max_pool2d.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import functools
2+
3+
import ninetoothed.language as ntl
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.pooling import arrangement
7+
8+
9+
def application(input, output):
    # Max pooling body: reduce the trailing axis — which, after
    # `pooling.arrangement`, holds the flattened pooling window — with max.
    # Out-of-bounds lanes load as -inf (see `other=float("-inf")` in premake),
    # so they never win the max.
    # NOTE(review): the bare assignment to `output` appears to be the
    # ninetoothed convention for emitting the kernel result (hence the F841
    # suppression) — confirm against the ninetoothed tracing docs.
    output = ntl.max(input, axis=-1)  # noqa: F841
11+
12+
13+
def premake(
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    dtype=None,
    block_size=None,
):
    """Assemble the (arrangement, application, tensors) triple for max_pool2d.

    Every pooling hyper-parameter is forwarded to ``pooling.arrangement``;
    any argument left as ``None`` falls back to that function's own defaults.
    """
    window_options = {
        "kernel_size_h": kernel_size_h,
        "kernel_size_w": kernel_size_w,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": padding_h,
        "padding_w": padding_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "ceil_mode": ceil_mode,
        "block_size": block_size,
    }
    arrangement_ = functools.partial(arrangement, **window_options)

    # Out-of-bounds input lanes read as -inf so they cannot win the max.
    input_spec = Tensor(4, dtype=dtype, other=float("-inf"))
    output_spec = Tensor(4, dtype=dtype)

    return arrangement_, application, (input_spec, output_spec)

src/ntops/kernels/pooling.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import ninetoothed
2+
from ninetoothed import Symbol
3+
4+
5+
def arrangement(
    input,
    output,
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    block_size=None,
):
    """Shared 2-D pooling arrangement for max_pool2d / avg_pool2d.

    Pads the input, tiles it into (kernel_size_h, kernel_size_w) windows,
    and flattens each window onto the trailing axis so the kernel's
    ``application`` can reduce over ``axis=-1``. Hyper-parameters left as
    ``None`` become constexpr ninetoothed Symbols (with defaults below).
    """
    # NOTE(review): upper_bound=16 presumably caps the constexpr kernel-size
    # specialization range — confirm with the ninetoothed Symbol docs.
    if kernel_size_h is None:
        kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=16)

    if kernel_size_w is None:
        kernel_size_w = Symbol("kernel_size_w", constexpr=True, upper_bound=16)

    if stride_h is None:
        stride_h = Symbol("stride_h", constexpr=True)

    if stride_w is None:
        stride_w = Symbol("stride_w", constexpr=True)

    if padding_h is None:
        padding_h = Symbol("padding_h", constexpr=True)

    if padding_w is None:
        padding_w = Symbol("padding_w", constexpr=True)

    if dilation_h is None:
        dilation_h = Symbol("dilation_h", constexpr=True)

    if dilation_w is None:
        dilation_w = Symbol("dilation_w", constexpr=True)

    # Default matches torch pooling: floor the output size unless asked not to.
    if ceil_mode is None:
        ceil_mode = False

    if block_size is None:
        block_size = ninetoothed.block_size()

    # Pad only the two spatial dims (NCHW layout is implied — TODO confirm).
    input_arranged = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    # One window per (batch, channel, out_h, out_w) position.
    # NOTE(review): the -1 strides presumably mean "non-overlapping / full
    # extent" along batch and channel — verify with the ninetoothed tile docs.
    input_arranged = input_arranged.tile(
        (1, 1, kernel_size_h, kernel_size_w),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=not ceil_mode,
    )
    input_arranged = input_arranged.ravel()
    # Rows = all window positions, trailing axis = flattened window elements;
    # then group rows into blocks of `block_size` for the launch grid.
    input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
    input_arranged = input_arranged.tile((block_size, -1))

    # Mirror the same row blocking on the output (one scalar per window).
    output_arranged = output.tile((1, 1, 1, 1))
    output_arranged = output_arranged.ravel()
    output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
    output_arranged = output_arranged.tile((block_size, -1))
    # Drop the singleton per-element axis so each row is a scalar slot.
    output_arranged.dtype = output_arranged.dtype.squeeze(1)

    return input_arranged, output_arranged

src/ntops/torch/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from ntops.torch.abs import abs
22
from ntops.torch.add import add
33
from ntops.torch.addmm import addmm
4+
from ntops.torch.avg_pool2d import avg_pool2d
45
from ntops.torch.bitwise_and import bitwise_and
56
from ntops.torch.bitwise_not import bitwise_not
67
from ntops.torch.bitwise_or import bitwise_or
78
from ntops.torch.bmm import bmm
89
from ntops.torch.clamp import clamp
10+
from ntops.torch.conv2d import conv2d
911
from ntops.torch.cos import cos
1012
from ntops.torch.div import div
1113
from ntops.torch.dropout import dropout
@@ -20,6 +22,7 @@
2022
from ntops.torch.le import le
2123
from ntops.torch.lt import lt
2224
from ntops.torch.matmul import matmul
25+
from ntops.torch.max_pool2d import max_pool2d
2326
from ntops.torch.mm import mm
2427
from ntops.torch.mul import mul
2528
from ntops.torch.ne import ne
@@ -41,11 +44,13 @@
4144
"abs",
4245
"add",
4346
"addmm",
47+
"avg_pool2d",
4448
"bitwise_and",
4549
"bitwise_not",
4650
"bitwise_or",
4751
"bmm",
4852
"clamp",
53+
"conv2d",
4954
"cos",
5055
"div",
5156
"dropout",
@@ -60,6 +65,7 @@
6065
"le",
6166
"lt",
6267
"matmul",
68+
"max_pool2d",
6369
"mm",
6470
"mul",
6571
"ne",

0 commit comments

Comments
 (0)