Skip to content

Commit 17cfd40

Browse files
committed
feat: quantize_mixed(), IntxWeightOnlyConfig (Conv2d INT4), observer parameter
1. quantize_mixed() — per-layer quantization via torchao FqnToConfig
2. IntxWeightOnlyConfig — Conv2d + Linear INT4/INT8 (excludes depthwise)
3. Observer parameter — histogram/moving_average for better PTQ calibration
1 parent 004d55c commit 17cfd40

3 files changed

Lines changed: 164 additions & 274 deletions

File tree

fasterai/_modidx.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,35 @@
208208
'fasterai/export/onnx_exporter.py'),
209209
'fasterai.export.onnx_exporter.verify_onnx': ( 'export/onnx_exporter.html#verify_onnx',
210210
'fasterai/export/onnx_exporter.py')},
211+
'fasterai.huggingface.all': {},
212+
'fasterai.huggingface.huggingface': { 'fasterai.huggingface.huggingface.HFSparsifyCallback': ( 'huggingface/huggingface.html#hfsparsifycallback',
213+
'fasterai/huggingface/huggingface.py'),
214+
'fasterai.huggingface.huggingface.HFSparsifyCallback.__init__': ( 'huggingface/huggingface.html#hfsparsifycallback.__init__',
215+
'fasterai/huggingface/huggingface.py'),
216+
'fasterai.huggingface.huggingface.HFSparsifyCallback._sparsity_value': ( 'huggingface/huggingface.html#hfsparsifycallback._sparsity_value',
217+
'fasterai/huggingface/huggingface.py'),
218+
'fasterai.huggingface.huggingface.HFSparsifyCallback.on_epoch_end': ( 'huggingface/huggingface.html#hfsparsifycallback.on_epoch_end',
219+
'fasterai/huggingface/huggingface.py'),
220+
'fasterai.huggingface.huggingface.HFSparsifyCallback.on_log': ( 'huggingface/huggingface.html#hfsparsifycallback.on_log',
221+
'fasterai/huggingface/huggingface.py'),
222+
'fasterai.huggingface.huggingface.HFSparsifyCallback.on_optimizer_step': ( 'huggingface/huggingface.html#hfsparsifycallback.on_optimizer_step',
223+
'fasterai/huggingface/huggingface.py'),
224+
'fasterai.huggingface.huggingface.HFSparsifyCallback.on_step_begin': ( 'huggingface/huggingface.html#hfsparsifycallback.on_step_begin',
225+
'fasterai/huggingface/huggingface.py'),
226+
'fasterai.huggingface.huggingface.HFSparsifyCallback.on_train_begin': ( 'huggingface/huggingface.html#hfsparsifycallback.on_train_begin',
227+
'fasterai/huggingface/huggingface.py'),
228+
'fasterai.huggingface.huggingface.HFSparsifyCallback.on_train_end': ( 'huggingface/huggingface.html#hfsparsifycallback.on_train_end',
229+
'fasterai/huggingface/huggingface.py'),
230+
'fasterai.huggingface.huggingface._has_transformers': ( 'huggingface/huggingface.html#_has_transformers',
231+
'fasterai/huggingface/huggingface.py'),
232+
'fasterai.huggingface.huggingface._load_model': ( 'huggingface/huggingface.html#_load_model',
233+
'fasterai/huggingface/huggingface.py'),
234+
'fasterai.huggingface.huggingface._require_transformers': ( 'huggingface/huggingface.html#_require_transformers',
235+
'fasterai/huggingface/huggingface.py'),
236+
'fasterai.huggingface.huggingface._save_compressed': ( 'huggingface/huggingface.html#_save_compressed',
237+
'fasterai/huggingface/huggingface.py'),
238+
'fasterai.huggingface.huggingface.sparsify_model': ( 'huggingface/huggingface.html#sparsify_model',
239+
'fasterai/huggingface/huggingface.py')},
211240
'fasterai.misc.all': {},
212241
'fasterai.misc.bn_folding': { 'fasterai.misc.bn_folding.BN_Folder': ( 'misc/bn_folding.html#bn_folder',
213242
'fasterai/misc/bn_folding.py'),
@@ -305,7 +334,9 @@
305334
'fasterai.quantize.quantizer.Quantizer._update_qconfig_for_per_tensor': ( 'quantize/quantizer.html#quantizer._update_qconfig_for_per_tensor',
306335
'fasterai/quantize/quantizer.py'),
307336
'fasterai.quantize.quantizer.Quantizer.quantize': ( 'quantize/quantizer.html#quantizer.quantize',
308-
'fasterai/quantize/quantizer.py')},
337+
'fasterai/quantize/quantizer.py'),
338+
'fasterai.quantize.quantizer.quantize_mixed': ( 'quantize/quantizer.html#quantize_mixed',
339+
'fasterai/quantize/quantizer.py')},
309340
'fasterai.regularize.all': {},
310341
'fasterai.regularize.regularize_callback': { 'fasterai.regularize.regularize_callback.RegularizeCallback': ( 'regularize/regularize_callback.html#regularizecallback',
311342
'fasterai/regularize/regularize_callback.py'),

fasterai/quantize/quantizer.py

Lines changed: 106 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/quantize/quantizer.ipynb.
22

33
# %% auto #0
4-
__all__ = ['Quantizer']
4+
__all__ = ['Quantizer', 'quantize_mixed']
55

66
# %% ../../nbs/quantize/quantizer.ipynb #80613b7a-9ee9-4729-80e0-a33e6406a83e
77
import torch
88
import torch.nn as nn
99
from fastcore.basics import store_attr
1010
from torch.ao.quantization import QConfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping
1111
from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
12-
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver
12+
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver
1313
from torch.ao.quantization.fake_quantize import FakeQuantize
1414
from torch.quantization import quantize_dynamic
1515
from torch.ao.quantization.qconfig import default_dynamic_qconfig
@@ -35,33 +35,67 @@
3535
_HAS_TORCHAO = False
3636
_HAS_INT4 = False
3737

38+
try:
39+
from torchao.quantization import IntxWeightOnlyConfig
40+
from torchao.quantization.granularity import PerAxis
41+
from torchao.quantization.quant_primitives import MappingType
42+
_HAS_INTX = True
43+
except ImportError:
44+
_HAS_INTX = False
45+
3846
_TORCHAO_CONFIGS = {}
3947
if _HAS_TORCHAO:
4048
_TORCHAO_CONFIGS['int8_weight_only'] = lambda: Int8WeightOnlyConfig()
4149
_TORCHAO_CONFIGS['int8_dynamic'] = lambda: Int8DynamicActivationInt8WeightConfig()
4250
if _HAS_INT4:
4351
_TORCHAO_CONFIGS['int4_weight_only'] = lambda: Int4WeightOnlyConfig(group_size=128)
52+
if _HAS_INTX:
53+
_TORCHAO_CONFIGS['intx_int4'] = lambda: IntxWeightOnlyConfig(
54+
weight_dtype=torch.int4,
55+
granularity=PerAxis(0),
56+
mapping_type=MappingType.ASYMMETRIC,
57+
)
58+
_TORCHAO_CONFIGS['intx_int8'] = lambda: IntxWeightOnlyConfig(
59+
weight_dtype=torch.int8,
60+
granularity=PerAxis(0),
61+
mapping_type=MappingType.SYMMETRIC,
62+
)
63+
64+
_OBSERVERS = {
65+
'minmax': MinMaxObserver,
66+
'histogram': HistogramObserver,
67+
'moving_average': MovingAverageMinMaxObserver,
68+
}
4469

4570
# %% ../../nbs/quantize/quantizer.ipynb #fb1fd84a-dcf6-4ec5-966e-6fdd01e1d19b
4671
import contextlib
4772

73+
# IntxWeightOnlyConfig methods that support Conv2d (need explicit filter_fn)
74+
_INTX_METHODS = frozenset({'intx_int4', 'intx_int8'})
75+
4876
class Quantizer:
4977
def __init__(self,
5078
backend: str = "x86", # Target backend: 'x86', 'qnnpack', 'fbgemm', or 'torchao'
5179
method: str = "static", # Method: 'static', 'dynamic', 'qat', 'int8_weight_only', 'int8_dynamic'
5280
qconfig_mapping: dict | None = None, # Optional custom quantization config (legacy backends only)
5381
custom_configs: dict | None = None, # Custom module-specific configurations
5482
use_per_tensor: bool = False, # Force per-tensor quantization (legacy backends only)
83+
observer: str = 'minmax', # Activation observer: 'minmax', 'histogram', 'moving_average'
5584
verbose: bool = False # Enable verbose output
5685
):
5786
"Initialize a quantizer with specified backend and options."
5887
store_attr()
5988

89+
if observer not in _OBSERVERS:
90+
raise ValueError(f"Unknown observer: {observer}. Choose from: {list(_OBSERVERS)}")
91+
6092
if backend == 'torchao':
6193
if not _HAS_TORCHAO:
6294
raise ImportError("torchao backend requires torchao. Install with: pip install torchao")
6395
if method not in _TORCHAO_CONFIGS:
6496
raise ValueError(f"Unknown torchao method '{method}'. Available: {list(_TORCHAO_CONFIGS.keys())}")
97+
if observer != 'minmax':
98+
warnings.warn("observer parameter is ignored for torchao backend")
6599
return
66100

67101
# Legacy backend setup
@@ -89,31 +123,36 @@ def _update_qconfig_for_per_tensor(self):
89123
"Replace per-channel with per-tensor quantization to avoid conversion issues"
90124
if self.verbose:
91125
print("Using per-tensor quantization instead of per-channel")
92-
93-
if self.method == "qat":
126+
127+
act_obs_cls = _OBSERVERS[self.observer]
128+
129+
if self.method == "qat":
94130
weight_observer = MinMaxObserver.with_args(
95131
dtype=torch.qint8,
96132
qscheme=torch.per_tensor_symmetric,
97133
quant_min=-128,
98134
quant_max=127
99135
)
100-
activation_observer = MovingAverageMinMaxObserver.with_args(
136+
activation_observer = act_obs_cls.with_args(
101137
averaging_constant=0.01,
102138
quant_min=0,
103139
quant_max=255
140+
) if self.observer == 'moving_average' else act_obs_cls.with_args(
141+
quant_min=0,
142+
quant_max=255
104143
)
105144
per_tensor_qconfig = QConfig(
106145
activation=FakeQuantize.with_args(
107146
observer=activation_observer, quant_min=0, quant_max=255),
108147
weight=FakeQuantize.with_args(
109148
observer=weight_observer, quant_min=-128, quant_max=127))
110149
else:
111-
activation_observer = MinMaxObserver.with_args(
150+
activation_observer = act_obs_cls.with_args(
112151
dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255)
113152
weight_observer = MinMaxObserver.with_args(
114153
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127)
115154
per_tensor_qconfig = QConfig(activation=activation_observer, weight=weight_observer)
116-
155+
117156
self.qconfig_mapping.global_qconfig = per_tensor_qconfig
118157

119158
def _apply_custom_configs(self):
@@ -122,17 +161,17 @@ def _apply_custom_configs(self):
122161
for module_name, config in self.custom_configs.items():
123162
if self.verbose: print(f"Setting custom config for {module_name}")
124163
self.qconfig_mapping.set_module_name(module_name, config)
125-
164+
126165
def _prepare_model(self, model, example_inputs):
127166
"Prepare model for quantization based on selected method"
128167
model = model.cpu()
129168
model = model.train() if self.method == "qat" else model.eval()
130-
169+
131170
try:
132171
with self._quantized_engine():
133172
if self.method == "static":
134173
return prepare_fx(model, self.qconfig_mapping, example_inputs)
135-
elif self.method == "dynamic":
174+
elif self.method == "dynamic":
136175
self.qconfig_mapping.set_object_type(torch.nn.Linear, default_dynamic_qconfig)
137176
self.qconfig_mapping.set_object_type(torch.nn.LSTM, default_dynamic_qconfig)
138177
self.qconfig_mapping.set_object_type(torch.nn.GRU, default_dynamic_qconfig)
@@ -147,20 +186,20 @@ def _prepare_model(self, model, example_inputs):
147186
raise ValueError(f"Unknown quantization method: {self.method}")
148187
except Exception as e:
149188
raise RuntimeError(f"Error preparing model for quantization: {e}")
150-
189+
151190
def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):
152191
"Calibrate the model on CPU (PyTorch quantization is CPU-only)."
153192
model.eval()
154193
device = torch.device(device)
155194
model = model.to(device)
156-
195+
157196
num_samples = getattr(dataloader, 'n', None)
158197
if max_samples is not None and num_samples is not None:
159198
num_samples = min(num_samples, max_samples)
160-
199+
161200
data_iter = dataloader if not self.verbose else tqdm(
162201
dataloader, desc="Calibrating", total=num_samples//dataloader.bs if num_samples else None)
163-
202+
164203
samples_seen = 0
165204
with torch.no_grad():
166205
for i, batch in enumerate(data_iter):
@@ -174,7 +213,7 @@ def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):
174213
batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else inputs[0].shape[0]
175214
samples_seen += batch_size
176215
if max_samples is not None and samples_seen >= max_samples: break
177-
216+
178217
def _quantize_dynamic(self, model):
179218
"Quantize a model with dynamic quantization"
180219
try:
@@ -192,10 +231,17 @@ def _quantize_torchao(self, model):
192231
config = _TORCHAO_CONFIGS[self.method]()
193232
if self.verbose:
194233
print(f"torchao: applying {self.method} ({type(config).__name__})")
234+
# IntxWeightOnlyConfig supports Conv2d but needs explicit filter_fn
235+
# Exclude depthwise convolutions (groups>1) — PerAxis(0) fails on (C,1,K,K) weights
236+
if self.method in _INTX_METHODS:
237+
filter_fn = lambda m, fqn: (isinstance(m, nn.Linear) or
238+
(isinstance(m, nn.Conv2d) and m.groups == 1))
239+
else:
240+
filter_fn = None
195241
with warnings.catch_warnings():
196242
warnings.simplefilter('ignore')
197243
try:
198-
quantize_(model, config)
244+
quantize_(model, config, filter_fn=filter_fn)
199245
except ImportError as e:
200246
raise ImportError(f"torchao method '{self.method}' requires additional dependencies: {e}")
201247
if self.verbose:
@@ -219,18 +265,18 @@ def quantize(self,
219265
if self.verbose: print(f"Performing dynamic quantization with {self.backend} backend")
220266
self._apply_custom_configs()
221267
return self._quantize_dynamic(model)
222-
268+
223269
self._apply_custom_configs()
224270
example_batch, _ = calibration_dl.one_batch()
225-
271+
226272
try:
227273
if self.verbose: print(f"Preparing model for {self.method} quantization with {self.backend} backend")
228274
model_prepared = self._prepare_model(model, example_batch.cpu())
229-
275+
230276
if self.method in ["static", "qat"]:
231277
if self.verbose: print(f"Calibrating with up to {max_calibration_samples} samples")
232278
self._calibrate_model(model_prepared, calibration_dl, max_samples=max_calibration_samples, device=device)
233-
279+
234280
if self.verbose: print("Converting to quantized model")
235281
try:
236282
with self._quantized_engine():
@@ -243,13 +289,51 @@ def quantize(self,
243289
return self.quantize(model, calibration_dl, max_calibration_samples, device)
244290
else:
245291
raise e
246-
292+
247293
if self.verbose: print("Quantization complete")
248294
return quantized_model
249-
295+
250296
except Exception as e:
251297
print(f"Error during quantization: {e}")
252298
if self.verbose:
253299
import traceback
254300
traceback.print_exc()
255301
return model
302+
303+
# %% ../../nbs/quantize/quantizer.ipynb #37r0vysj2l2
304+
import warnings as _warnings
305+
from collections import OrderedDict as _OrderedDict
306+
307+
def quantize_mixed(
    model: nn.Module,                      # model to quantize (deepcopied internally)
    layer_configs: dict[str, Any | None],  # {fqn: torchao_config_or_None} from to_quant_config()
    verbose: bool = False,                 # print per-layer summary
) -> nn.Module:
    """Apply per-layer quantization using torchao FqnToConfig. Layers mapped to None are skipped.

    The input model is never modified: a deep copy is switched to eval mode,
    quantized in place by torchao and returned. When every entry of
    `layer_configs` is None (or the dict is empty) the unquantized copy is
    returned as-is.

    Raises ImportError when torchao is not installed.
    """
    if not _HAS_TORCHAO:
        raise ImportError("quantize_mixed requires torchao. Install with: pip install torchao")

    from torchao.quantization import quantize_, FqnToConfig
    import copy

    model = copy.deepcopy(model).eval()

    # Print the summary before the early return so an all-SKIP plan is still reported.
    if verbose:
        for fqn, cfg in layer_configs.items():
            status = type(cfg).__name__ if cfg is not None else "SKIP"
            print(f" {fqn:30s}{status}")

    # Only entries carrying a config are forwarded to torchao.
    active = {fqn: cfg for fqn, cfg in layer_configs.items() if cfg is not None}
    if not active:
        return model

    # FqnToConfig keys may be module FQNs, parameter FQNs, "re:"-prefixed
    # regexes or the special "_default" key (per the torchao documentation), so
    # only plain names matching neither a module nor a parameter are reported.
    known_fqns = {name for name, _ in model.named_modules()}
    known_fqns.update(name for name, _ in model.named_parameters())
    # Sorted so the (truncated) list in the warning is deterministic.
    unmatched = sorted(
        fqn for fqn in active
        if fqn not in known_fqns and fqn != "_default" and not fqn.startswith("re:")
    )
    if unmatched:
        _warnings.warn(f"quantize_mixed: {len(unmatched)} FQN(s) not found in model: {unmatched[:5]}")

    fqn_config = FqnToConfig(fqn_to_config=_OrderedDict(active))
    # Warnings raised inside torchao's quantize_ are deliberately silenced here.
    with _warnings.catch_warnings():
        _warnings.simplefilter('ignore')
        quantize_(model, fqn_config, filter_fn=None)
    return model

0 commit comments

Comments
 (0)