Skip to content

Commit d8235d6

Browse files
committed
Add pooling.arrangement
1 parent dea2c85 commit d8235d6

3 files changed

Lines changed: 72 additions & 156 deletions

File tree

src/ntops/kernels/avg_pool2d.py

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,9 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
5-
from ninetoothed import Symbol, Tensor
4+
from ninetoothed import Tensor
65

7-
BLOCK_SIZE = ninetoothed.block_size()
8-
9-
KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16)
10-
KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16)
11-
STRIDE_H = Symbol("stride_h", constexpr=True)
12-
STRIDE_W = Symbol("stride_w", constexpr=True)
13-
PADDING_H = Symbol("padding_h", constexpr=True)
14-
PADDING_W = Symbol("padding_w", constexpr=True)
15-
DILATION_H = Symbol("dilation_h", constexpr=True)
16-
DILATION_W = Symbol("dilation_w", constexpr=True)
17-
18-
19-
def arrangement(
20-
input,
21-
output,
22-
kernel_size_h=None,
23-
kernel_size_w=None,
24-
stride_h=None,
25-
stride_w=None,
26-
padding_h=None,
27-
padding_w=None,
28-
dilation_h=None,
29-
dilation_w=None,
30-
ceil_mode=None,
31-
block_size=None,
32-
):
33-
if kernel_size_h is None:
34-
kernel_size_h = KERNEL_SIZE_H
35-
36-
if kernel_size_w is None:
37-
kernel_size_w = KERNEL_SIZE_W
38-
39-
if stride_h is None:
40-
stride_h = STRIDE_H
41-
42-
if stride_w is None:
43-
stride_w = STRIDE_W
44-
45-
if padding_h is None:
46-
padding_h = PADDING_H
47-
48-
if padding_w is None:
49-
padding_w = PADDING_W
50-
51-
if dilation_h is None:
52-
dilation_h = DILATION_H
53-
54-
if dilation_w is None:
55-
dilation_w = DILATION_W
56-
57-
if ceil_mode is None:
58-
ceil_mode = False
59-
60-
if block_size is None:
61-
block_size = BLOCK_SIZE
62-
63-
input_arranged = input.pad(
64-
((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
65-
)
66-
input_arranged = input_arranged.tile(
67-
(1, 1, kernel_size_h, kernel_size_w),
68-
strides=(-1, -1, stride_h, stride_w),
69-
dilation=(1, 1, dilation_h, dilation_w),
70-
floor_mode=not ceil_mode,
71-
)
72-
input_arranged = input_arranged.ravel()
73-
input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
74-
input_arranged = input_arranged.tile((block_size, -1))
75-
76-
output_arranged = output.tile((1, 1, 1, 1))
77-
output_arranged = output_arranged.ravel()
78-
output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
79-
output_arranged = output_arranged.tile((block_size, -1))
80-
output_arranged.dtype = output_arranged.dtype.squeeze(1)
81-
82-
return input_arranged, output_arranged
6+
from ntops.kernels.pooling import arrangement
837

848

859
def application(input, output):

src/ntops/kernels/max_pool2d.py

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,9 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
5-
from ninetoothed import Symbol, Tensor
4+
from ninetoothed import Tensor
65

7-
BLOCK_SIZE = ninetoothed.block_size()
8-
9-
KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16)
10-
KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16)
11-
STRIDE_H = Symbol("stride_h", constexpr=True)
12-
STRIDE_W = Symbol("stride_w", constexpr=True)
13-
PADDING_H = Symbol("padding_h", constexpr=True)
14-
PADDING_W = Symbol("padding_w", constexpr=True)
15-
DILATION_H = Symbol("dilation_h", constexpr=True)
16-
DILATION_W = Symbol("dilation_w", constexpr=True)
17-
18-
19-
def arrangement(
20-
input,
21-
output,
22-
kernel_size_h=None,
23-
kernel_size_w=None,
24-
stride_h=None,
25-
stride_w=None,
26-
padding_h=None,
27-
padding_w=None,
28-
dilation_h=None,
29-
dilation_w=None,
30-
ceil_mode=None,
31-
block_size=None,
32-
):
33-
if kernel_size_h is None:
34-
kernel_size_h = KERNEL_SIZE_H
35-
36-
if kernel_size_w is None:
37-
kernel_size_w = KERNEL_SIZE_W
38-
39-
if stride_h is None:
40-
stride_h = STRIDE_H
41-
42-
if stride_w is None:
43-
stride_w = STRIDE_W
44-
45-
if padding_h is None:
46-
padding_h = PADDING_H
47-
48-
if padding_w is None:
49-
padding_w = PADDING_W
50-
51-
if dilation_h is None:
52-
dilation_h = DILATION_H
53-
54-
if dilation_w is None:
55-
dilation_w = DILATION_W
56-
57-
if ceil_mode is None:
58-
ceil_mode = False
59-
60-
if block_size is None:
61-
block_size = BLOCK_SIZE
62-
63-
input_arranged = input.pad(
64-
((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
65-
)
66-
input_arranged = input_arranged.tile(
67-
(1, 1, kernel_size_h, kernel_size_w),
68-
strides=(-1, -1, stride_h, stride_w),
69-
dilation=(1, 1, dilation_h, dilation_w),
70-
floor_mode=not ceil_mode,
71-
)
72-
input_arranged = input_arranged.ravel()
73-
input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
74-
input_arranged = input_arranged.tile((block_size, -1))
75-
76-
output_arranged = output.tile((1, 1, 1, 1))
77-
output_arranged = output_arranged.ravel()
78-
output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
79-
output_arranged = output_arranged.tile((block_size, -1))
80-
output_arranged.dtype = output_arranged.dtype.squeeze(1)
81-
82-
return input_arranged, output_arranged
6+
from ntops.kernels.pooling import arrangement
837

848

859
def application(input, output):

src/ntops/kernels/pooling.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import ninetoothed
2+
from ninetoothed import Symbol
3+
4+
5+
def arrangement(
6+
input,
7+
output,
8+
kernel_size_h=None,
9+
kernel_size_w=None,
10+
stride_h=None,
11+
stride_w=None,
12+
padding_h=None,
13+
padding_w=None,
14+
dilation_h=None,
15+
dilation_w=None,
16+
ceil_mode=None,
17+
block_size=None,
18+
):
19+
if kernel_size_h is None:
20+
kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=16)
21+
22+
if kernel_size_w is None:
23+
kernel_size_w = Symbol("kernel_size_w", constexpr=True, upper_bound=16)
24+
25+
if stride_h is None:
26+
stride_h = Symbol("stride_h", constexpr=True)
27+
28+
if stride_w is None:
29+
stride_w = Symbol("stride_w", constexpr=True)
30+
31+
if padding_h is None:
32+
padding_h = Symbol("padding_h", constexpr=True)
33+
34+
if padding_w is None:
35+
padding_w = Symbol("padding_w", constexpr=True)
36+
37+
if dilation_h is None:
38+
dilation_h = Symbol("dilation_h", constexpr=True)
39+
40+
if dilation_w is None:
41+
dilation_w = Symbol("dilation_w", constexpr=True)
42+
43+
if ceil_mode is None:
44+
ceil_mode = False
45+
46+
if block_size is None:
47+
block_size = ninetoothed.block_size()
48+
49+
input_arranged = input.pad(
50+
((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
51+
)
52+
input_arranged = input_arranged.tile(
53+
(1, 1, kernel_size_h, kernel_size_w),
54+
strides=(-1, -1, stride_h, stride_w),
55+
dilation=(1, 1, dilation_h, dilation_w),
56+
floor_mode=not ceil_mode,
57+
)
58+
input_arranged = input_arranged.ravel()
59+
input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1)
60+
input_arranged = input_arranged.tile((block_size, -1))
61+
62+
output_arranged = output.tile((1, 1, 1, 1))
63+
output_arranged = output_arranged.ravel()
64+
output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1)
65+
output_arranged = output_arranged.tile((block_size, -1))
66+
output_arranged.dtype = output_arranged.dtype.squeeze(1)
67+
68+
return input_arranged, output_arranged

0 commit comments

Comments
 (0)