|
1 | | -# Copyright 2025 Arm Limited and/or its affiliates. |
| 1 | +# Copyright 2025-2026 Arm Limited and/or its affiliates. |
2 | 2 | # |
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 |
|
7 | | -from typing import Any, Callable, cast, List, Optional |
| 7 | +from collections import defaultdict |
| 8 | +from typing import Any, Callable, cast, Iterator, List, Optional |
8 | 9 |
|
9 | 10 | import torch |
10 | 11 | from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor |
@@ -195,6 +196,8 @@ class OperatorConfigQuantizer(Quantizer): |
195 | 196 | skipped. Used to match for example particular targets or modules. |
196 | 197 | """ |
197 | 198 |
|
| 199 | + Q_PATTERN_MATCHED_KEY = "quantizer_matched" |
| 200 | + |
198 | 201 | def __init__( |
199 | 202 | self, |
200 | 203 | operator_config: QuantizationConfig, |
@@ -236,24 +239,33 @@ def check_pattern( |
236 | 239 | return match |
237 | 240 |
|
def match_patterns(
    self, model: GraphModule, patterns: List[List[OpOverload]]
) -> Iterator[List[Node]]:
    """
    Lazily match the given patterns in the graph, yielding one match at a time.

    Each node can only be part of one match. Patterns are indexed by their
    first operator; among patterns starting with the same operator, longer
    ones are tried first. Currently only linear patterns (single chain) are
    supported.

    Q_PATTERN_MATCHED_KEY is set to True in node.meta to track which nodes have
    already been matched. Because this is a generator, that marking only
    happens as the caller consumes the iterator.

    Args:
        model: Graph module whose ``graph.nodes`` are scanned in order.
        patterns: Operator chains to look for. The list is not modified;
            empty patterns are ignored.

    Yields:
        The list of nodes returned by ``check_pattern`` for each match.
    """
    matched_key = self.Q_PATTERN_MATCHED_KEY

    # maps operator -> list of patterns starting with operator (longest first)
    patterns_by_first = defaultdict(list)
    for pattern in sorted(patterns, key=len, reverse=True):
        if not pattern:
            # An empty pattern has no first operator to index on.
            continue
        patterns_by_first[pattern[0]].append(pattern)

    for node in model.graph.nodes:
        if node.meta.get(matched_key, False):
            continue
        for pattern in patterns_by_first.get(node.target, ()):
            match_or_none = self.check_pattern(node, pattern)
            if match_or_none is None:
                continue
            for matched_node in match_or_none:
                matched_node.meta[matched_key] = True
            yield match_or_none
            if node.meta.get(matched_key, False):
                # The start node is now claimed; trying shorter patterns on
                # it could only produce overlapping matches.
                break
257 | 269 |
|
258 | 270 | def is_parameter(self, node: Node, model: GraphModule) -> bool: |
259 | 271 | """Returns True if the given node is a parameter of the model.""" |
|
0 commit comments