Skip to content

Commit f5afa0b

Browse files
Cortex-M backend: Optimize quantizer pattern matching (pytorch#16700)
Current implementation naively loops over all nodes in the graph for each pattern, scaling badly as the number of patterns increase. New implementation instead uses a lookup table for the first node in each pattern and then loops over patterns starting with that node. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent eaf1c65 commit f5afa0b

1 file changed

Lines changed: 27 additions & 15 deletions

File tree

backends/cortex_m/quantizer/quantizer.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import Any, Callable, cast, List, Optional
7+
from collections import defaultdict
8+
from typing import Any, Callable, cast, Iterator, List, Optional
89

910
import torch
1011
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
@@ -195,6 +196,8 @@ class OperatorConfigQuantizer(Quantizer):
195196
skipped. Used to match for example particular targets or modules.
196197
"""
197198

199+
Q_PATTERN_MATCHED_KEY = "quantizer_matched"
200+
198201
def __init__(
199202
self,
200203
operator_config: QuantizationConfig,
@@ -236,24 +239,33 @@ def check_pattern(
236239
return match
237240

238241
def match_patterns(
239-
self, model: GraphModule, patterns: List[List[str]]
240-
) -> List[List[Node]]:
242+
self, model: GraphModule, patterns: List[List[OpOverload]]
243+
) -> Iterator[List[Node]]:
241244
"""
242245
Match all given patterns in the graph and return list of matches.
243246
Each node can only be part of one match, larger patterns are prioritized.
244247
Currently only linear patterns (single chain) are supported.
248+
249+
Q_PATTERN_MATCHED_KEY is set to True in node.meta to track which nodes have
250+
already been matched.
245251
"""
246-
patterns.sort(key=len, reverse=True)
247-
matches: List[List[Node]] = []
248-
for pattern in patterns:
249-
for node in model.graph.nodes:
250-
potential_match = self.check_pattern(node, pattern)
251-
if potential_match:
252-
matches.append(potential_match)
253-
for node in potential_match:
254-
node.meta["quantizer_matched"] = True
255-
256-
return matches
252+
253+
# maps operator -> list of patterns starting with operator
254+
patterns_by_first = defaultdict(list)
255+
for p in sorted(patterns, key=len, reverse=True):
256+
patterns_by_first[p[0]].append(p)
257+
258+
for node in model.graph.nodes:
259+
if node.meta.get(OperatorConfigQuantizer.Q_PATTERN_MATCHED_KEY, False):
260+
continue
261+
for pattern in patterns_by_first.get(node.target, []):
262+
match_or_none = self.check_pattern(node, pattern)
263+
if match_or_none is not None:
264+
for matched_node in match_or_none:
265+
matched_node.meta[
266+
OperatorConfigQuantizer.Q_PATTERN_MATCHED_KEY
267+
] = True
268+
yield match_or_none
257269

258270
def is_parameter(self, node: Node, model: GraphModule) -> bool:
259271
"""Returns True if the given node is a parameter of the model."""

0 commit comments

Comments
 (0)