
Commit 2d0a1ad

Merge pull request #52 from InfiniTensor/add-input-precision-support-for-matmul-operations
Add input precision support for matmul operations
2 parents 75445b5 + 22c5be2 commit 2d0a1ad

5 files changed

Lines changed: 101 additions & 23 deletions


src/ntops/kernels/addmm.py

Lines changed: 27 additions & 6 deletions
@@ -13,6 +13,7 @@ def arrangement(
     beta,
     alpha,
     output,
+    input_precision,
     block_size_m=None,
     block_size_n=None,
     block_size_k=None,
@@ -26,27 +27,46 @@ def arrangement(
     if block_size_k is None:
         block_size_k = mm.BLOCK_SIZE_K

-    _, _, input_arranged = mm.arrangement(
+    _, _, input_arranged, _ = mm.arrangement(
         x,
         y,
         input,
+        input_precision,
         block_size_m=block_size_m,
         block_size_n=block_size_n,
         block_size_k=block_size_k,
     )

-    x_arranged, y_arranged, output_arranged = mm.arrangement(x, y, output)
+    x_arranged, y_arranged, output_arranged, _ = mm.arrangement(
+        x, y, output, input_precision
+    )
+
+    input_precision_arranged = input_precision

-    return input_arranged, x_arranged, y_arranged, beta, alpha, output_arranged
+    return (
+        input_arranged,
+        x_arranged,
+        y_arranged,
+        beta,
+        alpha,
+        output_arranged,
+        input_precision_arranged,
+    )


-def application(input, x, y, beta, alpha, output):
+def application(input, x, y, beta, alpha, output, input_precision):
     mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
-    mm.application(x, y, mm_output)
+    mm.application(x, y, mm_output, input_precision)
     output = beta * input + alpha * mm_output


-def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
+def premake(
+    input_precision=None,
+    dtype=None,
+    block_size_m=None,
+    block_size_n=None,
+    block_size_k=None,
+):
     arrangement_ = functools.partial(
         arrangement,
         block_size_m=block_size_m,
@@ -61,6 +81,7 @@ def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None)
         Tensor(0, dtype=dtype),
         Tensor(0, dtype=dtype),
         Tensor(2, dtype=dtype),
+        Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
     )

     return arrangement_, application, tensors
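
addmm.premake now takes the precision variant as its first parameter and bakes it into the kernel as a 0-D constexpr tensor, so each variant specializes to its own kernel. A minimal sketch of driving it by hand (the ninetoothed.make call is an assumption about how premake's three outputs are consumed; in-tree, _cached_make in src/ntops/torch.py does this wiring):

import ninetoothed

import ntops.kernels.addmm
from ntops.kernels.mm import InputPrecisionVariant

# Pin addmm to exact IEEE fp32 dot products rather than TF32.
arrangement, application, tensors = ntops.kernels.addmm.premake(
    input_precision=InputPrecisionVariant.IEEE
)
kernel = ninetoothed.make(arrangement, application, tensors)  # assumed API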

src/ntops/kernels/bmm.py

Lines changed: 23 additions & 4 deletions
@@ -6,7 +6,13 @@


 def arrangement(
-    input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
+    input,
+    other,
+    output,
+    input_precision,
+    block_size_m=None,
+    block_size_n=None,
+    block_size_k=None,
 ):
     if block_size_m is None:
         block_size_m = BLOCK_SIZE_M
@@ -32,17 +38,30 @@ def arrangement(
     other_arranged.dtype = other_arranged.dtype.squeeze((0, 2))
     other_arranged.dtype.dtype = other_arranged.dtype.dtype.squeeze(0)

-    return input_arranged, other_arranged, output_arranged
+    input_precision_arranged = input_precision

+    return input_arranged, other_arranged, output_arranged, input_precision_arranged

-def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
+
+def premake(
+    input_precision=None,
+    dtype=None,
+    block_size_m=None,
+    block_size_n=None,
+    block_size_k=None,
+):
     arrangement_ = functools.partial(
         arrangement,
         block_size_m=block_size_m,
         block_size_n=block_size_n,
         block_size_k=block_size_k,
     )

-    tensors = (Tensor(3, dtype=dtype), Tensor(3, dtype=dtype), Tensor(3, dtype=dtype))
+    tensors = (
+        Tensor(3, dtype=dtype),
+        Tensor(3, dtype=dtype),
+        Tensor(3, dtype=dtype),
+        Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
+    )

     return arrangement_, application, tensors
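
bmm.premake follows the same pattern with 3-D (batched) operand declarations; the fourth meta-tensor is the 0-D constexpr scalar carrying the variant. A quick sanity-check sketch (assumes ntops and its ninetoothed dependency are importable):

import ntops.kernels.bmm
from ntops.kernels.mm import InputPrecisionVariant

arrangement, application, tensors = ntops.kernels.bmm.premake(
    input_precision=InputPrecisionVariant.TF32
)

# Three batched Tensor(3, ...) declarations plus the constexpr precision scalar.
assert len(tensors) == 4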

src/ntops/kernels/mm.py

Lines changed: 37 additions & 6 deletions
@@ -1,3 +1,4 @@
+import enum
 import functools

 import ninetoothed
@@ -9,8 +10,20 @@
 BLOCK_SIZE_K = ninetoothed.block_size()


+class InputPrecisionVariant(enum.IntEnum):
+    TF32 = enum.auto()
+
+    IEEE = enum.auto()
+
+
 def arrangement(
-    input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
+    input,
+    other,
+    output,
+    input_precision,
+    block_size_m=None,
+    block_size_n=None,
+    block_size_k=None,
 ):
     if block_size_m is None:
         block_size_m = BLOCK_SIZE_M
@@ -33,26 +46,44 @@ def arrangement(
     other_arranged = other_arranged.expand((output_arranged.shape[0], -1))
     other_arranged.dtype = other_arranged.dtype.squeeze(1)

-    return input_arranged, other_arranged, output_arranged
+    input_precision_arranged = input_precision
+
+    return input_arranged, other_arranged, output_arranged, input_precision_arranged


-def application(input, other, output):
+def application(input, other, output, input_precision):
     accumulator = ntl.zeros(output.shape, dtype=ntl.float32)

+    if input_precision == 2:  # InputPrecisionVariant.IEEE
+        input_precision_: ntl.constexpr = "ieee"
+    else:
+        input_precision_: ntl.constexpr = "tf32"
+
     for k in range(input.shape[0]):
-        accumulator += ntl.dot(input[k], other[k])
+        accumulator += ntl.dot(input[k], other[k], input_precision=input_precision_)

     output = accumulator


-def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
+def premake(
+    input_precision=None,
+    dtype=None,
+    block_size_m=None,
+    block_size_n=None,
+    block_size_k=None,
+):
     arrangement_ = functools.partial(
         arrangement,
         block_size_m=block_size_m,
         block_size_n=block_size_n,
         block_size_k=block_size_k,
     )

-    tensors = (Tensor(2, dtype=dtype), Tensor(2, dtype=dtype), Tensor(2, dtype=dtype))
+    tensors = (
+        Tensor(2, dtype=dtype),
+        Tensor(2, dtype=dtype),
+        Tensor(2, dtype=dtype),
+        Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
+    )

     return arrangement_, application, tensors
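
The literal 2 in application is the value of InputPrecisionVariant.IEEE: enum.auto() numbers an IntEnum from 1, and the comparison uses the raw integer because input_precision reaches the kernel body as a plain constexpr scalar, not as the Python enum. A standalone check of that numbering:

import enum

class InputPrecisionVariant(enum.IntEnum):
    TF32 = enum.auto()  # == 1
    IEEE = enum.auto()  # == 2

assert InputPrecisionVariant.TF32 == 1
assert InputPrecisionVariant.IEEE == 2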

src/ntops/torch.py

Lines changed: 10 additions & 3 deletions
@@ -76,7 +76,7 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):

     kernel = _cached_make(ntops.kernels.addmm.premake)

-    kernel(input, mat1, mat2, beta, alpha, out)
+    kernel(input, mat1, mat2, beta, alpha, out, _get_matmul_input_precision())

     return out

@@ -125,7 +125,7 @@ def bmm(input, mat2, *, out=None):

     kernel = _cached_make(ntops.kernels.bmm.premake)

-    kernel(input, mat2, out)
+    kernel(input, mat2, out, _get_matmul_input_precision())

     return out

@@ -294,7 +294,7 @@ def mm(input, mat2, *, out=None):

     kernel = _cached_make(ntops.kernels.mm.premake)

-    kernel(input, mat2, out)
+    kernel(input, mat2, out, _get_matmul_input_precision())

     return out

@@ -619,3 +619,10 @@ def _cached_make(
         num_stages=num_stages,
         max_num_configs=max_num_configs,
     )
+
+
+def _get_matmul_input_precision():
+    if torch.get_float32_matmul_precision() == "highest":
+        return ntops.kernels.mm.InputPrecisionVariant.IEEE
+
+    return ntops.kernels.mm.InputPrecisionVariant.TF32
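
The new _get_matmul_input_precision helper ties these kernels to PyTorch's global fp32 matmul switch: "highest" selects the IEEE variant, while "high" and "medium" fall back to TF32. A usage sketch under that mapping (the CUDA device is an assumption, since these Triton-backed kernels need a GPU):

import torch

import ntops.torch

a = torch.randn(256, 128, device="cuda")
b = torch.randn(128, 64, device="cuda")

torch.set_float32_matmul_precision("highest")
exact = ntops.torch.mm(a, b)  # kernel compiled with input_precision="ieee"

torch.set_float32_matmul_precision("high")
fast = ntops.torch.mm(a, b)  # kernel compiled with input_precision="tf32"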

tests/test_mm.py

Lines changed: 4 additions & 4 deletions
@@ -12,11 +12,11 @@ def generate_arguments():

     for dtype in (torch.float32, torch.float16):
         if dtype is torch.float32:
-            atol = 0.05
-            rtol = 0.05
+            atol = 0.001
+            rtol = 0.001
         else:
-            atol = 0.025
-            rtol = 0.025
+            atol = 0.01
+            rtol = 0.01

         def generate_random_size():
             return random.randint(1, 1024)
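
The fp32 tolerances tighten from 0.05 to 0.001, presumably because PyTorch defaults float32_matmul_precision to "highest", so these tests now exercise the IEEE variant instead of TF32 and should track torch.mm much more closely. A minimal sketch of the comparison those tolerances imply (the full test body isn't shown; the CUDA device is an assumption):

import torch

import ntops.torch

torch.set_float32_matmul_precision("highest")

a = torch.randn(512, 512, device="cuda")
b = torch.randn(512, 512, device="cuda")

assert torch.allclose(
    ntops.torch.mm(a, b), torch.mm(a, b), atol=0.001, rtol=0.001
)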
