Skip to content

Commit bbeae9c

Browse files
committed
Add conv2d operator
1 parent b2623d4 commit bbeae9c

3 files changed

Lines changed: 191 additions & 0 deletions

File tree

src/ntops/kernels/conv2d.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import functools
2+
3+
import ninetoothed.language as ntl
4+
from ninetoothed import Symbol, Tensor
5+
6+
import ntops.kernels.mm as mm
7+
8+
# Compile-time (constexpr) symbols for the conv2d hyperparameters.
# Concrete values are bound per-kernel at make time (see `premake` /
# the `kernel(...)` call site), so each (stride, dilation) combination
# compiles to its own specialization.
STRIDE_H = Symbol("stride_h", constexpr=True)
STRIDE_W = Symbol("stride_w", constexpr=True)
DILATION_H = Symbol("dilation_h", constexpr=True)
DILATION_W = Symbol("dilation_w", constexpr=True)
12+
13+
14+
def arrangement(
    input,
    weight,
    bias,
    output,
    stride_h=None,
    stride_w=None,
    dilation_h=None,
    dilation_w=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Arrange conv2d operands so the convolution runs as a matmul.

    The input is tiled into sliding windows and flattened, the weight is
    flattened and transposed, and both (plus bias/output reshaped to
    ``(N*P*Q, K)``) are handed to ``mm.arrangement`` — i.e. an
    im2col-style / implicit-GEMM layout (NOTE(review): inferred from the
    tile/flatten sequence; confirm against ninetoothed's docs).

    Any ``None`` hyperparameter falls back to its constexpr ``Symbol``
    (or ``mm``'s block-size symbols), keeping it a compile-time value.

    Returns the arranged ``(input, weight, bias, output)`` tensors.
    """
    if stride_h is None:
        stride_h = STRIDE_H

    if stride_w is None:
        stride_w = STRIDE_W

    if dilation_h is None:
        dilation_h = DILATION_H

    if dilation_w is None:
        dilation_w = DILATION_W

    if block_size_m is None:
        block_size_m = mm.BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = mm.BLOCK_SIZE_N

    if block_size_k is None:
        block_size_k = mm.BLOCK_SIZE_K

    # All three matmul operands must share the same block sizes, so bind
    # them once and reuse the partial below.
    mm_arrangement = functools.partial(
        mm.arrangement,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
        block_size_k=block_size_k,
    )

    # Tile the (presumably NCHW — TODO confirm) input into one window per
    # output pixel: window shape (1, C, R, S), stepped by (stride_h,
    # stride_w) spatially. -1 strides and floor_mode are ninetoothed
    # tiling conventions — see its docs for their exact semantics.
    input_arranged = input.tile(
        (1, *weight.shape[1:]),
        strides=(-1, -1, stride_h, stride_w),
        dilation=(1, 1, dilation_h, dilation_w),
        floor_mode=True,
    )
    input_arranged = input_arranged.squeeze(1)
    input_arranged.dtype = input_arranged.dtype.squeeze(0)
    input_arranged = input_arranged.ravel()
    # Collapse to a 2-D (rows = output positions, cols = C*R*S) view for
    # the matmul arrangement.
    input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)

    # Weight (K, C, R, S) -> (C*R*S, K) so the matmul contracts over the
    # flattened window dimension.
    weight_arranged = weight.flatten(start_dim=1)
    weight_arranged = weight_arranged.permute((1, 0))

    # Bias arrives pre-expanded to the output's shape (see the caller in
    # torch.py); move channels last and flatten to match the matmul rows.
    bias_arranged = bias.permute((0, 2, 3, 1)).flatten(end_dim=3)

    # Only the third (output-position) arrangement is needed for bias;
    # the first two results are recomputed below for the real operands.
    _, _, bias_arranged = mm_arrangement(input_arranged, weight_arranged, bias_arranged)

    output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3)

    input_arranged, weight_arranged, output_arranged = mm_arrangement(
        input_arranged, weight_arranged, output_arranged
    )

    return input_arranged, weight_arranged, bias_arranged, output_arranged
80+
81+
82+
def application(input, weight, bias, output):
    """Per-tile compute: ``output = input @ weight + bias``.

    Accumulates the matmul in float32 regardless of the operand dtype,
    then adds the (pre-arranged) bias.

    NOTE(review): in plain Python the final ``output = ...`` would be a
    dead local rebind; presumably ninetoothed's tracer interprets
    assignment to a parameter tensor as a store into it — confirm.
    """
    mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
    mm.application(input, weight, mm_output)
    output = mm_output + bias
86+
87+
88+
def premake(
    stride_h=None,
    stride_w=None,
    dilation_h=None,
    dilation_w=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    """Assemble the conv2d kernel pieces for ninetoothed's `make`.

    Binds every hyperparameter onto ``arrangement`` up front, so the
    returned arrangement only takes the four tensor arguments. A ``None``
    hyperparameter keeps the corresponding constexpr symbol, deferring
    the choice to kernel-launch time.

    Returns ``(arrangement, application, tensors)`` where ``tensors`` are
    the four rank-4 placeholders (input, weight, bias, output).
    """
    config = {
        "stride_h": stride_h,
        "stride_w": stride_w,
        "dilation_h": dilation_h,
        "dilation_w": dilation_w,
        "block_size_m": block_size_m,
        "block_size_n": block_size_n,
        "block_size_k": block_size_k,
    }
    bound_arrangement = functools.partial(arrangement, **config)

    # input, weight, bias, output — all declared as 4-D tensors.
    input_t, weight_t, bias_t, output_t = (Tensor(4, dtype=dtype) for _ in range(4))

    return bound_arrangement, application, (input_t, weight_t, bias_t, output_t)

src/ntops/torch.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import ntops.kernels.bitwise_or
1414
import ntops.kernels.bmm
1515
import ntops.kernels.clamp
16+
import ntops.kernels.conv2d
1617
import ntops.kernels.cos
1718
import ntops.kernels.div
1819
import ntops.kernels.dropout
@@ -140,6 +141,54 @@ def clamp(input, min=None, max=None, *, out=None):
140141
return out
141142

142143

144+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """2-D convolution with a ninetoothed kernel (torch-compatible API).

    Args:
        input: ``(N, C, H, W)`` tensor.
        weight: ``(K, C, R, S)`` filter tensor.
        bias: optional ``(K,)`` tensor; zeros are used when ``None``.
        stride, padding, dilation: int or ``(h, w)`` pair; ``padding`` may
            also be the string ``"valid"`` (i.e. no padding).
        groups: must be 1 for now.

    Returns:
        ``(N, K, P, Q)`` output tensor.
    """
    if isinstance(stride, int):
        stride = (stride, stride)

    if isinstance(padding, str):
        if padding == "valid":
            padding = 0
        # NOTE(review): "same" padding is not handled and falls through to
        # the unsupported-padding check below.

    if isinstance(padding, int):
        padding = (padding, padding)

    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    # Normalization must happen BEFORE these checks so that
    # `padding="valid"` and `padding=(0, 0)` (both meaning zero padding)
    # are accepted.
    # TODO: Support `padding != 0`.
    assert padding == (0, 0), "`padding != 0` is not supported yet."

    # TODO: Support `groups != 1`.
    assert groups == 1, "`groups != 1` is not supported yet."

    n, _, h, w = input.shape
    k, _, r, s = weight.shape
    # Standard conv output size; integer floor division is exact, unlike
    # the float `math.floor(x / s + 1)` form.
    p = (h + 2 * padding[0] - dilation[0] * (r - 1) - 1) // stride[0] + 1
    q = (w + 2 * padding[1] - dilation[1] * (s - 1) - 1) // stride[1] + 1

    output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device)

    if bias is None:
        bias = torch.zeros((k,), dtype=output.dtype, device=output.device)

    # The kernel's arrangement expects bias already broadcast to the
    # output's shape (it is permuted/flattened alongside the output).
    bias = bias[None, :, None, None].expand_as(output)

    kernel = _cached_make(ntops.kernels.conv2d.premake)

    kernel(
        input,
        weight,
        bias,
        output,
        stride_h=stride[0],
        stride_w=stride[1],
        dilation_h=dilation[0],
        dilation_w=dilation[1],
    )

    return output
191+
143192
def cos(input, *, out=None):
144193
if out is None:
145194
out = torch.empty_like(input)

tests/test_conv2d.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops.torch
6+
from tests.skippers import skip_if_cuda_not_available
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize(
    "dtype, atol, rtol", ((torch.float16, 0.025, 0.025), (torch.float32, 0.01, 0.01))
)
@pytest.mark.parametrize("dilation", (1, 2, (2, 3)))
@pytest.mark.parametrize("stride", (1, 2, (2, 3)))
@pytest.mark.parametrize("r, s", ((1, 1), (3, 3)))
@pytest.mark.parametrize("n, c, h, w, k", ((2, 3, 112, 112, 4),))
def test_cuda(n, c, h, w, k, r, s, stride, dilation, dtype, atol, rtol):
    """Check ntops conv2d against torch.nn.functional.conv2d on CUDA."""
    device = "cuda"

    # Random NCHW activations, OIHW filters, and a per-channel bias in
    # the dtype under test.
    input = torch.randn((n, c, h, w), dtype=dtype, device=device)
    weight = torch.randn((k, c, r, s), dtype=dtype, device=device)
    bias = torch.randn((k,), dtype=dtype, device=device)

    # Identical keyword arguments for both implementations.
    conv_kwargs = {"bias": bias, "stride": stride, "dilation": dilation}

    actual = ntops.torch.conv2d(input, weight, **conv_kwargs)
    expected = F.conv2d(input, weight, **conv_kwargs)

    assert torch.allclose(actual, expected, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)