forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharm_quantizer.py
More file actions
1287 lines (1077 loc) · 47.4 KB
/
arm_quantizer.py
File metadata and controls
1287 lines (1077 loc) · 47.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Quantizer for Arm backend
#
from __future__ import annotations
import functools
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from executorch.backends.arm._passes import ArmPassManager
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.quantizer.quantization_config import (
QuantizationConfig,
TOSAQuantizationConfig,
)
from executorch.backends.arm.quantizer.quantizer_support import (
TOSA_QUANTIZER_SUPPORT_DICT,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.cortex_m.quantizer.node_finders import (
GlobalNodeFinder,
InputNodeFinder,
ModuleNameNodeFinder,
ModuleTypeNodeFinder,
NodeNameNodeFinder,
NodeTargetNodeFinder,
OutputNodeFinder,
)
from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher
from executorch.backends.cortex_m.quantizer.quantizer_reporter import (
QuantizerReporter,
SUPPORTED_QCONFIGS,
SUPPORTED_QSPECS,
)
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer import (
ComposableQuantizer,
QuantizationAnnotation,
Quantizer,
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
from executorch.backends.arm.common.arm_compile_spec import (
ArmCompileSpec,
) # isort: skip
from executorch.backends.arm._passes.arm_pass_utils import (
get_cond_while_submodules_nested,
is_submodule_node,
)
from executorch.backends.arm.quantizer.arm_quantizer_utils import (
_get_int32_bias_qspec,
_get_int32_per_channel_bias_qspec,
is_annotated,
mark_node_as_annotated,
NodeFinder,
PatternQuantizer,
SharedQspecQuantizer,
)
from executorch.backends.arm.vgf import VgfCompileSpec
from executorch.exir._warnings import experimental
from torch.fx import GraphModule, Node
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
ObserverOrFakeQuantizeConstructor,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
get_module_name_filter,
QuantizationSpec,
)
from .quantization_annotator import annotate_graph
__all__ = [
    "TOSAQuantizer",
    "EthosUQuantizer",
    "VgfQuantizer",
    "get_symmetric_a16w8_quantization_config",
    # Public helper whose call path is advertised in SUPPORTED_QCONFIGS below,
    # so it is exported alongside the other config factories.
    "get_symmetric_a8w4_quantization_config",
    "get_symmetric_quantization_config",
    "get_uint8_io_quantization_config",
]
logger = logging.getLogger(__name__)
@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = True,
    is_qat: bool = False,
    is_dynamic: bool = False,
    act_qmin: int = -128,
    act_qmax: int = 127,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
    eps: float = 2**-16,
) -> QuantizationConfig:
    """Create symmetric quantization config for activations and weights.

    Activations use an affine qscheme; "symmetric" refers to the weight
    quantization qscheme. Results are memoized with ``functools.lru_cache``,
    so repeated calls with the same arguments (passed the same way) return
    the same config object.

    Args:
        is_per_channel (bool): Whether to use per-channel quantization for
            weights.
        is_qat (bool): Whether the configuration targets quantization aware
            training.
        is_dynamic (bool): Whether to generate dynamic activation observers.
            Dynamic configs carry no output activation spec.
        act_qmin (int): Minimum activation quantization value.
        act_qmax (int): Maximum activation quantization value.
        weight_qmin (int): Minimum weight quantization value.
        weight_qmax (int): Maximum weight quantization value.
        eps (float): ``eps`` passed to every activation and weight
            observer/fake-quant constructor.

    Returns:
        QuantizationConfig: Quantization settings for activations, weights, and
        bias.
    """
    # --- Activations -------------------------------------------------------
    act_extra_args: Dict[str, Any] = {"eps": eps}
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            act_extra_args["observer"] = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=act_qmin,
        quant_max=act_qmax,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **act_extra_args,
        ),
    )

    # --- Weights -----------------------------------------------------------
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    # Weights are quantized statically (is_dynamic=False below), so they get
    # their own kwargs instead of sharing the activation ones: the dynamic-QAT
    # activation observer must not leak into the weight fake-quant. With
    # per-channel QAT it would also clash with the explicit ``observer=``
    # argument and raise a TypeError.
    weight_extra_args: Dict[str, Any] = {"eps": eps}
    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor
    if is_qat:
        if is_per_channel:
            weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(
                observer=PerChannelMinMaxObserver,
                quant_min=weight_qmin,
                quant_max=weight_qmax,
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                reduce_range=False,
                ch_axis=0,
                **weight_extra_args,
            )
        else:
            # Plain fake-quant using FakeQuantize's default observer.
            weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(
                **weight_extra_args
            )
    else:
        # PTQ: min/max observers.
        weight_observer_cls = (
            PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
        )
        weight_observer_or_fake_quant_ctr = weight_observer_cls.with_args(
            **weight_extra_args,
        )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
    )

    # --- Bias --------------------------------------------------------------
    bias_quantization_spec = (
        _get_int32_per_channel_bias_qspec if is_per_channel else _get_int32_bias_qspec
    )

    # Dynamic configs have no output activation spec.
    return TOSAQuantizationConfig(
        act_quantization_spec,
        None if is_dynamic else act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
    )
@functools.lru_cache
def get_uint8_io_quantization_config(
    is_qat: bool = False,
    is_dynamic: bool = False,
    eps: float = 2**-16,
) -> QuantizationConfig:
    """Create a uint8 IO quantization config for TOSA backends.

    This config is intended for model inputs/outputs only. Internal tensors
    should remain int8 for TOSA INT lowering.

    Args:
        is_qat (bool): Build fake-quant constructors for quantization aware
            training instead of PTQ observers.
        is_dynamic (bool): Mark the activation spec as dynamic.
        eps (float): ``eps`` forwarded to the observer/fake-quant constructor.

    Returns:
        QuantizationConfig: Config whose input and output activation specs are
        the same full-range uint8 per-tensor affine spec, with no weight or
        bias specs.
    """
    observer_kwargs: Dict[str, Any] = {"eps": eps}
    observer_ctr: Any
    if not is_qat:
        observer_ctr = PlaceholderObserver if is_dynamic else HistogramObserver
    elif is_dynamic:
        observer_ctr = FakeQuantize
        observer_kwargs["observer"] = MovingAverageMinMaxObserver.with_args(
            averaging_constant=1
        )
    else:
        observer_ctr = FusedMovingAvgObsFakeQuantize

    uint8_range = torch.iinfo(torch.uint8)
    io_spec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=uint8_range.min,
        quant_max=uint8_range.max,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=observer_ctr.with_args(**observer_kwargs),
    )
    # Same spec on both sides of the model boundary; no weights/bias involved.
    return TOSAQuantizationConfig(io_spec, io_spec, None, None)
def get_symmetric_a8w4_quantization_config(
    is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False
) -> QuantizationConfig:
    """Create an 8-bit activation / 4-bit weight symmetric quantization config.

    Thin wrapper around ``get_symmetric_quantization_config`` that restricts
    the weight range to [-7, 7]. All other settings use that function's
    defaults, and the result comes from its lru_cache.

    Note that, unlike the other config helpers in this module, ``is_qat``
    defaults to ``True`` here.

    Args:
        is_per_channel (bool): Whether to use per-channel quantization for
            weights.
        is_qat (bool): Whether the configuration targets quantization aware
            training.
        is_dynamic (bool): Whether to generate dynamic activation observers.

    Returns:
        QuantizationConfig: int8 activation / int4-range weight config.
    """
    # Arguments are passed positionally on purpose: changing how they are
    # passed would change the lru_cache key of the wrapped function.
    return get_symmetric_quantization_config(
        is_per_channel, is_qat, is_dynamic, weight_qmin=-7, weight_qmax=7
    )
@functools.lru_cache
def get_symmetric_a16w8_quantization_config(
    is_per_channel: bool = True,
    is_qat: bool = False,
    is_dynamic: bool = False,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
    epsilon: float = 2**-12,
) -> QuantizationConfig:
    """16A8W quantization config: 16-bit activations, 8-bit weights.

    This configuration provides better accuracy than 8A8W while maintaining
    reasonable memory usage through 8-bit weights. Weight and bias specs are
    taken from ``get_symmetric_quantization_config`` with the same
    per-channel/QAT/dynamic settings and weight range.

    Args:
        is_per_channel (bool): Whether to use per-channel quantization for
            weights.
        is_qat (bool): Whether this is for quantization aware training.
        is_dynamic (bool): Whether to use dynamic quantization. Dynamic configs
            carry no output activation spec.
        weight_qmin (int): Minimum quantization value for weights.
        weight_qmax (int): Maximum quantization value for weights.
        epsilon (float): Value used to pad observed [qmin, qmax] before initial
            zero-point and scale calculation (activations only).

    Returns:
        QuantizationConfig: Configuration with 16-bit activations and 8-bit
        weights.
    """
    extra_args: Dict[str, Any] = {"eps": epsilon}
    # Setup observer/fake-quant for 16-bit activations
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            # HistogramObserver works well for 16-bit range
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    # 16-bit activation quantization spec (symmetric range -32767..32767)
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min + 1,  # -32767
        quant_max=torch.iinfo(torch.int16).max,  # 32767
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )

    # Reuse the 8-bit weight and bias specs of the matching 8A8W config.
    # The weight range is forwarded so weight_qmin/weight_qmax take effect.
    base_config = get_symmetric_quantization_config(
        is_per_channel=is_per_channel,
        is_qat=is_qat,
        is_dynamic=is_dynamic,
        weight_qmin=weight_qmin,
        weight_qmax=weight_qmax,
    )

    return TOSAQuantizationConfig(
        act_quantization_spec,  # 16-bit input activations
        None if is_dynamic else act_quantization_spec,  # 16-bit outputs (static)
        base_config.weight,  # 8-bit weights from base config
        base_config.bias,  # bias from base config
    )
# Register supported quantization configs and qspecs in the reporter for human-readable reporting
# MLETORCH-1854: Temporary solution, refactor to automatically register these instead
_symmetric_a8w4_config_per_channel = get_symmetric_a8w4_quantization_config()
_symmetric_a8w8_config_per_channel = get_symmetric_quantization_config()
_symmetric_a16w8_config_per_channel = get_symmetric_a16w8_quantization_config()
_symmetric_a8w4_config_per_tensor = get_symmetric_a8w4_quantization_config(
    is_per_channel=False
)
_symmetric_a8w8_config_per_tensor = get_symmetric_quantization_config(
    is_per_channel=False
)
_symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config(
    is_per_channel=False
)
SUPPORTED_QCONFIGS.update(
    {
        _symmetric_a8w8_config_per_channel: f"{__name__}.get_symmetric_quantization_config(is_per_channel=True)",
        _symmetric_a16w8_config_per_channel: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=True)",
        _symmetric_a8w4_config_per_channel: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=True)",
        _symmetric_a8w8_config_per_tensor: f"{__name__}.get_symmetric_quantization_config(is_per_channel=False)",
        _symmetric_a16w8_config_per_tensor: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=False)",
        _symmetric_a8w4_config_per_tensor: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=False)",
    }
)
SUPPORTED_QSPECS.update(
    {
        _symmetric_a8w4_config_per_channel.get_weight_qspec(): "INT4_PER_CHANNEL_QSPEC",
        _symmetric_a8w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC",
        _symmetric_a8w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC",
        _symmetric_a8w4_config_per_tensor.get_weight_qspec(): "INT4_PER_TENSOR_QSPEC",
        _symmetric_a8w8_config_per_tensor.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC",
        _symmetric_a16w8_config_per_tensor.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC",
        # Each config factory call builds its own spec objects. The specs of
        # the per-channel (default) configs and the a16w8 weight specs are
        # therefore distinct from the ones above, so register them as well.
        _symmetric_a8w8_config_per_channel.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC",
        _symmetric_a16w8_config_per_channel.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC",
        _symmetric_a16w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC",
        _symmetric_a16w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC",
    }
)
NodeFilterType = Callable[[Node], bool]
"""Type for a Node Filter used by annotators.
A Node filter is a function that takes a Node and returns whether the node
should be annotated or not.
"""
def _get_module_type_filter(tp: Callable) -> NodeFilterType:
    """Get the module_type_filter function for a given module type.

    The filter accepts a node and checks if the node comes from a module that
    has a certain module type.

    Args:
        tp (Callable): Module class to match against the graph node metadata.

    Returns:
        NodeFilterType: Predicate that returns True for nodes from the module
        type.

    For example:
        node: linear_op = call_function[...](...)  # type Block -> Sub -> Linear

        >> module_type_filter = _get_module_type_filter(Sub)
        >> print(module_type_filter(node))
        True  # the node is from the submodule `Sub` (same for `Block` and `Linear`)
    """
    tp_str = tp.__module__ + "." + tp.__qualname__

    def module_type_filter(n: Node) -> bool:
        """Return True if the node originates from the target module type."""
        # node_stack example: {
        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
        # }
        # The stack may hold either the module class itself (as in the example
        # above) or its fully qualified name string. Normalize classes to
        # "module.qualname" strings so both forms match.
        nn_module_stack = n.meta.get("nn_module_stack", {})
        types = [
            f"{t.__module__}.{t.__qualname__}" if isinstance(t, type) else t
            for _, t in nn_module_stack.values()
        ]
        return tp_str in types

    return module_type_filter
def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> NodeFilterType:
    """Create a filter that excludes provided module types and names.

    Args:
        tp_list (List[Callable]): Module types to exclude from annotation.
        module_name_list (List[str]): Module names to exclude from annotation.

    Returns:
        NodeFilterType: Filter that returns True when the node does not match
        any provided module type or name.
    """
    # Build the combined list once; the original concatenated it per node.
    # Type filters come first, then name filters, as before.
    blocked_filters = [_get_module_type_filter(tp) for tp in tp_list] + [
        get_module_name_filter(m) for m in module_name_list
    ]

    def not_module_type_or_name_filter(n: Node) -> bool:
        """Return True when the node matches none of the blocked filters."""
        return not any(f(n) for f in blocked_filters)

    return not_module_type_or_name_filter
def _for_each_filtered_node(
    model: GraphModule,
    filter_fn: Callable[[Node], bool],
):
    """Lazily yield the nodes of ``model.graph`` accepted by ``filter_fn``.

    Nodes come out in graph order. ``filter_fn`` is evaluated one node at a
    time as the generator is consumed.
    """
    yield from (
        candidate for candidate in model.graph.nodes if filter_fn(candidate)
    )
class TOSAQuantizer(Quantizer):
    """Manage quantization annotations for TOSA-compatible backends.

    Configuration and annotation calls are forwarded to the implementation
    selected in ``__init__``: ``_TOSAQuantizerV1`` by default, or
    ``_TOSAQuantizerV2`` when ``use_composable_quantizer=True``.

    .. warning::
        Setting ``use_composable_quantizer=True`` enables an experimental API
        surface that may change without notice.
    """

    def __init__(
        self,
        compile_spec_or_tosa_spec,
        use_composable_quantizer: bool = False,
    ) -> None:
        """Create a TOSA quantizer from a TOSA spec or Arm compile spec.

        Args:
            compile_spec_or_tosa_spec: TOSA specification or Arm compile spec,
                forwarded unchanged to the selected implementation.
            use_composable_quantizer (bool): Select the composable
                implementation instead of the default one.

        .. warning::
            Setting ``use_composable_quantizer=True`` enables an experimental
            API surface that may change without notice.
        """
        self.use_composable_quantizer = use_composable_quantizer
        self.quantizer: _TOSAQuantizerV1 | _TOSAQuantizerV2
        if use_composable_quantizer:
            logger.info(
                "Using composable quantizer implementation in the arm backend. See https://github.com/pytorch/executorch/issues/17701"
            )
            self.quantizer = _TOSAQuantizerV2(compile_spec_or_tosa_spec)
        else:
            logger.info(
                "Using default quantizer in the arm backend. This quantizer is planned to be replaced by the composable quantizer implementation in the future, see https://github.com/pytorch/executorch/issues/17701"
            )
            self.quantizer = _TOSAQuantizerV1(compile_spec_or_tosa_spec)

    @property
    def tosa_spec(self):
        """TOSA specification of the wrapped implementation."""
        return self.quantizer.tosa_spec

    @property
    def compile_spec(self):
        """Compile spec of the wrapped implementation."""
        return self.quantizer.compile_spec

    @property
    def global_config(self):
        """Global quantization config of the wrapped implementation."""
        return self.quantizer.global_config

    @global_config.setter
    def global_config(self, value: Optional[QuantizationConfig]) -> None:
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            self.quantizer.global_config = value
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow setting global_config directly. Please use set_global() instead."
            )

    @property
    def io_config(self):
        """IO quantization config (default implementation only)."""
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            return self.quantizer.io_config
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow accessing io_config."
            )

    @io_config.setter
    def io_config(self, value: Optional[QuantizationConfig]) -> None:
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            self.quantizer.io_config = value
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow setting io_config directly. Please use set_io() instead."
            )

    @property
    def module_type_config(self):
        """Per-module-type configs (default implementation only)."""
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            return self.quantizer.module_type_config
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow accessing module_type_config."
            )

    @module_type_config.setter
    def module_type_config(
        self, value: Dict[Callable, Optional[QuantizationConfig]]
    ) -> None:
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            self.quantizer.module_type_config = value
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow setting module_type_config directly. Please use set_module_type() instead."
            )

    @property
    def module_name_config(self):
        """Per-module-name configs (default implementation only)."""
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            return getattr(self.quantizer, "module_name_config", {})
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow accessing module_name_config."
            )

    @module_name_config.setter
    def module_name_config(
        self, value: Dict[str, Optional[QuantizationConfig]]
    ) -> None:
        if isinstance(self.quantizer, _TOSAQuantizerV1):
            self.quantizer.module_name_config = value
        else:
            raise NotImplementedError(
                "Composable quantizer does not allow setting module_name_config directly. Please use set_module_name() instead."
            )

    def set_global(
        self, quantization_config: Optional[QuantizationConfig]
    ) -> TOSAQuantizer:
        """Set quantization_config for submodules not matched by other filters.

        Args:
            quantization_config (Optional[QuantizationConfig]): Configuration to
                apply to modules that are not captured by name or type filters.
                ``None`` indicates no quantization.
        """
        self.quantizer.set_global(quantization_config)
        return self

    def set_module_type(
        self, module_type: Callable, quantization_config: Optional[QuantizationConfig]
    ) -> TOSAQuantizer:
        """Set quantization_config for submodules with a given module type.

        For example, calling set_module_type(Softmax) quantizes supported
        patterns in each Softmax instance with the provided quantization_config.

        Args:
            module_type (Callable): Type whose submodules should use the
                provided quantization configuration.
            quantization_config (Optional[QuantizationConfig]): Configuration to
                apply to submodules of the given type. ``None`` indicates no
                quantization.
        """
        self.quantizer.set_module_type(module_type, quantization_config)
        return self

    def set_module_name(
        self, module_name: str, quantization_config: Optional[QuantizationConfig]
    ) -> TOSAQuantizer:
        """Set quantization_config for submodules with a given module name.

        For example, calling set_module_name("blocks.sub") quantizes supported
        patterns for that submodule with the provided quantization_config.

        Args:
            module_name (str): Fully qualified module name to configure.
            quantization_config (Optional[QuantizationConfig]): Configuration
                applied to the named submodule. ``None`` indicates no
                quantization.
        """
        self.quantizer.set_module_name(module_name, quantization_config)
        return self

    def set_io(
        self, quantization_config: Optional[QuantizationConfig]
    ) -> TOSAQuantizer:
        """Set quantization_config for input and output nodes.

        Args:
            quantization_config (Optional[QuantizationConfig]): Configuration
                describing activation quantization for model inputs and outputs.
                ``None`` indicates no quantization.
        """
        self.quantizer.set_io(quantization_config)
        return self

    @experimental(
        "This API is experimental and may change without notice. "
        "It is only available when use_composable_quantizer=True."
    )
    def add_quantizer(self, quantizer: Quantizer) -> TOSAQuantizer:
        """Insert a quantizer with highest precedence."""
        if self.use_composable_quantizer:
            return self.quantizer.add_quantizer(quantizer)  # type: ignore[union-attr,return-value]
        raise NotImplementedError(
            "add_quantizer is only supported in the composable quantizer implementation."
        )

    @experimental(
        "This API is experimental and may change without notice. "
        "It is only available when use_composable_quantizer=True."
    )
    def set_node_finder(
        self, quantization_config: Optional[QuantizationConfig], node_finder: NodeFinder
    ) -> TOSAQuantizer:
        """Set quantization_config for nodes matched by a custom NodeFinder.

        Args:
            quantization_config (Optional[QuantizationConfig]): Configuration
                describing quantization settings for nodes matched by the provided
                NodeFinder. ``None`` indicates no quantization.
            node_finder (NodeFinder): Finder selecting the nodes to configure.
        """
        if self.use_composable_quantizer:
            return self.quantizer.set_node_finder(quantization_config, node_finder)  # type: ignore[union-attr,return-value]
        raise NotImplementedError(
            "set_node_finder is only supported in the composable quantizer implementation."
        )

    @experimental(
        "This API is experimental and may change without notice. "
        "It is only available when use_composable_quantizer=True."
    )
    def set_node_target(
        self, node_target: OpOverload, quantization_config: Optional[QuantizationConfig]
    ) -> TOSAQuantizer:
        """Set quantization config for a specific operator target."""
        if self.use_composable_quantizer:
            return self.quantizer.set_node_target(node_target, quantization_config)  # type: ignore[union-attr,return-value]
        raise NotImplementedError(
            "set_node_target is only supported in the composable quantizer implementation."
        )

    @experimental(
        "This API is experimental and may change without notice. "
        "It is only available when use_composable_quantizer=True."
    )
    def set_node_name(
        self, node_name: str, quantization_config: Optional[QuantizationConfig]
    ) -> TOSAQuantizer:
        """Set quantization config for a specific node name."""
        if self.use_composable_quantizer:
            return self.quantizer.set_node_name(node_name, quantization_config)  # type: ignore[union-attr,return-value]
        raise NotImplementedError(
            "set_node_name is only supported in the composable quantizer implementation."
        )

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Transform the graph to prepare it for quantization annotation.

        Decomposes all operators where required to get correct quantization parameters.

        Args:
            model (GraphModule): Model whose graph will be transformed.

        Returns:
            GraphModule: Transformed model prepared for annotation.
        """
        return self.quantizer.transform_for_annotation(model)

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate the graph with the configured quantization settings.

        Currently only does static quantization annotation.

        Args:
            model (GraphModule): Model to annotate statically.

        Returns:
            GraphModule: Annotated model ready for export.
        """
        return self.quantizer.annotate(model)

    def validate(self, model: GraphModule) -> None:
        """Validate the quantization results. Currently, this includes:
            - Ensure tensor inputs to each operator live on the same device.

        Only metadata values exposing a ``device`` attribute take part in the
        check. Non-tensor inputs, such as symbolic sizes or ``None`` entries in
        tuple-valued metadata, are skipped so they cannot cause a spurious
        mismatch against a tensor input.

        Args:
            model (GraphModule): GraphModule being validated.

        Raises:
            ValueError: If tensor inputs for any operator span more than one
                device.
        """
        for node in model.graph.nodes:
            if node.op != "call_function":
                continue
            devices = set()
            for arg_node in node.all_input_nodes:
                meta_val = arg_node.meta.get("val", None)
                if meta_val is None:
                    continue
                values = (
                    meta_val if isinstance(meta_val, (tuple, list)) else (meta_val,)
                )
                for value in values:
                    device = getattr(value, "device", None)
                    if device is not None:
                        devices.add(str(device))
            if len(devices) > 1:
                raise ValueError(
                    f"Quantizer detected operator {node.name} with different device inputs: {devices}."
                )

    @staticmethod
    def _transform_cond_submodules(
        root: GraphModule, transform: Callable[[GraphModule], GraphModule]
    ) -> None:
        """Replace cond-branch / while_loop cond_fn submodules of ``root`` in place.

        Each submodule reported by ``get_cond_while_submodules_nested`` is
        replaced by ``transform(submodule)``. If it contains submodule nodes,
        its own nested submodules are transformed once each and installed on
        the transformed parent.

        NOTE(review): assumes ``get_cond_while_submodules_nested`` returns names
        relative to the module it is called on; TODO confirm.

        Args:
            root (GraphModule): Module whose submodules are replaced.
            transform (Callable[[GraphModule], GraphModule]): prepare- or
                convert-style function applied to each submodule.
        """
        for name, submodule, _ in get_cond_while_submodules_nested(
            root, apply_quantization=True
        ):
            transformed = transform(submodule)
            if any(is_submodule_node(node) for node in submodule.graph.nodes):
                for nested_name, nested_sub, _ in get_cond_while_submodules_nested(
                    submodule, apply_quantization=True
                ):
                    transformed.set_submodule(
                        nested_name, transform(nested_sub), strict=True
                    )
            root.set_submodule(name, transformed, strict=True)

    def _quantize_with_submodules(
        self,
        model: GraphModule,
        calibration_samples: list[tuple],
        is_qat: bool = False,
        fold_quantize: bool = True,
    ) -> GraphModule:
        """Quantizes a GraphModule in a way such that conditional submodules are
        handled properly.

        Note: torchao's prepare_pt2e and convert_pt2e natively handle
        while_loop body_fn submodules, so we only manually process cond
        branches and while_loop cond_fn here.

        Args:
            model (GraphModule): The model to quantize.
            calibration_samples (list[tuple]): A list of inputs used to
                calibrate the model during quantization. To properly calibrate a
                model with submodules, at least one sample per code path is
                needed.
            is_qat (bool): Whether to do quantization aware training or not.
            fold_quantize (bool): Enables or disables constant folding when quantization
                is completed.

        Returns:
            GraphModule: The quantized model.
        """
        prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e

        prepared = prepare_fn(model, self)
        # Prepare only cond branches and while_loop cond_fn (and the submodules
        # nested inside them); the top level was prepared above.
        self._transform_cond_submodules(prepared, lambda gm: prepare_fn(gm, self))

        for inp in calibration_samples:
            prepared(*inp)

        # Convert only cond branches and while_loop cond_fn (and their nested
        # submodules). As in the prepare step, convert_pt2e's return value is
        # installed back into ``prepared`` rather than discarded.
        self._transform_cond_submodules(
            prepared, lambda gm: convert_pt2e(gm, fold_quantize=fold_quantize)
        )
        return convert_pt2e(prepared, fold_quantize=fold_quantize)
class _TOSAQuantizerV1(Quantizer):
def __init__(
    self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec
) -> None:
    """Create the default (non-composable) TOSA quantizer.

    Args:
        compile_spec_or_tosa_spec (TosaSpecification | ArmCompileSpec): Either
            a bare TOSA specification, which is wrapped in a
            ``TosaCompileSpec``, or an Arm compile spec used as-is.

    Raises:
        TypeError: If the argument is neither a ``TosaSpecification`` nor an
            ``ArmCompileSpec``.
    """
    super().__init__()
    self.compile_spec: ArmCompileSpec
    if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
        # Function-scoped import; presumably avoids an import cycle (unconfirmed).
        from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec

        self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec)
    elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec):
        self.compile_spec = compile_spec_or_tosa_spec
    else:
        raise TypeError(
            f"TOSAQuantizer constructor expects "
            f"a TosaSpecification or ArmCompileSpec, "
            f"got {type(compile_spec_or_tosa_spec)}"
        )
    self.tosa_spec = self.compile_spec.tosa_spec
    # Config used for nodes not matched by a module type/name config.
    self.global_config: Optional[QuantizationConfig] = None
    # Config for model inputs and outputs.
    self.io_config: Optional[QuantizationConfig] = None
    # Per-module-type and per-module-name overrides; None means "do not quantize".
    self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
    self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
def set_global(
    self, quantization_config: Optional[QuantizationConfig]
) -> _TOSAQuantizerV1:
    """Store the fallback config for nodes not covered by type/name configs.

    Args:
        quantization_config (Optional[QuantizationConfig]): Fallback config;
            ``None`` leaves such nodes unquantized.

    Returns:
        _TOSAQuantizerV1: ``self``, allowing calls to be chained.
    """
    self.global_config = quantization_config
    return self
def set_module_type(
    self, module_type: Callable, quantization_config: Optional[QuantizationConfig]
) -> _TOSAQuantizerV1:
    """Register (or overwrite) the config used for submodules of ``module_type``.

    Args:
        module_type (Callable): Module class whose nodes get this config.
        quantization_config (Optional[QuantizationConfig]): Config to use;
            ``None`` means these nodes are not quantized.

    Returns:
        _TOSAQuantizerV1: ``self``, allowing calls to be chained.
    """
    self.module_type_config.update({module_type: quantization_config})
    return self
def set_module_name(
    self, module_name: str, quantization_config: Optional[QuantizationConfig]
) -> _TOSAQuantizerV1:
    """Register (or overwrite) the config used for the submodule ``module_name``.

    No validation is performed: ``None`` is an accepted value and marks the
    named submodule as not to be quantized.

    Args:
        module_name (str): Fully qualified name of the submodule.
        quantization_config (Optional[QuantizationConfig]): Config to use, or
            ``None`` for no quantization.

    Returns:
        _TOSAQuantizerV1: ``self``, allowing calls to be chained.
    """
    self.module_name_config[module_name] = quantization_config
    return self
def set_io(
    self, quantization_config: Optional[QuantizationConfig]
) -> _TOSAQuantizerV1:
    """Store the config applied to model inputs and outputs.

    Args:
        quantization_config (Optional[QuantizationConfig]): IO activation
            config; ``None`` means no IO quantization.

    Returns:
        _TOSAQuantizerV1: ``self``, allowing calls to be chained.
    """
    self.io_config = quantization_config
    return self
def _set_disallow_tfa_for_nodes(self, model: GraphModule) -> None:
    """Populate `disallow_tfa` metadata for each FX node.

    Transform-for-annotation passes inspect this flag to decide whether they
    may transform a node. A node whose effective config is ``None`` (i.e. it
    will not be quantized) gets ``True``, which matters for partially
    quantized models.

    Precedence, lowest to highest: global config, module type configs (in
    insertion order), module name configs (in insertion order).
    """
    # Baseline for every node comes from the global config.
    baseline = self.global_config is None
    for fx_node in model.graph.nodes:
        fx_node.meta[DISALLOW_TFA_META_KEY] = baseline

    # Overrides in increasing precedence: later entries overwrite earlier ones.
    overrides = [
        (_get_module_type_filter(mod_type), cfg)
        for mod_type, cfg in self.module_type_config.items()
    ] + [
        (get_module_name_filter(mod_name), cfg)
        for mod_name, cfg in self.module_name_config.items()
    ]
    for node_filter, cfg in overrides:
        disallow = cfg is None
        for fx_node in _for_each_filtered_node(model, node_filter):
            fx_node.meta[DISALLOW_TFA_META_KEY] = disallow
def transform_for_annotation(self, model: GraphModule) -> GraphModule:
    """Tag nodes with `disallow_tfa` metadata, then run the Arm
    transform-for-annotation pass pipeline.

    Args:
        model (GraphModule): Model to transform.

    Returns:
        GraphModule: The module produced by the pass pipeline.
    """
    self._set_disallow_tfa_for_nodes(model)
    return ArmPassManager(self.compile_spec).transform_for_annotation_pipeline(
        graph_module=model
    )
def annotate(self, model: GraphModule) -> GraphModule:
    """Annotate ``model`` using the static quantization configs.

    Args:
        model (GraphModule): Model to annotate.

    Returns:
        GraphModule: The annotated model.
    """
    return self._annotate_for_static_quantization_config(model)
def _annotate_all_static_patterns(
    self,
    model: GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> GraphModule:
    """Annotate all static patterns registered for the backend.

    A ``None`` config is a no-op: the model is returned untouched.

    Args:
        model (GraphModule): Model to annotate statically (in place).
        quantization_config (Optional[QuantizationConfig]): Quantization
            specs for input activations, output activations, weights, and
            biases.
        filter_fn (Optional[Callable[[Node], bool]]): Optional node filter
            specifying which nodes to annotate.

    Returns:
        GraphModule: The same ``model`` object, now carrying annotations.
    """
    # TODO: implement the support for None to be canceling out previous annotations
    if quantization_config is not None:
        annotate_graph(model, quantization_config, filter_fn)
    return model
def _annotate_for_static_quantization_config(
self, model: GraphModule
) -> GraphModule:
"""Match QuantizationConfigs to modules before annotating patterns.
Args:
model (GraphModule): Model whose modules are being matched to
quantization configs.