Skip to content

Commit 763cdd1

Browse files
authored
NXP backend: added support for aten.bmm (#17818)
### Summary Adds support for the `aten.bmm` operator. The original PR is [here](#17670); however, I pushed to the branch without committing the work first, and the PR closed itself automatically. ### Test plan Tests can be run manually using `pytest -c /dev/null backends/nxp/tests/`.
1 parent 502d2de commit 763cdd1

9 files changed

Lines changed: 376 additions & 0 deletions

File tree

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
3333
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
3434
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
35+
exir_ops.edge.aten.bmm.default: BMMConverter, # noqa F405
3536
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
3637
exir_ops.edge.aten.clamp.default: ClampConverter, # noqa F405
3738
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.avg_pool_2d_converter import (
1414
AvgPool2dConverter,
1515
)
16+
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.bmm_converter import (
17+
BMMConverter,
18+
)
1619
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.cat_converter import (
1720
CatConverter,
1821
)
@@ -99,6 +102,7 @@
99102
"AddMMConverter",
100103
"AddTensorConverter",
101104
"AvgPool2dConverter",
105+
"BMMConverter",
102106
"CatConverter",
103107
"ClampConverter",
104108
"CloneConverter",
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2026 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
7+
from executorch.backends.nxp.backend.edge_helper import input_rank
8+
from executorch.backends.nxp.backend.ir.converter.conversion import translator
9+
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
10+
from executorch.backends.nxp.backend.ir.converter.node_converter import (
11+
CustomDelegationOptions,
12+
NodeConverter,
13+
)
14+
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
15+
batch_mat_mul_options,
16+
)
17+
from executorch.backends.nxp.backend.neutron_operator_support import (
18+
transposition_is_supported_on_neutron,
19+
)
20+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
21+
from torch.fx import Node
22+
from torch.nn import Parameter
23+
24+
25+
class BMMConverter(NodeConverter):
    """Node converter that lowers `aten.bmm` to the TFLite `BatchMatMul` operator."""

    @staticmethod
    def _is_supported_in_IR(
        node: Node,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        """Return True only for a node with exactly two input nodes, each of rank 3.

        `parameters_mapping` and `custom_delegation_options` are not consulted.
        """
        # `all` stops at the first failing input, so the second rank lookup is
        # skipped whenever the first one already disqualifies the node.
        return len(node.all_input_nodes) == 2 and all(
            input_rank(node, input_idx) == 3 for input_idx in (0, 1)
        )

    @staticmethod
    def _is_supported_on_target(
        node: Node,
        neutron_target_spec: NeutronTargetSpec,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        """Return True when the node can be delegated to the given Neutron target.

        Three conditions are checked in order: the combination of input node
        formats, the feasibility of the transposition appended after
        `BatchMatMul` for a channels-first output, and the divisibility of the
        multiplied dimensions by the target's number of MACs.
        `parameters_mapping` and `custom_delegation_options` are not consulted.
        """
        lhs_channels_first, rhs_channels_first = [
            node.args[input_idx].meta[NXP_NODE_FORMAT].is_channels_first()
            for input_idx in (0, 1)
        ]

        # A channels-first first input with a non-channels-first second input
        # would map to `adj_x = True`, `adj_y = False`, which Neutron does not
        # support. It is not expected in practice, since both inputs should
        # share the same format.
        if lhs_channels_first and not rhs_channels_first:
            return False

        # A channels-first output gets a transpose appended after `BatchMatMul`;
        # that transposition has to be supported by the target as well.
        if node.meta[NXP_NODE_FORMAT].is_channels_first():
            output_shape = node.meta["val"].shape
            perm = translator.create_channels_first_to_channels_last_permutation(
                len(output_shape), return_list=True
            )
            permuted_shape = [output_shape[axis] for axis in perm]
            transpose_ok = transposition_is_supported_on_neutron(
                permuted_shape, perm, neutron_target_spec
            )
            if not transpose_ok:
                return False

        # Drop the leading dimension of each (rank 3) input and keep the two
        # dimensions that take part in the matrix multiplication.
        _, lhs_rows, lhs_cols = node.args[0].meta["val"].shape
        _, rhs_rows, rhs_cols = node.args[1].meta["val"].shape

        # The Neutron converter requires every dimension participating in the
        # multiplication to be divisible by NUM_MACS.
        num_macs = neutron_target_spec.get_num_macs()
        return all(
            dim % num_macs == 0 for dim in (lhs_rows, lhs_cols, rhs_rows, rhs_cols)
        )

    def convert(self, node: Node):
        """Convert the `aten.bmm` operator to TFLite `BatchMatMul`.

        A transpose operator is appended after `BatchMatMul` when the node's
        format is channels-first.
        """
        self.assert_convertible(node)

        bmm_op = self._create_tflite_op_with_io_tensors(node)

        # `adj_x` / `adj_y` follow the formats of the respective inputs. When
        # both are set, TFLite transposes the (channels-last) inputs internally
        # to channels-first, and the output then comes out channels-first too,
        # which is why it is transposed back to channels-last further below.
        adj_x = node.args[0].meta[NXP_NODE_FORMAT].is_channels_first()
        adj_y = node.args[1].meta[NXP_NODE_FORMAT].is_channels_first()

        # `asymmetric_quantize_inputs` is passed as False. Neutron ignores this
        # parameter entirely, so the value has no effect on delegation.
        bmm_op.builtin_options = batch_mat_mul_options.BatchMatMul(
            adj_x, adj_y, False
        )

        # Assign the operator exactly its two TFLite inputs and single output.
        lhs_tensor = bmm_op.tmp_inputs[0]
        rhs_tensor = bmm_op.tmp_inputs[1]
        out_tensor = bmm_op.tmp_outputs[0]
        bmm_op.tmp_inputs = [lhs_tensor, rhs_tensor]
        bmm_op.tmp_outputs = [out_tensor]

        ops = OpsList(middle_op=bmm_op)

        # Transpose the output back to channels-last if needed.
        if node.meta[NXP_NODE_FORMAT].is_channels_first():
            output_rank = len(node.meta["val"].shape)
            perm = translator.create_channels_first_to_channels_last_permutation(
                output_rank, return_list=True
            )
            ops.add_post(self.builder.create_transpose_operator_after(bmm_op, 0, perm))

        self.builder.append_operators(ops.flatten())

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
205205
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
206206
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
207207
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
208+
exir_ops.edge.aten.bmm.default: BMMConverter, # noqa F405
208209
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
209210
exir_ops.edge.aten.clamp.default: ClampConverter, # noqa F405
210211
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AvgPool1DPattern,
2121
AvgPool2DPattern,
2222
BatchNormPattern,
23+
BMMPattern,
2324
CatPattern,
2425
ClampPattern,
2526
Conv1dPattern,
@@ -262,6 +263,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
262263
OpQuantizer(AvgPool1DPattern(is_qat=is_qat), static_qconfig),
263264
OpQuantizer(AvgPool2DPattern(is_qat=is_qat), static_qconfig),
264265
OpQuantizer(BatchNormPattern(is_qat=is_qat), static_qconfig),
266+
OpQuantizer(BMMPattern(is_qat=is_qat), static_qconfig),
265267
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
266268
OpQuantizer(ClampPattern(is_qat=is_qat), static_qconfig),
267269
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,29 @@ def get_anchors(
298298
)
299299

300300

301+
class BMMPattern(QuantizationPattern):
    """
    Quantization pattern for the `aten.bmm` (BatchMatMul) operator.
    """

    def partition_types(self) -> list[torch.nn.Module]:
        """Return the operators this pattern matches: only `aten.bmm.default`."""
        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        """Build the anchors for the `bmm` node of the matched partition.

        Both positional arguments of the node are registered as inputs, the
        node itself is registered as the output, and no biases are declared.
        `gm` is not consulted.
        """
        # The `bmm` node is taken as the last node of the first partition.
        node = fused_partition[0].nodes[-1]

        quantized_inputs = [(node, NodeArgsIdx(arg_idx)) for arg_idx in range(2)]
        quantized_output = [(node,)]

        return PartitionAnchors(
            inputs=quantized_inputs,
            biases=[],
            output=quantized_output,
        )
322+
323+
301324
class SubTensorPattern(QuantizationPattern):
302325
"""
303326
Quantization pattern for Sub Tensor quantization. Accepts 1 or 2 input nodes.

0 commit comments

Comments
 (0)