Skip to content

Commit 02fe862

Browse files
NXP backend: Update partitioning to avoid delegating no-op partitions. (#16901)
### Summary If a partition delegated to Neutron contained only no-ops, it would cause a crash. This PR identifies such cases and prohibits their delegation. ### Test plan Unit-tests provided.
1 parent e4492a3 commit 02fe862

10 files changed

Lines changed: 509 additions & 88 deletions

File tree

backends/nxp/backend/custom_delegation_options.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -17,3 +17,8 @@ class CustomDelegationOptions:
1717
# of `num_macs`. The `force_delegate_cat` allows the user to turn off the defensive check if from the model design
1818
# it is known this constraint will be satisfied.
1919
force_delegate_cat: bool = False
20+
21+
# Proposed partitions which only contain Neutron no-ops are normally not delegated, as the NeutronConverter would
22+
# not create any NeutronGraph that can be called. This is done by the partitioner itself, and is not handled by
23+
# the individual node converters.
24+
allow_no_op_partitions: bool = False

backends/nxp/backend/edge_helper.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
# Copyright 2024-2025 NXP
1+
# Copyright 2024-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
7+
68
import torch
79

810
from executorch.exir.dialects._ops import ops as exir_ops
@@ -19,6 +21,14 @@
1921
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
2022
]
2123

24+
# A set of operators which could possibly be no-ops in certain conditions. The operators in this set will be proclaimed
25+
# as no-ops (and potentially not delegated), if their input and output tensors are equal (when run on random data).
26+
no_op_candidates = {
27+
exir_ops.edge.aten.add.Tensor,
28+
exir_ops.edge.aten.mul.Tensor,
29+
exir_ops.edge.aten.sub.Tensor,
30+
}
31+
2232

2333
def input_tensor(node: Node, input_index: int) -> torch.Tensor:
2434
if len(node.all_input_nodes) <= input_index:
@@ -220,3 +230,127 @@ def get_non_qdq_parent(node: Node, input_index: int = 0) -> Node | None:
220230
return None
221231

222232
return quant_node.args[0]
233+
234+
235+
def try_get_dequantized_data(
236+
dequantize_node: Node, parameters_mapping: dict[str, Parameter]
237+
) -> Parameter | None:
238+
"""Get the dequantized data from the following pattern. The dequantization formula is `r = (q - Z) * S`, where `q`
239+
represents the static quantized data.
240+
241+
┌─────────────────────────┐
242+
│ <static_quantized_data> │
243+
└────────────┬────────────┘
244+
245+
┌─────▼──────┐
246+
│ Dequantize │
247+
└─────┬──────┘
248+
249+
250+
251+
:param dequantize_node: The Dequantize node from the pattern, which dequantizes the static quantized data.
252+
:param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
253+
`state_dict` attribute of an edge program.
254+
:return: The dequantized static parameter, or `None` if the data is not available.
255+
"""
256+
if not _is_dequantize(dequantize_node):
257+
return None
258+
259+
if not node_is_static_tensor(param := dequantize_node.args[0], parameters_mapping):
260+
return None
261+
262+
# The pattern is correct. Dequantize the static data and return it.
263+
scale, zp = get_quantization_parameters_for(dequantize_node)
264+
quantized_data = parameters_mapping[param.name]
265+
266+
dequantized_data = (quantized_data - zp) * scale
267+
return dequantized_data
268+
269+
270+
def is_no_op_on_neutron(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
271+
"""Check if a node is a no-op operation from the perspective of Neutron."""
272+
if node.op != "call_function":
273+
raise ValueError(
274+
f"is_no_op_on_neutron(): Expected call_function node, got {node.op}."
275+
)
276+
277+
if node.target in [
278+
exir_ops.edge.aten.view_copy.default,
279+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
280+
exir_ops.edge.aten.clone.default,
281+
]:
282+
# Known operators which are always no-ops on Neutron.
283+
return True
284+
285+
if node.target == exir_ops.edge.aten.cat.default and len(node.args[0]) == 1:
286+
# Concatenation with 1 input is a no-op.
287+
return True
288+
289+
# For any other operators, run them with random data and see if the output is identical to the input.
290+
torch.manual_seed(42)
291+
# noinspection PyBroadException
292+
try:
293+
input_data = None
294+
args_with_random_data = []
295+
for arg in node.args:
296+
match arg:
297+
case Node():
298+
# `arg` is either another operator, a model input, or a static parameter.
299+
300+
if (
301+
data := try_get_dequantized_data(arg, parameters_mapping)
302+
) is not None:
303+
# The `arg` is a static parameter. Use its actual static data during the no-op test.
304+
args_with_random_data.append(data)
305+
306+
else:
307+
# The `arg` is a compute node or a model input. Replace it with random data for the no-op test.
308+
if input_data is not None:
309+
# Some random input data for `node` has already been stored, which means that the node has
310+
# more than 1 dynamic input node. Therefore, it cannot be a no-op.
311+
return False
312+
313+
# Generate the random data. Use the range [-5, 5) to avoid proclaiming operations like Relu as
314+
# no-ops.
315+
val = arg.meta["val"]
316+
input_data = torch.rand(val.shape, dtype=val.dtype) * 10 - 5
317+
args_with_random_data.append(input_data)
318+
319+
case list():
320+
# Lists of input nodes are not supported to keep the code simple. It is not crucial to support this
321+
# case as the affected operators are either not supported on Neutron, or are extremely unlikely to
322+
# be no-ops (e.g. GRU). One exception is `aten.cat`, which is explicitly supported above.
323+
return False
324+
325+
case _:
326+
# Generic argument (value). Not an input from a previous node. Store it in the arguments for the
327+
# no-op test.
328+
args_with_random_data.append(arg)
329+
330+
# Run the operator with the random data. If the input equals the output, the node is considered a no-op.
331+
output_data = node.target(*args_with_random_data)
332+
333+
val = node.meta["val"]
334+
if (
335+
output_data.dtype == val.dtype
336+
and output_data.shape == val.shape
337+
and torch.all(input_data == output_data)
338+
):
339+
# The operator preserves the shape, data type, and data. Therefore, it is a no-op from the perspective of
340+
# Neutron.
341+
if node.target in no_op_candidates:
342+
return True
343+
else:
344+
logging.info(
345+
f"Found the operator `{node.target}`, which appears to be a no-op, but is not in the "
346+
"known no-op list. Please report this issue."
347+
)
348+
return False
349+
350+
else:
351+
# Type, shape, or data doesn't match.
352+
return False
353+
354+
except Exception:
355+
# If execution fails, assume it's not a no-op.
356+
return False

backends/nxp/backend/edge_program_converter.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1414
CustomDelegationOptions,
1515
)
16+
from torch._subclasses import FakeTensor
1617
from torch.export import ExportedProgram
1718
from torch.export.graph_signature import InputKind
1819
from torch.fx import Node
@@ -161,20 +162,49 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContex
161162
)
162163

163164
@staticmethod
164-
def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
165+
def map_inputs_to_parameters(
166+
edge_program: ExportedProgram,
167+
post_quantization_state_dict: dict[str, Parameter] | None = None,
168+
) -> dict[str, Parameter]:
165169
"""
166170
Create mapping between program parameters (input nodes & static data nodes) and their names.
167171
168172
:param edge_program: EdgeProgram instance.
173+
:param post_quantization_state_dict: State-dict of the model right after quantization. During partitioning, the
174+
`edge_program` only contains fake tensors without any data. In this case,
175+
this state dict is used instead (if provided). Notice: It may potentially
176+
contain outdated data.
169177
:return: Mapping from parameter name to parameter instance.
170178
"""
171179
result_map = {}
172180

173181
for input_spec in edge_program.graph_signature.input_specs:
174182
if input_spec.kind in [InputKind.PARAMETER, InputKind.BUFFER]:
175-
result_map[input_spec.arg.name] = edge_program.state_dict[
176-
input_spec.target
177-
]
183+
184+
# First, try to load the static data from the model.
185+
param = edge_program.state_dict[input_spec.target]
186+
187+
if not isinstance(param, FakeTensor):
188+
# Use the data from the model.
189+
result_map[input_spec.arg.name] = param
190+
191+
else:
192+
# It is the partitioning stage, which uses a FakeModel with FakeTensors (without the actual data).
193+
# Try to load the data from the post-quantization state dict.
194+
if (
195+
post_quantization_state_dict is not None
196+
and (
197+
param := post_quantization_state_dict.get(
198+
input_spec.target, None
199+
)
200+
)
201+
is not None
202+
):
203+
result_map[input_spec.arg.name] = param
204+
205+
else:
206+
# There is no data available.
207+
continue
178208

179209
return result_map
180210

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def _is_supported_in_IR(
7575
Classes which implement conversion for individual operators must overwrite this method.
7676
7777
:param node: torch.Node to check.
78-
:param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
78+
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
79+
(if they have any). During partitioning, this data is extracted from the model right
80+
after quantization and before edge dialect passes. Therefore, it could potentially
81+
be outdated.
7982
:param custom_delegation_options: Custom options which affect delegation.
8083
"""
8184
pass
@@ -93,7 +96,10 @@ def _is_supported_on_target(
9396
9497
:param node: The node (edge operator) to check.
9598
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
96-
:param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
99+
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
100+
(if they have any). During partitioning, this data is extracted from the model right
101+
after quantization and before edge dialect passes. Therefore, it could potentially
102+
be outdated.
97103
:param custom_delegation_options: Custom options which affect delegation.
98104
"""
99105
return True
@@ -110,7 +116,10 @@ def is_supported(
110116
111117
:param node: torch.Node to check.
112118
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
113-
:param parameters_mapping: Dict mapping tensor names to their data.
119+
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
120+
(if they have any). During partitioning, this data is extracted from the model right
121+
after quantization and before edge dialect passes. Therefore, it could potentially
122+
be outdated.
114123
:param custom_delegation_options: Custom user options which affect node delegation.
115124
"""
116125
return cls._is_supported_in_IR(
@@ -136,7 +145,10 @@ def supports_partitioning_result(
136145
:param partition_list: List of proposed partitions.
137146
:param custom_delegation_options: Custom user options which affect node delegation.
138147
:param neutron_target_spec: NeutronTargetSpec instance.
139-
:param parameters_mapping: Dictionary mapping tensor names to their static data.
148+
:param parameters_mapping: Dictionary mapping static parameter names to Parameter objects containing their data
149+
(if they have any). During partitioning, this data is extracted from the model right
150+
after quantization and before edge dialect passes. Therefore, it could potentially
151+
be outdated.
140152
:return: Boolean indicating whether the node supports the current partitioning.
141153
"""
142154
return True

backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
)
2222
from executorch.backends.nxp.backend.ir.converter.node_converter import (
2323
CustomDelegationOptions,
24-
is_not_qdq_node,
2524
NodeConverter,
2625
)
2726
from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reshape_transposition import (
@@ -59,24 +58,6 @@ def _is_supported_in_IR(
5958

6059
return True
6160

62-
@classmethod
63-
def _partition_contains_compute_nodes(cls, view_copy_partition: Partition) -> bool:
64-
non_q_dq_partition_nodes = list(
65-
filter(is_not_qdq_node, view_copy_partition.nodes)
66-
)
67-
68-
if len(non_q_dq_partition_nodes) == 1:
69-
# The `view_copy` cannot be the only node in a partition.
70-
return False
71-
72-
# It is common for a `clone` node to come before the `view_copy`. Make sure these are not the only two nodes
73-
# in the partition.
74-
if any("clone" in n.name for n in non_q_dq_partition_nodes):
75-
if len(non_q_dq_partition_nodes) <= 2:
76-
return False
77-
78-
return True
79-
8061
@classmethod
8162
def supports_partitioning_result(
8263
cls,
@@ -91,9 +72,6 @@ def supports_partitioning_result(
9172
]
9273
assert len(view_copy_partitions) == 1
9374

94-
if not cls._partition_contains_compute_nodes(view_copy_partitions[0]):
95-
return False
96-
9775
input_format = node.args[0].meta[NXP_NODE_FORMAT]
9876
output_format = node.meta[NXP_NODE_FORMAT]
9977
input_shape = list(node.args[0].meta["val"].shape)

0 commit comments

Comments
 (0)