Skip to content

Commit 4f2640d

Browse files
committed
Add max_pool2d operator
1 parent 10c2cac commit 4f2640d

File tree

5 files changed

+247
-0
lines changed

5 files changed

+247
-0
lines changed

src/ntops/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
layer_norm,
2121
le,
2222
lt,
23+
max_pool2d,
2324
mm,
2425
mul,
2526
ne,
@@ -60,6 +61,7 @@
6061
"layer_norm",
6162
"le",
6263
"lt",
64+
"max_pool2d",
6365
"mm",
6466
"mul",
6567
"ne",

src/ntops/kernels/max_pool2d.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Symbol, Tensor
6+
7+
BLOCK_SIZE = ninetoothed.block_size()
8+
9+
KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16)
10+
KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16)
11+
STRIDE_H = Symbol("stride_h", constexpr=True)
12+
STRIDE_W = Symbol("stride_w", constexpr=True)
13+
PADDING_H = Symbol("padding_h", constexpr=True)
14+
PADDING_W = Symbol("padding_w", constexpr=True)
15+
DILATION_H = Symbol("dilation_h", constexpr=True)
16+
DILATION_W = Symbol("dilation_w", constexpr=True)
17+
18+
19+
def arrangement(
    input,
    output,
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    block_size=None,
):
    """Arrange input/output tensors for the max_pool2d kernel.

    Any pooling parameter left as ``None`` falls back to the module-level
    compile-time symbol of the same name; ``ceil_mode`` defaults to False
    and ``block_size`` to the auto-tuned ``BLOCK_SIZE``.
    """
    kernel_size_h = KERNEL_SIZE_H if kernel_size_h is None else kernel_size_h
    kernel_size_w = KERNEL_SIZE_W if kernel_size_w is None else kernel_size_w
    stride_h = STRIDE_H if stride_h is None else stride_h
    stride_w = STRIDE_W if stride_w is None else stride_w
    padding_h = PADDING_H if padding_h is None else padding_h
    padding_w = PADDING_W if padding_w is None else padding_w
    dilation_h = DILATION_H if dilation_h is None else dilation_h
    dilation_w = DILATION_W if dilation_w is None else dilation_w
    ceil_mode = False if ceil_mode is None else ceil_mode
    block_size = BLOCK_SIZE if block_size is None else block_size

    # Pad H and W, then tile the padded input into pooling windows.
    # NOTE(review): strides of -1 on the N/C axes and floor_mode=not ceil_mode
    # follow the ninetoothed tiling API — confirm against its docs.
    windows = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    windows = windows.tile(
        (1, 1, kernel_size_h, kernel_size_w),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=not ceil_mode,
    )
    windows = windows.ravel()
    windows = windows.flatten(end_dim=4).flatten(start_dim=1)
    windows = windows.tile((block_size, -1))

    # One output element per pooling window, grouped the same way so the
    # kernel sees matching (block, window) layouts on both sides.
    results = output.tile((1, 1, 1, 1))
    results = results.ravel()
    results = results.flatten(end_dim=4).flatten(start_dim=1)
    results = results.tile((block_size, -1))
    results.dtype = results.dtype.squeeze(1)

    return windows, results
83+
84+
85+
def application(input, output):
    """Reduce each flattened pooling window to its maximum value.

    NOTE(review): the assignment to the `output` parameter looks dead
    (hence the noqa F841), but ninetoothed presumably compiles this
    function from its source and treats writes to `output` as stores
    into the arranged output tensor — confirm against the ninetoothed
    application-function convention before restructuring.
    """
    output = ntl.max(input, axis=1)  # noqa: F841
87+
88+
89+
def premake(
    kernel_size_h=None,
    kernel_size_w=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    ceil_mode=None,
    dtype=None,
    block_size=None,
):
    """Build the (arrangement, application, tensors) triple for kernel creation.

    All pooling parameters are forwarded to :func:`arrangement`, where
    ``None`` values fall back to compile-time symbols.
    """
    pooling_params = dict(
        kernel_size_h=kernel_size_h,
        kernel_size_w=kernel_size_w,
        stride_h=stride_h,
        stride_w=stride_w,
        padding_h=padding_h,
        padding_w=padding_w,
        dilation_h=dilation_h,
        dilation_w=dilation_w,
        ceil_mode=ceil_mode,
        block_size=block_size,
    )
    bound_arrangement = functools.partial(arrangement, **pooling_params)

    # Out-of-bounds reads of the (padded) input yield -inf so they can
    # never win the max reduction.
    input_tensor = Tensor(4, dtype=dtype, other=float("-inf"))
    output_tensor = Tensor(4, dtype=dtype)

    return bound_arrangement, application, (input_tensor, output_tensor)

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ntops.torch.le import le
2121
from ntops.torch.lt import lt
2222
from ntops.torch.matmul import matmul
23+
from ntops.torch.max_pool2d import max_pool2d
2324
from ntops.torch.mm import mm
2425
from ntops.torch.mul import mul
2526
from ntops.torch.ne import ne
@@ -60,6 +61,7 @@
6061
"le",
6162
"lt",
6263
"matmul",
64+
"max_pool2d",
6365
"mm",
6466
"mul",
6567
"ne",

src/ntops/torch/max_pool2d.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import math
2+
3+
import torch
4+
5+
import ntops
6+
from ntops.torch.utils import _cached_make
7+
8+
9+
def max_pool2d(
    input,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
    return_indices=False,
):
    """Apply 2D max pooling over an NCHW `input`.

    Mirrors `torch.nn.functional.max_pool2d`: `kernel_size`, `stride`,
    `padding`, and `dilation` each accept an int or an (h, w) pair, and
    `stride` defaults to `kernel_size`.

    Args:
        input: 4D tensor of shape (n, c, h, w).
        kernel_size: Pooling window size, int or (h, w) pair.
        stride: Window step, int or (h, w) pair; defaults to `kernel_size`.
        padding: Implicit negative-infinity padding, int or (h, w) pair.
        dilation: Spacing between window elements, int or (h, w) pair.
        ceil_mode: Use ceil instead of floor for the output size.
        return_indices: Not supported yet; must be False.

    Returns:
        A new tensor of shape (n, c, h_out, w_out) with the pooled maxima.
    """
    # Accept an int for any (h, w) parameter, like the PyTorch API does.
    # kernel_size must be normalized too: it is indexed below, and
    # `stride = kernel_size` must yield a pair.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    # PyTorch semantics: stride defaults to the kernel size.
    if stride is None:
        stride = kernel_size

    if isinstance(stride, int):
        stride = (stride, stride)

    if isinstance(padding, int):
        padding = (padding, padding)

    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    assert not return_indices, "`return_indices == True` is not supported yet."

    n, c, h, w = input.shape

    def _calculate_output_size(
        input_size, kernel_size, stride, padding, dilation, ceil_mode
    ):
        # Standard pooling shape formula; ceil_mode rounds up instead of down.
        int_ = math.ceil if ceil_mode else math.floor

        result = int_(
            (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
        )

        # With ceil_mode, drop the last window if it would start entirely
        # inside the trailing padding (matches PyTorch's pooling shape rule).
        if ceil_mode and (result - 1) * stride >= input_size + padding:
            result -= 1

        return result

    h_ = _calculate_output_size(
        h, kernel_size[0], stride[0], padding[0], dilation[0], ceil_mode
    )
    w_ = _calculate_output_size(
        w, kernel_size[1], stride[1], padding[1], dilation[1], ceil_mode
    )

    output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device)

    # ceil_mode is baked into the cached kernel; the remaining pooling
    # parameters are passed at launch time.
    kernel = _cached_make(ntops.kernels.max_pool2d.premake, ceil_mode=ceil_mode)

    kernel(
        input,
        output,
        kernel_size_h=kernel_size[0],
        kernel_size_w=kernel_size[1],
        stride_h=stride[0],
        stride_w=stride[1],
        padding_h=padding[0],
        padding_w=padding[1],
        dilation_h=dilation[0],
        dilation_w=dilation[1],
    )

    return output

tests/test_max_pool2d.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops
6+
from tests.skippers import skip_if_cuda_not_available
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize("device", ("cuda",))
@pytest.mark.parametrize("dtype", (torch.float32, torch.float16))
@pytest.mark.parametrize("ceil_mode", (False, True))
@pytest.mark.parametrize("dilation", (1, 2, (2, 3)))
@pytest.mark.parametrize("padding", (0, 1, (2, 3)))
@pytest.mark.parametrize("stride", (None, 1, (2, 3)))
@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3)))
@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),))
def test_max_pool2d(
    n, c, h, w, kernel_size, stride, padding, dilation, ceil_mode, dtype, device
):
    """Compare ntops.torch.max_pool2d against torch.nn.functional.max_pool2d."""
    pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding

    # PyTorch rejects padding larger than half the kernel size; skip those
    # parameter combinations instead of failing.
    if pad_h > kernel_size[0] / 2 or pad_w > kernel_size[1] / 2:
        pytest.skip(reason="Invalid padding.")

    input = torch.randn((n, c, h, w), dtype=dtype, device=device)

    pool_kwargs = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
    )

    ninetoothed_output = ntops.torch.max_pool2d(input, **pool_kwargs)
    reference_output = F.max_pool2d(input, **pool_kwargs)

    assert torch.allclose(ninetoothed_output, reference_output)

0 commit comments

Comments
 (0)