Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion fasterai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@
'fasterai.quantize.quantizer.Quantizer._update_qconfig_for_per_tensor': ( 'quantize/quantizer.html#quantizer._update_qconfig_for_per_tensor',
'fasterai/quantize/quantizer.py'),
'fasterai.quantize.quantizer.Quantizer.quantize': ( 'quantize/quantizer.html#quantizer.quantize',
'fasterai/quantize/quantizer.py')},
'fasterai/quantize/quantizer.py'),
'fasterai.quantize.quantizer.quantize_mixed': ( 'quantize/quantizer.html#quantize_mixed',
'fasterai/quantize/quantizer.py')},
'fasterai.regularize.all': {},
'fasterai.regularize.regularize_callback': { 'fasterai.regularize.regularize_callback.RegularizeCallback': ( 'regularize/regularize_callback.html#regularizecallback',
'fasterai/regularize/regularize_callback.py'),
Expand Down
128 changes: 106 additions & 22 deletions fasterai/quantize/quantizer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/quantize/quantizer.ipynb.

# %% auto #0
__all__ = ['Quantizer']
__all__ = ['Quantizer', 'quantize_mixed']

# %% ../../nbs/quantize/quantizer.ipynb #80613b7a-9ee9-4729-80e0-a33e6406a83e
import torch
import torch.nn as nn
from fastcore.basics import store_attr
from torch.ao.quantization import QConfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.quantization import quantize_dynamic
from torch.ao.quantization.qconfig import default_dynamic_qconfig
Expand All @@ -35,33 +35,67 @@
_HAS_TORCHAO = False
_HAS_INT4 = False

try:
from torchao.quantization import IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_primitives import MappingType
_HAS_INTX = True
except ImportError:
_HAS_INTX = False

_TORCHAO_CONFIGS = {}
if _HAS_TORCHAO:
_TORCHAO_CONFIGS['int8_weight_only'] = lambda: Int8WeightOnlyConfig()
_TORCHAO_CONFIGS['int8_dynamic'] = lambda: Int8DynamicActivationInt8WeightConfig()
if _HAS_INT4:
_TORCHAO_CONFIGS['int4_weight_only'] = lambda: Int4WeightOnlyConfig(group_size=128)
if _HAS_INTX:
_TORCHAO_CONFIGS['intx_int4'] = lambda: IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerAxis(0),
mapping_type=MappingType.ASYMMETRIC,
)
_TORCHAO_CONFIGS['intx_int8'] = lambda: IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=PerAxis(0),
mapping_type=MappingType.SYMMETRIC,
)

_OBSERVERS = {
'minmax': MinMaxObserver,
'histogram': HistogramObserver,
'moving_average': MovingAverageMinMaxObserver,
}

# %% ../../nbs/quantize/quantizer.ipynb #fb1fd84a-dcf6-4ec5-966e-6fdd01e1d19b
import contextlib

# IntxWeightOnlyConfig methods that support Conv2d (need explicit filter_fn)
_INTX_METHODS = frozenset({'intx_int4', 'intx_int8'})

class Quantizer:
def __init__(self,
backend: str = "x86", # Target backend: 'x86', 'qnnpack', 'fbgemm', or 'torchao'
method: str = "static", # Method: 'static', 'dynamic', 'qat', 'int8_weight_only', 'int8_dynamic'
qconfig_mapping: dict | None = None, # Optional custom quantization config (legacy backends only)
custom_configs: dict | None = None, # Custom module-specific configurations
use_per_tensor: bool = False, # Force per-tensor quantization (legacy backends only)
observer: str = 'minmax', # Activation observer: 'minmax', 'histogram', 'moving_average'
verbose: bool = False # Enable verbose output
):
"Initialize a quantizer with specified backend and options."
store_attr()

if observer not in _OBSERVERS:
raise ValueError(f"Unknown observer: {observer}. Choose from: {list(_OBSERVERS)}")

if backend == 'torchao':
if not _HAS_TORCHAO:
raise ImportError("torchao backend requires torchao. Install with: pip install torchao")
if method not in _TORCHAO_CONFIGS:
raise ValueError(f"Unknown torchao method '{method}'. Available: {list(_TORCHAO_CONFIGS.keys())}")
if observer != 'minmax':
warnings.warn("observer parameter is ignored for torchao backend")
return

# Legacy backend setup
Expand Down Expand Up @@ -89,31 +123,36 @@ def _update_qconfig_for_per_tensor(self):
"Replace per-channel with per-tensor quantization to avoid conversion issues"
if self.verbose:
print("Using per-tensor quantization instead of per-channel")

if self.method == "qat":

act_obs_cls = _OBSERVERS[self.observer]

if self.method == "qat":
weight_observer = MinMaxObserver.with_args(
dtype=torch.qint8,
qscheme=torch.per_tensor_symmetric,
quant_min=-128,
quant_max=127
)
activation_observer = MovingAverageMinMaxObserver.with_args(
activation_observer = act_obs_cls.with_args(
averaging_constant=0.01,
quant_min=0,
quant_max=255
) if self.observer == 'moving_average' else act_obs_cls.with_args(
quant_min=0,
quant_max=255
)
per_tensor_qconfig = QConfig(
activation=FakeQuantize.with_args(
observer=activation_observer, quant_min=0, quant_max=255),
weight=FakeQuantize.with_args(
observer=weight_observer, quant_min=-128, quant_max=127))
else:
activation_observer = MinMaxObserver.with_args(
activation_observer = act_obs_cls.with_args(
dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255)
weight_observer = MinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127)
per_tensor_qconfig = QConfig(activation=activation_observer, weight=weight_observer)

self.qconfig_mapping.global_qconfig = per_tensor_qconfig

def _apply_custom_configs(self):
Expand All @@ -122,17 +161,17 @@ def _apply_custom_configs(self):
for module_name, config in self.custom_configs.items():
if self.verbose: print(f"Setting custom config for {module_name}")
self.qconfig_mapping.set_module_name(module_name, config)

def _prepare_model(self, model, example_inputs):
"Prepare model for quantization based on selected method"
model = model.cpu()
model = model.train() if self.method == "qat" else model.eval()

try:
with self._quantized_engine():
if self.method == "static":
return prepare_fx(model, self.qconfig_mapping, example_inputs)
elif self.method == "dynamic":
elif self.method == "dynamic":
self.qconfig_mapping.set_object_type(torch.nn.Linear, default_dynamic_qconfig)
self.qconfig_mapping.set_object_type(torch.nn.LSTM, default_dynamic_qconfig)
self.qconfig_mapping.set_object_type(torch.nn.GRU, default_dynamic_qconfig)
Expand All @@ -147,20 +186,20 @@ def _prepare_model(self, model, example_inputs):
raise ValueError(f"Unknown quantization method: {self.method}")
except Exception as e:
raise RuntimeError(f"Error preparing model for quantization: {e}")

def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):
"Calibrate the model on CPU (PyTorch quantization is CPU-only)."
model.eval()
device = torch.device(device)
model = model.to(device)

num_samples = getattr(dataloader, 'n', None)
if max_samples is not None and num_samples is not None:
num_samples = min(num_samples, max_samples)

data_iter = dataloader if not self.verbose else tqdm(
dataloader, desc="Calibrating", total=num_samples//dataloader.bs if num_samples else None)

samples_seen = 0
with torch.no_grad():
for i, batch in enumerate(data_iter):
Expand All @@ -174,7 +213,7 @@ def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):
batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else inputs[0].shape[0]
samples_seen += batch_size
if max_samples is not None and samples_seen >= max_samples: break

def _quantize_dynamic(self, model):
"Quantize a model with dynamic quantization"
try:
Expand All @@ -192,10 +231,17 @@ def _quantize_torchao(self, model):
config = _TORCHAO_CONFIGS[self.method]()
if self.verbose:
print(f"torchao: applying {self.method} ({type(config).__name__})")
# IntxWeightOnlyConfig supports Conv2d but needs explicit filter_fn
# Exclude depthwise convolutions (groups>1) — PerAxis(0) fails on (C,1,K,K) weights
if self.method in _INTX_METHODS:
filter_fn = lambda m, fqn: (isinstance(m, nn.Linear) or
(isinstance(m, nn.Conv2d) and m.groups == 1))
else:
filter_fn = None
with warnings.catch_warnings():
warnings.simplefilter('ignore')
try:
quantize_(model, config)
quantize_(model, config, filter_fn=filter_fn)
except ImportError as e:
raise ImportError(f"torchao method '{self.method}' requires additional dependencies: {e}")
if self.verbose:
Expand All @@ -219,18 +265,18 @@ def quantize(self,
if self.verbose: print(f"Performing dynamic quantization with {self.backend} backend")
self._apply_custom_configs()
return self._quantize_dynamic(model)

self._apply_custom_configs()
example_batch, _ = calibration_dl.one_batch()

try:
if self.verbose: print(f"Preparing model for {self.method} quantization with {self.backend} backend")
model_prepared = self._prepare_model(model, example_batch.cpu())

if self.method in ["static", "qat"]:
if self.verbose: print(f"Calibrating with up to {max_calibration_samples} samples")
self._calibrate_model(model_prepared, calibration_dl, max_samples=max_calibration_samples, device=device)

if self.verbose: print("Converting to quantized model")
try:
with self._quantized_engine():
Expand All @@ -243,13 +289,51 @@ def quantize(self,
return self.quantize(model, calibration_dl, max_calibration_samples, device)
else:
raise e

if self.verbose: print("Quantization complete")
return quantized_model

except Exception as e:
print(f"Error during quantization: {e}")
if self.verbose:
import traceback
traceback.print_exc()
return model

# %% ../../nbs/quantize/quantizer.ipynb #37r0vysj2l2
import warnings as _warnings
from collections import OrderedDict as _OrderedDict

def quantize_mixed(
model: nn.Module, # model to quantize (deepcopied internally)
layer_configs: dict[str, Any | None], # {fqn: torchao_config_or_None} from to_quant_config()
verbose: bool = False, # print per-layer summary
) -> nn.Module:
"Apply per-layer quantization using torchao FqnToConfig. Layers mapped to None are skipped."
if not _HAS_TORCHAO:
raise ImportError("quantize_mixed requires torchao. Install with: pip install torchao")

from torchao.quantization import quantize_, FqnToConfig
import copy

model = copy.deepcopy(model).eval()

# Filter out None entries and validate FQNs
active = {k: v for k, v in layer_configs.items() if v is not None}
if not active: return model

model_fqns = {n for n, _ in model.named_modules()}
unmatched = set(active) - model_fqns
if unmatched:
_warnings.warn(f"quantize_mixed: {len(unmatched)} FQN(s) not found in model: {list(unmatched)[:5]}")

if verbose:
for fqn, cfg in layer_configs.items():
status = type(cfg).__name__ if cfg is not None else "SKIP"
print(f" {fqn:30s} → {status}")

fqn_config = FqnToConfig(fqn_to_config=_OrderedDict(active))
with _warnings.catch_warnings():
_warnings.simplefilter('ignore')
quantize_(model, fqn_config, filter_fn=None)
return model
Loading
Loading