
Commit 4f492a6

Reza Sajadiany authored and facebook-github-bot committed
support for cadence platform QAT
Summary: Adds support for the QAT variant of the Cadence quantizer, using fake-quantize modules as observer wrappers.

Differential Revision: D99712539
1 parent 38c5ca3 commit 4f492a6

2 files changed: 99 additions & 30 deletions
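In practice, the new is_qat flag is meant to be passed in by the caller of the AOT flow. Below is a minimal sketch of how a QAT run might use these changes; the import paths, the toy model, and the training loop are illustrative assumptions, not part of this diff, and conversion of the trained graph to quantized ops follows the usual PT2E convert step elsewhere.

import torch

# Assumed import paths, inferred from the file locations in this commit.
from executorch.backends.cadence.aot.compiler import prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

# A toy float model to fine-tune with QAT (purely illustrative).
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).train()
example_inputs = (torch.randn(1, 16),)

# New in this commit: is_qat selects the FakeQuantize-based qconfigs and is
# threaded through trace/prepare so fake-quant modules are inserted instead
# of plain observers.
quantizer = CadenceDefaultQuantizer(is_qat=True)
prepared = prepare_pt2(model, example_inputs, quantizer, is_qat=True)

# QAT fine-tuning: the fake-quant modules observe ranges and simulate int8
# rounding in the forward pass, so the loss "sees" quantization error.
optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
for _ in range(10):
    x = torch.randn(1, 16)
    loss = prepared(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()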


backends/cadence/aot/compiler.py

Lines changed: 7 additions & 4 deletions
@@ -55,14 +55,15 @@ def trace(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
     """
     if ops_to_keep is None:
         ops_to_keep = []
     program = trace_fn(
-        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
+        model, inputs, is_qat=is_qat, strict=True, ops_to_keep=ops_to_keep
     )

     if dump_graphs:
@@ -77,6 +78,7 @@ def prepare_pt2(
     inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Trace and Prepare a model using the given quantizer.
@@ -89,10 +91,10 @@ def prepare_pt2(

     ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
     traced_program = trace(
-        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
     )
     prepared_program = prepare_traced_pt2(
-        traced_program, quantizer, dump_graphs=dump_graphs
+        traced_program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
     )

     return prepared_program
@@ -102,6 +104,7 @@ def prepare_traced_pt2(
     program: ExportedProgram,
     quantizer: CadenceQuantizer,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare a model using the given quantizer.
@@ -112,7 +115,7 @@ def prepare_traced_pt2(
     Returns a GraphModule with the prepared model.
     """

-    prepared_model = prepare_fn(program, quantizer, is_qat=False)
+    prepared_model = prepare_fn(program, quantizer, is_qat=is_qat)

     if dump_graphs:
         logging.info("Graph after preparation:")

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 92 additions & 26 deletions
@@ -41,7 +41,12 @@
     no_outside_users,
 )
 from torch import fx
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -106,6 +111,50 @@
     None,
 )

+# QAT variants: use FakeQuantize wrappers so gradients flow during training
+act_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
+        observer=HistogramObserver, eps=2**-12
+    ),
+)
+
+wgt_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+wgt_qat_qspec_sym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+qconfig_A8W8_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_asym8s,
+    None,
+)
+
+qconfig_A8W8sym_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_sym8s,
+    None,
+)
+
 qconfig_A16 = QuantizationConfig(
     act_qspec_asym16s,
     act_qspec_asym16s,
@@ -221,18 +270,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []


-def get_cadence_default_quantizers() -> List[Quantizer]:
+def get_cadence_default_quantizers(is_qat: bool = False) -> List[Quantizer]:
+    a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+    a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
     return [
-        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(BmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
+        CadenceAtenQuantizer(AddmmPattern(), a8w8),
+        CadenceAtenQuantizer(BmmPattern(), a8w8),
+        CadenceAtenQuantizer(Conv1dPattern(), a8w8sym),
+        CadenceAtenQuantizer(Conv2dPattern(), a8w8sym),
+        CadenceAtenQuantizer(LinearPattern(), a8w8),
+        CadenceAtenQuantizer(MatmulPattern(), a8w8),
+        CadenceAtenQuantizer(MaxPool2dPattern(), a8w8),
+        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), a8w8),
+        CadenceAtenQuantizer(ReluPattern0(), a8w8),
+        CadenceAtenQuantizer(ReluPattern1(), a8w8),
     ]


@@ -270,9 +321,13 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
     Default quantizer for Cadence backend.
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
         super().__init__(quantizers)


@@ -314,11 +369,16 @@ class CadenceWakeWordQuantizer(CadenceQuantizer):
     Quantizer for WakeWord, including add and cat
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
         super().__init__(quantizers)


@@ -327,17 +387,23 @@ class CadenceFusedConvReluQuantizer(CadenceQuantizer):
     Quantizer using fused conv+relu patterns, and including add and cat
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
             quantizers = []
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
         # Order matters here, perform the "fused" patterns first
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
-        quantizers = quantizers + get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym))
+        quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat)
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
         super().__init__(quantizers)
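The practical difference between the PTQ and QAT qconfigs is what gets instantiated at each quantization point: the existing specs place observers that only record ranges, while the QAT specs place fake-quantize modules that also round and clamp in the forward pass while letting gradients through. A small, standalone illustration of that behavior, mirroring wgt_qat_qspec_asym8s above (the exact values and the eager-mode usage here are illustrative, not how the quantizer framework instantiates the spec):

import torch
from torchao.quantization.pt2e import FakeQuantize, MinMaxObserver

# Weight-style fake-quant: int8 range, per-tensor affine, MinMaxObserver for range tracking.
fq = FakeQuantize.with_args(
    observer=MinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.int8,
    qscheme=torch.per_tensor_affine,
)()

w = torch.randn(8, 8, requires_grad=True)
w_fq = fq(w)                       # forward: observe min/max, then snap values to the int8 grid
print((w_fq - w).abs().max())      # small but nonzero quantization error

w_fq.sum().backward()              # backward: straight-through estimator
print(w.grad.abs().sum())          # gradients still reach the float weight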

0 commit comments