Skip to content

Commit f1062a7

Browse files
NXP backend: Test Mul Tensor with new Neutron flow (#19336)
### Summary Add tests verifying correct support for mul.tensor by the Neutron backend using the new Neutron MLIR flow. ### Test plan Unit tests provided. cc @robert-kalmar @JakeStevens @digantdesai
1 parent 4d9b0e9 commit f1062a7

9 files changed

Lines changed: 268 additions & 71 deletions

File tree

backends/nxp/backend/ir/converter/conversion/common.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
transpose_conv_options,
2424
)
2525

26-
from torch.fx import Node
27-
2826

2927
def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor | None:
3028
"""Return the input tensors of 't_op' at index 'idx', or None if the operator doesn't have that input.
@@ -135,34 +133,6 @@ def uses_shape_broadcasting(t_op: tflite_model.Operator) -> bool:
135133
)
136134

137135

138-
def node_uses_shape_broadcasting(node: Node) -> bool:
    """Determine if the given PyTorch fx Node uses shape broadcasting for its input nodes.

    Broadcasting is assumed whenever any input's shape differs from the shape of
    the first input (TFLite may then need extra `Transpose` ops during conversion).

    :param node: PyTorch fx Node with 'all_input_nodes' initialized.
    :return: True, if the node uses shape broadcasting for its input nodes.
             False otherwise.
    """

    if node.all_input_nodes is None:
        logger.e(
            logger.Code.INTERNAL_ERROR,
            "common.node_uses_shape_broadcasting(): 'all_input_nodes' are None!",
        )

    if len(node.all_input_nodes) == 0:
        logger.e(
            logger.Code.INTERNAL_ERROR,
            "common.node_uses_shape_broadcasting(): Operator has no inputs!",
        )

    # Compare every other input's shape against the first input's shape.
    first_input_shape = node.all_input_nodes[0].meta["val"].shape

    return any(
        input_tensor.meta["val"].shape != first_input_shape
        for input_tensor in node.all_input_nodes[1:]
    )
164-
165-
166136
class OpsList:
167137
"""
168138
Holder of TFLite operator (middle_op) that can be prefixed (pre_ops) of suffixed (post_ops)

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
input_quantization_type,
1717
output_quantization_type,
1818
)
19+
from executorch.backends.nxp.backend.ir import logger as logger
1920
from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext
2021
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
2122
AtenModelBuilderDirector,
@@ -380,3 +381,67 @@ def uses_quantization_type_for_io(
380381
) and NodeConverter.uses_quantization_type_for_outputs(
381382
node, supported_types, output_indices
382383
)
384+
385+
@staticmethod
def uses_shape_broadcasting(node: Node) -> bool:
    """Determine if the given PyTorch fx Node uses shape broadcasting for its input nodes.

    Broadcasting is assumed whenever any input's shape differs from the shape of
    the first input.

    :param node: PyTorch fx Node with 'all_input_nodes' initialized.
    :return: True, if the node uses shape broadcasting for its input nodes.
             False otherwise.
    """

    if node.all_input_nodes is None:
        logger.e(
            logger.Code.INTERNAL_ERROR,
            "node_converter.uses_shape_broadcasting(): 'all_input_nodes' are None!",
        )

    if len(node.all_input_nodes) == 0:
        logger.e(
            logger.Code.INTERNAL_ERROR,
            "node_converter.uses_shape_broadcasting(): Operator has no inputs!",
        )

    # Compare every other input's shape against the first input's shape.
    first_input_shape = node.all_input_nodes[0].meta["val"].shape

    return any(
        input_tensor.meta["val"].shape != first_input_shape
        for input_tensor in node.all_input_nodes[1:]
    )
412+
413+
@staticmethod
def at_least_one_input_shape_matches_the_output_shape(node: Node) -> bool:
    """Determine if at least one input of the given PyTorch fx Node has the same shape as its output.

    :param node: PyTorch fx Node with 'all_input_nodes' initialized.
    :return: True, if at least one input has the same shape as the output node.
             False otherwise.
    """

    if node.all_input_nodes is None:
        logger.e(
            logger.Code.INTERNAL_ERROR,
            "node_converter.at_least_one_input_shape_matches_the_output_shape(): 'all_input_nodes' are None!",
        )

    if len(node.all_input_nodes) == 0:
        logger.e(
            logger.Code.INTERNAL_ERROR,
            "node_converter.at_least_one_input_shape_matches_the_output_shape(): Operator has no inputs!",
        )

    # The node's own `meta["val"]` carries the output tensor metadata.
    output_shape = node.meta["val"].shape

    return any(
        input_tensor.meta["val"].shape == output_shape
        for input_tensor in node.all_input_nodes
    )
440+
441+
@staticmethod
442+
def _node_inputs_ranks_not_equal(node) -> bool:
443+
first_input_shape = node.all_input_nodes[0].meta["val"].shape
444+
return not all(
445+
len(input_node.meta["val"].shape) == len(first_input_shape)
446+
for input_node in node.all_input_nodes[1:]
447+
)

backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from executorch.backends.nxp.backend.ir.converter.conversion.common import (
7-
node_uses_shape_broadcasting,
8-
)
96
from executorch.backends.nxp.backend.ir.converter.node_converter import (
107
CustomDelegationOptions,
118
NodeConverter,
@@ -26,7 +23,7 @@ def _is_supported_on_target(
2623
parameters_mapping: dict[str, Parameter],
2724
custom_delegation_options: CustomDelegationOptions,
2825
) -> bool:
29-
if node_uses_shape_broadcasting(node):
26+
if NodeConverter.uses_shape_broadcasting(node):
3027
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
3128
return False
3229

backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from executorch.backends.nxp.backend.ir.converter.conversion.common import (
7-
node_uses_shape_broadcasting,
8-
)
6+
import torch
7+
8+
from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
99
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1010
CustomDelegationOptions,
1111
NodeConverter,
@@ -26,19 +26,41 @@ def _is_supported_on_target(
2626
parameters_mapping: dict[str, Parameter],
2727
custom_delegation_options: CustomDelegationOptions,
2828
) -> bool:
29-
if node_uses_shape_broadcasting(node):
30-
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
31-
return False
29+
if custom_delegation_options.use_new_flow_neutron_c:
30+
if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(
31+
node
32+
):
33+
return False
34+
35+
# If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes
36+
# Transpose is currently not supported for new flow
37+
if any(
38+
input_node.meta[NXP_NODE_FORMAT].is_channels_first()
39+
for input_node in node.all_input_nodes
40+
) and NodeConverter._node_inputs_ranks_not_equal(node):
41+
return False
42+
43+
supported_types = [torch.int8, torch.uint8]
44+
if not NodeConverter.uses_quantization_type_for_io(
45+
node, supported_types, [0, 1], [0]
46+
):
47+
return False
48+
49+
return True
50+
else:
51+
if NodeConverter.uses_shape_broadcasting(node):
52+
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
53+
return False
3254

33-
node_shape = node.meta["val"].shape
55+
node_shape = node.meta["val"].shape
3456

35-
# Check that at least one dimension is divisible by number of MACS
36-
# or all dimensions are equal to one
37-
# Otherwise Neutron cannot convert it
38-
dim_divisible = any(s % 8 == 0 for s in node_shape) or all(
39-
s == 1 for s in node_shape
40-
)
41-
return dim_divisible
57+
# Check that at least one dimension is divisible by number of MACS
58+
# or all dimensions are equal to one
59+
# Otherwise Neutron cannot convert it
60+
dim_divisible = any(s % 8 == 0 for s in node_shape) or all(
61+
s == 1 for s in node_shape
62+
)
63+
return dim_divisible
4264

4365
@staticmethod
4466
def _is_supported_in_IR(
@@ -51,9 +73,11 @@ def _is_supported_in_IR(
5173

5274
return True
5375

54-
# mul.Tensor Node format: (Tensor self, Tensor other, *)
5576
def convert(self, node: Node):
56-
"""Convert 'mul_tensor' operator to NeutronIR 'Mul'."""
77+
"""Convert 'mul_tensor' operator to NeutronIR 'Mul'.
78+
The ExecuTorch schema is:
79+
mul.Tensor(Tensor self, Tensor other)
80+
"""
5781
self.assert_convertible(node)
5882
t_op = self._create_tflite_op_with_io_tensors(node)
5983
t_op.builtin_options = mul_options.Mul()

backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from executorch.backends.nxp.backend.ir.converter.conversion.common import (
7-
node_uses_shape_broadcasting,
8-
)
96
from executorch.backends.nxp.backend.ir.converter.node_converter import (
107
CustomDelegationOptions,
118
NodeConverter,
@@ -26,7 +23,7 @@ def _is_supported_on_target(
2623
parameters_mapping: dict[str, Parameter],
2724
custom_delegation_options: CustomDelegationOptions,
2825
) -> bool:
29-
if node_uses_shape_broadcasting(node):
26+
if NodeConverter.uses_shape_broadcasting(node):
3027
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
3128
return False
3229

backends/nxp/quantizer/patterns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ class MulTensorPattern(QuantizationPattern):
880880
Basic quantization for all inputs and output.
881881
"""
882882

883-
def partition_types(self) -> list[OpOverload]:
    """Return the aten operator overloads that anchor this quantization pattern.

    The returned items are `torch._ops.OpOverload` objects (not `nn.Module`s),
    hence the annotation.
    """
    return [torch.ops.aten.mul.Tensor]
885885

886886
def get_anchors(

0 commit comments

Comments
 (0)