11# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/quantize/quantizer.ipynb.
22
33# %% auto #0
4- __all__ = ['Quantizer' ]
4+ __all__ = ['Quantizer' , 'quantize_mixed' ]
55
66# %% ../../nbs/quantize/quantizer.ipynb #80613b7a-9ee9-4729-80e0-a33e6406a83e
77import torch
88import torch .nn as nn
99from fastcore .basics import store_attr
1010from torch .ao .quantization import QConfig , get_default_qconfig_mapping , get_default_qat_qconfig_mapping
1111from torch .ao .quantization .quantize_fx import prepare_fx , prepare_qat_fx , convert_fx
12- from torch .ao .quantization .observer import MinMaxObserver , MovingAverageMinMaxObserver
12+ from torch .ao .quantization .observer import MinMaxObserver , MovingAverageMinMaxObserver , HistogramObserver
1313from torch .ao .quantization .fake_quantize import FakeQuantize
1414from torch .quantization import quantize_dynamic
1515from torch .ao .quantization .qconfig import default_dynamic_qconfig
3535 _HAS_TORCHAO = False
3636 _HAS_INT4 = False
3737
38+ try :
39+ from torchao .quantization import IntxWeightOnlyConfig
40+ from torchao .quantization .granularity import PerAxis
41+ from torchao .quantization .quant_primitives import MappingType
42+ _HAS_INTX = True
43+ except ImportError :
44+ _HAS_INTX = False
45+
# Registry of torchao quantization methods: method name -> zero-argument
# factory returning a fresh config object. Entries are only registered when
# the corresponding `_HAS_*` flag is set, so `method in _TORCHAO_CONFIGS`
# also tells whether that method is usable in this environment.
_TORCHAO_CONFIGS = {}
if _HAS_TORCHAO:
    _TORCHAO_CONFIGS['int8_weight_only'] = lambda: Int8WeightOnlyConfig()
    _TORCHAO_CONFIGS['int8_dynamic'] = lambda: Int8DynamicActivationInt8WeightConfig()
if _HAS_INT4:
    _TORCHAO_CONFIGS['int4_weight_only'] = lambda: Int4WeightOnlyConfig(group_size=128)
if _HAS_INTX:
    # Weight-only configs with PerAxis(0) granularity; the factories are lazy,
    # so `torch.int4` / `torch.int8` are only resolved when a config is built.
    _TORCHAO_CONFIGS['intx_int4'] = lambda: IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerAxis(0),
        mapping_type=MappingType.ASYMMETRIC,
    )
    _TORCHAO_CONFIGS['intx_int8'] = lambda: IntxWeightOnlyConfig(
        weight_dtype=torch.int8,
        granularity=PerAxis(0),
        mapping_type=MappingType.SYMMETRIC,
    )
63+
# Activation observer classes selectable by name via `Quantizer(observer=...)`.
# Insertion order is kept as-is because the keys are listed verbatim in the
# "Unknown observer" error message.
_OBSERVERS = {
    'minmax': MinMaxObserver,
    'histogram': HistogramObserver,
    'moving_average': MovingAverageMinMaxObserver,
}
4469
4570# %% ../../nbs/quantize/quantizer.ipynb #fb1fd84a-dcf6-4ec5-966e-6fdd01e1d19b
4671import contextlib
4772
# torchao method names backed by IntxWeightOnlyConfig. These can be applied to
# Conv2d as well as Linear, but only when an explicit filter_fn is passed to
# `quantize_`, so the torchao path checks membership here to build that filter.
_INTX_METHODS = frozenset({'intx_int4', 'intx_int8'})
75+
4876class Quantizer :
4977 def __init__ (self ,
5078 backend : str = "x86" , # Target backend: 'x86', 'qnnpack', 'fbgemm', or 'torchao'
5179 method : str = "static" , # Method: 'static', 'dynamic', 'qat', 'int8_weight_only', 'int8_dynamic'
5280 qconfig_mapping : dict | None = None , # Optional custom quantization config (legacy backends only)
5381 custom_configs : dict | None = None , # Custom module-specific configurations
5482 use_per_tensor : bool = False , # Force per-tensor quantization (legacy backends only)
83+ observer : str = 'minmax' , # Activation observer: 'minmax', 'histogram', 'moving_average'
5584 verbose : bool = False # Enable verbose output
5685 ):
5786 "Initialize a quantizer with specified backend and options."
5887 store_attr ()
5988
89+ if observer not in _OBSERVERS :
90+ raise ValueError (f"Unknown observer: { observer } . Choose from: { list (_OBSERVERS )} " )
91+
6092 if backend == 'torchao' :
6193 if not _HAS_TORCHAO :
6294 raise ImportError ("torchao backend requires torchao. Install with: pip install torchao" )
6395 if method not in _TORCHAO_CONFIGS :
6496 raise ValueError (f"Unknown torchao method '{ method } '. Available: { list (_TORCHAO_CONFIGS .keys ())} " )
97+ if observer != 'minmax' :
98+ warnings .warn ("observer parameter is ignored for torchao backend" )
6599 return
66100
67101 # Legacy backend setup
@@ -89,31 +123,36 @@ def _update_qconfig_for_per_tensor(self):
89123 "Replace per-channel with per-tensor quantization to avoid conversion issues"
90124 if self .verbose :
91125 print ("Using per-tensor quantization instead of per-channel" )
92-
93- if self .method == "qat" :
126+
127+ act_obs_cls = _OBSERVERS [self .observer ]
128+
129+ if self .method == "qat" :
94130 weight_observer = MinMaxObserver .with_args (
95131 dtype = torch .qint8 ,
96132 qscheme = torch .per_tensor_symmetric ,
97133 quant_min = - 128 ,
98134 quant_max = 127
99135 )
100- activation_observer = MovingAverageMinMaxObserver .with_args (
136+ activation_observer = act_obs_cls .with_args (
101137 averaging_constant = 0.01 ,
102138 quant_min = 0 ,
103139 quant_max = 255
140+ ) if self .observer == 'moving_average' else act_obs_cls .with_args (
141+ quant_min = 0 ,
142+ quant_max = 255
104143 )
105144 per_tensor_qconfig = QConfig (
106145 activation = FakeQuantize .with_args (
107146 observer = activation_observer , quant_min = 0 , quant_max = 255 ),
108147 weight = FakeQuantize .with_args (
109148 observer = weight_observer , quant_min = - 128 , quant_max = 127 ))
110149 else :
111- activation_observer = MinMaxObserver .with_args (
150+ activation_observer = act_obs_cls .with_args (
112151 dtype = torch .quint8 , qscheme = torch .per_tensor_affine , quant_min = 0 , quant_max = 255 )
113152 weight_observer = MinMaxObserver .with_args (
114153 dtype = torch .qint8 , qscheme = torch .per_tensor_symmetric , quant_min = - 128 , quant_max = 127 )
115154 per_tensor_qconfig = QConfig (activation = activation_observer , weight = weight_observer )
116-
155+
117156 self .qconfig_mapping .global_qconfig = per_tensor_qconfig
118157
119158 def _apply_custom_configs (self ):
@@ -122,17 +161,17 @@ def _apply_custom_configs(self):
122161 for module_name , config in self .custom_configs .items ():
123162 if self .verbose : print (f"Setting custom config for { module_name } " )
124163 self .qconfig_mapping .set_module_name (module_name , config )
125-
164+
126165 def _prepare_model (self , model , example_inputs ):
127166 "Prepare model for quantization based on selected method"
128167 model = model .cpu ()
129168 model = model .train () if self .method == "qat" else model .eval ()
130-
169+
131170 try :
132171 with self ._quantized_engine ():
133172 if self .method == "static" :
134173 return prepare_fx (model , self .qconfig_mapping , example_inputs )
135- elif self .method == "dynamic" :
174+ elif self .method == "dynamic" :
136175 self .qconfig_mapping .set_object_type (torch .nn .Linear , default_dynamic_qconfig )
137176 self .qconfig_mapping .set_object_type (torch .nn .LSTM , default_dynamic_qconfig )
138177 self .qconfig_mapping .set_object_type (torch .nn .GRU , default_dynamic_qconfig )
@@ -147,20 +186,20 @@ def _prepare_model(self, model, example_inputs):
147186 raise ValueError (f"Unknown quantization method: { self .method } " )
148187 except Exception as e :
149188 raise RuntimeError (f"Error preparing model for quantization: { e } " )
150-
189+
151190 def _calibrate_model (self , model , dataloader , max_samples = None , device = 'cpu' ):
152191 "Calibrate the model on CPU (PyTorch quantization is CPU-only)."
153192 model .eval ()
154193 device = torch .device (device )
155194 model = model .to (device )
156-
195+
157196 num_samples = getattr (dataloader , 'n' , None )
158197 if max_samples is not None and num_samples is not None :
159198 num_samples = min (num_samples , max_samples )
160-
199+
161200 data_iter = dataloader if not self .verbose else tqdm (
162201 dataloader , desc = "Calibrating" , total = num_samples // dataloader .bs if num_samples else None )
163-
202+
164203 samples_seen = 0
165204 with torch .no_grad ():
166205 for i , batch in enumerate (data_iter ):
@@ -174,7 +213,7 @@ def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):
174213 batch_size = inputs .shape [0 ] if isinstance (inputs , torch .Tensor ) else inputs [0 ].shape [0 ]
175214 samples_seen += batch_size
176215 if max_samples is not None and samples_seen >= max_samples : break
177-
216+
178217 def _quantize_dynamic (self , model ):
179218 "Quantize a model with dynamic quantization"
180219 try :
@@ -192,10 +231,17 @@ def _quantize_torchao(self, model):
192231 config = _TORCHAO_CONFIGS [self .method ]()
193232 if self .verbose :
194233 print (f"torchao: applying { self .method } ({ type (config ).__name__ } )" )
234+ # IntxWeightOnlyConfig supports Conv2d but needs explicit filter_fn
235+ # Exclude depthwise convolutions (groups>1) — PerAxis(0) fails on (C,1,K,K) weights
236+ if self .method in _INTX_METHODS :
237+ filter_fn = lambda m , fqn : (isinstance (m , nn .Linear ) or
238+ (isinstance (m , nn .Conv2d ) and m .groups == 1 ))
239+ else :
240+ filter_fn = None
195241 with warnings .catch_warnings ():
196242 warnings .simplefilter ('ignore' )
197243 try :
198- quantize_ (model , config )
244+ quantize_ (model , config , filter_fn = filter_fn )
199245 except ImportError as e :
200246 raise ImportError (f"torchao method '{ self .method } ' requires additional dependencies: { e } " )
201247 if self .verbose :
@@ -219,18 +265,18 @@ def quantize(self,
219265 if self .verbose : print (f"Performing dynamic quantization with { self .backend } backend" )
220266 self ._apply_custom_configs ()
221267 return self ._quantize_dynamic (model )
222-
268+
223269 self ._apply_custom_configs ()
224270 example_batch , _ = calibration_dl .one_batch ()
225-
271+
226272 try :
227273 if self .verbose : print (f"Preparing model for { self .method } quantization with { self .backend } backend" )
228274 model_prepared = self ._prepare_model (model , example_batch .cpu ())
229-
275+
230276 if self .method in ["static" , "qat" ]:
231277 if self .verbose : print (f"Calibrating with up to { max_calibration_samples } samples" )
232278 self ._calibrate_model (model_prepared , calibration_dl , max_samples = max_calibration_samples , device = device )
233-
279+
234280 if self .verbose : print ("Converting to quantized model" )
235281 try :
236282 with self ._quantized_engine ():
@@ -243,13 +289,51 @@ def quantize(self,
243289 return self .quantize (model , calibration_dl , max_calibration_samples , device )
244290 else :
245291 raise e
246-
292+
247293 if self .verbose : print ("Quantization complete" )
248294 return quantized_model
249-
295+
250296 except Exception as e :
251297 print (f"Error during quantization: { e } " )
252298 if self .verbose :
253299 import traceback
254300 traceback .print_exc ()
255301 return model
302+
303+ # %% ../../nbs/quantize/quantizer.ipynb #37r0vysj2l2
304+ import warnings as _warnings
305+ from collections import OrderedDict as _OrderedDict
306+
307+ def quantize_mixed (
308+ model : nn .Module , # model to quantize (deepcopied internally)
309+ layer_configs : dict [str , Any | None ], # {fqn: torchao_config_or_None} from to_quant_config()
310+ verbose : bool = False , # print per-layer summary
311+ ) -> nn .Module :
312+ "Apply per-layer quantization using torchao FqnToConfig. Layers mapped to None are skipped."
313+ if not _HAS_TORCHAO :
314+ raise ImportError ("quantize_mixed requires torchao. Install with: pip install torchao" )
315+
316+ from torchao .quantization import quantize_ , FqnToConfig
317+ import copy
318+
319+ model = copy .deepcopy (model ).eval ()
320+
321+ # Filter out None entries and validate FQNs
322+ active = {k : v for k , v in layer_configs .items () if v is not None }
323+ if not active : return model
324+
325+ model_fqns = {n for n , _ in model .named_modules ()}
326+ unmatched = set (active ) - model_fqns
327+ if unmatched :
328+ _warnings .warn (f"quantize_mixed: { len (unmatched )} FQN(s) not found in model: { list (unmatched )[:5 ]} " )
329+
330+ if verbose :
331+ for fqn , cfg in layer_configs .items ():
332+ status = type (cfg ).__name__ if cfg is not None else "SKIP"
333+ print (f" { fqn :30s} → { status } " )
334+
335+ fqn_config = FqnToConfig (fqn_to_config = _OrderedDict (active ))
336+ with _warnings .catch_warnings ():
337+ _warnings .simplefilter ('ignore' )
338+ quantize_ (model , fqn_config , filter_fn = None )
339+ return model
0 commit comments