55
66import abc
77import re
8+ from collections import defaultdict
9+ from copy import deepcopy
810from dataclasses import dataclass
9- from typing import Union
11+ from typing import Callable , Union
1012
13+ from executorch .backends .nxp .neutron_partitioner import (
14+ NeutronPartitioner ,
15+ NXP_DELEGATION_TAG ,
16+ )
17+ from executorch .backends .nxp .tests .ops_aliases import (
18+ DequantizePerChannel ,
19+ DequantizePerTensor ,
20+ QuantizePerChannel ,
21+ QuantizePerTensor ,
22+ )
23+
24+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
25+
26+ from pytest_mock import MockerFixture
27+
28+ from torch .fx import Node
1129from torch .fx .graph import Graph
1230
1331
1432@dataclass
1533class NonDelegatedNode :
34+ """Represents an expected non-delegated node in the graph.
35+
36+ :param node_name: The name of the node to check for
37+ :param num_occurrences: Expected number of occurrences. If None, just verifies that at least one exists
38+ """
39+
1640 node_name : str
1741 num_occurrences : Union [int , None ] = None
1842
1943
2044class GraphVerifier (abc .ABC ):
45+ """Abstract base class for graph verification strategies."""
46+
2147 @abc .abstractmethod
2248 def verify_graph (self , graph : Graph ):
23- pass
49+ """Verifies the graph meets expected criteria.
2450
25- @abc .abstractmethod
26- def check_num_delegated_nodes (self , num_dlg_nodes : int ):
51+ :param graph: The FX graph to verify
52+ :raises AssertionError: If the graph does not meet expectations
53+ """
2754 pass
2855
2956
3057class BaseGraphVerifier (GraphVerifier ):
31- """Graph verifier base class. Checks for number of delegated nodes and number of selected expected nodes."""
58+ """Graph verifier base class. Checks for number of delegated nodes and number of selected expected nodes.
59+
60+ This verifier performs the following checks:
61+ - The total number of delegated call nodes matches expectations
62+ - Specific non-delegated nodes appear with the expected frequency
63+ - No unexpected aten nodes are present in the graph
64+ """
3265
3366 def __init__ (
3467 self ,
3568 exp_num_delegate_call_nodes : int ,
3669 exp_non_delegated_nodes : list [NonDelegatedNode ] = None ,
3770 ):
71+ """Initializes the BaseGraphVerifier.
72+
73+ :param exp_num_delegate_call_nodes: Expected number of delegated nodes
74+ :param exp_non_delegated_nodes: List of expected non-delegated nodes to verify
75+ """
3876 self .exp_non_delegated_nodes = (
3977 exp_non_delegated_nodes if exp_non_delegated_nodes is not None else []
4078 )
4179 self .exp_num_delegate_call_nodes = exp_num_delegate_call_nodes
4280
4381 def check_num_delegated_nodes (self , num_dlg_nodes ):
82+ """Checks that the number of delegated nodes matches expectations.
83+
84+ :param num_dlg_nodes: Actual number of delegated nodes
85+ :raises AssertionError: If the count doesn't match expectations
86+ """
4487 assert not (
4588 num_dlg_nodes < self .exp_num_delegate_call_nodes
4689 ), f"Number of delegated nodes decreased from { self .exp_num_delegate_call_nodes } to { num_dlg_nodes } ."
@@ -49,6 +92,11 @@ def check_num_delegated_nodes(self, num_dlg_nodes):
4992 ), f"Number of delegated nodes increased from { self .exp_num_delegate_call_nodes } to { num_dlg_nodes } ."
5093
5194 def verify_graph (self , graph ):
95+ """Verifies the graph meets delegation and node presence expectations.
96+
97+ :param graph: The FX graph to verify
98+ :raises AssertionError: If verification fails
99+ """
52100 nodes = list (graph .nodes )
53101
54102 # Check for specific non delegated nodes
@@ -84,3 +132,133 @@ def verify_graph(self, graph):
84132 assert (
85133 not unexpected_aten_fn_nodes
86134 ), f"Graphs contains unexpected aten nodes:\n { unexpected_aten_fn_nodes } ."
135+
136+
137+ # Type alias for operators - can be either EdgeOpOverload or any callable (e.g., operator.getitem).
138+ Operator = EdgeOpOverload | Callable
139+
140+
141+ class DetailedGraphVerifier (GraphVerifier ):
142+ """Graph verifier that checks for exact delegated and non-delegated operators.
143+
144+ This verifier captures a snapshot of the graph immediately after partitioning and verifies
145+ that specific operators were delegated/non-delegated the expected number of times. It uses
146+ mocker to intercept the partition() call and create a deep copy of the nodes before they
147+ can be modified. Quantization/dequantization operators are ignored by default as they are
148+ typically not the focus of delegation verification.
149+ """
150+
151+ default_ops_to_ignore = {
152+ QuantizePerTensor ,
153+ QuantizePerChannel ,
154+ DequantizePerTensor ,
155+ DequantizePerChannel ,
156+ }
157+
158+ def __init__ (
159+ self ,
160+ mocker : MockerFixture ,
161+ * ,
162+ expected_delegated_ops : dict [Operator , int ],
163+ expected_non_delegated_ops : dict [Operator , int ],
164+ ops_to_ignore : set [Operator ] | None = None ,
165+ ):
166+ """Initializes the DetailedGraphVerifier and patches NeutronPartitioner.partition() to capture node state.
167+
168+ :param expected_delegated_ops: Dictionary mapping operators to their expected delegation count
169+ :param expected_non_delegated_ops: Dictionary mapping operators to their expected non-delegation count
170+ :param mocker: Pytest mocker fixture for intercepting the partition method
171+ :param ops_to_ignore: Set of operators to ignore during verification. Defaults to quantization ops
172+ """
173+ self .expected_delegated_ops = expected_delegated_ops
174+ self .expected_non_delegated_ops = expected_non_delegated_ops
175+
176+ self .ops_to_ignore = ops_to_ignore or self .default_ops_to_ignore
177+
178+ # We need to use mocker to capture a copy of the nodes returned by NeutronPartitioner.partition() to access
179+ # their partition tag. The nodes in the returned graph may be modified after partition() returns, so we
180+ # capture a deep copy immediately when the method completes.
181+ self .captured_partitioned_nodes : list [Node ] | None = None
182+
183+ # Store original partition method for the wrapper.
184+ # Note: pytest-mock automatically restores the original method after the test completes,
185+ # so manual cleanup is not required.
186+ original_partition_method = NeutronPartitioner .partition
187+
188+ def partition_wrapper (self_ , exported_program ):
189+ """Wraps NeutronPartitioner.partition() to capture a snapshot of nodes after partitioning.
190+
191+ :param self_: The NeutronPartitioner instance
192+ :param exported_program: The ExportedProgram being partitioned
193+ :return: The PartitionResult from the original partition method
194+ """
195+ result = original_partition_method (self_ , exported_program )
196+ # Capture a deep copy of the nodes with their metadata.
197+ # This ensures we have the exact state immediately after partitioning,
198+ # before any subsequent transformations modify the graph.
199+ self .captured_partitioned_nodes = list (
200+ deepcopy (exported_program .graph .nodes )
201+ )
202+ return result
203+
204+ # Patch the partition method to intercept and capture results.
205+ mocker .patch .object (NeutronPartitioner , "partition" , partition_wrapper )
206+
207+ def verify_graph (self , graph ):
208+ """Verifies that operators were delegated/non-delegated as expected by comparing actual counts against expectations.
209+
210+ :param graph: The FX graph to verify (not directly used; we use captured nodes instead)
211+ :raises AssertionError: If the NeutronPartitioner wasn't used or if delegation doesn't match expectations
212+ """
213+ assert (
214+ self .captured_partitioned_nodes is not None
215+ ), "The NeutronPartitioner was not used. Cannot access delegated nodes."
216+
217+ delegated_ops = defaultdict (int )
218+ non_delegated_ops = defaultdict (int )
219+
220+ for node in self .captured_partitioned_nodes :
221+ # Only process call_function nodes with a target
222+ if not hasattr (node , "target" ) or node .op != "call_function" :
223+ continue
224+
225+ # Skip operators we're configured to ignore (e.g., quantization ops)
226+ if node .target in self .ops_to_ignore :
227+ continue
228+
229+ # Check if the node was tagged for delegation during partitioning
230+ if NXP_DELEGATION_TAG in node .meta :
231+ delegated_ops [node .target ] += 1
232+ else :
233+ non_delegated_ops [node .target ] += 1
234+
235+ # All ops which were either expected to be delegated, or were actually delegated.
236+ all_delegated_ops = list (set (self .expected_delegated_ops ).union (delegated_ops ))
237+
238+ # All ops which were either expected to be non-delegated, or were actually non-delegated.
239+ all_non_delegated_ops = list (
240+ set (self .expected_non_delegated_ops ).union (non_delegated_ops )
241+ )
242+
243+ message = ""
244+
245+ # Check delegated operators
246+ for op in all_delegated_ops :
247+ expected_count = self .expected_delegated_ops .get (op , 0 )
248+ real_count = delegated_ops .get (op , 0 )
249+ op_name = op .name () if hasattr (op , "name" ) else str (op )
250+ if expected_count != real_count :
251+ message += f"\t `{ op_name } ` was delegated { real_count } times instead of the expected { expected_count } times.\n "
252+
253+ # Check non-delegated operators
254+ for op in all_non_delegated_ops :
255+ expected_count = self .expected_non_delegated_ops .get (op , 0 )
256+ real_count = non_delegated_ops .get (op , 0 )
257+ op_name = op .name () if hasattr (op , "name" ) else str (op )
258+ if expected_count != real_count :
259+ message += f"\t `{ op_name } ` was NON-delegated { real_count } times instead of the expected { expected_count } times.\n "
260+
261+ if message :
262+ raise AssertionError (
263+ "Some operators were not delegated as expected:\n " + message
264+ )
0 commit comments