Commit ec35724

Test Add Tensor with new Neutron flow
1 parent 9ccbc4a commit ec35724

4 files changed: 229 additions & 18 deletions

backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py

Lines changed: 34 additions & 8 deletions
@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
+
+from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
 from executorch.backends.nxp.backend.ir.converter.node_converter import (
     CustomDelegationOptions,
     NodeConverter,
@@ -23,11 +26,33 @@ def _is_supported_on_target(
         parameters_mapping: dict[str, Parameter],
         custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
-        if NodeConverter.uses_shape_broadcasting(node):
-            # Shape broadcasting may require the addition of `Transpose` ops during conversion.
-            return False
+        if custom_delegation_options.use_new_flow_neutron_c:
+            if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(
+                node
+            ):
+                return False
 
-        return True
+            # If one input is channels-first and the input tensors have unequal ranks, `Transpose` ops
+            # would have to be added. `Transpose` is currently not supported in the new flow.
+            if any(
+                input_node.meta[NXP_NODE_FORMAT].is_channels_first()
+                for input_node in node.all_input_nodes
+            ) and NodeConverter._node_inputs_ranks_not_equal(node):
+                return False
+
+            supported_types = [torch.int8, torch.uint8]
+            if not NodeConverter.uses_quantization_type_for_io(
+                node, supported_types, [0, 1], [0]
+            ):
+                return False
+
+            return True
+        else:
+            if NodeConverter.uses_shape_broadcasting(node):
+                # Shape broadcasting may require the addition of `Transpose` ops during conversion.
+                return False
+
+            return True
 
     @staticmethod
     def _is_supported_in_IR(
@@ -43,12 +68,13 @@ def _is_supported_in_IR(
 
         return True
 
-    # add.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
     def convert(self, node: Node):
-        """Convert 'add_tensor' operator to TFLite 'add'."""
+        """Convert 'add_tensor' operator to NeutronIR 'Add'.
+        The ExecuTorch schema is:
+            add.Tensor(Tensor self, Tensor other, Scalar alpha=1)
+        """
         self.assert_convertible(node)
-
         t_op = self._create_tflite_op_with_io_tensors(node)
-
         t_op.builtin_options = add_options.Add()
+
         self.builder.append_operators([t_op])
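
For orientation, the new-flow check above only delegates `add.Tensor` when at least one input shape already equals the broadcast output shape, so the Neutron kernel never has to materialize a broadcast of both operands. Below is a minimal, self-contained sketch of that criterion; `one_input_matches_output` is a hypothetical helper written here for illustration with `torch.broadcast_shapes`, not the backend's `at_least_one_input_shape_matches_the_output_shape` implementation.

import torch

def one_input_matches_output(*shapes):
    # Illustration only: mirrors the delegation criterion by name, not the backend code.
    out = torch.broadcast_shapes(*shapes)
    return any(tuple(s) == tuple(out) for s in shapes)

# Shapes taken from the tests in this commit:
assert one_input_matches_output((4, 6), (1, 6))          # delegated
assert one_input_matches_output((5, 3, 4), (1, 3, 1))    # delegated
assert not one_input_matches_output((4, 1), (1, 6))      # falls back to CPU
assert not one_input_matches_output((6, 4), (6, 6, 1))   # falls back to CPU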

backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py

Lines changed: 192 additions & 8 deletions
@@ -1,25 +1,36 @@
-# Copyright 2025 NXP
+# Copyright 2025-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 import numpy as np
 import pytest
 import torch
 
 from executorch.backends.nxp.backend.edge_program_converter import (
     EdgeProgramToIRConverter,
 )
-from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executorch_pipeline import (
+    ModelInputSpec,
+    to_quantized_edge_program,
+)
 from executorch.backends.nxp.tests.executors import (
     convert_run_compare,
+    graph_contains_any_of_ops,
     ToChannelFirstPreprocess,
     ToChannelLastPreprocess,
 )
+from executorch.backends.nxp.tests.graph_verifier import BaseGraphVerifier
+from executorch.backends.nxp.tests.model_output_comparator import (
+    NumericalStatsOutputComparator,
+)
 from executorch.backends.nxp.tests.models import (
     AddTensorConvModule,
     AddTensorModule,
     AddTensorOneInputModule,
 )
+from executorch.backends.nxp.tests.nsys_testing import lower_run_compare
+from executorch.backends.nxp.tests.ops_aliases import AddTensor, ExecutorchDelegateCall
 from torch.export import ExportedProgram
 from executorch.backends.nxp.tests.use_qat import *  # noqa F403
 
@@ -64,7 +75,6 @@ def test_add_tensor_quant_conversion(mocker, input_shape, use_qat):
 @pytest.mark.parametrize(
     "input_shape",
     [
-        pytest.param((4,), id="1D."),
         pytest.param((6, 6), id="2D."),
         pytest.param((1, 4, 8), id="3D."),
         pytest.param((1, 4, 8, 8), id="4D."),
@@ -92,20 +102,26 @@ def test_add_tensor_one_input_quant_conversion(mocker, input_shape, use_qat):
 
 
 @pytest.mark.parametrize(
-    "input_shape",
+    "x_input_shape",
     [
         pytest.param((1, 4, 8, 8), id="4D."),
         pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."),
     ],
 )
-def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat):
+def test_add_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat):
     model = AddTensorConvModule()
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
+    n, c, h, w = x_input_shape
+    y_input_shape = (n, 8, h, w)
+
     # Run conversion
     _ = to_quantized_edge_program(
-        model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False
+        model,
+        [x_input_shape, y_input_shape],
+        use_qat=use_qat,
+        use_neutron_for_format_conversion=False,
     )
 
     # Capture generated model
@@ -114,7 +130,13 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat):
     # Capture converted program
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
 
-    input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
+    input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype(
+        np.int8
+    )
+    input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype(
+        np.int8
+    )
+    input_data = {0: input_data_1, 1: input_data_2}
 
     convert_run_compare(
         exported_program,
@@ -149,7 +171,7 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion(
     nodes = list(edge_program.graph.nodes)
 
     # Broadcast is not supported, node is not converted
-    assert nodes[6].target.__name__ == "aten.add.Tensor"  # Add Tensor is not delegated.
+    assert nodes[6].target == AddTensor  # Add Tensor is not delegated.
 
     # Capture converted program
     # exported_program: ExportedProgram = converter_spy.call_args.args[1]
@@ -159,3 +181,165 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion(
     # input_data = {0: x_input_data, 1: y_input_data}
     #
     # convert_run_compare(exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data)
+
+
+class TestAddTensorNewNeutronFlow:
+    @pytest.mark.skip("AIR-14602: incorrect results")
+    @pytest.mark.parametrize(
+        "x_input_shape",
+        [
+            pytest.param((1,), id="1D."),
+            pytest.param((6, 8), id="2D."),
+            pytest.param((1, 4, 8), id="3D."),
+            pytest.param((1, 4, 8, 8), id="4D."),
+        ],
+    )
+    def test__basic_nsys_inference(self, x_input_shape):
+        x_input_spec = ModelInputSpec(x_input_shape)
+        model = AddTensorModule()
+        graph_verifier = BaseGraphVerifier(
+            exp_num_delegate_call_nodes=1,
+            exp_non_delegated_nodes=[],
+        )
+
+        lower_run_compare(
+            model,
+            [x_input_spec, x_input_spec],
+            graph_verifier,
+            use_new_flow_neutron_c=True,
+        )
+
+    @pytest.mark.skip("AIR-14602: incorrect results")
+    @pytest.mark.parametrize(
+        "x_input_shape",
+        [
+            pytest.param((6, 8), id="2D."),
+            pytest.param((1, 4, 8), id="3D."),
+            pytest.param((1, 4, 8, 8), id="4D."),
+        ],
+    )
+    def test__basic_nsys_inference_qat(self, x_input_shape):
+        x_input_spec = ModelInputSpec(x_input_shape)
+        model = AddTensorModule()
+        comparator = NumericalStatsOutputComparator()
+        graph_verifier = BaseGraphVerifier(
+            exp_num_delegate_call_nodes=1,
+            exp_non_delegated_nodes=[],
+        )
+
+        lower_run_compare(
+            model,
+            [x_input_spec, x_input_spec],
+            graph_verifier,
+            output_comparator=comparator,
+            use_new_flow_neutron_c=True,
+            use_qat=True,
+        )
+
+    @pytest.mark.parametrize(
+        "input_spec",
+        [
+            pytest.param(
+                [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D."
+            ),
+            pytest.param(
+                [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))],
+                id="2 inputs 3D.",
+            ),
+            pytest.param(
+                [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 2D+3D."
+            ),
+        ],
+    )
+    def test__correct_broadcast(self, input_spec):
+        model = AddTensorModule()
+        graph_verifier = BaseGraphVerifier(
+            exp_num_delegate_call_nodes=1,
+            exp_non_delegated_nodes=[],
+        )
+
+        lower_run_compare(
+            model, input_spec, graph_verifier, use_new_flow_neutron_c=True
+        )
+
+    @pytest.mark.parametrize(
+        "input_spec",
+        [
+            pytest.param(
+                [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D."
+            ),
+            pytest.param(
+                [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))],
+                id="2 inputs 3D.",
+            ),
+            pytest.param(
+                [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))],
+                id="2 inputs 2D+3D.",
+            ),
+        ],
+    )
+    def test__incorrect_broadcast(self, input_spec):
+        # Broadcasting is not supported when neither input shape equals the output shape.
+        model = AddTensorModule()
+
+        delegated_ep = to_quantized_edge_program(
+            model, input_spec, use_new_flow_neutron_c=True
+        ).exported_program()
+
+        # Make sure the `add.Tensor` was NOT delegated.
+        assert not graph_contains_any_of_ops(
+            delegated_ep.graph, [ExecutorchDelegateCall]
+        )
+        assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor])
+
+    @pytest.mark.skip("AIR-14602: incorrect results")
+    @pytest.mark.parametrize(
+        "x_input_shape",
+        [
+            pytest.param(
+                (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."
+            ),
+        ],
+    )
+    def test__w_conv(self, x_input_shape):
+        model = AddTensorConvModule()
+
+        n, c, h, w = x_input_shape
+        y_input_spec = ModelInputSpec((n, 8, h, w))
+        x_input_spec = ModelInputSpec(x_input_shape)
+
+        graph_verifier = BaseGraphVerifier(
+            exp_num_delegate_call_nodes=1,
+            exp_non_delegated_nodes=[],
+        )
+
+        lower_run_compare(
+            model,
+            [x_input_spec, y_input_spec],
+            graph_verifier,
+            use_new_flow_neutron_c=True,
+        )
+
+    @pytest.mark.parametrize(
+        "input_spec",
+        [
+            pytest.param(
+                [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))],
+                id="2 inputs 4D + 2D.",
+            ),
+            pytest.param(
+                [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))],
+                id="2 inputs 4D + 3D.",
+            ),
+        ],
+    )
+    def test__w_conv_unsupported(self, input_spec):
+        model = AddTensorConvModule()
+
+        delegated_ep = to_quantized_edge_program(
+            model, input_spec, use_new_flow_neutron_c=True
+        ).exported_program()
+
+        # The `Conv` is delegated, but the `add.Tensor` is NOT.
+        assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall])
+        assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor])
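
The two `test__w_conv_unsupported` cases are rejected by the rank/format criterion rather than the broadcast one: once the 4D conv output is kept in a channels-last layout, a lower-rank operand no longer broadcasts the way `aten.add.Tensor` semantics require, so the converter would have to insert `Transpose` ops, which the new flow does not support. A rough, hedged illustration in plain PyTorch (not backend code; the channels-last layout of the delegated conv output is an assumption inferred from the converter comment above):

import torch

x_nchw = torch.randn(1, 8, 5, 5)     # conv output, channels first (NCHW)
y = torch.randn(1, 5)                # lower-rank second operand

reference = x_nchw + y               # aten.add.Tensor broadcasts against the trailing NCHW dims

x_nhwc = x_nchw.permute(0, 2, 3, 1)  # channels-last layout assumed for the delegated conv output
try:
    x_nhwc + y                       # the same broadcast no longer lines up after the layout change
except RuntimeError:
    pass                             # a Transpose around the Add would be needed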

backends/nxp/tests/models.py

Lines changed: 2 additions & 2 deletions
@@ -656,9 +656,9 @@ def __init__(self):
         super().__init__()
         self.conv = Conv2dModule(padding=1, stride=1)
 
-    def forward(self, x):
+    def forward(self, x, y):
         x = self.conv(x)
-        return x + x
+        return x + y
 
 
 class AddTensorOneInputModule(torch.nn.Module):
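
After this change `AddTensorConvModule` takes two inputs, and the second one has to match the conv output shape (8 channels in the tests). A self-contained sketch of an equivalent module follows; the channel counts and the 3x3 kernel (so that padding=1, stride=1 preserve H and W) are assumptions inferred from the test shapes, not the actual `Conv2dModule` definition.

import torch

class AddTensorConvSketch(torch.nn.Module):
    # Hedged stand-in for AddTensorConvModule; parameters are inferred from the
    # test shapes (x: (n, 4, h, w), y: (n, 8, h, w)).
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1, stride=1)

    def forward(self, x, y):
        x = self.conv(x)
        return x + y

m = AddTensorConvSketch()
out = m(torch.randn(1, 4, 5, 5), torch.randn(1, 8, 5, 5))
assert out.shape == (1, 8, 5, 5)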

backends/nxp/tests/ops_aliases.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 
 Abs = exir_ops.edge.aten.abs.default
+AddTensor = exir_ops.edge.aten.add.Tensor
 AvgPool2D = exir_ops.edge.aten.avg_pool2d.default
 Bmm = exir_ops.edge.aten.bmm.default
 Convolution = exir_ops.edge.aten.convolution.default