-
Notifications
You must be signed in to change notification settings - Fork 930
Expand file tree
/
Copy pathquantization_annotator.py
More file actions
971 lines (848 loc) · 34.8 KB
/
quantization_annotator.py
File metadata and controls
971 lines (848 loc) · 34.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide quantization annotation logic for Arm backends.
This module computes per-node quantization properties and applies input/output
annotations to FX graphs using TorchAO qspecs.
"""
import functools
import logging
import operator
from dataclasses import dataclass, replace
from typing import Any, Callable, cast, Iterable, List, NamedTuple, Optional, Sequence
import torch
import torch.fx
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.quantizer import QuantizationConfig
from torch._subclasses import FakeTensor
from torch.fx import Node
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
MovingAveragePerChannelMinMaxObserver,
PartialWrapper,
)
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
FixedQParamsQuantizationSpec,
QuantizationSpec,
QuantizationSpecBase,
SharedQuantizationSpec,
)
from .arm_quantizer_utils import (
is_annotated,
is_output_annotated,
mark_node_as_annotated,
)
logger = logging.getLogger(__name__)


def _is_fused_moving_avg_obs_fake_quant_ctor(func: object) -> bool:
    """Report whether ``func`` is the fused fake-quant class or a subclass of it.

    Only classes can match. Instances, functions and ``functools.partial``
    objects are rejected up front, before ``issubclass`` is consulted (which
    would raise ``TypeError`` for a non-class argument).

    Args:
        func: Candidate constructor, typically the ``func`` attribute of a
            ``functools.partial`` taken from a qspec constructor.

    Returns:
        bool: ``True`` if ``func`` is ``FusedMovingAvgObsFakeQuantize`` or
        derives from it, otherwise ``False``.
    """
    if not isinstance(func, type):
        return False
    return issubclass(func, FusedMovingAvgObsFakeQuantize)
@dataclass(frozen=True)
class _QuantProperty:
    """Describe how one input or output of a node must be quantized.

    Instances are immutable (``frozen=True``) and are consumed by the input
    and output annotation helpers of this module.
    """

    # Position in ``node.args`` for inputs (may be negative, e.g. for the
    # operand lists of higher-order cond/while_loop), or the output slot;
    # output annotation currently only accepts 0.
    index: int
    # One qspec, or one qspec per element when the argument at ``index`` is a
    # list of nodes (e.g. the tensors passed to ``cat``).
    qspec: QuantizationSpecBase | list[QuantizationSpecBase]
    # When True, an absent or ``None`` argument at ``index`` is skipped
    # instead of being annotated.
    optional: bool = False
    # When True, the argument nodes themselves are also marked as annotated
    # (used for weight and bias inputs).
    mark_annotated: bool = False
class _OpQuantProperties:
    """Mutable container for the quantization decisions taken for one node.

    Attributes:
        quant_inputs (list[_QuantProperty]): Specs for the node's inputs;
            each entry's ``index`` refers to a position in ``node.args``.
        quant_output (Optional[_QuantProperty]): Spec for the node's output,
            or ``None`` when the output is left unannotated.
    """

    def __init__(self) -> None:
        # A fresh list per instance, so callers can replace or extend it freely.
        self.quant_inputs: list[_QuantProperty] = list()
        self.quant_output: Optional[_QuantProperty] = None
class _QParams(NamedTuple):
    """Fixed quantization parameters used to build a fixed-qparams qspec."""

    # Value passed as ``scale`` to FixedQParamsQuantizationSpec.
    scale: float
    # Value passed as ``zero_point`` to FixedQParamsQuantizationSpec.
    zero_point: int
def _as_list(x):
    """Normalize ``x`` into a sequence that can be iterated element by element.

    Lists and tuples are returned unchanged (the very same object, not a
    copy). Any other value, including ``None`` and strings, is wrapped in a
    new single-element list.

    Args:
        x: A single value, or a list/tuple of values.

    Returns:
        list | tuple: ``x`` itself when it is a list or tuple, else ``[x]``.
    """
    return x if isinstance(x, (list, tuple)) else [x]
def _adjust_weight_qspec_for_conv_transpose(
    node: Node, weight_qspec: QuantizationSpec | None
) -> QuantizationSpec | None:
    """Adjust weight qspec axis/ctor for conv_transpose2d per-channel
    quantization.

    Use axis 1 for ungrouped ConvTranspose2d weights because the weight layout
    is (in_channels, out_channels / groups, kH, kW). Grouped transpose conv
    keeps axis 0.

    If the weight qspec contains a TorchAO QAT fake-quant/observer constructor
    (e.g. PartialWrapper(partial(...)) or a with_args-based constructor), the
    constructor is rebuilt with the corrected axis. When the rebuilt partial
    keeps the same callable, its bound positional and keyword arguments are
    both preserved; only ``ch_axis`` is overridden. For fused per-channel
    FakeQuantize, which only supports axis 0, the constructor is replaced with
    a non-fused FakeQuantize + MovingAveragePerChannelMinMaxObserver when the
    required axis is not 0.

    Args:
        node (Node): Node whose weight qspec is being adjusted. Only
            ``aten.conv_transpose2d.input`` nodes are affected.
        weight_qspec (QuantizationSpec | None): Weight qspec taken from the
            active quantization config.

    Returns:
        QuantizationSpec | None: ``weight_qspec`` unchanged when the node is
        not a transpose conv, weights are unset, the scheme is not
        ``per_channel_symmetric``, the range is the int4 ``[-7, 7]`` one, or
        nothing needs to change. Otherwise a new ``QuantizationSpec`` with
        ``ch_axis`` set to the expected axis and the (possibly rebuilt)
        constructor.
    """
    if (
        node.target != torch.ops.aten.conv_transpose2d.input
        or weight_qspec is None
        or weight_qspec.qscheme != torch.per_channel_symmetric
    ):
        return weight_qspec

    # For now skip axis adjustment for a8w4 per-channel configs (int4 weights).
    if weight_qspec.quant_min == -7 and weight_qspec.quant_max == 7:
        return weight_qspec

    # aten.conv_transpose2d.input args: (input, weight, bias, stride, padding,
    # output_padding, groups, dilation) -> ``groups`` is positional index 6.
    groups = 1
    if len(node.args) > 6 and isinstance(node.args[6], int):
        groups = node.args[6]
    expected_axis = 0 if groups != 1 else 1

    observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr
    observer_or_fake_quant_ctr_changed = False

    # QAT FakeQuantize uses PartialWrapper; rebuild its partial to update
    # ch_axis without breaking TorchAO introspection.
    if isinstance(observer_or_fake_quant_ctr, PartialWrapper):
        original_callable_args = dict(observer_or_fake_quant_ctr.callable_args)
        base_partial = observer_or_fake_quant_ctr.p
        if isinstance(base_partial, functools.partial):
            base_keywords = dict(base_partial.keywords or {})
            base_keywords["ch_axis"] = expected_axis
            if (
                _is_fused_moving_avg_obs_fake_quant_ctor(base_partial.func)
                and expected_axis != 0
            ):
                # Fused per-channel FakeQuant only supports axis 0; for other
                # axes, fall back to FakeQuantize with a per-channel observer.
                # The callable changes here, so only keyword arguments are
                # carried over to the new constructor.
                base_keywords["observer"] = MovingAveragePerChannelMinMaxObserver
                observer_or_fake_quant_ctr = PartialWrapper(
                    functools.partial(FakeQuantize, **base_keywords)
                )
            else:
                # Same callable: forward the original positional arguments too,
                # so nothing bound in the partial is silently lost.
                observer_or_fake_quant_ctr = PartialWrapper(
                    functools.partial(
                        base_partial.func, *base_partial.args, **base_keywords
                    )
                )
            observer_or_fake_quant_ctr.callable_args = original_callable_args
            observer_or_fake_quant_ctr_changed = True
    # Non-QAT observer/fake-quant ctrs can be updated via with_args.
    elif hasattr(observer_or_fake_quant_ctr, "with_args"):
        observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
            ch_axis=expected_axis
        )
        observer_or_fake_quant_ctr_changed = True

    if weight_qspec.ch_axis == expected_axis and not observer_or_fake_quant_ctr_changed:
        return weight_qspec

    return QuantizationSpec(
        dtype=weight_qspec.dtype,
        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
        quant_min=weight_qspec.quant_min,
        quant_max=weight_qspec.quant_max,
        qscheme=weight_qspec.qscheme,
        ch_axis=expected_axis,
        is_dynamic=weight_qspec.is_dynamic,
    )
def _is_ok_for_quantization(
    node: Node, quant_properties: _OpQuantProperties, gm: torch.fx.GraphModule
) -> bool:
    """Decide whether ``node`` can be annotated with ``quant_properties``.

    Every tensor that would receive a quantization spec is inspected:

    * the node's own output, when ``quant_properties.quant_output`` is set;
    * each argument node referenced by ``quant_properties.quant_inputs``.
      Optional inputs that are absent or ``None`` are skipped.

    A tensor is rejected when ``_is_non_float_tensor`` or ``_is_large_scalar``
    flags it. Each rejection is logged at debug level.

    Args:
        node (Node): Node under consideration.
        quant_properties (_OpQuantProperties): Candidate input/output specs.
        gm (torch.fx.GraphModule): Graph module owning ``node``; used to
            resolve ``get_attr`` targets and to build debug output.

    Returns:
        bool: ``True`` when every tensor to be annotated can be quantized.

    Raises:
        TypeError: If a referenced input argument is not a ``Node``.
    """
    # The output is only relevant when an output spec was requested.
    if quant_properties.quant_output is not None:
        if _is_non_float_tensor(node):
            logger.debug(
                "Could not quantize non float tensor for the following output node: "
                f"{get_node_debug_info(node, gm)}"
            )
            return False
        if _is_large_scalar(node, gm):
            logger.debug(
                "Could not quantize large scalar node for the following output node: "
                f"{get_node_debug_info(node, gm)}"
            )
            return False

    for input_property in quant_properties.quant_inputs:
        position = input_property.index
        if input_property.optional and (
            position >= len(node.args) or node.args[position] is None
        ):
            # Optional input (such as a bias) that is not present.
            continue

        for arg_node in _as_list(node.args[position]):
            if not isinstance(arg_node, Node):
                raise TypeError(
                    f"n_arg must be a Node instance, got {type(arg_node).__name__!r}"
                )
            if _is_non_float_tensor(arg_node):
                logger.debug(
                    "Could not quantize non float tensor for the following input "
                    f"node: {get_node_debug_info(node, gm)}"
                )
                return False
            if _is_large_scalar(arg_node, gm):
                logger.debug(
                    "Could not quantize large scalar node for the following input "
                    f"node: {get_node_debug_info(node, gm)}"
                )
                return False

    return True
def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str):
"""Get an attribute from a module by dotted path.
Args:
module (torch.nn.Module | torch.fx.GraphModule): Root module.
target_str (str): Dotted attribute path, e.g., ``"sub.weight"``.
Returns:
Any: Resolved attribute on the module.
"""
targets = target_str.split(".")
for target in targets[:-1]:
module = module.get_submodule(target)
return getattr(module, targets[-1])
def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Return True if ``node`` produces a scalar too large to observe.

    Large scalars are skipped because ``torch.histc`` supports values only up
    to a certain upper bound. Two producers are recognized:

    * ``get_attr`` nodes that resolve to a single-element tensor;
    * ``aten.full`` / ``aten.fill_.Scalar`` calls whose fill value
      (``node.args[1]``) is a plain Python number.

    Fill values that are not plain numbers, such as symbolic values supplied
    by another graph node, cannot be inspected at annotation time. They are
    treated as not large, as are missing fill values and ``get_attr`` targets
    that are not tensors.

    Args:
        node (Node): Node to inspect.
        gm (torch.fx.GraphModule): Module used to resolve ``get_attr`` targets.

    Returns:
        bool: ``True`` if the scalar's magnitude exceeds the histc bound.
    """
    # torch.histc works until this upper bound
    HISTC_UPPER_BOUND = 3.4028235e15

    if node.op == "get_attr" and isinstance(node.target, str):
        tensor = _get_node_target(gm, node.target)
        if not isinstance(tensor, torch.Tensor):
            return False
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND

    if node.op == "call_function" and node.target in (
        torch.ops.aten.full.default,
        torch.ops.aten.full,
        torch.ops.aten.fill_.Scalar,
    ):
        if len(node.args) < 2:
            return False
        fill_value = node.args[1]
        # A Node (or other non-numeric value) has no magnitude to compare yet.
        if not isinstance(fill_value, (int, float)):
            return False
        return abs(fill_value) > HISTC_UPPER_BOUND

    return False
def _is_non_float_tensor(node: Node) -> bool:
    """Tell whether ``node`` produces anything other than float32 fake tensors.

    Observers only work on floating-point tensors, so such nodes cannot be
    quantized.

    Args:
        node (Node): Node whose ``meta["val"]`` describes its output(s).

    Returns:
        bool: ``True`` if the output is not (entirely) float32.

    Note:
        - When ``meta["val"]`` is a sequence (multi-output node), the result
          is ``True`` as soon as one element is not a ``FakeTensor`` or is not
          float32. An empty sequence therefore yields ``False``.
        - When ``meta["val"]`` is missing, ``None`` or not a ``FakeTensor``,
          the result is ``True``.
    """
    meta_val = node.meta.get("val")
    if isinstance(meta_val, Sequence):
        return not all(
            isinstance(item, FakeTensor) and item.dtype == torch.float32
            for item in meta_val
        )
    if not isinstance(meta_val, FakeTensor):
        return True
    return meta_val.dtype != torch.float32
def _annotate_input(node: Node, quant_property: _QuantProperty):
    """Attach the input qspec(s) described by ``quant_property`` to ``node``.

    The argument at ``node.args[quant_property.index]`` may be a single node
    or a list/tuple of nodes. It is paired element-wise with
    ``quant_property.qspec`` (a single qspec or a list of qspecs), and each
    pair is recorded with ``annotate_input_qspec_map``. Optional arguments
    that are absent or ``None`` are skipped without annotating anything.

    Args:
        node (Node): Consumer node receiving the input annotation.
        quant_property (_QuantProperty): Argument position, qspec(s) and the
            ``optional`` / ``mark_annotated`` flags.

    Raises:
        RuntimeError: If ``node`` is already annotated.
        TypeError: If an argument to annotate is not a ``Node``.
        ValueError: If the number of argument nodes and qspecs differ
            (raised by ``zip(..., strict=True)``).
    """
    if is_annotated(node):
        raise RuntimeError(
            f"Cannot annotate input: node '{node.name}' is already annotated"
        )

    position = quant_property.index
    if quant_property.optional and (
        position >= len(node.args) or node.args[position] is None
    ):
        return

    arg_nodes = _as_list(node.args[position])
    qspecs = _as_list(quant_property.qspec)
    for arg_node, qspec in zip(arg_nodes, qspecs, strict=True):
        if not isinstance(arg_node, Node):
            raise TypeError(
                f"n_arg must be a Node instance, got {type(arg_node).__name__!r}"
            )
        annotate_input_qspec_map(node, arg_node, qspec)
        # Mark the producer itself as well when requested (e.g. weight/bias).
        if quant_property.mark_annotated:
            mark_node_as_annotated(arg_node)  # type: ignore[attr-defined]
def _annotate_output(node: Node, quant_property: _QuantProperty):
    """Attach the output qspec described by ``quant_property`` to ``node``.

    Only a plain, single-output annotation is supported. The property must
    not request marking or be optional, and its index must be 0.

    Args:
        node (Node): Node whose output should be annotated.
        quant_property (_QuantProperty): Output slot and qspec.

    Raises:
        RuntimeError: If ``node`` is already annotated.
        ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
            ``index`` is not zero (checked in that order).
    """
    if is_annotated(node):
        raise RuntimeError(
            f"Cannot annotate output: node '{node.name}' is already annotated"
        )

    # (violated, message) pairs, checked in order; the first violation wins.
    checks = (
        (
            quant_property.mark_annotated,
            "quant_property.mark_annotated must be False for output annotation",
        ),
        (
            quant_property.optional,
            "quant_property.optional must be False for output annotation",
        ),
        (
            quant_property.index != 0,
            "Only one output annotation supported currently",
        ),
    )
    for violated, message in checks:
        if violated:
            raise ValueError(message)

    annotate_output_qspec(node, quant_property.qspec)
def _match_pattern(
    node: Node,
    pattern: Sequence[Iterable[object]],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> bool:
    """Check whether a linear node chain through ``node`` matches a pattern.

    Each ``pattern`` element is an iterable of disjunctive node targets.
    ``node`` may sit at any position of the pattern. The elements before it
    are matched by walking to the node's first input, and the elements after
    it by walking to the node's sole user. If ``filter_fn`` is provided,
    every node in the chain must also pass it.

    Every intermediate edge of the chain must be single-consumer, in both
    directions. A node whose output also feeds something outside the chain
    cannot be treated as fused, because fused producers are left without an
    output annotation. Nodes with no input (or no user) where the pattern
    needs one simply do not match.

    Args:
        node (Node): Node at which matching starts.
        pattern (Sequence[Iterable[object]]): Ordered chain of target sets.
        filter_fn (Optional[Callable[[Node], bool]]): Predicate every node in
            the chain must satisfy.

    Returns:
        bool: ``True`` if the chain through ``node`` matches ``pattern``.

    Raises:
        ValueError: If ``pattern`` is empty.
    """
    if len(pattern) < 1:
        raise ValueError("No pattern provided")
    if filter_fn is not None and not filter_fn(node):
        return False
    if len(pattern) == 1:
        # Base case: node passed the filter; only its target remains to check.
        return node.target in pattern[0]

    # Position of the first sub-pattern containing node.target.
    idx = next(
        (i for i, sub_pattern in enumerate(pattern) if node.target in sub_pattern),
        None,
    )
    if idx is None:
        # node.target not in pattern. No need to look at the rest of the pattern.
        return False

    left_pattern = pattern[:idx]
    # Exclude idx as this contains node.target which we have already matched.
    right_pattern = pattern[idx + 1 :]

    if len(left_pattern) > 0:
        inputs = node.all_input_nodes
        if len(inputs) == 0:
            return False
        parent = inputs[0]
        if len(parent.users) != 1:
            return False
        if not _match_pattern(parent, left_pattern, filter_fn):
            return False

    if len(right_pattern) > 0:
        # Mirror the upstream check: the chain may only continue through a
        # node whose output is consumed exclusively by the next chain element.
        if len(node.users) != 1:
            return False
        (child,) = node.users
        if not _match_pattern(child, right_pattern, filter_fn):
            return False

    return True
# Convolution ops. Their inputs are annotated as (0: activation, 1: weight,
# 2: optional bias), and they are matched together with a following
# batch_norm and/or relu/hardtanh node for fused annotation.
_conv_ops = {
    # 1D
    torch.ops.aten.conv1d.default,
    # 3D
    torch.ops.aten.conv3d.padding,
    torch.ops.aten.conv3d.default,
    # 2D (including the transposed variant)
    torch.ops.aten.conv_transpose2d.input,
    torch.ops.aten.conv2d.padding,
    torch.ops.aten.conv2d.default,
}
def _symmetric_range_qparams(bound: float, num_bits: int) -> _QParams:
    """Return fixed qparams that spread ``[-bound, bound]`` over ``2**num_bits`` steps.

    The zero point is fixed at 0.
    """
    return _QParams((bound - (-bound)) / (1 << num_bits), 0)


# Inputs of the ops below are quantized with statically defined (fixed)
# qparams instead of observed ones. This prevents issues with out-of-range
# values when using dynamic quantization.
#
# Maps operator -> {num_bits: qparams for that bit width}.
_fixed_input_qspec_ops: dict[Any, dict[int, _QParams]] = {
    # acos has a valid range of [-1, 1]
    torch.ops.aten.acos.default: {
        8: _symmetric_range_qparams(1.0, 8),
        16: _symmetric_range_qparams(1.0, 16),
    },
    # asin has a valid range of [-1, 1]
    torch.ops.aten.asin.default: {
        8: _symmetric_range_qparams(1.0, 8),
        16: _symmetric_range_qparams(1.0, 16),
    },
    # atanh has a valid range of (-1, 1) (excluding -1 and 1), so the range is
    # pulled slightly inside the open interval for each bit width.
    torch.ops.aten.atanh.default: {
        8: _symmetric_range_qparams(0.999, 8),
        16: _symmetric_range_qparams(0.99999, 16),
    },
}
# Single-input ops whose input is annotated with the input activation qspec
# and whose output gets its own output activation qspec (no sharing).
_one_to_one = {
    # Rounding, sign and magnitude
    torch.ops.aten.abs.default,
    torch.ops.aten.ceil.default,
    torch.ops.aten.floor.default,
    torch.ops.aten.sign.default,
    # Exponentials and logarithms
    torch.ops.aten.exp.default,
    torch.ops.aten.expm1.default,
    torch.ops.aten.log.default,
    torch.ops.aten.log1p.default,
    torch.ops.aten.log10.default,
    # Trigonometric and hyperbolic functions
    torch.ops.aten.cos.default,
    torch.ops.aten.sin.default,
    torch.ops.aten.tan.default,
    torch.ops.aten.atan.default,
    torch.ops.aten.sinh.default,
    torch.ops.aten.cosh.default,
    torch.ops.aten.tanh.default,
    torch.ops.aten.asinh.default,
    torch.ops.aten.acosh.default,
    # Activations
    torch.ops.aten.sigmoid.default,
    torch.ops.aten.hardsigmoid.default,
    torch.ops.aten.hardswish.default,
    torch.ops.aten.hardswish_.default,
    torch.ops.aten.gelu.default,
    torch.ops.aten.silu.default,
    torch.ops.aten.elu.default,
    # Other elementwise math
    torch.ops.aten.erf.default,
    torch.ops.aten.erfinv.default,
    torch.ops.aten.reciprocal.default,
    torch.ops.aten.rsqrt.default,
    torch.ops.aten.pow.Tensor_Scalar,
    torch.ops.aten.remainder.Scalar,
    # Reductions and scans
    torch.ops.aten.sum.dim_IntList,
    torch.ops.aten.sum.default,
    torch.ops.aten.cumsum.default,
    # Constant-filled tensors shaped like the input
    torch.ops.aten.full_like.default,
    torch.ops.aten.zeros_like.default,
}
# Ops whose input is annotated with the input activation qspec and whose
# output shares the qparams of that input edge. These ops only move, select,
# repeat or clip values.
_one_to_one_shared_input_qspec = {
    # Shape changes and views
    torch.ops.aten.squeeze.default,
    torch.ops.aten.squeeze.dim,
    torch.ops.aten.squeeze.dims,
    torch.ops.aten.squeeze_.dim,
    torch.ops.aten.squeeze_copy.default,
    torch.ops.aten.squeeze_copy.dim,
    torch.ops.aten.unsqueeze.default,
    torch.ops.aten.unsqueeze_copy.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.view.default,
    torch.ops.aten.view_as.default,
    torch.ops.aten.view_copy.default,
    torch.ops.aten._unsafe_view.default,
    torch.ops.aten.unflatten.int,
    torch.ops.aten.expand.default,
    torch.ops.aten.expand_copy.default,
    torch.ops.aten.as_strided_copy.default,
    torch.ops.aten.contiguous.default,
    torch.ops.aten.detach_copy.default,
    # Complex view ops (view_as_complex[_copy], view_as_real[_copy]) are
    # intentionally left out: torch support for complex datatypes seems to
    # have issues here.
    # Axis reordering
    torch.ops.aten.transpose.Dimname,
    torch.ops.aten.transpose.int,
    torch.ops.aten.transpose_copy.int,
    torch.ops.aten.t_copy.default,
    torch.ops.aten.flip.default,
    # Selecting, slicing, splitting and indexing
    torch.ops.aten.select.int,
    torch.ops.aten.select_copy.int,
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.slice_copy.Tensor,
    torch.ops.aten.split.Tensor,
    torch.ops.aten.split_with_sizes.default,
    torch.ops.aten.split_copy.Tensor,
    torch.ops.aten.chunk.default,
    torch.ops.aten.unbind.int,
    torch.ops.aten.gather.default,
    torch.ops.aten.index_select.default,
    torch.ops.aten.index.Tensor,
    torch.ops.aten.unfold_copy.default,
    # Repetition, padding and resampling
    torch.ops.aten.repeat.default,
    torch.ops.aten.repeat_interleave.self_int,
    torch.ops.aten.tile.default,
    torch.ops.aten.pad.default,
    torch.ops.aten.upsample_bilinear2d.vec,
    torch.ops.aten.upsample_nearest2d.vec,
    # Value selection and clipping
    torch.ops.aten.amax.default,
    torch.ops.aten.amin.default,
    torch.ops.aten.clamp.default,
    torch.ops.aten.clamp.Tensor,
    # Neg flips the range but keeps the magnitude the same. That is why it is
    # forced to use the same qparams, avoiding a dequant -> neg -> requant
    # chain.
    torch.ops.aten.neg.default,
}
# Ops whose input reuses the producer's output qparams when that producer is
# already output-annotated (falling back to the input activation qspec
# otherwise), and whose output shares the qparams of that input edge.
_one_to_one_shared_input_or_input_act_qspec = {
    # Aliasing and copies
    torch.ops.aten.alias.default,
    torch.ops.aten.alias_copy.default,
    torch.ops.aten.clone.default,
    # Activations
    torch.ops.aten.relu.default,
    torch.ops.aten.relu_.default,
    torch.ops.aten.hardtanh.default,
    torch.ops.aten.hardtanh_.default,
    torch.ops.aten.silu_.default,
    # Layout changes
    torch.ops.aten.permute.default,
    torch.ops.aten.permute_copy.default,
    torch.ops.aten.flatten.using_ints,
    torch.ops.aten.pixel_shuffle.default,
    torch.ops.aten.pixel_unshuffle.default,
    # Averaging and pooling
    torch.ops.aten.mean.default,
    torch.ops.aten.mean.dim,
    torch.ops.aten.avg_pool2d.default,
    torch.ops.aten.adaptive_avg_pool2d.default,
    torch.ops.aten.max_pool2d.default,
    # Dropout
    torch.ops.aten.dropout.default,
    torch.ops.aten.dropout_.default,
}
def get_quant_properties(  # noqa: C901
    node: Node, gm: torch.fx.GraphModule, quantization_config
) -> _OpQuantProperties | None:
    """Compute quantization properties for a node.

    Determine which inputs and/or outputs should be annotated for quantization
    based on the node's operator and surrounding pattern.

    Dispatch is a single ``if``/``elif`` chain, so the first matching branch
    wins. Fused patterns are checked before the per-operator tables:

    * conv -> batch_norm -> relu/hardtanh
    * conv -> batch_norm
    * conv/linear -> relu/hardtanh

    As a result, the producers inside a fused chain annotate only their
    inputs (or nothing, for the middle batch_norm), and the last node of the
    chain annotates the output.

    Args:
        node (Node): Node to analyze.
        gm (torch.fx.GraphModule): Owning graph module.
        quantization_config: Source for activation/weight/bias qspecs. It is
            queried through ``get_input_act_qspec``, ``get_weight_qspec``,
            ``get_output_act_qspec`` and ``get_bias_qspec(node)``. For
            conv_transpose2d it is also rebuilt with
            ``dataclasses.replace(..., weight=...)``.

    Returns:
        _OpQuantProperties | None: Properties to apply, or ``None`` if the
        node is unsupported or not suitable for quantization.
    """
    # conv_transpose2d may need a different per-channel weight axis, so its
    # weight qspec is adjusted before any qspec is read from the config.
    if node.target == torch.ops.aten.conv_transpose2d.input:
        weight_qspec = _adjust_weight_qspec_for_conv_transpose(
            node, quantization_config.get_weight_qspec()
        )
        quantization_config = replace(quantization_config, weight=weight_qspec)

    input_act_qspec = quantization_config.get_input_act_qspec()
    weight_qspec = quantization_config.get_weight_qspec()
    output_act_qspec = quantization_config.get_output_act_qspec()
    bias_qspec = quantization_config.get_bias_qspec(node)
    if output_act_qspec is not None:
        # Check if output activation qspec is symmetric. In that case
        # we avoid conv + relu fusion for quantization annotation.
        is_symmetric = output_act_qspec.qscheme == torch.per_tensor_symmetric
    else:
        is_symmetric = False

    quant_properties = _OpQuantProperties()

    def any_or_hardtanh_min_zero(n: Node):
        """Return True for any op or hardtanh with ``min_val == 0``."""
        # Check that if the node is a hardtanh, its min_val is zero
        return (
            n.target
            not in (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default)
            or n.args[1] == 0
        )

    # Fused conv -> batch_norm -> relu/hardtanh(min 0), asymmetric outputs only.
    # The conv annotates its inputs, the activation annotates the output and
    # the batch_norm in the middle is left empty.
    if not is_symmetric and _match_pattern(
        node,
        [
            _conv_ops,
            {torch.ops.aten.batch_norm.default},
            {
                torch.ops.aten.relu.default,
                torch.ops.aten.relu_.default,
                torch.ops.aten.hardtanh.default,
                torch.ops.aten.hardtanh_.default,
            },
        ],
        filter_fn=any_or_hardtanh_min_zero,
    ):
        if node.target in _conv_ops:
            conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec)  # For MyPy
            quant_properties.quant_inputs = [
                _QuantProperty(0, input_act_qspec),
                _QuantProperty(1, conv_weight_qspec, mark_annotated=True),
                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
            ]
        elif node.target in (
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
            torch.ops.aten.hardtanh.default,
            torch.ops.aten.hardtanh_.default,
        ):
            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # Fused conv -> batch_norm (not gated on is_symmetric): the conv annotates
    # its inputs, the batch_norm annotates the output.
    elif _match_pattern(
        node,
        [
            _conv_ops,
            {torch.ops.aten.batch_norm.default},
        ],
    ):
        if node.target in _conv_ops:
            conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec)  # For MyPy
            quant_properties.quant_inputs = [
                _QuantProperty(0, input_act_qspec),
                _QuantProperty(1, conv_weight_qspec, mark_annotated=True),
                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
            ]
        elif node.target in {torch.ops.aten.batch_norm.default}:
            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # Fused conv/linear -> relu/hardtanh(min 0), asymmetric outputs only.
    elif not is_symmetric and _match_pattern(
        node,
        [
            {
                *_conv_ops,
                torch.ops.aten.linear.default,
            },
            {
                torch.ops.aten.relu.default,
                torch.ops.aten.relu_.default,
                torch.ops.aten.hardtanh.default,
                torch.ops.aten.hardtanh_.default,
            },
        ],
        any_or_hardtanh_min_zero,
    ):
        if node.target in (
            *_conv_ops,
            torch.ops.aten.linear.default,
        ):
            conv_or_linear_weight_qspec = ensure_type(
                QuantizationSpec, weight_qspec
            )  # For MyPy
            quant_properties.quant_inputs = [
                _QuantProperty(0, input_act_qspec),
                _QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
            ]
        else:
            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # Stand-alone conv/linear: activation, weight, optional bias and output.
    elif node.target in (
        *_conv_ops,
        torch.ops.aten.linear.default,
    ):
        conv_or_linear_weight_qspec = ensure_type(
            QuantizationSpec, weight_qspec
        )  # For MyPy
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
            _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
        ]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # Binary arithmetic / matmul: both operands and the output are observed
    # independently.
    elif node.target in (
        torch.ops.aten.add.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.mm.default,
        torch.ops.aten.bmm.default,
        torch.ops.aten.mul.Tensor,
    ):
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(1, input_act_qspec),
        ]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in (
        torch.ops.aten.minimum.default,
        torch.ops.aten.maximum.default,
    ):
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(1, input_act_qspec),
        ]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # where(condition, self, other): the condition (arg 0) is not annotated.
    # Arg 1 reuses its producer's qparams when available; arg 2 and the output
    # share the qparams of the args[1] -> node edge.
    elif node.target in (torch.ops.aten.where.self,):
        true_node = ensure_type(Node, node.args[1])
        input_qspec = (
            SharedQuantizationSpec(true_node)
            if is_output_annotated(true_node)
            else input_act_qspec
        )
        quant_properties.quant_inputs = [
            _QuantProperty(1, input_qspec),
            _QuantProperty(2, SharedQuantizationSpec((true_node, node))),
        ]
        quant_properties.quant_output = _QuantProperty(
            0,
            SharedQuantizationSpec((true_node, node)),
        )
    # Input reuses the producer's qparams when the producer is already
    # output-annotated; the output shares the input edge's qparams.
    elif node.target in _one_to_one_shared_input_or_input_act_qspec:
        input_node = ensure_type(Node, node.args[0])
        input_qspec = (
            SharedQuantizationSpec(input_node)
            if is_output_annotated(input_node)
            else input_act_qspec
        )
        quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
        quant_properties.quant_output = _QuantProperty(
            0,
            SharedQuantizationSpec((input_node, node)),
        )
    # Concatenation: the first tensor is observed; every other tensor and the
    # output share the qparams of the inputs[0] -> node edge.
    elif node.target in (
        torch.ops.aten.cat.default,
        torch.ops.aten.concatenate.default,
        torch.ops.aten.stack.default,
    ):
        # first argument should be a non-empty list of nodes
        if not isinstance(node.args[0], list):
            raise TypeError(
                "Expected node.args[0] to be a list, got "
                f"{type(node.args[0]).__name__!r}"
            )
        if len(node.args[0]) == 0:
            raise ValueError("Expected non-empty list for node.args[0]")
        inputs = [ensure_type(Node, element) for element in node.args[0]]
        shared_qspec = SharedQuantizationSpec((inputs[0], node))
        quant_properties.quant_inputs = [
            _QuantProperty(
                0,
                [input_act_qspec if n == inputs[0] else shared_qspec for n in inputs],
            )
        ]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
    # index_put(self, indices, values): the indices (arg 1) are not annotated.
    # The values and the output share the args[0] -> node edge's qparams.
    elif node.target in (
        torch.ops.aten.index_put.default,
        torch.ops.aten.index_put_.default,
    ):
        shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(2, shared_qspec),
        ]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
    elif node.target in _one_to_one:
        quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # Fixed input qparams selected by the activation dtype's bit width.
    # NOTE(review): a bit width missing from _fixed_input_qspec_ops (anything
    # other than 8 or 16) raises KeyError here.
    elif node.target in _fixed_input_qspec_ops:
        num_bits = torch.iinfo(input_act_qspec.dtype).bits
        qparams = _fixed_input_qspec_ops[node.target][num_bits]
        quant_properties.quant_inputs = [
            _QuantProperty(
                0,
                FixedQParamsQuantizationSpec(
                    dtype=input_act_qspec.dtype,
                    scale=qparams.scale,
                    zero_point=qparams.zero_point,
                    quant_min=input_act_qspec.quant_min,
                    quant_max=input_act_qspec.quant_max,
                    qscheme=input_act_qspec.qscheme,
                    is_dynamic=input_act_qspec.is_dynamic,
                ),
            )
        ]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in _one_to_one_shared_input_qspec:
        input_node = ensure_type(Node, node.args[0])
        quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
        quant_properties.quant_output = _QuantProperty(
            0,
            SharedQuantizationSpec((input_node, node)),
        )
    # copy_(self, src): both operands are observed and the output shares the
    # qparams of the src (args[1]) edge.
    elif node.target in [torch.ops.aten.copy_.default]:
        input_node = ensure_type(Node, node.args[1])
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(1, input_act_qspec),
        ]
        quant_properties.quant_output = _QuantProperty(
            0,
            SharedQuantizationSpec((input_node, node)),
        )
    # Comparisons: only the inputs are annotated (no output annotation).
    # Arg 1 shares arg 0's edge unless both arguments are the same node.
    elif node.target in [
        torch.ops.aten.eq.Tensor,
        torch.ops.aten.ge.Tensor,
        torch.ops.aten.gt.Tensor,
        torch.ops.aten.le.Tensor,
        torch.ops.aten.lt.Tensor,
    ]:
        input_node = ensure_type(Node, node.args[0])
        shared_qspec = SharedQuantizationSpec((input_node, node))
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(
                1,
                input_act_qspec if node.args[0] == node.args[1] else shared_qspec,
            ),
        ]
        quant_properties.quant_output = None
    # Tensor-creating ops: no annotated inputs, only the output.
    elif node.target in [
        torch.ops.aten.full.default,
        torch.ops.aten.full,
        torch.ops.aten.zeros.default,
        torch.ops.aten.ones.default,
        torch.ops.aten.fill_.Scalar,
        torch.ops.aten.scalar_tensor.default,
    ]:
        quant_properties.quant_inputs = []
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    # getitem shares its producer's output qparams; it is skipped entirely
    # when the producer has no output annotation.
    elif node.target in [operator.getitem]:
        input_node = ensure_type(Node, node.args[0])
        if not is_output_annotated(input_node):
            return None
        shared_qspec = SharedQuantizationSpec(input_node)
        quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)
    # Higher-order control flow: the operand list is args[-1] for cond and
    # args[-2] for while_loop.
    elif node.target in (
        torch.ops.higher_order.cond,
        torch.ops.higher_order.while_loop,
    ):
        submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
        submodule_args = node.args[submodule_args_pos]
        output_qspec = output_act_qspec
        if len(submodule_args) > 0:  # type: ignore[arg-type]
            # The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
            # conditional graph) need shared quantization.
            shared_qspec = SharedQuantizationSpec(
                (cast(list[Node], submodule_args)[0], node)
            )
            quant_properties.quant_inputs = [
                _QuantProperty(
                    submodule_args_pos,
                    [
                        input_act_qspec,
                        *([shared_qspec] * (len(submodule_args) - 1)),  # type: ignore[arg-type]
                    ],
                )
            ]
            if node.target == torch.ops.higher_order.while_loop:
                # The output of the while loop body can either re-enter the body, or exit the while loop.
                # Therefore, A and B in the diagram below need to share the same quantization parameters.
                # A -> while ( RESCALE -> ... RESCALE -> ) -> B
                output_qspec = shared_qspec
        quant_properties.quant_output = _QuantProperty(0, output_qspec)
    else:
        # Unsupported operator: leave the node unannotated.
        return None

    # Don't check if operator.getitem is ok for quantization, it's always ok
    if node.target == operator.getitem:
        return quant_properties

    # Check that each inputs/outputs can be quantized properly with the
    # provided quantization properties.
    if not _is_ok_for_quantization(node, quant_properties, gm):
        return None

    return quant_properties
def annotate_graph(  # type: ignore[return]
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate every supported ``call_function`` node of ``gm`` in place.

    Nodes are visited in graph order. A node is skipped when it is not a
    ``call_function``, is already annotated, is rejected by ``filter_fn``, or
    has no quantization properties. Otherwise, its input annotations are
    applied, then its output annotation (if any), and finally the node is
    marked as annotated.

    Args:
        gm (torch.fx.GraphModule): Graph to annotate.
        quantization_config (QuantizationConfig): Default qspecs for nodes.
        filter_fn (Optional[Callable[[Node], bool]]): Optional predicate;
            only nodes for which it returns True are considered.

    Returns:
        Optional[List[List[Node]]]: Reserved for future use; currently always
        ``None``.
    """
    for node in gm.graph.nodes:
        should_skip = (
            node.op != "call_function"
            or is_annotated(node)
            or (filter_fn is not None and not filter_fn(node))
        )
        if should_skip:
            continue

        quant_properties = get_quant_properties(node, gm, quantization_config)
        if quant_properties is None:
            continue

        for input_property in quant_properties.quant_inputs:
            _annotate_input(node, input_property)

        output_property = quant_properties.quant_output
        if output_property is not None:
            _annotate_output(node, output_property)

        mark_node_as_annotated(node)  # type: ignore[attr-defined]