4141 no_outside_users ,
4242)
4343from torch import fx
44- from torchao .quantization .pt2e import HistogramObserver , MinMaxObserver
44+ from torchao .quantization .pt2e import (
45+ FakeQuantize ,
46+ FusedMovingAvgObsFakeQuantize ,
47+ HistogramObserver ,
48+ MinMaxObserver ,
49+ )
4550from torchao .quantization .pt2e .quantizer import (
4651 ComposableQuantizer ,
4752 DerivedQuantizationSpec ,
106111 None ,
107112)
108113
# QAT (quantization-aware training) specs. All three are static (non-dynamic)
# signed 8-bit specs over [-128, 127]; they differ only in qscheme and in the
# fake-quantize constructor used:
#   - activations: FusedMovingAvgObsFakeQuantize (fused observer + fake-quant)
#   - weights:     FakeQuantize driven by a MinMaxObserver


def _make_qat_qspec_8s(qscheme, fake_quant_ctr):
    """Build an int8 [-128, 127] static QAT QuantizationSpec."""
    return QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=qscheme,
        is_dynamic=False,
        observer_or_fake_quant_ctr=fake_quant_ctr,
    )


# Asymmetric (affine) int8 activation spec for QAT.
act_qat_qspec_asym8s = _make_qat_qspec_8s(
    torch.per_tensor_affine,
    FusedMovingAvgObsFakeQuantize,
)

# Asymmetric (affine) int8 weight spec for QAT.
wgt_qat_qspec_asym8s = _make_qat_qspec_8s(
    torch.per_tensor_affine,
    FakeQuantize.with_args(observer=MinMaxObserver),
)

# Symmetric int8 weight spec for QAT.
wgt_qat_qspec_sym8s = _make_qat_qspec_8s(
    torch.per_tensor_symmetric,
    FakeQuantize.with_args(observer=MinMaxObserver),
)

# QAT config: asymmetric activations (input and output), asymmetric weights,
# no bias spec.
qconfig_A8W8_qat = QuantizationConfig(
    act_qat_qspec_asym8s,
    act_qat_qspec_asym8s,
    wgt_qat_qspec_asym8s,
    None,
)

# QAT config: asymmetric activations (input and output), symmetric weights,
# no bias spec.
qconfig_A8W8sym_qat = QuantizationConfig(
    act_qat_qspec_asym8s,
    act_qat_qspec_asym8s,
    wgt_qat_qspec_sym8s,
    None,
)
109155qconfig_A16 = QuantizationConfig (
110156 act_qspec_asym16s ,
111157 act_qspec_asym16s ,
@@ -221,18 +267,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
221267 return []
222268
223269
def get_cadence_default_quantizers(is_qat: bool = False) -> List[Quantizer]:
    """Return the default list of Cadence ATen pattern quantizers.

    Args:
        is_qat: if True, use the QAT (fake-quantize) qconfigs; otherwise use
            the observer-based PTQ qconfigs.

    Returns:
        One CadenceAtenQuantizer per supported op pattern.
    """
    asym_cfg = qconfig_A8W8_qat if is_qat else qconfig_A8W8
    sym_cfg = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
    # Conv patterns use the symmetric-weight config; all other patterns use
    # the asymmetric-weight config.
    pattern_configs = (
        (AddmmPattern, asym_cfg),
        (BmmPattern, asym_cfg),
        (Conv1dPattern, sym_cfg),
        (Conv2dPattern, sym_cfg),
        (LinearPattern, asym_cfg),
        (MatmulPattern, asym_cfg),
        (MaxPool2dPattern, asym_cfg),
        (MaxPool2dWithoutIndicesPattern, asym_cfg),
        (ReluPattern0, asym_cfg),
        (ReluPattern1, asym_cfg),
    )
    return [
        CadenceAtenQuantizer(pattern_cls(), cfg)
        for pattern_cls, cfg in pattern_configs
    ]
237285
238286
@@ -270,9 +318,13 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
270318 Default quantizer for Cadence backend.
271319 """
272320
273- def __init__ (self , quantizers : Optional [list [Quantizer ]] = None ) -> None :
321+ def __init__ (
322+ self ,
323+ quantizers : Optional [list [Quantizer ]] = None ,
324+ is_qat : bool = False ,
325+ ) -> None :
274326 if quantizers is None :
275- quantizers = get_cadence_default_quantizers ()
327+ quantizers = get_cadence_default_quantizers (is_qat = is_qat )
276328 super ().__init__ (quantizers )
277329
278330
@@ -314,11 +366,16 @@ class CadenceWakeWordQuantizer(CadenceQuantizer):
314366 Quantizer for WakeWord, including add and cat
315367 """
316368
317- def __init__ (self , quantizers : Optional [list [Quantizer ]] = None ) -> None :
369+ def __init__ (
370+ self ,
371+ quantizers : Optional [list [Quantizer ]] = None ,
372+ is_qat : bool = False ,
373+ ) -> None :
318374 if quantizers is None :
319- quantizers = get_cadence_default_quantizers ()
320- quantizers .append (CadenceAtenQuantizer (AddPattern (), qconfig_A8W8 ))
321- quantizers .append (CadenceAtenQuantizer (CatPattern (), qconfig_A8W8 ))
375+ quantizers = get_cadence_default_quantizers (is_qat = is_qat )
376+ a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
377+ quantizers .append (CadenceAtenQuantizer (AddPattern (), a8w8 ))
378+ quantizers .append (CadenceAtenQuantizer (CatPattern (), a8w8 ))
322379 super ().__init__ (quantizers )
323380
324381
@@ -327,17 +384,23 @@ class CadenceFusedConvReluQuantizer(CadenceQuantizer):
327384 Quantizer using fused conv+relu patterns, and including add and cat
328385 """
329386
330- def __init__ (self , quantizers : Optional [list [Quantizer ]] = None ) -> None :
387+ def __init__ (
388+ self ,
389+ quantizers : Optional [list [Quantizer ]] = None ,
390+ is_qat : bool = False ,
391+ ) -> None :
331392 if quantizers is None :
332393 quantizers = []
394+ a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
395+ a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
333396 # Order matters here, perform the "fused" patterns first
334- quantizers .append (CadenceAtenQuantizer (Conv1dReluPattern0 (), qconfig_A8W8sym ))
335- quantizers .append (CadenceAtenQuantizer (Conv1dReluPattern1 (), qconfig_A8W8sym ))
336- quantizers .append (CadenceAtenQuantizer (Conv2dReluPattern0 (), qconfig_A8W8sym ))
337- quantizers .append (CadenceAtenQuantizer (Conv2dReluPattern1 (), qconfig_A8W8sym ))
338- quantizers = quantizers + get_cadence_default_quantizers ()
339- quantizers .append (CadenceAtenQuantizer (AddPattern (), qconfig_A8W8 ))
340- quantizers .append (CadenceAtenQuantizer (CatPattern (), qconfig_A8W8 ))
397+ quantizers .append (CadenceAtenQuantizer (Conv1dReluPattern0 (), a8w8sym ))
398+ quantizers .append (CadenceAtenQuantizer (Conv1dReluPattern1 (), a8w8sym ))
399+ quantizers .append (CadenceAtenQuantizer (Conv2dReluPattern0 (), a8w8sym ))
400+ quantizers .append (CadenceAtenQuantizer (Conv2dReluPattern1 (), a8w8sym ))
401+ quantizers = quantizers + get_cadence_default_quantizers (is_qat = is_qat )
402+ quantizers .append (CadenceAtenQuantizer (AddPattern (), a8w8 ))
403+ quantizers .append (CadenceAtenQuantizer (CatPattern (), a8w8 ))
341404 super ().__init__ (quantizers )
342405
343406
0 commit comments