
Commit 4d5269f

Martin Lindström authored and committed
Arm backend: Add label attribute to QuantizationConfig
The quantizer reporter logs the quantization config in a human-readable format. Prior to this patch, this was done with the help of a dict called `SUPPORTED_QCONFIGS`, which was defined in quantizer_reporter.py and populated by the user. This patch reworks this concept by instead adding a label attribute to `QuantizationConfig` that the reporter can use to print the config in a human-readable format.

Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
Change-Id: I38e80c9c3d57fb9d858119fe4281b713bf472475
1 parent 316e435 commit 4d5269f

6 files changed

Lines changed: 69 additions & 75 deletions
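
To illustrate the idea in the commit message, here is a minimal standalone sketch (DummyQConfig and the label strings are stand-ins, not the real ExecuTorch types): before this patch the reporter looked a config up in a user-populated dict; after it, the human-readable name travels with the config itself.

# Stand-in class only; the real QuantizationConfig lives in executorch.backends.arm.
class DummyQConfig:
    def __init__(self, label=None):
        self.label = label  # attribute introduced by this patch

# Before: look-up dict that users had to populate manually.
SUPPORTED_QCONFIGS = {}
cfg_old = DummyQConfig()
SUPPORTED_QCONFIGS[cfg_old] = "arm_quantizer.get_symmetric_quantization_config(...)"
print(SUPPORTED_QCONFIGS.get(cfg_old, "UNREGISTERED_QCONFIG"))

# After: no registration step; the reporter reads the label directly.
cfg_new = DummyQConfig(label="arm_quantizer.get_symmetric_quantization_config(...)")
print(cfg_new.label)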


backends/arm/quantizer/arm_quantizer.py

Lines changed: 44 additions & 38 deletions
@@ -41,7 +41,6 @@

 from executorch.backends.cortex_m.quantizer_reporter import (
     QuantizerReporter,
-    SUPPORTED_QCONFIGS,
     SUPPORTED_QSPECS,
 )

@@ -219,20 +218,28 @@ def get_symmetric_quantization_config(
     bias_quantization_spec = _get_int32_bias_qspec

     if is_dynamic:
-        quantization_config = TOSAQuantizationConfig(
-            act_quantization_spec,
-            None,
-            weight_quantization_spec,
-            bias_quantization_spec,
-        )
+        output_activation = None
     else:
-        quantization_config = TOSAQuantizationConfig(
-            act_quantization_spec,
-            act_quantization_spec,
-            weight_quantization_spec,
-            bias_quantization_spec,
-        )
-    return quantization_config
+        output_activation = act_quantization_spec
+
+    module_name = __name__.rsplit(".", maxsplit=1)[-1]
+    label = (
+        f"{module_name}.get_symmetric_quantization_config("
+        f"per_channel={int(is_per_channel)}, "
+        f"qat={int(is_qat)}, "
+        f"dynamic={int(is_dynamic)}, "
+        f"act_range=[{act_qmin}, {act_qmax}], "
+        f"weight_range=[{weight_qmin}, {weight_qmax}]"
+        ")"
+    )
+
+    return TOSAQuantizationConfig(
+        act_quantization_spec,
+        output_activation,
+        weight_quantization_spec,
+        bias_quantization_spec,
+        label,
+    )


 @functools.lru_cache
@@ -357,22 +364,31 @@ def get_symmetric_a16w8_quantization_config(
         is_qat=is_qat,
         is_dynamic=is_dynamic,
     )
-    # Replace activation quantization spec with 16-bit version
+
     if is_dynamic:
-        quantization_config = TOSAQuantizationConfig(
-            act_quantization_spec,  # 16-bit input activations
-            None,
-            base_config.weight,  # 8-bit weights from base config
-            base_config.bias,  # bias from base config
-        )
+        output_activation = None
     else:
-        quantization_config = TOSAQuantizationConfig(
-            act_quantization_spec,  # 16-bit input activations
-            act_quantization_spec,  # 16-bit output activations
-            base_config.weight,  # 8-bit weights from base config
-            base_config.bias,  # bias from base config
-        )
-    return quantization_config
+        output_activation = act_quantization_spec
+
+    module_name = __name__.rsplit(".", maxsplit=1)[-1]
+    label = (
+        f"{module_name}.get_symmetric_a16w8_quantization_config("
+        f"per_channel={int(is_per_channel)}, "
+        f"qat={int(is_qat)}, "
+        f"dynamic={int(is_dynamic)}, "
+        f"act_range=[{act_quantization_spec.quant_min}, {act_quantization_spec.quant_max}], "
+        f"weight_range=[{weight_qmin}, {weight_qmax}]"
+        ")"
+    )
+
+    # Replace activation quantization spec with 16-bit version
+    return TOSAQuantizationConfig(
+        act_quantization_spec,  # 16-bit input activations
+        output_activation,
+        base_config.weight,  # 8-bit weights from base config
+        base_config.bias,  # bias from base config
+        label,
+    )


 # Register supported quantization configs and qspecs in the reporter for human-readable reporting
@@ -389,16 +405,6 @@ def get_symmetric_a16w8_quantization_config(
 _symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config(
     is_per_channel=False
 )
-SUPPORTED_QCONFIGS.update(
-    {
-        _symmetric_a8w8_config_per_channel: f"{__name__}.get_symmetric_quantization_config(is_per_channel=True)",
-        _symmetric_a16w8_config_per_channel: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=True)",
-        _symmetric_a8w4_config_per_channel: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=True)",
-        _symmetric_a8w8_config_per_tensor: f"{__name__}.get_symmetric_quantization_config(is_per_channel=False)",
-        _symmetric_a16w8_config_per_tensor: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=False)",
-        _symmetric_a8w4_config_per_tensor: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=False)",
-    }
-)

 SUPPORTED_QSPECS.update(
     {
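
As a rough illustration of what the generated label looks like, the snippet below reproduces the label-building f-string from get_symmetric_quantization_config above with hypothetical parameter values (the ranges and flags are examples, not defaults taken from the patch):

# Hypothetical values; only the string format mirrors the diff above.
module_name = "arm_quantizer"
is_per_channel, is_qat, is_dynamic = True, False, False
act_qmin, act_qmax = -128, 127
weight_qmin, weight_qmax = -127, 127

label = (
    f"{module_name}.get_symmetric_quantization_config("
    f"per_channel={int(is_per_channel)}, "
    f"qat={int(is_qat)}, "
    f"dynamic={int(is_dynamic)}, "
    f"act_range=[{act_qmin}, {act_qmax}], "
    f"weight_range=[{weight_qmin}, {weight_qmax}]"
    ")"
)
print(label)
# arm_quantizer.get_symmetric_quantization_config(per_channel=1, qat=0, dynamic=0, act_range=[-128, 127], weight_range=[-127, 127])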

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 11 additions & 13 deletions
@@ -21,7 +21,10 @@
 from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
-from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporterUser
+from executorch.backends.cortex_m.quantizer_reporter import (
+    QuantizerInfo,
+    QuantizerReporterUser,
+)
 from torch.fx import Node

 from torchao.quantization.pt2e.quantizer import (
@@ -253,22 +256,19 @@ def __init__(
         self.pattern_matcher: "PatternMatcher" = pattern_matcher

     def get_quantizer_info(self):
-        from executorch.backends.cortex_m.quantizer_reporter import (
-            QuantizerInfo,
-            SUPPORTED_QCONFIGS,
-        )
-
         name = self.__class__.__name__
         targeted_nodes_description = str(self.node_finder)
-        quantization_config_path = SUPPORTED_QCONFIGS.get(
-            self.quantization_config, "UNREGISTERED_QCONFIG"
+        qconfig_label = (
+            self.quantization_config.label
+            if self.quantization_config.label is not None
+            else self.quantization_config.__class__.__name__  # no label, fallback to class name
         )
         support_config_path = self.pattern_matcher.support_dict_name

         return QuantizerInfo(
             name,
             targeted_nodes_description,
-            quantization_config_path,
+            qconfig_label,
             support_config_path,
         )

@@ -490,16 +490,14 @@ def __init__(self, targets: Optional[list[Callable[..., object]]] = None) -> Non
         )

     def get_quantizer_info(self):
-        from executorch.backends.cortex_m.quantizer_reporter import QuantizerInfo
-
         name = self.__class__.__name__
         targeted_nodes_description = ""
-        quantization_config_path = "SHARED_QCONFIG"
+        qconfig_label = "shared qparams for connected targeted nodes"
         support_config_path = self.support_config_path
         return QuantizerInfo(
             name,
             targeted_nodes_description,
-            quantization_config_path,
+            qconfig_label,
             support_config_path,
         )

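
A standalone sketch of the label fallback used in get_quantizer_info() above; DummyConfig is a stand-in for the real QuantizationConfig, and the example labels are hypothetical:

from dataclasses import dataclass
from typing import Optional

# Stand-in config carrying only the label attribute relevant here.
@dataclass(frozen=True)
class DummyConfig:
    label: Optional[str] = None

def qconfig_label_for(config) -> str:
    # Prefer the human-readable label; fall back to the class name when no label is set.
    return config.label if config.label is not None else config.__class__.__name__

print(qconfig_label_for(DummyConfig()))                           # DummyConfig
print(qconfig_label_for(DummyConfig(label="a8w8, per-channel")))  # a8w8, per-channel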

backends/arm/quantizer/quantization_config.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class QuantizationConfig:
     output_activation: Optional[QuantizationSpecBase]
     weight: Optional[QuantizationSpecBase]
     bias: Optional[QuantizationSpecBase] | Callable[[Any], Any]
+    label: Optional[str] = None  # Optional label for debugging/visualization purposes

     def get_input_act_qspec(
         self, node: Optional[Node] = None, input_node: Optional[Node] = None

backends/cortex_m/quantizer/quantization_configs.py

Lines changed: 3 additions & 13 deletions
@@ -10,10 +10,7 @@
     _get_int32_per_channel_bias_qspec,
 )
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
-from executorch.backends.cortex_m.quantizer_reporter import (
-    SUPPORTED_QCONFIGS,
-    SUPPORTED_QSPECS,
-)
+from executorch.backends.cortex_m.quantizer_reporter import SUPPORTED_QSPECS
 from torch.fx import Node
 from torchao.quantization.pt2e import (
     HistogramObserver,
@@ -156,6 +153,7 @@ def get_bias_qspec(
     INT8_ACTIVATION_PER_TENSOR_QSPEC,
     INT8_WEIGHT_PER_TENSOR_QSPEC,
     _get_int32_bias_qspec,
+    f"{__name__}.INT8_PER_TENSOR_CONFIG",
 )


@@ -164,18 +162,10 @@ def get_bias_qspec(
     INT8_ACTIVATION_PER_TENSOR_QSPEC,
     INT8_WEIGHT_PER_CHANNEL_QSPEC,
     _get_int32_per_channel_bias_qspec,
+    f"{__name__}.INT8_PER_CHANNEL_CONFIG",
 )


-# Register supported quantization configs and qspecs in the reporter for human-readable reporting
-# MLETORCH-1854: Temporary solution, refactor to automatically register these instead
-SUPPORTED_QCONFIGS.update(
-    {
-        INT8_PER_CHANNEL_CONFIG: f"{__name__}.INT8_PER_CHANNEL_QCONFIG",
-        INT8_PER_TENSOR_CONFIG: f"{__name__}.INT8_PER_TENSOR_QCONFIG",
-    }
-)
-
 SUPPORTED_QSPECS.update(
     {
         INT8_ACTIVATION_PER_TENSOR_QSPEC: "INT8_ACTIVATION_PER_TENSOR_QSPEC",

backends/cortex_m/quantizer_reporter.py

Lines changed: 9 additions & 10 deletions
@@ -17,7 +17,7 @@

 import logging
 from importlib import import_module
-from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional
+from typing import Callable, cast, Dict, List, NamedTuple, Optional

 from torch.fx import GraphModule, Node
 from torchao.quantization.pt2e.quantizer import (
@@ -32,8 +32,7 @@
 logger = logging.getLogger(__name__)
 tabulate = cast(Callable[..., str], import_module("tabulate").tabulate)

-# Look-up dicts used to get human readable names for supported quantization configs and specs
-SUPPORTED_QCONFIGS: dict[Any, str] = {}
+# Look-up dicts used to get human readable names for supported quantization specs
 SUPPORTED_QSPECS: dict[QuantizationSpecBase | None, str] = {}


@@ -77,7 +76,7 @@ class QuantizerInfo(NamedTuple):

     name: str
     targeted_nodes_description: str
-    quantization_config_path: str
+    qconfig_label: str
     support_config_path: str


@@ -112,8 +111,8 @@ class QuantizerReport:

     _PREVIOUS_ANNOTATION_REJECT_REASON = "Tried annotating already quantized node."

-    def __init__(self, quantizer):
-        self.quantizer = quantizer.get_quantizer_info()
+    def __init__(self, quantizer_info: QuantizerInfo):
+        self.quantizer_info = quantizer_info
         self.accepted_patterns: List[AnnotatedPatternReport] = []
         self.rejected_patterns: List[RejectedPatternReport] = []

@@ -180,11 +179,11 @@ def report_reject(self, pattern, reason):
     def get_quantizer_info_rows(self) -> List[str]:
         rows = []
         rows.append(
-            f"{self.quantizer.name} using {self.quantizer.targeted_nodes_description}"
+            f"{self.quantizer_info.name} using {self.quantizer_info.targeted_nodes_description}"
         )
-        rows.append(f"Annotating with {self.quantizer.quantization_config_path}")
+        rows.append(f"Annotating with {self.quantizer_info.qconfig_label}")
         rows.append(
-            f"Supported operators and patterns defined by {self.quantizer.support_config_path}"
+            f"Supported operators and patterns defined by {self.quantizer_info.support_config_path}"
         )

         if (
@@ -317,7 +316,7 @@ def set_quantizers(self, quantizers: List[QuantizerReporterUser]) -> None:
                 f"Quantizer {quantizer.__class__.__name__} does not implement QuantizerReporterUser interface and will not report quantization decisions."
             )

-            self.quantizers[quantizer] = QuantizerReport(quantizer)
+            self.quantizers[quantizer] = QuantizerReport(quantizer.get_quantizer_info())

     def report_reject(
         self, quantizer: QuantizerReporterUser, pattern: List[Node], reason: str
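
For reference, a small sketch of the header rows the reporter produces from a QuantizerInfo after this patch; the QuantizerInfo stand-in mirrors the NamedTuple fields in the hunk above, and the field values are hypothetical:

from typing import NamedTuple

# Stand-in mirroring the fields of quantizer_reporter.QuantizerInfo after this patch.
class QuantizerInfo(NamedTuple):
    name: str
    targeted_nodes_description: str
    qconfig_label: str
    support_config_path: str

info = QuantizerInfo(
    name="ExampleQuantizer",                            # hypothetical
    targeted_nodes_description="conv2d/linear nodes",   # hypothetical
    qconfig_label="arm_quantizer.get_symmetric_quantization_config(per_channel=1, ...)",
    support_config_path="example_support_dict",         # hypothetical
)

# Mirrors get_quantizer_info_rows() in the diff above.
print(f"{info.name} using {info.targeted_nodes_description}")
print(f"Annotating with {info.qconfig_label}")
print(f"Supported operators and patterns defined by {info.support_config_path}")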

backends/cortex_m/test/misc/test_quantizer_reporter.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def get_quantizer_info(self) -> QuantizerInfo:
         return QuantizerInfo(
             name="DummyQuantizer",
             targeted_nodes_description="dummy nodes",
-            quantization_config_path="dummy.config",
+            qconfig_label="dummy.config",
             support_config_path="dummy.support",
         )
