Skip to content

Commit a206afb

Browse files
committed
NXP backend: Add DetailedGraphVerifier for strict checking of delegated nodes.
1 parent 2e58165 commit a206afb

5 files changed

Lines changed: 249 additions & 78 deletions

File tree

backends/nxp/tests/graph_verifier.py

Lines changed: 183 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,85 @@
55

66
import abc
77
import re
8+
from collections import defaultdict
9+
from copy import deepcopy
810
from dataclasses import dataclass
9-
from typing import Union
11+
from typing import Callable, Union
1012

13+
from executorch.backends.nxp.neutron_partitioner import (
14+
NeutronPartitioner,
15+
NXP_DELEGATION_TAG,
16+
)
17+
from executorch.backends.nxp.tests.ops_aliases import (
18+
DequantizePerChannel,
19+
DequantizePerTensor,
20+
QuantizePerChannel,
21+
QuantizePerTensor,
22+
)
23+
24+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
25+
26+
from pytest_mock import MockerFixture
27+
28+
from torch.fx import Node
1129
from torch.fx.graph import Graph
1230

1331

1432
@dataclass
1533
class NonDelegatedNode:
34+
"""Represents an expected non-delegated node in the graph.
35+
36+
:param node_name: The name of the node to check for
37+
:param num_occurrences: Expected number of occurrences. If None, just verifies that at least one exists
38+
"""
39+
1640
node_name: str
1741
num_occurrences: Union[int, None] = None
1842

1943

2044
class GraphVerifier(abc.ABC):
45+
"""Abstract base class for graph verification strategies."""
46+
2147
@abc.abstractmethod
2248
def verify_graph(self, graph: Graph):
23-
pass
49+
"""Verifies the graph meets expected criteria.
2450
25-
@abc.abstractmethod
26-
def check_num_delegated_nodes(self, num_dlg_nodes: int):
51+
:param graph: The FX graph to verify
52+
:raises AssertionError: If the graph does not meet expectations
53+
"""
2754
pass
2855

2956

3057
class BaseGraphVerifier(GraphVerifier):
31-
"""Graph verifier base class. Checks for number of delegated nodes and number of selected expected nodes."""
58+
"""Graph verifier base class. Checks for number of delegated nodes and number of selected expected nodes.
59+
60+
This verifier performs the following checks:
61+
- The total number of delegated call nodes matches expectations
62+
- Specific non-delegated nodes appear with the expected frequency
63+
- No unexpected aten nodes are present in the graph
64+
"""
3265

3366
def __init__(
3467
self,
3568
exp_num_delegate_call_nodes: int,
3669
exp_non_delegated_nodes: list[NonDelegatedNode] = None,
3770
):
71+
"""Initializes the BaseGraphVerifier.
72+
73+
:param exp_num_delegate_call_nodes: Expected number of delegated nodes
74+
:param exp_non_delegated_nodes: List of expected non-delegated nodes to verify
75+
"""
3876
self.exp_non_delegated_nodes = (
3977
exp_non_delegated_nodes if exp_non_delegated_nodes is not None else []
4078
)
4179
self.exp_num_delegate_call_nodes = exp_num_delegate_call_nodes
4280

4381
def check_num_delegated_nodes(self, num_dlg_nodes):
82+
"""Checks that the number of delegated nodes matches expectations.
83+
84+
:param num_dlg_nodes: Actual number of delegated nodes
85+
:raises AssertionError: If the count doesn't match expectations
86+
"""
4487
assert not (
4588
num_dlg_nodes < self.exp_num_delegate_call_nodes
4689
), f"Number of delegated nodes decreased from {self.exp_num_delegate_call_nodes} to {num_dlg_nodes}."
@@ -49,6 +92,11 @@ def check_num_delegated_nodes(self, num_dlg_nodes):
4992
), f"Number of delegated nodes increased from {self.exp_num_delegate_call_nodes} to {num_dlg_nodes}."
5093

5194
def verify_graph(self, graph):
95+
"""Verifies the graph meets delegation and node presence expectations.
96+
97+
:param graph: The FX graph to verify
98+
:raises AssertionError: If verification fails
99+
"""
52100
nodes = list(graph.nodes)
53101

54102
# Check for specific non delegated nodes
@@ -84,3 +132,133 @@ def verify_graph(self, graph):
84132
assert (
85133
not unexpected_aten_fn_nodes
86134
), f"Graphs contains unexpected aten nodes:\n{unexpected_aten_fn_nodes}."
135+
136+
137+
# Type alias for operators - can be either EdgeOpOverload or any callable (e.g., operator.getitem).
138+
Operator = EdgeOpOverload | Callable
139+
140+
141+
class DetailedGraphVerifier(GraphVerifier):
142+
"""Graph verifier that checks for exact delegated and non-delegated operators.
143+
144+
This verifier captures a snapshot of the graph immediately after partitioning and verifies
145+
that specific operators were delegated/non-delegated the expected number of times. It uses
146+
mocker to intercept the partition() call and create a deep copy of the nodes before they
147+
can be modified. Quantization/dequantization operators are ignored by default as they are
148+
typically not the focus of delegation verification.
149+
"""
150+
151+
default_ops_to_ignore = {
152+
QuantizePerTensor,
153+
QuantizePerChannel,
154+
DequantizePerTensor,
155+
DequantizePerChannel,
156+
}
157+
158+
def __init__(
159+
self,
160+
mocker: MockerFixture,
161+
*,
162+
expected_delegated_ops: dict[Operator, int],
163+
expected_non_delegated_ops: dict[Operator, int],
164+
ops_to_ignore: set[Operator] | None = None,
165+
):
166+
"""Initializes the DetailedGraphVerifier and patches NeutronPartitioner.partition() to capture node state.
167+
168+
:param expected_delegated_ops: Dictionary mapping operators to their expected delegation count
169+
:param expected_non_delegated_ops: Dictionary mapping operators to their expected non-delegation count
170+
:param mocker: Pytest mocker fixture for intercepting the partition method
171+
:param ops_to_ignore: Set of operators to ignore during verification. Defaults to quantization ops
172+
"""
173+
self.expected_delegated_ops = expected_delegated_ops
174+
self.expected_non_delegated_ops = expected_non_delegated_ops
175+
176+
self.ops_to_ignore = ops_to_ignore or self.default_ops_to_ignore
177+
178+
# We need to use mocker to capture a copy of the nodes returned by NeutronPartitioner.partition() to access
179+
# their partition tag. The nodes in the returned graph may be modified after partition() returns, so we
180+
# capture a deep copy immediately when the method completes.
181+
self.captured_partitioned_nodes: list[Node] | None = None
182+
183+
# Store original partition method for the wrapper.
184+
# Note: pytest-mock automatically restores the original method after the test completes,
185+
# so manual cleanup is not required.
186+
original_partition_method = NeutronPartitioner.partition
187+
188+
def partition_wrapper(self_, exported_program):
189+
"""Wraps NeutronPartitioner.partition() to capture a snapshot of nodes after partitioning.
190+
191+
:param self_: The NeutronPartitioner instance
192+
:param exported_program: The ExportedProgram being partitioned
193+
:return: The PartitionResult from the original partition method
194+
"""
195+
result = original_partition_method(self_, exported_program)
196+
# Capture a deep copy of the nodes with their metadata.
197+
# This ensures we have the exact state immediately after partitioning,
198+
# before any subsequent transformations modify the graph.
199+
self.captured_partitioned_nodes = list(
200+
deepcopy(exported_program.graph.nodes)
201+
)
202+
return result
203+
204+
# Patch the partition method to intercept and capture results.
205+
mocker.patch.object(NeutronPartitioner, "partition", partition_wrapper)
206+
207+
def verify_graph(self, graph):
208+
"""Verifies that operators were delegated/non-delegated as expected by comparing actual counts against expectations.
209+
210+
:param graph: The FX graph to verify (not directly used; we use captured nodes instead)
211+
:raises AssertionError: If the NeutronPartitioner wasn't used or if delegation doesn't match expectations
212+
"""
213+
assert (
214+
self.captured_partitioned_nodes is not None
215+
), "The NeutronPartitioner was not used. Cannot access delegated nodes."
216+
217+
delegated_ops = defaultdict(int)
218+
non_delegated_ops = defaultdict(int)
219+
220+
for node in self.captured_partitioned_nodes:
221+
# Only process call_function nodes with a target
222+
if not hasattr(node, "target") or node.op != "call_function":
223+
continue
224+
225+
# Skip operators we're configured to ignore (e.g., quantization ops)
226+
if node.target in self.ops_to_ignore:
227+
continue
228+
229+
# Check if the node was tagged for delegation during partitioning
230+
if NXP_DELEGATION_TAG in node.meta:
231+
delegated_ops[node.target] += 1
232+
else:
233+
non_delegated_ops[node.target] += 1
234+
235+
# All ops which were either expected to be delegated, or were actually delegated.
236+
all_delegated_ops = list(set(self.expected_delegated_ops).union(delegated_ops))
237+
238+
# All ops which were either expected to be non-delegated, or were actually non-delegated.
239+
all_non_delegated_ops = list(
240+
set(self.expected_non_delegated_ops).union(non_delegated_ops)
241+
)
242+
243+
message = ""
244+
245+
# Check delegated operators
246+
for op in all_delegated_ops:
247+
expected_count = self.expected_delegated_ops.get(op, 0)
248+
real_count = delegated_ops.get(op, 0)
249+
op_name = op.name() if hasattr(op, "name") else str(op)
250+
if expected_count != real_count:
251+
message += f"\t`{op_name}` was delegated {real_count} times instead of the expected {expected_count} times.\n"
252+
253+
# Check non-delegated operators
254+
for op in all_non_delegated_ops:
255+
expected_count = self.expected_non_delegated_ops.get(op, 0)
256+
real_count = non_delegated_ops.get(op, 0)
257+
op_name = op.name() if hasattr(op, "name") else str(op)
258+
if expected_count != real_count:
259+
message += f"\t`{op_name}` was NON-delegated {real_count} times instead of the expected {expected_count} times.\n"
260+
261+
if message:
262+
raise AssertionError(
263+
"Some operators were not delegated as expected:\n" + message
264+
)

backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@
2828
ToNCHWPreprocess,
2929
ToNHWCPreprocess,
3030
)
31-
from executorch.backends.nxp.tests.graph_verifier import (
32-
BaseGraphVerifier,
33-
NonDelegatedNode,
34-
)
31+
from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier
3532
from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule
3633

3734
from executorch.backends.nxp.tests.nsys_testing import lower_run_compare
@@ -306,25 +303,23 @@ def test_from_avg_pool_1d(mocker):
306303

307304

308305
class TestAvgPool2DNewNeutronFlow:
309-
def test__basic_nsys_inference(self):
306+
def test__basic_nsys_inference(self, mocker):
310307
input_shape = (2, 4, 6, 7)
311308
model = AvgPool2dModule(False, 0)
312-
graph_verifier = BaseGraphVerifier(
313-
exp_num_delegate_call_nodes=1, # Delegated AvgPool.
314-
exp_non_delegated_nodes=[],
309+
graph_verifier = DetailedGraphVerifier(
310+
mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={}
315311
)
316312

317313
lower_run_compare(
318314
model, input_shape, graph_verifier, use_new_flow_neutron_c=True
319315
)
320316

321-
def test__kernel_size_limit(self):
317+
def test__kernel_size_limit(self, mocker):
322318
kernel_size = (1, 4096)
323319
input_shape = (1, 4) + kernel_size
324320
model = AvgPool2dModule(False, 0, kernel_size)
325-
graph_verifier = BaseGraphVerifier(
326-
exp_num_delegate_call_nodes=1, # Delegated AvgPool.
327-
exp_non_delegated_nodes=[],
321+
graph_verifier = DetailedGraphVerifier(
322+
mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={}
328323
)
329324

330325
lower_run_compare(
@@ -346,13 +341,12 @@ def test__kernel_size_limit_exceeded(self):
346341
)
347342
assert graph_contains_any_of_ops(delegated_ep.graph, [AvgPool2D])
348343

349-
def test__stride_limit(self):
344+
def test__stride_limit(self, mocker):
350345
stride = 4096
351346
input_shape = (1, 4, 1, 4096)
352347
model = AvgPool2dModule(False, 0, 1, stride)
353-
graph_verifier = BaseGraphVerifier(
354-
exp_num_delegate_call_nodes=1, # Delegated AvgPool.
355-
exp_non_delegated_nodes=[],
348+
graph_verifier = DetailedGraphVerifier(
349+
mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={}
356350
)
357351

358352
lower_run_compare(
@@ -378,16 +372,13 @@ def test__stride_limit_exceeded(self):
378372
class TestAvgPool1DNewNeutronFlow:
379373

380374
# Just a basic test to verify that the operator gets extended to the 2D variant correctly.
381-
def test__basic_nsys_inference__view_not_delegated(self):
375+
def test__basic_nsys_inference__view_not_delegated(self, mocker):
382376
input_shape = (2, 4, 6) # The old flow limited the batch size to 1.
383377
model = AvgPool1DModule()
384-
graph_verifier = BaseGraphVerifier(
385-
exp_num_delegate_call_nodes=1, # Delegated AvgPool.
386-
exp_non_delegated_nodes=[
387-
NonDelegatedNode(
388-
"aten_view_copy_default", 2
389-
) # Non delegated due to shape requirements.
390-
],
378+
graph_verifier = DetailedGraphVerifier(
379+
mocker,
380+
expected_delegated_ops={AvgPool2D: 1},
381+
expected_non_delegated_ops={ViewCopy: 2},
391382
)
392383

393384
lower_run_compare(

0 commit comments

Comments
 (0)