     no_outside_users,
 )
 from torch import fx
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ ... @@
     None,
 )

+# QAT variants: use FakeQuantize wrappers so gradients flow during training
+act_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
+        observer=HistogramObserver, eps=2**-12
+    ),
+)
+
+wgt_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+wgt_qat_qspec_sym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+qconfig_A8W8_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_asym8s,
+    None,
+)
+
+qconfig_A8W8sym_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_sym8s,
+    None,
+)
+
 qconfig_A16 = QuantizationConfig(
     act_qspec_asym16s,
     act_qspec_asym16s,
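The QAT specs above replace observer-only constructors with fake-quantize modules: the forward pass rounds activations and weights onto the int8 grid while keeping tensors in float, so training sees the quantization error but gradients still flow. Below is a minimal sketch of that forward step for the per-tensor affine case, with the straight-through estimator written out explicitly; it is an illustration only, not the FusedMovingAvgObsFakeQuantize implementation, which additionally tracks moving-average min/max statistics to derive scale and zero_point.

import torch

def fake_quant_per_tensor(
    x: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int = -128,
    quant_max: int = 127,
) -> torch.Tensor:
    # Quantize: round onto the integer grid and clamp to the int8 range.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to float so downstream ops see the rounding error.
    dq = (q - zero_point) * scale
    # Straight-through estimator: forward returns dq, backward treats the
    # round-trip as identity, so gradients reach x unchanged.
    return x + (dq - x).detach()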
@@ -221,18 +270,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []


-def get_cadence_default_quantizers() -> List[Quantizer]:
+def get_cadence_default_quantizers(is_qat: bool = False) -> List[Quantizer]:
+    a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+    a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
     return [
-        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(BmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
+        CadenceAtenQuantizer(AddmmPattern(), a8w8),
+        CadenceAtenQuantizer(BmmPattern(), a8w8),
+        CadenceAtenQuantizer(Conv1dPattern(), a8w8sym),
+        CadenceAtenQuantizer(Conv2dPattern(), a8w8sym),
+        CadenceAtenQuantizer(LinearPattern(), a8w8),
+        CadenceAtenQuantizer(MatmulPattern(), a8w8),
+        CadenceAtenQuantizer(MaxPool2dPattern(), a8w8),
+        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), a8w8),
+        CadenceAtenQuantizer(ReluPattern0(), a8w8),
+        CadenceAtenQuantizer(ReluPattern1(), a8w8),
     ]

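A quick sanity check of the new flag: the factory returns the same ten patterns either way, and only the attached QuantizationConfig changes (observer-based for PTQ, fake-quantize-based for QAT). The snippet below is illustrative usage, not part of the diff.

ptq_quantizers = get_cadence_default_quantizers()             # observer configs
qat_quantizers = get_cadence_default_quantizers(is_qat=True)  # fake-quant configs
assert len(ptq_quantizers) == len(qat_quantizers) == 10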
@@ -270,9 +321,13 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
     Default quantizer for Cadence backend.
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
         super().__init__(quantizers)

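With the constructor change, callers opt into QAT with a single flag. A hypothetical end-to-end flow is sketched below; model and example_inputs are placeholders, and the prepare_qat_pt2e/convert_pt2e entry points are assumed from the standard pt2e workflow that this file's torchao imports mirror, not something this diff adds.

import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

quantizer = CadenceDefaultQuantizer(is_qat=True)  # selects the *_qat qconfigs
# example_inputs is a tuple of sample tensors matching the model's forward().
gm = torch.export.export(model, example_inputs).module()
prepared = prepare_qat_pt2e(gm, quantizer)  # inserts fake-quantize modules
# ... fine-tune `prepared` so scales and zero-points adapt during training ...
quantized = convert_pt2e(prepared)  # swaps fake-quant for real quant/dequant ops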
@@ -314,11 +369,16 @@ class CadenceWakeWordQuantizer(CadenceQuantizer):
     Quantizer for WakeWord, including add and cat
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
         super().__init__(quantizers)

@@ -327,17 +387,23 @@ class CadenceFusedConvReluQuantizer(CadenceQuantizer):
     Quantizer using fused conv+relu patterns, and including add and cat
     """

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
         if quantizers is None:
             quantizers = []
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
         # Order matters here, perform the "fused" patterns first
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
-        quantizers = quantizers + get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym))
+        quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat)
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
         super().__init__(quantizers)

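The "order matters" comment is the crux of this class: the composed quantizer annotates patterns in list order, so the fused conv+relu entries must come before the standalone conv and relu quantizers from get_cadence_default_quantizers; otherwise the individual patterns would claim the nodes first and the conv/relu pair would be quantized separately. A small illustrative check follows; it assumes the composed quantizer exposes its sub-quantizers as a list and that each CadenceAtenQuantizer keeps its pattern on a `pattern` attribute, neither of which is guaranteed by this diff.

q = CadenceFusedConvReluQuantizer(is_qat=True)
# Expect the four fused conv+relu patterns ahead of everything else.
for sub in q.quantizers[:4]:  # `.quantizers` and `.pattern` are assumptions
    print(type(sub.pattern).__name__)  # Conv1dReluPattern0 ... Conv2dReluPattern1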