
Commit a43675c

support for cadence platform QAT (#18746)
Differential Revision: D99712539
Pull Request resolved: #18746
1 parent 9ca0ff1 commit a43675c

2 files changed: 96 additions & 30 deletions


backends/cadence/aot/compiler.py

Lines changed: 7 additions & 4 deletions
@@ -55,14 +55,15 @@ def trace(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
+    is_qat: bool = False,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
     """
     if ops_to_keep is None:
         ops_to_keep = []
     program = trace_fn(
-        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
+        model, inputs, is_qat=is_qat, strict=True, ops_to_keep=ops_to_keep
     )

     if dump_graphs:
@@ -77,6 +78,7 @@ def prepare_pt2(
     inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Trace and Prepare a model using the given quantizer.
@@ -89,10 +91,10 @@ def prepare_pt2(

     ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
     traced_program = trace(
-        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
     )
     prepared_program = prepare_traced_pt2(
-        traced_program, quantizer, dump_graphs=dump_graphs
+        traced_program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
     )

     return prepared_program
@@ -102,6 +104,7 @@ def prepare_traced_pt2(
     program: ExportedProgram,
     quantizer: CadenceQuantizer,
     dump_graphs: bool = False,
+    is_qat: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare a model using the given quantizer.
@@ -112,7 +115,7 @@ def prepare_traced_pt2(
     Returns a GraphModule with the prepared model.
     """

-    prepared_model = prepare_fn(program, quantizer, is_qat=False)
+    prepared_model = prepare_fn(program, quantizer, is_qat=is_qat)

     if dump_graphs:
         logging.info("Graph after preparation:")
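With this change the QAT path can be selected at prepare time instead of the previously hard-coded is_qat=False. A minimal usage sketch with a toy model, assuming the import paths below; the fine-tuning loop is illustrative only, and the later convert/export steps (unchanged by this commit) are elided:

import torch

# Assumed import paths for illustration; adjust to the actual package layout.
from executorch.backends.cadence.aot.compiler import prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
example_inputs = (torch.randn(1, 16),)

# is_qat=True is forwarded through trace() and prepare_traced_pt2(), so the
# prepared GraphModule carries fake-quantize modules instead of plain observers.
quantizer = CadenceDefaultQuantizer(is_qat=True)
prepared = prepare_pt2(model, example_inputs, quantizer, is_qat=True)

# The prepared module is a torch.fx.GraphModule; a short fine-tuning loop runs
# with quantization error simulated in the forward pass.
optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
for _ in range(10):
    out = prepared(*example_inputs)
    out = out[0] if isinstance(out, (tuple, list)) else out  # export may return a tuple
    loss = out.pow(2).mean()  # placeholder loss for the sketch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()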

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 89 additions & 26 deletions
@@ -41,7 +41,12 @@
     no_outside_users,
 )
 from torch import fx
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -106,6 +111,47 @@
     None,
 )

+act_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize,
+)
+
+wgt_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+wgt_qat_qspec_sym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+qconfig_A8W8_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_asym8s,
+    None,
+)
+
+qconfig_A8W8sym_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_sym8s,
+    None,
+)
+
 qconfig_A16 = QuantizationConfig(
     act_qspec_asym16s,
     act_qspec_asym16s,
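The new QAT qspecs above replace the PTQ observers with fake-quantize constructors (FusedMovingAvgObsFakeQuantize for activations, FakeQuantize over a MinMaxObserver for weights), so quantization error is simulated during training rather than only measured after it. A rough per-tensor sketch of what fake quantization computes in the forward pass; this is illustrative only, not the torchao implementation:

import torch

def fake_quantize_per_tensor(
    x: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int = -128,
    quant_max: int = 127,
) -> torch.Tensor:
    # Quantize onto the int8 grid, clamp to [quant_min, quant_max], then
    # dequantize, so the float graph "sees" rounding and clamping error.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

# Example: simulate asym8s quantization of an activation tensor.
x = torch.randn(4, 8)
x_fq = fake_quantize_per_tensor(x, scale=0.05, zero_point=10)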
@@ -221,18 +267,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []


-def get_cadence_default_quantizers() -> List[Quantizer]:
+def get_cadence_default_quantizers(is_qat: bool = False) -> List[Quantizer]:
+    A8W8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+    A8W8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
     return [
-        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(BmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
+        CadenceAtenQuantizer(AddmmPattern(), A8W8),
+        CadenceAtenQuantizer(BmmPattern(), A8W8),
+        CadenceAtenQuantizer(Conv1dPattern(), A8W8sym),
+        CadenceAtenQuantizer(Conv2dPattern(), A8W8sym),
+        CadenceAtenQuantizer(LinearPattern(), A8W8),
+        CadenceAtenQuantizer(MatmulPattern(), A8W8),
+        CadenceAtenQuantizer(MaxPool2dPattern(), A8W8),
+        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), A8W8),
+        CadenceAtenQuantizer(ReluPattern0(), A8W8),
+        CadenceAtenQuantizer(ReluPattern1(), A8W8),
     ]


@@ -270,9 +318,13 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
     Default quantizer for Cadence backend.
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
         super().__init__(quantizers)

@@ -314,11 +366,16 @@ class CadenceWakeWordQuantizer(CadenceQuantizer):
     Quantizer for WakeWord, including add and cat
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
         super().__init__(quantizers)


@@ -327,17 +384,23 @@ class CadenceFusedConvReluQuantizer(CadenceQuantizer):
     Quantizer using fused conv+relu patterns, and including add and cat
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
             quantizers = []
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
         # Order matters here, perform the "fused" patterns first
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
-        quantizers = quantizers + get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym))
+        quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat)
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
         super().__init__(quantizers)

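Note that in all three quantizer classes is_qat only takes effect when no explicit quantizers list is passed; a caller-supplied list is used as-is and must already reference the QAT qconfigs. A small sketch, assuming the import paths below:

# Assumed import paths for illustration; adjust to the actual package layout.
from executorch.backends.cadence.aot.quantizer.patterns import LinearPattern
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceDefaultQuantizer,
    qconfig_A8W8_qat,
)

# Default quantizers, switched wholesale to the QAT qconfigs:
qat_quantizer = CadenceDefaultQuantizer(is_qat=True)

# Explicit list: is_qat is not consulted, so QAT configs are picked by the caller:
custom_quantizer = CadenceDefaultQuantizer(
    quantizers=[CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8_qat)]
)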
