Skip to content

Commit 2478798

Browse files
authored
Arm backend: Add FP16 (pt.8) and BF16 support to operators (#17370)
Add FP16 support for operators: - add - amax - amin - permute - sum Add BF16 support for operators: - amax - amin Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent c09db81 commit 2478798

11 files changed

Lines changed: 254 additions & 62 deletions

File tree

backends/arm/operators/op_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def define_node(
3737
validate_valid_dtype(
3838
self.target,
3939
[*inputs, output],
40-
[ts.DType.INT32, ts.DType.FP32, ts.DType.BF16],
40+
[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
4141
self.tosa_spec,
4242
)
4343

backends/arm/operators/op_amax.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,14 @@ def define_node(
3939
validate_valid_dtype(
4040
self.target,
4141
[inputs[0], output],
42-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
42+
[
43+
ts.DType.INT8,
44+
ts.DType.INT16,
45+
ts.DType.INT32,
46+
ts.DType.FP16,
47+
ts.DType.FP32,
48+
ts.DType.BF16,
49+
],
4350
self.tosa_spec,
4451
)
4552

backends/arm/operators/op_amin.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,14 @@ def define_node(
3939
validate_valid_dtype(
4040
self.target,
4141
[inputs[0], output],
42-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
42+
[
43+
ts.DType.INT8,
44+
ts.DType.INT16,
45+
ts.DType.INT32,
46+
ts.DType.FP16,
47+
ts.DType.FP32,
48+
ts.DType.BF16,
49+
],
4350
self.tosa_spec,
4451
)
4552

backends/arm/operators/op_permute.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def define_node(
120120
ts.DType.INT8,
121121
ts.DType.INT16,
122122
ts.DType.INT32,
123+
ts.DType.FP16,
123124
ts.DType.FP32,
124125
ts.DType.BF16,
125126
],

backends/arm/operators/op_sum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def define_node(
3737
validate_valid_dtype(
3838
self.target,
3939
[inputs[0], output],
40-
[ts.DType.INT32, ts.DType.FP32, ts.DType.BF16],
40+
[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
4141
self.tosa_spec,
4242
)
4343

backends/arm/test/ops/test_add.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ def forward(self, x: torch.Tensor):
4141
"4d_ones_2": lambda: (torch.ones(1, 3, 4, 2),),
4242
}
4343

44+
test_data_fp16 = {
45+
"1d_ones_fp16": lambda: (torch.ones(8, dtype=torch.float16),),
46+
"4d_ones_fp16": lambda: (torch.ones(1, 2, 3, 4, dtype=torch.float16),),
47+
}
48+
49+
test_data_bf16 = {
50+
"1d_ones_bf16": lambda: (torch.ones(8, dtype=torch.bfloat16),),
51+
"4d_ones_bf16": lambda: (torch.ones(1, 2, 3, 4, dtype=torch.bfloat16),),
52+
}
53+
4454

4555
input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y
4656

@@ -70,6 +80,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
7080
torch.randn(1, 10, 20, 30),
7181
),
7282
}
83+
84+
test_data_fp16 = {
85+
"4d_big_small_fp16": lambda: (
86+
(10e10) * torch.randn(1, 10, 20, 30, dtype=torch.float16),
87+
torch.randn(1, 10, 20, 30, dtype=torch.float16),
88+
),
89+
}
90+
7391
test_data_bf16 = {
7492
"4d_big_small_bf16": lambda: (
7593
(10e10) * torch.randn(1, 10, 20, 30, dtype=torch.bfloat16),
@@ -88,18 +106,27 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
88106
"4d_randn_diff_rank_2": lambda: (torch.randn(4, 1), torch.randn(1, 1, 4, 5)),
89107
}
90108

109+
test_data_fp16: list[input_t2] = {
110+
"4d_randn_diff_rank_fp16": lambda: (
111+
torch.randn(1, 1, 4, 4, dtype=torch.float16),
112+
torch.randn(4, 1, dtype=torch.float16),
113+
),
114+
}
91115

92-
@common.parametrize("test_data", Add.test_data)
93-
def test_add_tensor_tosa_FP(test_data: input_t1):
94-
pipeline = TosaPipelineFP[input_t1](Add(), test_data(), aten_op, exir_op)
95-
pipeline.run()
116+
test_data_bf16: list[input_t2] = {
117+
"4d_randn_diff_rank_bf16": lambda: (
118+
torch.randn(1, 1, 4, 4, dtype=torch.bfloat16),
119+
torch.randn(4, 1, dtype=torch.bfloat16),
120+
),
121+
}
96122

97123

98-
@common.parametrize("test_data", Add.test_data)
99-
def test_add_tensor_tosa_FP_bf16(test_data: input_t1):
100-
x = test_data()[0].to(torch.bfloat16)
124+
@common.parametrize(
125+
"test_data", Add.test_data | Add.test_data_fp16 | Add.test_data_bf16
126+
)
127+
def test_add_tensor_tosa_FP(test_data: input_t1):
101128
pipeline = TosaPipelineFP[input_t1](
102-
Add(), (x,), aten_op, exir_op, tosa_extensions=["bf16"]
129+
Add(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"]
103130
)
104131
pipeline.run()
105132

@@ -167,17 +194,23 @@ def test_add_tensor_u85_INT(test_data: input_t1):
167194
pipeline.run()
168195

169196

170-
@common.parametrize("test_data", Add2.test_data | Add2.test_data_bf16)
197+
@common.parametrize(
198+
"test_data", Add2.test_data | Add2.test_data_fp16 | Add2.test_data_bf16
199+
)
171200
def test_add_tensor_tosa_FP_2(test_data: input_t2):
172201
pipeline = TosaPipelineFP[input_t2](
173202
Add2(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"]
174203
)
175204
pipeline.run()
176205

177206

178-
@common.parametrize("test_data", Add3.test_data)
207+
@common.parametrize(
208+
"test_data", Add3.test_data | Add3.test_data_fp16 | Add3.test_data_bf16
209+
)
179210
def test_add_tensor_tosa_FP_3(test_data: input_t2):
180-
pipeline = TosaPipelineFP[input_t2](Add3(), test_data(), aten_op, exir_op)
211+
pipeline = TosaPipelineFP[input_t2](
212+
Add3(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"]
213+
)
181214
pipeline.run()
182215

183216

@@ -217,7 +250,7 @@ def test_add_tensor_u85_INT_2(test_data: input_t2):
217250
pipeline.run()
218251

219252

220-
@common.parametrize("test_data", Add.test_data)
253+
@common.parametrize("test_data", Add.test_data | Add.test_data_fp16)
221254
@common.SkipIfNoModelConverter
222255
def test_add_tensor_vgf_no_quant(test_data: input_t1):
223256
pipeline = VgfPipeline[input_t1](

backends/arm/test/ops/test_amax.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
VgfPipeline,
1818
)
1919

20+
amax_aten_op = "torch.ops.aten.amax"
21+
amax_exir_op = "executorch_exir_dialects_edge__ops_aten_amax_default"
22+
23+
max_aten_op = "torch.ops.aten.max"
24+
max_exir_op = "executorch_exir_dialects_edge__ops_aten_max_default"
25+
2026

2127
class Amax(torch.nn.Module):
2228
input_t = Tuple[Tuple[torch.Tensor], int | Tuple[int], bool]
23-
aten_op = ["torch.ops.aten.amax"]
2429

2530
def __init__(self, dim, keep_dims):
2631
self.dim = dim
@@ -38,10 +43,35 @@ def forward(self, x):
3843
"rank_4_mult_batches": lambda: ((torch.rand([2, 2, 2, 2]),), (0), True),
3944
}
4045

46+
test_data_fp16: Dict[str, input_t] = {
47+
"rank_1_dim_0_fp16": lambda: (
48+
(torch.rand([10], dtype=torch.float16),),
49+
0,
50+
False,
51+
),
52+
"rank_2_dim_1_keep_dims_fp16": lambda: (
53+
(torch.rand([2, 2], dtype=torch.float16),),
54+
(1,),
55+
True,
56+
),
57+
}
58+
59+
test_data_bf16: Dict[str, input_t] = {
60+
"rank_1_dim_0_bf16": lambda: (
61+
(torch.rand([10], dtype=torch.bfloat16),),
62+
0,
63+
False,
64+
),
65+
"rank_2_dim_1_keep_dims_bf16": lambda: (
66+
(torch.rand([2, 2], dtype=torch.bfloat16),),
67+
(1,),
68+
True,
69+
),
70+
}
71+
4172

4273
class Max(torch.nn.Module):
4374
input_t = Tuple[Tuple[torch.Tensor], int]
44-
aten_op = ["torch.ops.aten.amax"]
4575

4676
def __init__(self, dim):
4777
self.dim = dim
@@ -58,8 +88,20 @@ def forward(self, x):
5888
"rank_4_dim_3": lambda: ((torch.rand([2, 2, 2, 2]),), 3),
5989
}
6090

91+
test_data_fp16: Dict[str, input_t] = {
92+
"rank_1_dim_0_fp16": lambda: ((torch.rand([10], dtype=torch.float16),), 0),
93+
"rank_2_dim_1_fp16": lambda: ((torch.rand([2, 2], dtype=torch.float16),), 1),
94+
}
95+
96+
test_data_bf16: Dict[str, input_t] = {
97+
"rank_1_dim_0_bf16": lambda: ((torch.rand([10], dtype=torch.bfloat16),), 0),
98+
"rank_2_dim_1_bf16": lambda: ((torch.rand([2, 2], dtype=torch.bfloat16),), 1),
99+
}
100+
61101

62102
class MaxWithIndex(torch.nn.Module):
103+
input_t = Tuple[Tuple[torch.Tensor], int]
104+
63105
def __init__(self, dim):
64106
self.dim = dim
65107
super().__init__()
@@ -68,18 +110,29 @@ def forward(self, x):
68110
x, i = torch.max(x, self.dim)
69111
return x, i
70112

113+
test_data: Dict[str, input_t] = Max.test_data
114+
test_data_fp16: Dict[str, input_t] = Max.test_data_fp16
115+
test_data_bf16: Dict[str, input_t] = Max.test_data_bf16
71116

72-
@common.parametrize("test_data", Amax.test_data)
117+
118+
@common.parametrize(
119+
"test_data", Amax.test_data | Amax.test_data_fp16 | Amax.test_data_bf16
120+
)
73121
def test_amax_tosa_FP(test_data: Amax.input_t):
74122
data, dim, keep_dims = test_data()
75-
pipeline = TosaPipelineFP[Amax.input_t](Amax(dim, keep_dims), data, Amax.aten_op)
123+
pipeline = TosaPipelineFP[Amax.input_t](
124+
Amax(dim, keep_dims),
125+
data,
126+
amax_aten_op,
127+
tosa_extensions=["bf16"],
128+
)
76129
pipeline.run()
77130

78131

79132
@common.parametrize("test_data", Amax.test_data)
80133
def test_amax_tosa_INT(test_data: Amax.input_t):
81134
data, dim, keep_dims = test_data()
82-
pipeline = TosaPipelineINT[Amax.input_t](Amax(dim, keep_dims), data, Amax.aten_op)
135+
pipeline = TosaPipelineINT[Amax.input_t](Amax(dim, keep_dims), data, amax_aten_op)
83136
pipeline.run()
84137

85138

@@ -88,7 +141,7 @@ def test_amax_u55_INT_not_delegated():
88141
pipeline = OpNotSupportedPipeline[Amax.input_t](
89142
Amax(dim, keep_dims),
90143
data,
91-
{" executorch_exir_dialects_edge__ops_aten_amax_default": 1},
144+
{"executorch_exir_dialects_edge__ops_aten_amax_default": 1},
92145
quantize=True,
93146
u55_subset=True,
94147
)
@@ -102,23 +155,30 @@ def test_amax_u85_INT(test_data: Amax.input_t):
102155
pipeline = EthosU85PipelineINT[Amax.input_t](
103156
Amax(dim, keep_dims),
104157
data,
105-
Amax.aten_op,
158+
amax_aten_op,
106159
)
107160
pipeline.run()
108161

109162

110-
@common.parametrize("test_data", Max.test_data)
163+
@common.parametrize(
164+
"test_data", Max.test_data | Max.test_data_fp16 | Max.test_data_bf16
165+
)
111166
def test_max_dim_tosa_FP_to_amax(test_data: Max.input_t):
112167
data, dim = test_data()
113-
pipeline = TosaPipelineFP[Max.input_t](Max(dim), data, "torch.ops.aten.max")
168+
pipeline = TosaPipelineFP[Max.input_t](
169+
Max(dim),
170+
data,
171+
max_aten_op,
172+
tosa_extensions=["bf16"],
173+
)
114174
pipeline.run()
115175

116176

117177
@common.parametrize("test_data", Max.test_data)
118178
def test_max_dim_tosa_INT_to_amax(test_data: Max.input_t):
119179
data, dim = test_data()
120180
module = Max(dim)
121-
pipeline = TosaPipelineINT[Max.input_t](module, data, "torch.ops.aten.amax")
181+
pipeline = TosaPipelineINT[Max.input_t](module, data, amax_aten_op)
122182
pipeline.run()
123183

124184

@@ -137,15 +197,15 @@ def test_max_dim_tosa_FP_not_delegated():
137197
pipeline.run()
138198

139199

140-
@common.parametrize("test_data", Amax.test_data)
200+
@common.parametrize("test_data", Amax.test_data | Amax.test_data_fp16)
141201
@common.SkipIfNoModelConverter
142202
def test_amax_vgf_no_quant(test_data: Amax.input_t):
143203
data, dim, keep_dims = test_data()
144204
module = Amax(dim, keep_dims)
145205
pipeline = VgfPipeline[Amax.input_t](
146206
module,
147207
data,
148-
Amax.aten_op,
208+
amax_aten_op,
149209
quantize=False,
150210
)
151211
pipeline.run()
@@ -159,20 +219,20 @@ def test_amax_vgf_quant(test_data: Amax.input_t):
159219
pipeline = VgfPipeline[Amax.input_t](
160220
module,
161221
data,
162-
Amax.aten_op,
222+
amax_aten_op,
163223
quantize=True,
164224
)
165225
pipeline.run()
166226

167227

168-
@common.parametrize("test_data", Max.test_data)
228+
@common.parametrize("test_data", Max.test_data | Max.test_data_fp16)
169229
@common.SkipIfNoModelConverter
170230
def test_max_dim_vgf_no_quant_to_amax(test_data: Max.input_t):
171231
data, dim = test_data()
172232
pipeline = VgfPipeline[Max.input_t](
173233
Max(dim),
174234
data,
175-
"torch.ops.aten.max",
235+
max_aten_op,
176236
quantize=False,
177237
)
178238
pipeline.run()
@@ -185,7 +245,7 @@ def test_max_dim_vgf_quant_to_amax(test_data: Max.input_t):
185245
pipeline = VgfPipeline[Max.input_t](
186246
Max(dim),
187247
data,
188-
"torch.ops.aten.amax",
248+
amax_aten_op,
189249
quantize=True,
190250
)
191251
pipeline.run()
@@ -199,7 +259,7 @@ def test_amax_tosa_INT_a16w8(test_data: Amax.input_t):
199259
pipeline = TosaPipelineINT[Max.input_t](
200260
module,
201261
data,
202-
"torch.ops.aten.amax",
262+
amax_aten_op,
203263
tosa_extensions=["int16"],
204264
)
205265
pipeline.run()
@@ -214,7 +274,7 @@ def test_amax_u85_INT_a16w8(test_data: Amax.input_t):
214274
pipeline = EthosU85PipelineINT[Max.input_t](
215275
module,
216276
data,
217-
"torch.ops.aten.amax",
277+
amax_aten_op,
218278
a16w8_quantization=True,
219279
use_to_edge_transform_and_lower=True,
220280
)

0 commit comments

Comments
 (0)