Skip to content

Commit 4155842

Browse files
committed
NXP backend: Test avg_pool2d with new Neutron flow.
1 parent 32309af commit 4155842

5 files changed

Lines changed: 301 additions & 17 deletions

File tree

backends/nxp/backend/edge_helper.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
import operator
78

89
import torch
910

@@ -367,3 +368,104 @@ def node_has_well_defined_shape(node: Node) -> bool:
367368

368369
def try_get_arg(node: Node, idx: int) -> Argument | None:
369370
return node.args[idx] if idx < len(node.args) else None
371+
372+
373+
def input_quantization_type(
374+
node: Node, input_index: int | tuple[int, int]
375+
) -> torch.dtype | None:
376+
"""Return the quantization input datatype of the QDQ quantized `node`.
377+
378+
:param node: The compute node.
379+
:param input_index: The index into the `node.args`. If a tuple of 2 ints is provided,
380+
`args[input_index[0]][input_index[1]]` is used instead.
381+
:return: The input quantization datatype of the QDQ quantized `node`, or `None` if the graph does not follow the
382+
QDQ pattern or some metadata is incomplete or an invalid input index is given.
383+
384+
│ <returned type>
385+
┌─────▼──────┐
386+
│ Dequantize │
387+
└─────┬──────┘
388+
│ float
389+
┌───▼────┐
390+
│ `node` │
391+
└───┬────┘
392+
"""
393+
try:
394+
if isinstance(input_index, int):
395+
dequantize_node = node.args[input_index]
396+
elif (
397+
isinstance(input_index, tuple)
398+
and len(input_index) == 2
399+
and all(isinstance(i, int) for i in input_index)
400+
):
401+
dequantize_node = node.args[input_index[0]][input_index[1]]
402+
else:
403+
raise RuntimeError(
404+
"NXP backend: edge_helper.input_quantization_type(): Invalid input index."
405+
)
406+
except IndexError:
407+
return None # Invalid input args index.
408+
409+
if not _is_dequantize(dequantize_node):
410+
return None # Broken QDQ schema.
411+
412+
if (dequantize_input_val := dequantize_node.args[0].meta.get("val")) is None:
413+
return None # Invalid metadata.
414+
415+
return dequantize_input_val.dtype
416+
417+
418+
def output_quantization_type(
419+
node: Node, output_index: int | None = None
420+
) -> torch.dtype | None:
421+
"""Return the quantization output datatype of the QDQ quantized `node`.
422+
423+
:param node: The compute node.
424+
:param output_index: If the `node` has multiple outputs and therefore multiple `getitem` nodes follow it, the
425+
index selects the output.
426+
:return: The output quantization datatype of the QDQ quantized `node`, or `None` if the graph does not follow the
427+
QDQ pattern or some metadata is incomplete or an invalid input index is given.
428+
429+
┌───▼────┐
430+
│ `node` │
431+
┌───▼────┐ └───┬────┘
432+
│ `node` │ │
433+
└───┬────┘ ┌──┴───────────────...──
434+
│ float ┌─────────▼─────────────┐
435+
┌────▼─────┐ or │ getitem(output_index) │ ...
436+
│ Quantize │ └─────────┬─────────────┘
437+
└────┬─────┘ │ float
438+
│ <returned type> ┌────▼─────┐
439+
│ Quantize │
440+
└────┬─────┘
441+
│ <returned type>
442+
"""
443+
users = list(node.users)
444+
if len(users) == 1:
445+
if not _is_quantize(quantize_node := users[0]):
446+
return None
447+
448+
else: # Multiple users
449+
if not isinstance(output_index, int):
450+
return None # Invalid index.
451+
if not all(user.target == operator.getitem for user in users):
452+
# Broken QDQ schema (unexpected nodes). These nodes should be moved out by
453+
# `move_auxiliary_operator_into_separate_qdq_cluster_pass.py`.
454+
return None
455+
456+
selected_getitems = list(
457+
filter(lambda getitem: getitem.args[1] == output_index, users)
458+
)
459+
if len(selected_getitems) != 1:
460+
return None # Multiple getitems access the selected output -> broken QDQ schema.
461+
selected_getitem_users = list(selected_getitems[0].users)
462+
if not (
463+
len(selected_getitem_users) == 1
464+
and _is_quantize(quantize_node := selected_getitem_users[0])
465+
):
466+
return None # Broken QDQ schema.
467+
468+
if (quantize_val := quantize_node.meta.get("val")) is None:
469+
return None # Invalid metadata.
470+
471+
return quantize_val.dtype

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from executorch.backends.nxp.backend.custom_delegation_options import (
1313
CustomDelegationOptions,
1414
)
15+
from executorch.backends.nxp.backend.edge_helper import (
16+
input_quantization_type,
17+
output_quantization_type,
18+
)
1519
from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext
1620
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
1721
AtenModelBuilderDirector,
@@ -308,3 +312,68 @@ def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator
308312
t_operator.tmp_outputs.append(self.builder.tensor_for_name(tensor_name))
309313

310314
return t_operator
315+
316+
@staticmethod
317+
def uses_quantization_type_for_inputs(
318+
node: Node,
319+
supported_types: list[torch.dtype],
320+
input_indices: list[int | tuple[int, int]],
321+
) -> bool:
322+
"""Check if `node` uses the QDQ quantization schema and inputs on the provided indices use a quantization type
323+
that is in `supported_types`.
324+
325+
:param node: The compute node.
326+
:param supported_types: List of supported quantization types.
327+
:param input_indices: List of indices into the `node.args`, or tuples of 2 indices into `node.args[idx1][idx2]`.
328+
:return: True, if the `node` is QDQ quantized and has quantization input types in `supported_types`.
329+
"""
330+
return all(
331+
input_quantization_type(node, input_index) in supported_types
332+
for input_index in input_indices
333+
)
334+
335+
@staticmethod
336+
def uses_quantization_type_for_outputs(
337+
node: Node,
338+
supported_types: list[torch.dtype],
339+
output_indices: list[int] | None = None,
340+
):
341+
"""Check if `node` uses the QDQ quantization schema and outputs on the provided indices use a quantization type
342+
that is in `supported_types`.
343+
344+
:param node: The compute node.
345+
:param supported_types: List of supported quantization types.
346+
:param output_indices: If the `node` has multiple outputs and therefore multiple `getitem` nodes follow it, the
347+
indices select the outputs to be checked.
348+
:return: True, if the `node` is QDQ quantized and has quantization output types in `supported_types`.
349+
"""
350+
if output_indices is None:
351+
return output_quantization_type(node) in supported_types
352+
else:
353+
return all(
354+
output_quantization_type(node, output_index) in supported_types
355+
for output_index in output_indices
356+
)
357+
358+
@staticmethod
def uses_quantization_type_for_io(
    node: Node,
    supported_types: list[torch.dtype],
    input_indices: list[int | tuple[int, int]],
    output_indices: list[int] | None = None,
) -> bool:  # Annotation added for consistency with `uses_quantization_type_for_inputs`.
    """Check if `node` uses the QDQ quantization schema and inputs and outputs on the provided indices use a
    quantization type that is in `supported_types`.

    :param node: The compute node.
    :param supported_types: List of supported quantization types.
    :param input_indices: List of indices into the `node.args`, or tuples of 2 indices into `node.args[idx1][idx2]`.
    :param output_indices: If the `node` has multiple outputs and therefore multiple `getitem` nodes follow it, the
        indices select the outputs to be checked.
    :return: True, if the `node` is QDQ quantized and has quantization input and output types in `supported_types`.
    """
    return NodeConverter.uses_quantization_type_for_inputs(
        node, supported_types, input_indices
    ) and NodeConverter.uses_quantization_type_for_outputs(
        node, supported_types, output_indices
    )

backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
import numpy as np
7+
import torch
78

89
from executorch.backends.nxp.backend.ir.converter.conversion import (
910
aten_translator,
@@ -21,6 +22,8 @@
2122
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
2223
average_pool_2d_options,
2324
)
25+
26+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
2427
from torch.fx import Node
2528
from torch.nn import Parameter
2629

@@ -53,6 +56,33 @@ def _is_supported_in_IR(
5356

5457
return True
5558

59+
@staticmethod
def _is_supported_on_target(
    node: Node,
    neutron_target_spec: NeutronTargetSpec,
    parameters_mapping: dict[str, Parameter],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Check whether this `avg_pool2d` node can be delegated to the Neutron target.

    :param node: The `avg_pool2d` compute node.
    :param neutron_target_spec: Target specification (unused here).
    :param parameters_mapping: Mapping of parameter names to tensors (unused here).
    :param custom_delegation_options: Options controlling the delegation flow.
    :return: True, if the node satisfies the target's constraints.
    """
    if custom_delegation_options.use_new_flow_neutron_c:
        # Requirements specified by the new Neutron flow documentation.

        # The input (args[0]) and output must be quantized to a supported type.
        supported_types = [torch.int8, torch.uint8]
        if not NodeConverter.uses_quantization_type_for_io(
            node, supported_types, [0]
        ):
            return False

        kernel = node.args[1]
        # NOTE(review): the aten.avg_pool2d schema defaults `stride` to `[]`
        # (meaning stride == kernel_size) — guard against the arg being absent
        # instead of indexing unconditionally. An empty stride is covered by the
        # kernel check below. TODO: confirm whether export always materializes it.
        stride = node.args[2] if len(node.args) > 2 else []

        # Kernel and stride dimensions are limited to 4096 on the target.
        if any(k > 4096 for k in kernel):
            return False
        if any(s > 4096 for s in stride):
            return False

    return True
85+
5686
# noinspection PyMethodMayBeStatic
5787
def _convert_2d_avg_pool(
5888
self, kernel_size, stride, padding, t_op: tflite_model.Operator
@@ -85,10 +115,19 @@ def _convert_2d_avg_pool(
85115

86116
return ops.flatten()
87117

88-
# AvgPool2d Node format: (Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False
89-
# bool count_include_pad=True, int? divisor_override=None)
90118
def convert(self, node: Node):
91-
"""Convert 'avg_pool2d' operator to TFLite 'AveragePool2D'."""
119+
"""Convert 'avg_pool2d' operator to TFLite 'AveragePool2D'.
120+
The ExecuTorch schema is:
121+
aten.avg_pool2d(
122+
Tensor self,
123+
int[2] kernel_size,
124+
int[2] stride=[],
125+
int[2] padding=0,
126+
bool ceil_mode=False
127+
bool count_include_pad=True,
128+
int? divisor_override=None
129+
)
130+
"""
92131
self.assert_convertible(node)
93132

94133
kernel_size = node.args[1]

backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 NXP
1+
# Copyright 2024,2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -28,18 +28,22 @@
2828
ToNCHWPreprocess,
2929
ToNHWCPreprocess,
3030
)
31+
from executorch.backends.nxp.tests.graph_verifier import BaseGraphVerifier
3132
from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule
33+
34+
from executorch.backends.nxp.tests.nsys_testing import lower_run_compare
35+
36+
from executorch.backends.nxp.tests.ops_aliases import (
37+
AvgPool2D,
38+
ExecutorchDelegateCall,
39+
Squeeze,
40+
SqueezeDim,
41+
SqueezeDims,
42+
Unsqueeze,
43+
ViewCopy,
44+
)
3245
from torch.export import ExportedProgram
3346
from executorch.backends.nxp.tests.use_qat import * # noqa F403
34-
from executorch.exir.dialects._ops import ops as exir_ops
35-
36-
AvgPool2D = exir_ops.edge.aten.avg_pool2d.default
37-
ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate
38-
Squeeze = exir_ops.edge.aten.squeeze.default
39-
SqueezeDim = exir_ops.edge.aten.squeeze.dim
40-
SqueezeDims = exir_ops.edge.aten.squeeze.dims
41-
Unsqueeze = exir_ops.edge.aten.unsqueeze.default
42-
ViewCopy = exir_ops.edge.aten.view_copy.default
4347

4448

4549
@pytest.fixture(autouse=True)
@@ -296,3 +300,73 @@ def test_from_avg_pool_1d(mocker):
296300
tflite_input_preprocess=ToChannelLastPreprocess(),
297301
tflite_output_preprocess=ToChannelFirstPreprocess(),
298302
)
303+
304+
305+
class TestAvgPool2DNewNeutronFlow:
    """Tests for `avg_pool2d` delegation with the new Neutron flow (`use_new_flow_neutron_c`)."""

    @staticmethod
    def _lower_and_expect_delegated(model, input_shape):
        # Lower the model, run it, and verify the AvgPool was fully delegated.
        graph_verifier = BaseGraphVerifier(
            exp_num_delegate_call_nodes=1,  # Delegated AvgPool.
            exp_non_delegated_nodes=[],
        )
        lower_run_compare(
            model, input_shape, graph_verifier, use_new_flow_neutron_c=True
        )

    @staticmethod
    def _lower_and_expect_not_delegated(model, input_shape):
        # Lower the model and make sure the `avg_pool2d` was NOT delegated.
        delegated_ep = to_quantized_edge_program(
            model, input_shape, use_new_flow_neutron_c=True
        ).exported_program()

        assert not graph_contains_any_of_ops(
            delegated_ep.graph, [ExecutorchDelegateCall]
        )
        assert graph_contains_any_of_ops(delegated_ep.graph, [AvgPool2D])

    def test__basic_nsys_inference(self):
        input_shape = (2, 4, 6, 7)
        model = AvgPool2dModule(False, 0)
        self._lower_and_expect_delegated(model, input_shape)

    def test__kernel_size_limit(self):
        kernel_size = (1, 4096)  # Exactly at the kernel size limit.
        input_shape = (1, 4) + kernel_size
        model = AvgPool2dModule(False, 0, kernel_size)
        self._lower_and_expect_delegated(model, input_shape)

    def test__kernel_size_limit_exceeded(self):
        kernel_size = (1, 4097)  # Exceeds the kernel size limit.
        input_shape = (1, 4) + kernel_size
        model = AvgPool2dModule(False, 0, kernel_size)
        self._lower_and_expect_not_delegated(model, input_shape)

    def test__stride_limit(self):
        stride = 4096  # Exactly at the stride limit.
        input_shape = (1, 4, 1, 4096)
        model = AvgPool2dModule(False, 0, 1, stride)
        self._lower_and_expect_delegated(model, input_shape)

    def test__stride_limit_exceeded(self):
        stride = 4097  # Exceeds the stride limit.
        input_shape = (1, 4, 1, 4096)
        model = AvgPool2dModule(False, 0, 1, stride)
        self._lower_and_expect_not_delegated(model, input_shape)

backends/nxp/tests/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,12 @@ def forward(self, x):
348348

349349

350350
class AvgPool2dModule(torch.nn.Module):
351-
def __init__(self, count_include_pad: bool, padding: int = 0, kernel_size: int | tuple[int, int] = 3, stride: int | tuple[int, int] = 2):
    """Build a module wrapping a single `torch.nn.AvgPool2d`.

    :param count_include_pad: Forwarded to `AvgPool2d`; include zero-padding in the average.
    :param padding: Zero-padding added to both sides of the input.
    :param kernel_size: Pooling window size (int or (H, W) tuple).
    :param stride: Pooling window stride (int or (H, W) tuple).
    """
    super().__init__()

    # `avg_pool` holds the pooling layer applied in `forward`.
    self.avg_pool = torch.nn.AvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        count_include_pad=count_include_pad,
    )

0 commit comments

Comments (0)