Skip to content

Commit dea2c85

Browse files
committed
Add avg_pool2d operator
1 parent 7c9029e commit dea2c85

5 files changed

Lines changed: 237 additions & 0 deletions

File tree

src/ntops/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
abs,
33
add,
44
addmm,
5+
avg_pool2d,
56
bitwise_and,
67
bitwise_not,
78
bitwise_or,
@@ -43,6 +44,7 @@
4344
"abs",
4445
"add",
4546
"addmm",
47+
"avg_pool2d",
4648
"bitwise_and",
4749
"bitwise_not",
4850
"bitwise_or",

src/ntops/kernels/avg_pool2d.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Symbol, Tensor
6+
7+
# Default block size for the tiled launch; resolved by ninetoothed at
# compile time when the caller does not pin one in `premake`.
BLOCK_SIZE = ninetoothed.block_size()

# Compile-time (constexpr) pooling hyper-parameter symbols. These act as
# placeholders that get concrete values at kernel-launch time. The
# upper_bound of 16 on the kernel extents presumably limits constexpr
# specialization blow-up — TODO confirm against ninetoothed's Symbol docs.
KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16)
KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16)
STRIDE_H = Symbol("stride_h", constexpr=True)
STRIDE_W = Symbol("stride_w", constexpr=True)
PADDING_H = Symbol("padding_h", constexpr=True)
PADDING_W = Symbol("padding_w", constexpr=True)
DILATION_H = Symbol("dilation_h", constexpr=True)
DILATION_W = Symbol("dilation_w", constexpr=True)
17+
18+
19+
def arrangement(
    input,
    output,
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    block_size=None,
):
    """Arrange input/output tensors for the avg_pool2d kernel.

    Pads the NCHW input, tiles it into pooling windows, and flattens both
    tensors so each kernel instance averages `block_size` windows at once.
    Any hyper-parameter left as ``None`` falls back to the module-level
    compile-time symbol (or plain default, for ``ceil_mode``).
    """
    kernel_size_h = KERNEL_SIZE_H if kernel_size_h is None else kernel_size_h
    kernel_size_w = KERNEL_SIZE_W if kernel_size_w is None else kernel_size_w
    stride_h = STRIDE_H if stride_h is None else stride_h
    stride_w = STRIDE_W if stride_w is None else stride_w
    padding_h = PADDING_H if padding_h is None else padding_h
    padding_w = PADDING_W if padding_w is None else padding_w
    dilation_h = DILATION_H if dilation_h is None else dilation_h
    dilation_w = DILATION_W if dilation_w is None else dilation_w
    ceil_mode = False if ceil_mode is None else ceil_mode
    block_size = BLOCK_SIZE if block_size is None else block_size

    # Zero-pad only the spatial (H, W) dimensions.
    padded = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    # One tile per pooling window; stride -1 on N/C keeps those dims intact.
    windows = padded.tile(
        (1, 1, kernel_size_h, kernel_size_w),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=not ceil_mode,
    )
    # Collapse to (windows, window-elements), then group rows per block.
    input_arranged = (
        windows.ravel()
        .flatten(end_dim=4)
        .flatten(start_dim=1)
        .tile((block_size, -1))
    )

    # Mirror the same flattening on the output: one scalar per window.
    output_arranged = (
        output.tile((1, 1, 1, 1))
        .ravel()
        .flatten(end_dim=4)
        .flatten(start_dim=1)
        .tile((block_size, -1))
    )
    output_arranged.dtype = output_arranged.dtype.squeeze(1)

    return input_arranged, output_arranged
83+
84+
85+
def application(input, output):
    # Per-block compute body (compiled by ninetoothed): each row of `input`
    # holds one flattened pooling window, so a mean over the last axis yields
    # one pooled value per window. Rebinding the `output` parameter is how
    # the DSL expresses a store into the output tile — the value is not
    # actually discarded, hence the F841 (unused variable) suppression.
    output = ntl.sum(input, axis=-1) / input.shape[-1]  # noqa: F841
87+
88+
89+
def premake(
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    dtype=None,
    block_size=None,
):
    """Build the (arrangement, application, tensors) triple for avg_pool2d.

    Binds the given hyper-parameters into `arrangement` (``None`` values
    defer to the module-level compile-time symbols) and declares one 4-D
    input and one 4-D output tensor placeholder.
    """
    hyper_parameters = dict(
        kernel_size_h=kernel_size_h,
        kernel_size_w=kernel_size_w,
        stride_h=stride_h,
        stride_w=stride_w,
        padding_h=padding_h,
        padding_w=padding_w,
        dilation_h=dilation_h,
        dilation_w=dilation_w,
        ceil_mode=ceil_mode,
        block_size=block_size,
    )
    bound_arrangement = functools.partial(arrangement, **hyper_parameters)

    # NCHW input and output placeholders.
    io_tensors = (Tensor(4, dtype=dtype), Tensor(4, dtype=dtype))

    return bound_arrangement, application, io_tensors

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ntops.torch.abs import abs
22
from ntops.torch.add import add
33
from ntops.torch.addmm import addmm
4+
from ntops.torch.avg_pool2d import avg_pool2d
45
from ntops.torch.bitwise_and import bitwise_and
56
from ntops.torch.bitwise_not import bitwise_not
67
from ntops.torch.bitwise_or import bitwise_or
@@ -42,6 +43,7 @@
4243
"abs",
4344
"add",
4445
"addmm",
46+
"avg_pool2d",
4547
"bitwise_and",
4648
"bitwise_not",
4749
"bitwise_or",

src/ntops/torch/avg_pool2d.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import math
2+
3+
import torch
4+
5+
import ntops
6+
from ntops.torch.utils import _cached_make
7+
8+
9+
def avg_pool2d(
    input,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Apply 2D average pooling over an NCHW `input`.

    Mirrors the interface of ``torch.nn.functional.avg_pool2d``. Scalar
    ``kernel_size``/``stride``/``padding`` are broadcast to (h, w) pairs.
    ``ceil_mode``, ``count_include_pad=False``, and ``divisor_override``
    are not supported yet.

    Returns a new tensor of shape (n, c, h_out, w_out) on `input`'s device.
    """
    # Fix: the original only normalized `stride` and `padding`; a scalar
    # `kernel_size` (which F.avg_pool2d accepts) crashed at `kernel_size[0]`.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    # PyTorch semantics: stride defaults to the kernel size.
    if stride is None:
        stride = kernel_size

    if isinstance(stride, int):
        stride = (stride, stride)

    if isinstance(padding, int):
        padding = (padding, padding)

    assert not ceil_mode, "`ceil_mode` is not supported yet."

    assert count_include_pad, "`count_include_pad` is not supported yet."

    assert divisor_override is None, "`divisor_override` is not supported yet."

    n, c, h, w = input.shape

    def _calculate_output_size(input_size, kernel_size, stride, padding, ceil_mode):
        # Exact integer arithmetic instead of floor/ceil of a float
        # division, which can round incorrectly for large extents.
        numerator = input_size + 2 * padding - kernel_size

        if ceil_mode:
            result = -(-numerator // stride) + 1
            # A pooling window must start inside the input or its left
            # padding; drop the trailing window otherwise (PyTorch rule).
            if (result - 1) * stride >= input_size + padding:
                result -= 1
        else:
            result = numerator // stride + 1

        return result

    h_ = _calculate_output_size(h, kernel_size[0], stride[0], padding[0], ceil_mode)
    w_ = _calculate_output_size(w, kernel_size[1], stride[1], padding[1], ceil_mode)

    output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device)

    # Dilation is fixed to 1: F.avg_pool2d has no dilation parameter.
    kernel = _cached_make(
        ntops.kernels.avg_pool2d.premake,
        dilation_h=1,
        dilation_w=1,
        ceil_mode=ceil_mode,
    )

    kernel(
        input,
        output,
        kernel_size_h=kernel_size[0],
        kernel_size_w=kernel_size[1],
        stride_h=stride[0],
        stride_w=stride[1],
        padding_h=padding[0],
        padding_w=padding[1],
    )

    return output

tests/test_avg_pool2d.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops
6+
from tests.skippers import skip_if_cuda_not_available
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize("device", ("cuda",))
@pytest.mark.parametrize(
    "dtype, rtol, atol", ((torch.float32, 1e-5, 1e-5), (torch.float16, 1e-3, 1e-3))
)
@pytest.mark.parametrize("ceil_mode", (False,))
@pytest.mark.parametrize("padding", (0, 1, (2, 3)))
@pytest.mark.parametrize("stride", (None, 1, (2, 3)))
@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3)))
@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),))
def test_avg_pool2d(
    n, c, h, w, kernel_size, stride, padding, ceil_mode, dtype, device, rtol, atol
):
    """Check ntops.torch.avg_pool2d against the PyTorch reference."""
    pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding

    # F.avg_pool2d rejects padding larger than half the kernel size, so
    # skip those parameter combinations instead of failing.
    if pad_h > kernel_size[0] / 2 or pad_w > kernel_size[1] / 2:
        pytest.skip(reason="Invalid padding.")

    input = torch.randn((n, c, h, w), dtype=dtype, device=device)

    pool_kwargs = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        ceil_mode=ceil_mode,
    )

    ninetoothed_output = ntops.torch.avg_pool2d(input, **pool_kwargs)
    reference_output = F.avg_pool2d(input, **pool_kwargs)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)