Skip to content

Commit b213b42

Browse files
Arm backend: Allow PatternQuantizer to annotate all nodes with None (pytorch#19639)
Previously, annotating with None required nodes to be part of the quantizer's support_dict, so nodes quantized by the SharedQspecQuantizer were ignored.
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent cc3afbe commit b213b42

2 files changed

Lines changed: 109 additions & 6 deletions

File tree

backends/arm/test/quantizer/test_selective_quantization.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict
88

99
import torch
10+
1011
from executorch.backends.arm.quantizer import (
1112
get_symmetric_a16w8_quantization_config,
1213
get_symmetric_quantization_config,
@@ -16,13 +17,17 @@
1617
from executorch.backends.arm.test import common
1718
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
1819
from executorch.backends.arm.tosa import TosaSpecification
20+
from executorch.backends.test.harness.stages import StageType
21+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
1922
from torchvision import models, transforms # type: ignore[import-untyped]
2023
from torchvision.ops.misc import Conv2dNormActivation # type: ignore[import-untyped]
2124

2225

23-
def get_quantizer(use_composable_quantizer: bool = False):
    """Build the TOSAQuantizer shared by the tests in this file.

    The quantizer is created for the "TOSA-1.0+INT" specification and the
    symmetric quantization config is installed as its global config.

    Args:
        use_composable_quantizer: Forwarded unchanged to the ``TOSAQuantizer``
            constructor. Defaults to ``False`` so existing callers that pass
            no argument keep their previous behavior.

    Returns:
        The configured ``TOSAQuantizer`` instance.
    """
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(
        tosa_spec, use_composable_quantizer=use_composable_quantizer
    )
    quantizer.set_global(get_symmetric_quantization_config())
    return quantizer
2833

@@ -53,6 +58,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5358
return x + y
5459

5560

61+
class Cat(torch.nn.Module):
    """Minimal module that concatenates its two inputs along dimension 1."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` and ``y`` joined along ``dim=1``."""
        return torch.cat((x, y), dim=1)
65+
66+
67+
class LinearGraphTail(torch.nn.Module):
    """Linear layer followed by a chain of three unary ops.

    The forward pass applies ``Linear(10, 10)``, then ``relu``, ``sigmoid``
    and ``neg`` in that order.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``neg(sigmoid(relu(linear(x))))``."""
        x = self.linear(x)
        x = torch.relu(x)
        x = torch.sigmoid(x)
        return torch.neg(x)
78+
79+
5680
class AddSoftmaxAdd(torch.nn.Module):
5781
module_names = {"add_0": None, "add_1": None}
5882
module_types = {
@@ -131,6 +155,75 @@ def test_selective_quant_module_type_tosa_INT(model):
131155
pipeline.run()
132156

133157

158+
def test_selective_quant_cat_node_target_none_tosa_INT():
    """Run the quantization pipeline on ``Cat`` with the cat target set to None.

    A composable quantizer is configured with
    ``set_node_target(torch.ops.aten.cat.default, None)`` and handed to a
    ``QuantizationPipeline`` together with the expected qspec table.
    """
    model = Cat()
    inputs = (torch.randn(1, 2, 4), torch.randn(1, 3, 4))

    quantizer = get_quantizer(use_composable_quantizer=True)
    quantizer.set_node_target(torch.ops.aten.cat.default, None)

    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        # NOTE(review): presumably "one aten.cat node annotated with None";
        # confirm against QuantizationPipeline's qspecs contract.
        qspecs={
            "aten.cat.default": {
                None: 1,
            },
        },
    )

    pipeline.run()
177+
178+
179+
def test_composable_io_none_skips_global_tosa_INT():
    """Run the quantization pipeline on ``Add`` with the IO config set to None.

    The composable quantizer keeps the global config installed by
    ``get_quantizer`` and additionally calls ``set_io(None)``.
    """
    model = Add()
    inputs = (torch.randn(1, 10), torch.randn(1, 10))

    quantizer = get_quantizer(use_composable_quantizer=True)
    quantizer.set_io(None)

    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        # NOTE(review): presumably two inputs and one output are expected to
        # carry a None qspec; confirm against QuantizationPipeline.
        input_qspecs={None: 2},
        output_qspecs={None: 1},
    )

    pipeline.run()
195+
196+
197+
def test_composable_global_none_linear_graph_tail_tosa_INT():
    """Run the pipeline on ``LinearGraphTail`` with the global config set to None.

    After the pipeline has run, the graph from the QUANTIZE stage is inspected
    and every ``call_function`` node must carry ``Q_ANNOTATION_KEY`` in its
    ``meta`` dict.
    """
    model = LinearGraphTail()
    inputs = (torch.randn(1, 10),)

    quantizer = get_quantizer(use_composable_quantizer=True)
    # Overrides the symmetric global config that get_quantizer installed.
    quantizer.set_global(None)

    pipeline = QuantizationPipeline[tuple[torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        qspecs={
            "aten.linear.default": {None: 1},
            "aten.relu.default": {None: 1},
            "aten.sigmoid.default": {None: 1},
            "aten.neg.default": {None: 1},
        },
    )

    pipeline.run()

    graph = pipeline.tester.get_graph(StageType.QUANTIZE)
    unannotated_nodes = [
        node.name
        for node in graph.nodes
        if node.op == "call_function" and Q_ANNOTATION_KEY not in node.meta
    ]
    # Include the node names so a failure points at the missing annotations.
    assert not unannotated_nodes, f"Unannotated nodes: {unannotated_nodes}"
225+
226+
134227
mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights)
135228
mv3.eval()
136229
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

backends/cortex_m/quantizer/pattern_matcher.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,25 @@ def _get_match(self, node_queue: List[Node]) -> List[Node]:
113113
return []
114114

115115
def _get_matches(
116-
self, node_queue: List[Node], quantization_config: QuantizationConfig
116+
self, node_queue: List[Node], quantization_config: Optional[QuantizationConfig]
117117
) -> List[PatternMatchResult]:
118118
"""Returns the longest accepted match starting at the first node of the
119119
queue as well as longer rejected matches.
120120
"""
121+
# Annotating with None means rejecting quantization - this is always supported.
122+
if quantization_config is None:
123+
node = node_queue[0]
124+
if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False):
125+
return [
126+
PatternMatchResult([node], False, self.REJECT_PREVIOUSLY_ANNOTATED)
127+
]
128+
129+
node.meta[self.Q_PATTERN_MATCHED_KEY] = True
130+
return [PatternMatchResult([node], True)]
131+
121132
matches: list[PatternMatchResult] = []
122133
accepted = False
123134
max_match_length = len(node_queue)
124-
125135
while max_match_length > 0 and not accepted:
126136
match = self._get_match(node_queue[:max_match_length])
127137
max_match_length = (
@@ -136,7 +146,7 @@ def _get_matches(
136146
return matches
137147

138148
def _dequeue_and_get_matches(
139-
self, node_queue: List[Node], quantization_config: QuantizationConfig
149+
self, node_queue: List[Node], quantization_config: Optional[QuantizationConfig]
140150
) -> List[PatternMatchResult]:
141151
"""Dequeues the longest accepted match starting at the first node of the
142152
queue, and returns all potential matches that were checked (rejected
@@ -160,7 +170,7 @@ def _dequeue_and_get_matches(
160170
return potential_matches
161171

162172
def find_pattern_matches(
163-
self, nodes: Iterator[Node], quantization_config: QuantizationConfig
173+
self, nodes: Iterator[Node], quantization_config: Optional[QuantizationConfig]
164174
) -> Iterator[PatternMatchResult]:
165175
"""Match all given patterns in the graph and return match results with
166176
acceptance/rejection status. Each node can only be part of one match,

0 commit comments

Comments
 (0)