Skip to content

Commit 3476cd8

Browse files
Martin Lindström
authored and committed
Arm backend: Rework reporting of qspecs
The quantization reporter prints quantization specs in human-readable format. Prior to this patch, this was implemented such that quantizer_reporter.py defined a dict `SUPPORTED_QSPECS` which was populated by the user. This dict would map qspec objects to string representations. This patch removes this dict and instead modifies the helper function `_qspec_repr` to return a compact string representation based on the attributes of the qspec. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: I9ccd9127b8c332e7c30662be6986ccad4a38881f
1 parent 4d5269f commit 3476cd8

4 files changed

Lines changed: 117 additions & 88 deletions

File tree

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,7 @@
3939
)
4040
from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher
4141

42-
from executorch.backends.cortex_m.quantizer_reporter import (
43-
QuantizerReporter,
44-
SUPPORTED_QSPECS,
45-
)
42+
from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporter
4643

4744
from torch._ops import OpOverload
4845

@@ -391,32 +388,6 @@ def get_symmetric_a16w8_quantization_config(
391388
)
392389

393390

394-
# Register supported quantization configs and qspecs in the reporter for human-readable reporting
395-
# MLETORCH-1854: Temporary solution, refactor to automatically register these instead
396-
_symmetric_a8w4_config_per_channel = get_symmetric_a8w4_quantization_config()
397-
_symmetric_a8w8_config_per_channel = get_symmetric_quantization_config()
398-
_symmetric_a16w8_config_per_channel = get_symmetric_a16w8_quantization_config()
399-
_symmetric_a8w4_config_per_tensor = get_symmetric_a8w4_quantization_config(
400-
is_per_channel=False
401-
)
402-
_symmetric_a8w8_config_per_tensor = get_symmetric_quantization_config(
403-
is_per_channel=False
404-
)
405-
_symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config(
406-
is_per_channel=False
407-
)
408-
409-
SUPPORTED_QSPECS.update(
410-
{
411-
_symmetric_a8w4_config_per_channel.get_weight_qspec(): "INT4_PER_CHANNEL_QSPEC",
412-
_symmetric_a8w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC",
413-
_symmetric_a8w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC",
414-
_symmetric_a8w4_config_per_tensor.get_weight_qspec(): "INT4_PER_TENSOR_QSPEC",
415-
_symmetric_a8w8_config_per_tensor.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC",
416-
_symmetric_a16w8_config_per_tensor.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC",
417-
}
418-
)
419-
420391
NodeFilterType = Callable[[Node], bool]
421392
"""Type for a Node Filter used by annotators.
422393

backends/cortex_m/quantizer/quantization_configs.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
_get_int32_per_channel_bias_qspec,
1111
)
1212
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
13-
from executorch.backends.cortex_m.quantizer_reporter import SUPPORTED_QSPECS
1413
from torch.fx import Node
1514
from torchao.quantization.pt2e import (
1615
HistogramObserver,
@@ -164,15 +163,3 @@ def get_bias_qspec(
164163
_get_int32_per_channel_bias_qspec,
165164
f"{__name__}.INT8_PER_CHANNEL_CONFIG",
166165
)
167-
168-
169-
SUPPORTED_QSPECS.update(
170-
{
171-
INT8_ACTIVATION_PER_TENSOR_QSPEC: "INT8_ACTIVATION_PER_TENSOR_QSPEC",
172-
INT8_ACTIVATION_PER_CHANNEL_QSPEC: "INT8_ACTIVATION_PER_CHANNEL_QSPEC",
173-
INT8_WEIGHT_PER_TENSOR_QSPEC: "INT8_WEIGHT_PER_TENSOR_QSPEC",
174-
INT8_WEIGHT_PER_CHANNEL_QSPEC: "INT8_WEIGHT_PER_CHANNEL_QSPEC",
175-
INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC: "INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC",
176-
SOFTMAX_OUTPUT_FIXED_QSPEC: "SOFTMAX_OUTPUT_FIXED_QSPEC",
177-
}
178-
)

backends/cortex_m/quantizer_reporter.py

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
"""Contains classes for reporting quantization decisions made by Quantizers.
66
7-
Basic useage:
7+
Basic usage:
88
1. Implement the QuantizerReporterUser API for all quantizers intending to use the reporter.
99
2. Instantiate the QuantizerReporter with a list of quantizers to be reported.
1010
3. After annotation, log the report using QuantizerReporter.log_quantizer_report(model).
@@ -17,7 +17,7 @@
1717

1818
import logging
1919
from importlib import import_module
20-
from typing import Callable, cast, Dict, List, NamedTuple, Optional
20+
from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional
2121

2222
from torch.fx import GraphModule, Node
2323
from torchao.quantization.pt2e.quantizer import (
@@ -32,43 +32,29 @@
3232
logger = logging.getLogger(__name__)
3333
tabulate = cast(Callable[..., str], import_module("tabulate").tabulate)
3434

35-
# Look-up dicts used to get human readable names for supported quantization specs
36-
SUPPORTED_QSPECS: dict[QuantizationSpecBase | None, str] = {}
3735

36+
def qspec_repr(qspec: Optional[QuantizationSpecBase]) -> str:
37+
"""Get a human-readable representation of a QuantizationSpec."""
3838

39-
def _qspec_repr(qspec):
40-
"""Get a human readable representation of QuantizationSpecs.
41-
42-
Note that the observer_or_fake_quant_ctr field is created dynamically with
43-
the qspec so two qspecs created at different times will not evaluate as
44-
equal. Therefore a custom comparison is required.
45-
46-
#TODO: Clean up qconfig/ qspec string representation logic in cortex_m/arm
47-
backend.
48-
49-
"""
5039
if isinstance(qspec, SharedQuantizationSpec):
51-
return "SHARED_QSPEC"
40+
return f"SharedQuantizationSpec(edge_or_node={qspec.edge_or_node})"
5241
elif isinstance(qspec, DerivedQuantizationSpec):
53-
return "DERIVED_QSPEC"
54-
elif qspec is None:
55-
return "NO_QSPEC"
42+
return f"DerivedQuantizationSpec(derived_from={qspec.derived_from}, dtype={qspec.dtype})"
5643
elif isinstance(qspec, QuantizationSpec):
57-
for key, val in SUPPORTED_QSPECS.items():
58-
if type(qspec) is not type(key):
59-
continue
60-
if qspec.dtype != key.dtype:
61-
continue
62-
if qspec.quant_min != key.quant_min:
63-
continue
64-
if qspec.quant_max != key.quant_max:
65-
continue
66-
if qspec.qscheme != key.qscheme:
67-
continue
68-
if qspec.is_dynamic != key.is_dynamic:
69-
continue
70-
return val
71-
return "UNREGISTERED_QSPEC"
44+
45+
def _fmt(obj: Any) -> str:
46+
return str(obj).removeprefix("torch.").upper()
47+
48+
q_range_fmt = (
49+
f", range=({qspec.quant_min},{qspec.quant_max})"
50+
if (qspec.quant_min is not None or qspec.quant_max is not None)
51+
else ""
52+
)
53+
return f"QuantizationSpec(dtype={_fmt(qspec.dtype)}{q_range_fmt})"
54+
elif qspec is None:
55+
return "None"
56+
else:
57+
return qspec.__class__.__name__
7258

7359

7460
class QuantizerInfo(NamedTuple):
@@ -154,15 +140,15 @@ def report_accept(self, pattern: List[Node]) -> None:
154140
f"Node {node.name} was reported as annotated but annotation metadata is missing."
155141
)
156142
qspec_input_map_lines = [
157-
f"{node.name}: {_qspec_repr(qspec)}"
143+
f"{node.name}: {qspec_repr(qspec)}"
158144
for node, qspec in annotation.input_qspec_map.items()
159145
]
160146

161147
node_reports.append(
162148
NodeQSpecReport(
163149
node.name,
164150
qspec_input_map_lines,
165-
_qspec_repr(annotation.output_qspec),
151+
qspec_repr(annotation.output_qspec),
166152
)
167153
)
168154

backends/cortex_m/test/misc/test_quantizer_reporter.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,36 @@
66
import logging
77

88
import torch
9-
from executorch.backends.cortex_m.quantizer.quantization_configs import (
10-
INT8_ACTIVATION_PER_CHANNEL_QSPEC,
11-
INT8_WEIGHT_PER_TENSOR_QSPEC,
12-
)
139
from executorch.backends.cortex_m.quantizer.quantizer import mark_node_as_annotated
1410
from executorch.backends.cortex_m.quantizer_reporter import (
1511
logger as quantizer_logger,
12+
qspec_repr,
1613
QuantizerInfo,
1714
QuantizerReport,
1815
QuantizerReporter,
1916
QuantizerReporterUser,
2017
)
2118
from torch.export import export
19+
from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver
20+
from torchao.quantization.pt2e.quantizer import (
21+
DerivedQuantizationSpec,
22+
QuantizationSpec,
23+
SharedQuantizationSpec,
24+
)
25+
26+
INT8_WEIGHT_PER_TENSOR_QSPEC = QuantizationSpec(
27+
dtype=torch.int8,
28+
observer_or_fake_quant_ctr=MinMaxObserver,
29+
qscheme=torch.per_tensor_symmetric,
30+
quant_min=-127,
31+
quant_max=127,
32+
)
33+
INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec(
34+
dtype=torch.int8,
35+
observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
36+
qscheme=torch.per_channel_affine,
37+
ch_axis=0,
38+
)
2239

2340

2441
class _TwoOpModule(torch.nn.Module):
@@ -43,6 +60,74 @@ def get_quantizer_info(self) -> QuantizerInfo:
4360
)
4461

4562

63+
def test_qspec_repr_quantization_spec_with_range():
64+
qspec = QuantizationSpec(
65+
torch.int8,
66+
MinMaxObserver,
67+
quant_min=-42,
68+
quant_max=123,
69+
)
70+
assert qspec_repr(qspec) == "QuantizationSpec(dtype=INT8, range=(-42,123))"
71+
72+
73+
def test_qspec_repr_quantization_spec_without_range():
74+
qspec = QuantizationSpec(
75+
torch.int16,
76+
MinMaxObserver,
77+
)
78+
assert qspec_repr(qspec) == "QuantizationSpec(dtype=INT16)"
79+
80+
81+
def test_qspec_repr_quantization_spec_partial_range():
82+
qspec = QuantizationSpec(
83+
torch.int16,
84+
MinMaxObserver,
85+
quant_min=-100,
86+
)
87+
assert qspec_repr(qspec) == "QuantizationSpec(dtype=INT16, range=(-100,None))"
88+
89+
90+
def test_qspec_repr_shared_quantization_spec():
91+
graph_module = _export_two_op_graph_module()
92+
add_node = next(
93+
node
94+
for node in graph_module.graph.nodes
95+
if node.target == torch.ops.aten.add.Tensor
96+
)
97+
qspec = SharedQuantizationSpec(add_node)
98+
99+
assert qspec_repr(qspec) == f"SharedQuantizationSpec(edge_or_node={add_node})"
100+
101+
102+
def test_qspec_repr_derived_quantization_spec():
103+
graph_module = _export_two_op_graph_module()
104+
x_node = next(node for node in graph_module.graph.nodes if node.name == "x")
105+
y_node = next(node for node in graph_module.graph.nodes if node.name == "y")
106+
add_node = next(
107+
node
108+
for node in graph_module.graph.nodes
109+
if node.target == torch.ops.aten.add.Tensor
110+
)
111+
derived_from = [(x_node, add_node), (y_node, add_node)]
112+
qspec = DerivedQuantizationSpec(
113+
derived_from=derived_from,
114+
derive_qparams_fn=lambda _: (
115+
torch.tensor([1.0]),
116+
torch.tensor([0], dtype=torch.int32),
117+
),
118+
dtype=torch.int32,
119+
)
120+
121+
assert (
122+
qspec_repr(qspec)
123+
== f"DerivedQuantizationSpec(derived_from={derived_from}, dtype={qspec.dtype})"
124+
)
125+
126+
127+
def test_qspec_repr_none():
128+
assert qspec_repr(None) == "None"
129+
130+
46131
def test_warning_log_level(caplog):
47132
graph_module = _export_two_op_graph_module()
48133

@@ -128,11 +213,11 @@ def test_debug_log_level(caplog):
128213
Rejected due to previous annotation: 0
129214
Rejected nodes: 0
130215
131-
NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP
132-
-- ----------- ------------------------------- ---------------------------------
133-
╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC NO_QSPEC
134-
| y: NO_QSPEC
135-
╘ relu INT8_ACTIVATION_PER_CHANNEL_QSPEC
216+
NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP
217+
-- ----------- ------------------------------------------------- ----------------------------
218+
╒ add x: QuantizationSpec(dtype=INT8, range=(-127,127)) None
219+
| y: None
220+
╘ relu QuantizationSpec(dtype=INT8)
136221
----------------------------------------------------------------------------------------------------
137222
DummyQuantizer using dummy nodes
138223
Annotating with dummy.config

0 commit comments

Comments (0)