Skip to content

Commit ec4c462

Browse files
authored
Arm backend: Add BF16 support to operators pt. 4-5 (pytorch#17003)
ge, gt, alias_copy, index_tensor mapping to TOSA operators GREATER, GREATER_EQUAL, IDENTITY, GATHER; log, maxpool_2d, matmul, maximum, minimum mapping to TOSA operators LOG, MAX_POOL2D, MATMUL, MAXIMUM, MINIMUM. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai --------- Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent cfc71fe commit ec4c462

19 files changed

Lines changed: 195 additions & 39 deletions

backends/arm/_passes/rewrite_matmul.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -91,6 +91,22 @@ def call(self, graph_module):
9191
tosa_matmul_node.meta[TosaSpecialDtype.meta_key()] = (
9292
TosaSpecialDtype.INT48
9393
)
94+
elif (
95+
x1_fake_tensor.dtype == torch.bfloat16
96+
and x2_fake_tensor.dtype == torch.bfloat16
97+
and output_fake_tensor.dtype != torch.bfloat16
98+
):
99+
# A TOSA BF16 MATMUL outputs FP32 whereas PyTorch outputs BF16.
100+
# Cast back to BF16 to get matching semantics.
101+
with graph_module.graph.inserting_after(tosa_matmul_node):
102+
cast_node = create_node(
103+
graph_module.graph,
104+
op_target=exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
105+
kwargs={"dtype": torch.bfloat16},
106+
from_node=tosa_matmul_node,
107+
)
108+
tosa_matmul_node.replace_all_uses_with(cast_node)
109+
cast_node.args = (tosa_matmul_node,)
94110

95111
if modified:
96112
graph_module.recompile()

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def define_node(
4141
validate_valid_dtype(
4242
self.target,
4343
inputs,
44-
[ts.DType.INT32, ts.DType.FP32],
44+
[ts.DType.INT32, ts.DType.FP32, ts.DType.BF16],
4545
self.tosa_spec,
4646
)
4747
validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec)

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def define_node(
4141
validate_valid_dtype(
4242
self.target,
4343
inputs,
44-
[ts.DType.INT32, ts.DType.FP32],
44+
[ts.DType.INT32, ts.DType.FP32, ts.DType.BF16],
4545
self.tosa_spec,
4646
)
4747
validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec)

backends/arm/operators/op_log.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def define_node(
4141
validate_num_inputs(self.target, inputs, 1)
4242
validate_same_dtype(self.target, [*inputs, output], ts)
4343
validate_valid_dtype(
44-
self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec
44+
self.target,
45+
[*inputs, output],
46+
[ts.DType.FP32, ts.DType.BF16],
47+
self.tosa_spec,
4548
)
4649
attr = ts.TosaSerializerAttribute()
4750
attr.LogAttribute()

backends/arm/operators/op_max_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_node(
3838
) -> None:
3939
validate_num_inputs(self.target, inputs, [3, 4, 5, 6])
4040
validate_same_dtype(self.target, [inputs[0], output], ts)
41-
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
41+
supported_dtypes = [ts.DType.INT8, ts.DType.FP32, ts.DType.BF16]
4242
if self.tosa_spec.support_extension("int16"):
4343
supported_dtypes.append(ts.DType.INT16)
4444
validate_valid_dtype(

backends/arm/operators/op_mul.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ def define_node(
3838
validate_valid_dtype(
3939
self.target,
4040
[*inputs, output],
41-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
41+
[
42+
ts.DType.INT8,
43+
ts.DType.INT16,
44+
ts.DType.INT32,
45+
ts.DType.FP32,
46+
ts.DType.BF16,
47+
],
4248
self.tosa_spec,
4349
)
4450

backends/arm/operators/op_tosa_matmul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def define_node(
4646
validate_num_inputs(self.target, inputs, 2)
4747
validate_same_dtype(self.target, [*inputs], ts)
4848
supported_input_dtypes = [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32]
49+
if self.tosa_spec.support_extension("bf16"):
50+
supported_input_dtypes.append(ts.DType.BF16)
4951
if self.tosa_spec.support_extension("int16"):
5052
supported_input_dtypes.append(ts.DType.INT16)
5153
validate_valid_dtype(

backends/arm/operators/op_view.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def define_node(
4646
ts.DType.INT16,
4747
ts.DType.INT32,
4848
ts.DType.FP32,
49+
ts.DType.BF16,
4950
ts.DType.BOOL,
5051
],
5152
self.tosa_spec,

backends/arm/operators/ops_identity.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def define_node(
4949
]
5050
if self.tosa_spec.support_float():
5151
supported_dtypes += [ts.DType.FP32]
52+
if self.tosa_spec.support_extension("bf16"):
53+
supported_dtypes += [ts.DType.BF16]
5254
if self.tosa_spec.support_extension("int16"):
5355
supported_dtypes += [ts.DType.INT48]
5456
if self.tosa_spec.support_extension("int4"):

backends/arm/test/ops/test_alias_copy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -36,6 +36,9 @@ class AliasCopy(torch.nn.Module):
3636
"3d_rand": lambda: (torch.rand(3, 5, 5),),
3737
"4d_zeros": lambda: (torch.zeros(1, 10, 10, 10),),
3838
}
39+
test_data_bf16 = {
40+
"3d_rand_bf16": lambda: (torch.rand(3, 5, 2, dtype=torch.bfloat16),)
41+
}
3942

4043
def __init__(self):
4144
super().__init__()
@@ -46,13 +49,14 @@ def forward(self, x: torch.Tensor):
4649
) # Multiply by one to make sure it is partitioned.
4750

4851

49-
@common.parametrize("test_data", AliasCopy.test_data)
52+
@common.parametrize("test_data", AliasCopy.test_data | AliasCopy.test_data_bf16)
5053
def test_alias_tosa_FP(test_data: input_t1):
5154
TosaPipelineFP[input_t1](
5255
AliasCopy(),
5356
test_data(),
5457
AliasCopy.aten_op,
5558
AliasCopy.exir_op,
59+
tosa_extensions=["bf16"],
5660
).run()
5761

5862

0 commit comments

Comments
 (0)