forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathannotator.py
More file actions
137 lines (117 loc) · 5.23 KB
/
annotator.py
File metadata and controls
137 lines (117 loc) · 5.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import operator
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Union

import torch

from executorch.backends.qualcomm.quantizer.rules import _is_float_tensor
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
@dataclass
class IOQuantConfig:
    """
    Describes how a custom op's inputs and outputs should be quantized.

    Attributes:
        input_quant_specs: Mapping from positional input index to the spec
            to apply to that input. Indices absent from the mapping are left
            un-annotated; ``None`` disables input annotation entirely.
        output_quant_specs: Mapping from output index to the spec to apply.
            For a single-output op the spec is attached to the op node
            itself; for a multi-output op each index refers to the
            downstream ``getitem`` user selecting that output. ``None``
            disables output annotation entirely.
    """

    input_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
    output_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
class CustomOpsQuantAnnotator:
    """
    Registry of per-op IOQuantConfigs that builds a single annotation
    function compatible with make_quantizer(custom_annotations=...).
    """

    def __init__(self):
        # Maps op target (e.g. torch.ops.my_ops.custom_op.default) to its
        # IOQuantConfig.
        self._registry: Dict = {}

    def register_annotation(
        self,
        op_target,
        io_quant_config: "IOQuantConfig",
    ) -> "CustomOpsQuantAnnotator":
        """
        Register quantization config for custom op.

        Args:
            op_target: The torch op target (e.g. torch.ops.my_ops.custom_op.default).
            io_quant_config: IOQuantConfig specifying how to quantize inputs and outputs.

        Returns self for method chaining. Registering the same op_target
        twice overwrites the earlier config.
        """
        self._registry[op_target] = io_quant_config
        return self

    def build_annotation_fn(self) -> Callable[[torch.fx.GraphModule], None]:
        """
        Build and return an annotation function for all registered ops.

        The returned function has signature (gm: GraphModule) -> None and
        can be passed directly to make_quantizer(custom_annotations=(fn,)).

        Raises:
            ValueError: (from the returned function) when a registered
                input_quant_specs index is out of range for the op's args.
        """
        # Snapshot the registry so register_annotation() calls made after
        # this point do not mutate an annotation fn already handed out.
        registry = dict(self._registry)

        def annotate_custom_ops(gm: torch.fx.GraphModule) -> None:
            for node in gm.graph.nodes:
                if node.target not in registry:
                    continue
                cfg = registry[node.target]

                input_qspec_map = {}
                if cfg.input_quant_specs is not None:
                    for arg_idx, spec in cfg.input_quant_specs.items():
                        # Reject negative indices too: they would silently
                        # index from the end of node.args.
                        if arg_idx < 0 or arg_idx >= len(node.args):
                            raise ValueError(
                                f"IOQuantConfig error for '{node.name}' ({node.target}): "
                                f"input_quant_specs index {arg_idx} is out of range "
                                f"(op has {len(node.args)} args)"
                            )
                        if not _is_float_tensor(node.args[arg_idx]):
                            # Non-float inputs are not quantizable; skip
                            # rather than fail.
                            logger.debug(
                                "Skipping quantization of input %s for '%s' (%s): "
                                "expected a float tensor.",
                                arg_idx,
                                node.name,
                                node.target,
                            )
                            continue
                        logger.debug(
                            "Annotating input %s of '%s' (%s) with %s",
                            arg_idx,
                            node.name,
                            node.target,
                            spec,
                        )
                        input_qspec_map[node.args[arg_idx]] = spec

                if not cfg.output_quant_specs or len(cfg.output_quant_specs) <= 1:
                    # Single output — annotate on the op node. Only index 0
                    # is consulted here; a single-entry dict keyed by any
                    # other index yields no output spec.
                    output_spec = (
                        cfg.output_quant_specs.get(0)
                        if cfg.output_quant_specs
                        else None
                    )
                    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=output_spec,
                        _annotated=True,
                    )
                else:
                    # Tuple output — the op node carries only the input map;
                    # output specs are pushed down to the getitem users that
                    # select each tuple element.
                    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=None,
                        _annotated=True,
                    )
                    for user in node.users:
                        # A multi-output node may have non-getitem consumers
                        # (e.g. ops taking the whole tuple); user.args[1] is
                        # only meaningful for operator.getitem, so skip the
                        # rest instead of crashing.
                        if user.target is not operator.getitem:
                            continue
                        output_idx = user.args[1]
                        spec = cfg.output_quant_specs.get(output_idx)
                        if spec is not None:
                            user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                                output_qspec=spec,
                                _annotated=True,
                            )

        return annotate_custom_ops