Skip to content

Commit 64d7236

Browse files
authored
Arm backend: Add FP16 support to operators pt.2 (pytorch#17088)
Add FP16 support for operators: - conv2d - conv3d - cos - slice - pad Update op tests to cover the new datatype. Also correct the test name violations seen in test_conv2d.py. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 5dbe157 commit 64d7236

9 files changed

Lines changed: 100 additions & 21 deletions

File tree

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def define_node(
4646
ts.DType.INT8,
4747
ts.DType.INT16,
4848
ts.DType.INT32,
49+
ts.DType.FP16,
4950
ts.DType.FP32,
5051
ts.DType.BF16,
5152
ts.DType.BOOL,

backends/arm/operators/op_cos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def define_node(
4242
validate_valid_dtype(
4343
self.target,
4444
[*inputs, output],
45-
[ts.DType.FP32, ts.DType.BF16],
45+
[ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
4646
self.tosa_spec,
4747
)
4848
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_slice.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def define_node(
7777
ts.DType.INT16,
7878
ts.DType.INT32,
7979
ts.DType.BF16,
80+
ts.DType.FP16,
8081
ts.DType.FP32,
8182
],
8283
self.tosa_spec,

backends/arm/operators/op_tosa_conv2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151

5252
valid_input_dtypes = []
5353
if self.tosa_spec.support_float():
54-
valid_input_dtypes.append(ts.DType.FP32)
54+
valid_input_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
5555
if self.tosa_spec.support_integer():
5656
valid_input_dtypes.append(ts.DType.INT8)
5757

@@ -82,8 +82,8 @@ def define_node(
8282

8383
conv2d_output_name = output.name
8484
acc_type = output.dtype
85-
if output.dtype == ts.DType.BF16:
86-
# Accumulate BF16 inputs in FP32 for better precision per TOSA BF16 extension.
85+
if output.dtype in [ts.DType.BF16, ts.DType.FP16]:
86+
# Accumulate BF16, FP16 inputs in FP32 for better precision.
8787
acc_type = ts.DType.FP32
8888

8989
input_zp_name, weight_zp_name = add_input_weight_zp_consts(

backends/arm/test/ops/test_constant_pad_nd.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,18 @@
4545
-0.5,
4646
),
4747
}
48+
test_data_suite_fp16 = {
49+
"4dim_last1dim_fp16": lambda: (
50+
torch.rand(1, 1, 8, 8, dtype=torch.float16),
51+
(1, 1, 0, 0, 0, 0, 0, 0),
52+
1.0,
53+
),
54+
"3dim_last1dim_fp16": lambda: (
55+
torch.rand(1, 1, 8, dtype=torch.float16),
56+
(1, 0, 1, 0, 0, 0),
57+
-0.5,
58+
),
59+
}
4860

4961

5062
class ConstantPadND(torch.nn.Module):
@@ -65,7 +77,7 @@ def forward(self, x: torch.Tensor):
6577

6678
@common.parametrize(
6779
"test_data",
68-
test_data_suite | test_data_suite_bf16,
80+
test_data_suite | test_data_suite_bf16 | test_data_suite_fp16,
6981
)
7082
def test_constant_pad_nd_tosa_FP(test_data: Tuple):
7183
test_data, padding, value = test_data()
@@ -105,7 +117,7 @@ def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple):
105117
pipeline.run()
106118

107119

108-
@common.parametrize("test_data", test_data_suite)
120+
@common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
109121
@common.SkipIfNoModelConverter
110122
def test_constant_pad_nd_vgf_no_quant(test_data: Tuple):
111123
inp, padding, value = test_data()

backends/arm/test/ops/test_conv2d.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,30 @@ def forward(self, x):
410410
dtype=torch.bfloat16,
411411
),
412412
}
413+
test_data_FP_fp16 = {
414+
"fp16_3x3": lambda: Conv2d(
415+
height=12,
416+
width=12,
417+
in_channels=3,
418+
out_channels=4,
419+
kernel_size=(3, 3),
420+
stride=(1, 1),
421+
padding=(1, 1),
422+
bias=True,
423+
dtype=torch.float16,
424+
),
425+
"fp16_1x1": lambda: Conv2d(
426+
height=8,
427+
width=8,
428+
in_channels=2,
429+
out_channels=2,
430+
kernel_size=(1, 1),
431+
stride=(2, 1),
432+
padding=(0, 3),
433+
bias=False,
434+
dtype=torch.float16,
435+
),
436+
}
413437

414438
# Generate a new test set paired with per_channel_quant=True/False.
415439
test_data_INT = {
@@ -431,7 +455,7 @@ def _get_dtype_count(model: torch.nn.Module):
431455
}
432456

433457

434-
@common.parametrize("test_data", test_data_FP | test_data_FP_bf16)
458+
@common.parametrize("test_data", test_data_FP | test_data_FP_bf16 | test_data_FP_fp16)
435459
def test_convolution_2d_tosa_FP(test_data):
436460
model = test_data()
437461
pipeline = TosaPipelineFP[input_t](
@@ -539,7 +563,7 @@ def test_convolution_2d_u85_INT_a8w4(test_data):
539563
pipeline.run()
540564

541565

542-
@common.parametrize("test_data", test_data_FP)
566+
@common.parametrize("test_data", test_data_FP | test_data_FP_fp16)
543567
@common.SkipIfNoModelConverter
544568
def test_convolution_2d_vgf_no_quant(test_data):
545569
model = test_data()
@@ -614,7 +638,7 @@ def test_convolution_2d_u55_INT_not_delegated(module: Conv2d):
614638

615639

616640
@common.parametrize("test_data", test_data_INT)
617-
def test_conv2d_tosa_INT_a16w8(test_data: input_t):
641+
def test_convolution_2d_tosa_INT_a16w8(test_data: input_t):
618642
"""Test conv2d with 16A8W quantization for TOSA INT."""
619643
model, per_channel_quantization = test_data()
620644
pipeline = TosaPipelineINT[input_t](
@@ -630,7 +654,7 @@ def test_conv2d_tosa_INT_a16w8(test_data: input_t):
630654

631655
@common.parametrize("test_data", test_data_INT)
632656
@common.XfailIfNoCorstone300
633-
def test_conv2d_u55_INT_a16w8(test_data: input_t):
657+
def test_convolution_2d_u55_INT_a16w8(test_data: input_t):
634658
"""Test conv2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
635659
model, per_channel_quantization = test_data()
636660
pipeline = EthosU55PipelineINT[input_t](
@@ -647,7 +671,7 @@ def test_conv2d_u55_INT_a16w8(test_data: input_t):
647671

648672
@common.parametrize("test_data", test_data_INT)
649673
@common.XfailIfNoCorstone320
650-
def test_conv2d_u85_INT_a16w8(test_data: input_t):
674+
def test_convolution_2d_u85_INT_a16w8(test_data: input_t):
651675
"""Test conv2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
652676
model, per_channel_quantization = test_data()
653677
pipeline = EthosU85PipelineINT[input_t](

backends/arm/test/ops/test_conv3d.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,32 @@ def forward(self, x):
439439
dtype=torch.bfloat16,
440440
),
441441
}
442+
test_data_FP_fp16 = {
443+
"fp16_3x3": lambda: Conv3d(
444+
height=10,
445+
width=10,
446+
depth=6,
447+
in_channels=3,
448+
out_channels=4,
449+
kernel_size=(3, 3, 3),
450+
stride=(1, 1, 1),
451+
padding=(1, 1, 1),
452+
bias=True,
453+
dtype=torch.float16,
454+
),
455+
"fp16_1x1": lambda: Conv3d(
456+
height=6,
457+
width=6,
458+
depth=4,
459+
in_channels=2,
460+
out_channels=2,
461+
kernel_size=(1, 1, 1),
462+
stride=(1, 1, 1),
463+
padding=(0, 0, 0),
464+
bias=False,
465+
dtype=torch.float16,
466+
),
467+
}
442468

443469
# Generate a new test set paired with per_channel_quant=True/False.
444470
test_data_INT = {
@@ -466,11 +492,12 @@ def _get_dtype_count(model: torch.nn.Module):
466492
input_t = Tuple[torch.Tensor]
467493

468494

469-
@common.parametrize("test_data", test_data_FP | test_data_FP_bf16)
495+
@common.parametrize("test_data", test_data_FP | test_data_FP_bf16 | test_data_FP_fp16)
470496
def test_convolution_3d_tosa_FP(test_data):
497+
model = test_data()
471498
pipeline = TosaPipelineFP[input_t](
472-
test_data(),
473-
test_data().get_inputs(),
499+
model,
500+
model.get_inputs(),
474501
aten_op,
475502
exir_op,
476503
tosa_extensions=["bf16"],
@@ -623,12 +650,13 @@ def test_convolution_3d_u85_INT_a8w4(test_data):
623650
pipeline.run()
624651

625652

626-
@common.parametrize("test_data", test_data_FP)
653+
@common.parametrize("test_data", test_data_FP | test_data_FP_fp16)
627654
@common.SkipIfNoModelConverter
628655
def test_convolution_3d_vgf_no_quant(test_data):
656+
model = test_data()
629657
pipeline = VgfPipeline[input_t](
630-
test_data(),
631-
test_data().get_inputs(),
658+
model,
659+
model.get_inputs(),
632660
aten_op,
633661
exir_op,
634662
quantize=False,

backends/arm/test/ops/test_cos.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
"rand_bf16": torch.rand(4, 4, dtype=torch.bfloat16) - 0.5,
3636
"ramp_bf16": torch.arange(-8, 8, 0.5, dtype=torch.bfloat16),
3737
}
38+
test_data_suite_fp16 = {
39+
"rand_fp16": torch.rand(4, 4, dtype=torch.float16) - 0.5,
40+
"ramp_fp16": torch.arange(-8, 8, 0.5, dtype=torch.float16),
41+
}
3842

3943

4044
class Cos(torch.nn.Module):
@@ -43,7 +47,9 @@ def forward(self, x: torch.Tensor):
4347
return torch.cos(x)
4448

4549

46-
@common.parametrize("test_data", test_data_suite | test_data_suite_bf16)
50+
@common.parametrize(
51+
"test_data", test_data_suite | test_data_suite_bf16 | test_data_suite_fp16
52+
)
4753
@pytest.mark.tosa_ref_model
4854
def test_cos_tosa_FP(test_data: Tuple):
4955
pipeline = TosaPipelineFP[input_t1](
@@ -92,7 +98,7 @@ def test_cos_u85_INT(test_data: Tuple):
9298
pipeline.run()
9399

94100

95-
@common.parametrize("test_data", test_data_suite)
101+
@common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
96102
@common.SkipIfNoModelConverter
97103
def test_cos_vgf_no_quant(test_data: Tuple):
98104
pipeline = VgfPipeline[input_t1](

backends/arm/test/ops/test_slice.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,21 @@
3737
),
3838
}
3939

40+
test_data_suite_fp16 = {
41+
"ones_slice_4_fp16": lambda: (
42+
torch.ones((1, 12, 10, 10), dtype=torch.float16),
43+
[(0, 1), (0, 5), (3, 5), (4, 10)],
44+
),
45+
}
46+
4047

4148
class Slice(torch.nn.Module):
4249
def forward(self, x: torch.Tensor, s: list[tuple[int, int]]):
4350
slices = [slice(*i) for i in s]
4451
return x[slices]
4552

4653

47-
@common.parametrize("test_data", test_data_suite)
54+
@common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
4855
def test_slice_tensor_tosa_FP(test_data: torch.Tensor):
4956
pipeline = TosaPipelineFP[input_t1](Slice(), test_data(), aten_op, exir_op)
5057
pipeline.run()
@@ -96,7 +103,7 @@ def test_slice_tensor_u85_INT(test_data: torch.Tensor):
96103
pipeline.run()
97104

98105

99-
@common.parametrize("test_data", test_data_suite)
106+
@common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
100107
@common.SkipIfNoModelConverter
101108
def test_slice_tensor_vgf_no_quant(test_data: torch.Tensor):
102109
pipeline = VgfPipeline[input_t1](

0 commit comments

Comments
 (0)