Skip to content

Commit 9440ce9

Browse files
committed
Add conv2d operator
1 parent a55bb05 commit 9440ce9

5 files changed

Lines changed: 232 additions & 0 deletions

File tree

src/ntops/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
bitwise_or,
99
bmm,
1010
clamp,
11+
conv2d,
1112
cos,
1213
div,
1314
dropout,
@@ -50,6 +51,7 @@
5051
"bitwise_or",
5152
"bmm",
5253
"clamp",
54+
"conv2d",
5355
"cos",
5456
"div",
5557
"dropout",

src/ntops/kernels/conv2d.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import copy
2+
import functools
3+
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Symbol, Tensor
6+
7+
from ntops.kernels import mm
8+
9+
10+
def arrangement(
    input,
    weight,
    bias,
    output,
    input_precision,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Arrange conv2d operands so the convolution lowers to a matmul.

    The input is padded and tiled into sliding windows, then flattened so
    that each output position becomes a row of an implicit-GEMM-style
    matrix multiply; the weight is flattened and transposed to form the
    other matmul operand, and the final tiling is delegated to
    ``mm.arrangement``.

    Parameters left as ``None`` (strides, paddings, dilations) become
    compile-time ``Symbol``s so a single kernel can be specialized later;
    block sizes default to the shared ``mm`` block-size symbols.

    Returns a tuple of the five arranged tensors:
    ``(input, weight, bias, output, input_precision)``.
    """
    # Any meta-parameter not pinned by the caller becomes a constexpr
    # symbol, resolved at kernel-compilation time.
    if stride_h is None:
        stride_h = Symbol("stride_h", constexpr=True)

    if stride_w is None:
        stride_w = Symbol("stride_w", constexpr=True)

    if padding_h is None:
        padding_h = Symbol("padding_h", constexpr=True)

    if padding_w is None:
        padding_w = Symbol("padding_w", constexpr=True)

    if dilation_h is None:
        dilation_h = Symbol("dilation_h", constexpr=True)

    if dilation_w is None:
        dilation_w = Symbol("dilation_w", constexpr=True)

    # Reuse the matmul kernel's default block sizes so conv2d tiling stays
    # consistent with mm.
    if block_size_m is None:
        block_size_m = mm.BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = mm.BLOCK_SIZE_N

    if block_size_k is None:
        block_size_k = mm.BLOCK_SIZE_K

    mm_arrangement = functools.partial(
        mm.arrangement,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
        block_size_k=block_size_k,
    )

    # Zero-pad only the two spatial dims (N and C are untouched).
    input_arranged = input.pad(
        ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w))
    )
    # Tile into one (1, C, R, S) sliding window per output position.
    # strides=-1 means "don't slide" along N and C; floor_mode drops
    # partial windows at the border.
    input_arranged = input_arranged.tile(
        (1, *weight.shape[1:]),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=True,
    )
    input_arranged = input_arranged.squeeze(1)
    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    input_arranged = input_arranged.ravel()
    # Collapse to a 2-D (N*P*Q, C*R*S) matrix — the left matmul operand.
    input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)

    # Weight becomes the right operand: (K, C*R*S) transposed to (C*R*S, K).
    weight_arranged = weight.flatten(start_dim=1)
    weight_arranged = weight_arranged.permute((1, 0))

    # Broadcast the per-channel bias over (N, P, Q), then flatten it the
    # same way as the output so it lines up row-for-row with the matmul.
    bias_arranged = bias[None, :, None, None].expand(
        (output.shape[0], -1, output.shape[2], output.shape[3])
    )
    bias_arranged = bias_arranged.permute((0, 2, 3, 1)).flatten(end_dim=3)

    # Output in NHWC order flattened to (N*P*Q, K) to match the matmul.
    output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3)

    # Run the mm arrangement once just to arrange the bias in the output's
    # slot; deep copies keep the other operands from being mutated here.
    _, _, bias_arranged, _ = mm_arrangement(
        copy.deepcopy(input_arranged),
        copy.deepcopy(weight_arranged),
        bias_arranged,
        copy.deepcopy(input_precision),
    )

    input_arranged, weight_arranged, output_arranged, input_precision_arranged = (
        mm_arrangement(
            input_arranged, weight_arranged, output_arranged, input_precision
        )
    )

    return (
        input_arranged,
        weight_arranged,
        bias_arranged,
        output_arranged,
        input_precision_arranged,
    )
104+
105+
106+
def application(input, weight, bias, output, input_precision):
    """Compute one conv2d tile: matmul of arranged operands plus bias.

    Accumulates the matmul in float32 for precision, then adds the
    (already output-aligned) bias.
    """
    mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
    mm.application(input, weight, mm_output, input_precision)
    # NOTE(review): in plain Python this rebinding would be dead code;
    # presumably the ninetoothed DSL treats assignment to the `output`
    # parameter as the store back to the output tensor — confirm.
    output = mm_output + bias
110+
111+
112+
def premake(
    input_precision=None,
    stride_h=None,
    stride_w=None,
    padding_h=None,
    padding_w=None,
    dilation_h=None,
    dilation_w=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Prepare the pieces needed to build the conv2d kernel.

    Binds the convolution meta-parameters into ``arrangement`` and
    constructs symbolic tensor descriptors for the five kernel operands.

    Returns a ``(arrangement, application, tensors)`` triple, where
    ``tensors`` is ``(input, weight, bias, output, input_precision)``.
    """
    # Curry all meta-parameters into the arrangement function; any that
    # are None stay symbolic and get resolved at specialization time.
    bound_arrangement = functools.partial(
        arrangement,
        stride_h=stride_h,
        stride_w=stride_w,
        padding_h=padding_h,
        padding_w=padding_w,
        dilation_h=dilation_h,
        dilation_w=dilation_w,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
        block_size_k=block_size_k,
    )

    # Symbolic operand descriptors: 4-D input/weight/output, 1-D bias,
    # and a 0-D constexpr carrying the matmul input precision.
    input = Tensor(4, dtype=dtype)
    weight = Tensor(4, dtype=dtype)
    output = Tensor(4, dtype=dtype)
    bias = Tensor(1, dtype=dtype)
    precision_tensor = Tensor(0, dtype=dtype, constexpr=True, value=input_precision)

    operands = (input, weight, bias, output, precision_tensor)

    return bound_arrangement, application, operands

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ntops.torch.bitwise_or import bitwise_or
88
from ntops.torch.bmm import bmm
99
from ntops.torch.clamp import clamp
10+
from ntops.torch.conv2d import conv2d
1011
from ntops.torch.cos import cos
1112
from ntops.torch.div import div
1213
from ntops.torch.dropout import dropout
@@ -49,6 +50,7 @@
4950
"bitwise_or",
5051
"bmm",
5152
"clamp",
53+
"conv2d",
5254
"cos",
5355
"div",
5456
"dropout",

src/ntops/torch/conv2d.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
3+
import ntops
4+
from ntops.torch.pooling import _calculate_output_size
5+
from ntops.torch.utils import _cached_make, _get_matmul_input_precision
6+
7+
8+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in replacement for ``torch.nn.functional.conv2d`` backed by the
    ninetoothed conv2d kernel.

    Args:
        input: Input tensor of shape ``(N, C, H, W)``.
        weight: Filter tensor of shape ``(K, C, R, S)``.
        bias: Optional per-channel bias of shape ``(K,)``; zeros are used
            when omitted so the kernel signature stays fixed.
        stride, padding, dilation: Int or pair of ints, as in PyTorch.
            ``padding`` may also be the string ``"valid"`` or ``"same"``.
        groups: Only ``groups == 1`` is supported.

    Returns:
        Output tensor of shape ``(N, K, P, Q)``.

    Raises:
        ValueError: For an unrecognized padding string, or
            ``padding="same"`` combined with a stride other than 1.
        NotImplementedError: For ``groups != 1``, or a ``"same"`` padding
            that would require asymmetric padding.
    """
    if isinstance(stride, int):
        stride = (stride, stride)

    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    if isinstance(padding, str):
        if padding == "valid":
            padding = 0
        elif padding == "same":
            # PyTorch only allows "same" with unit stride.
            if stride != (1, 1):
                raise ValueError(
                    "padding='same' is not supported for strided convolutions."
                )
            total_h = dilation[0] * (weight.shape[2] - 1)
            total_w = dilation[1] * (weight.shape[3] - 1)
            # Only symmetric padding can be expressed; an odd total amount
            # would need one extra padded row/column on one side.
            if total_h % 2 != 0 or total_w % 2 != 0:
                raise NotImplementedError(
                    "padding='same' requiring asymmetric padding is not "
                    "supported yet."
                )
            padding = (total_h // 2, total_w // 2)
        else:
            # Previously unknown strings (e.g. "same") fell through and
            # were indexed as sequences, yielding garbage like "s".
            raise ValueError(f"Invalid padding string: {padding!r}")

    if isinstance(padding, int):
        padding = (padding, padding)

    # Raise (not assert): input validation must survive `python -O`.
    if groups != 1:
        raise NotImplementedError("`groups` is not supported yet.")

    n, c, h, w = input.shape
    k, _, r, s = weight.shape

    p = _calculate_output_size(
        h, r, stride=stride[0], padding=padding[0], dilation=dilation[0]
    )
    q = _calculate_output_size(
        w, s, stride=stride[1], padding=padding[1], dilation=dilation[1]
    )

    output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device)

    if bias is None:
        # The kernel always consumes a bias operand; substitute zeros.
        bias = torch.zeros((k,), dtype=output.dtype, device=output.device)

    kernel = _cached_make(
        ntops.kernels.conv2d.premake,
        stride_h=stride[0],
        stride_w=stride[1],
        padding_h=padding[0],
        padding_w=padding[1],
        dilation_h=dilation[0],
        dilation_w=dilation[1],
    )

    kernel(input, weight, bias, output, _get_matmul_input_precision())

    return output

tests/test_conv2d.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops
6+
from tests.skippers import skip_if_cuda_not_available
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize("device", ("cuda",))
@pytest.mark.parametrize(
    "dtype, rtol, atol", ((torch.float32, 1e-5, 1e-5), (torch.float16, 1e-3, 1e-3))
)
@pytest.mark.parametrize("dilation", (1, 2, (2, 3)))
@pytest.mark.parametrize("padding", (0, 1, (2, 3)))
@pytest.mark.parametrize("stride", (1, 2, (2, 3)))
@pytest.mark.parametrize("r, s", ((1, 1), (3, 3)))
@pytest.mark.parametrize("n, c, h, w, k", ((2, 3, 112, 112, 4),))
def test_conv2d(
    n, c, h, w, k, r, s, stride, padding, dilation, dtype, device, rtol, atol
):
    """Compare ntops conv2d against torch.nn.functional.conv2d."""
    input = torch.randn((n, c, h, w), dtype=dtype, device=device)
    weight = torch.randn((k, c, r, s), dtype=dtype, device=device)
    bias = torch.randn((k,), dtype=dtype, device=device)

    # Both implementations get byte-identical arguments.
    conv_kwargs = dict(bias=bias, stride=stride, padding=padding, dilation=dilation)

    ninetoothed_output = ntops.torch.conv2d(input, weight, **conv_kwargs)
    reference_output = F.conv2d(input, weight, **conv_kwargs)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)