diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py index b20f500..bfc1870 100644 --- a/fasterai/_modidx.py +++ b/fasterai/_modidx.py @@ -305,7 +305,9 @@ 'fasterai.quantize.quantizer.Quantizer._update_qconfig_for_per_tensor': ( 'quantize/quantizer.html#quantizer._update_qconfig_for_per_tensor', 'fasterai/quantize/quantizer.py'), 'fasterai.quantize.quantizer.Quantizer.quantize': ( 'quantize/quantizer.html#quantizer.quantize', - 'fasterai/quantize/quantizer.py')}, + 'fasterai/quantize/quantizer.py'), + 'fasterai.quantize.quantizer.quantize_mixed': ( 'quantize/quantizer.html#quantize_mixed', + 'fasterai/quantize/quantizer.py')}, 'fasterai.regularize.all': {}, 'fasterai.regularize.regularize_callback': { 'fasterai.regularize.regularize_callback.RegularizeCallback': ( 'regularize/regularize_callback.html#regularizecallback', 'fasterai/regularize/regularize_callback.py'), diff --git a/fasterai/quantize/quantizer.py b/fasterai/quantize/quantizer.py index 2a4e7f0..4cd78eb 100644 --- a/fasterai/quantize/quantizer.py +++ b/fasterai/quantize/quantizer.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/quantize/quantizer.ipynb. # %% auto #0 -__all__ = ['Quantizer'] +__all__ = ['Quantizer', 'quantize_mixed'] # %% ../../nbs/quantize/quantizer.ipynb #80613b7a-9ee9-4729-80e0-a33e6406a83e import torch @@ -9,7 +9,7 @@ from fastcore.basics import store_attr from torch.ao.quantization import QConfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx -from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver +from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver from torch.ao.quantization.fake_quantize import FakeQuantize from torch.quantization import quantize_dynamic from torch.ao.quantization.qconfig import default_dynamic_qconfig @@ -35,16 +35,44 @@ _HAS_TORCHAO = False _HAS_INT4 = False +try: + from torchao.quantization import IntxWeightOnlyConfig + from torchao.quantization.granularity import PerAxis + from torchao.quantization.quant_primitives import MappingType + _HAS_INTX = True +except ImportError: + _HAS_INTX = False + _TORCHAO_CONFIGS = {} if _HAS_TORCHAO: _TORCHAO_CONFIGS['int8_weight_only'] = lambda: Int8WeightOnlyConfig() _TORCHAO_CONFIGS['int8_dynamic'] = lambda: Int8DynamicActivationInt8WeightConfig() if _HAS_INT4: _TORCHAO_CONFIGS['int4_weight_only'] = lambda: Int4WeightOnlyConfig(group_size=128) +if _HAS_INTX: + _TORCHAO_CONFIGS['intx_int4'] = lambda: IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + mapping_type=MappingType.ASYMMETRIC, + ) + _TORCHAO_CONFIGS['intx_int8'] = lambda: IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + mapping_type=MappingType.SYMMETRIC, + ) + +_OBSERVERS = { + 'minmax': MinMaxObserver, + 'histogram': HistogramObserver, + 'moving_average': MovingAverageMinMaxObserver, +} # %% ../../nbs/quantize/quantizer.ipynb #fb1fd84a-dcf6-4ec5-966e-6fdd01e1d19b import contextlib +# IntxWeightOnlyConfig methods that support Conv2d (need explicit filter_fn) +_INTX_METHODS = frozenset({'intx_int4', 'intx_int8'}) + class Quantizer: def __init__(self, backend: str = "x86", # Target backend: 'x86', 'qnnpack', 'fbgemm', or 'torchao' @@ -52,16 +80,22 @@ def __init__(self, qconfig_mapping: dict | None = None, # Optional custom quantization config (legacy backends only) custom_configs: dict | None = None, # Custom module-specific configurations use_per_tensor: bool = False, # Force per-tensor quantization (legacy backends only) + observer: str = 'minmax', # Activation observer: 'minmax', 'histogram', 'moving_average' verbose: bool = False # Enable verbose output ): "Initialize a quantizer with specified backend and options." store_attr() + if observer not in _OBSERVERS: + raise ValueError(f"Unknown observer: {observer}. Choose from: {list(_OBSERVERS)}") + if backend == 'torchao': if not _HAS_TORCHAO: raise ImportError("torchao backend requires torchao. Install with: pip install torchao") if method not in _TORCHAO_CONFIGS: raise ValueError(f"Unknown torchao method '{method}'. Available: {list(_TORCHAO_CONFIGS.keys())}") + if observer != 'minmax': + warnings.warn("observer parameter is ignored for torchao backend") return # Legacy backend setup @@ -89,18 +123,23 @@ def _update_qconfig_for_per_tensor(self): "Replace per-channel with per-tensor quantization to avoid conversion issues" if self.verbose: print("Using per-tensor quantization instead of per-channel") - - if self.method == "qat": + + act_obs_cls = _OBSERVERS[self.observer] + + if self.method == "qat": weight_observer = MinMaxObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127 ) - activation_observer = MovingAverageMinMaxObserver.with_args( + activation_observer = act_obs_cls.with_args( averaging_constant=0.01, quant_min=0, quant_max=255 + ) if self.observer == 'moving_average' else act_obs_cls.with_args( + quant_min=0, + quant_max=255 ) per_tensor_qconfig = QConfig( activation=FakeQuantize.with_args( @@ -108,12 +147,12 @@ def _update_qconfig_for_per_tensor(self): weight=FakeQuantize.with_args( observer=weight_observer, quant_min=-128, quant_max=127)) else: - activation_observer = MinMaxObserver.with_args( + activation_observer = act_obs_cls.with_args( dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255) weight_observer = MinMaxObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127) per_tensor_qconfig = QConfig(activation=activation_observer, weight=weight_observer) - + self.qconfig_mapping.global_qconfig = per_tensor_qconfig def _apply_custom_configs(self): @@ -122,17 +161,17 @@ def _apply_custom_configs(self): for module_name, config in self.custom_configs.items(): if self.verbose: print(f"Setting custom config for {module_name}") self.qconfig_mapping.set_module_name(module_name, config) - + def _prepare_model(self, model, example_inputs): "Prepare model for quantization based on selected method" model = model.cpu() model = model.train() if self.method == "qat" else model.eval() - + try: with self._quantized_engine(): if self.method == "static": return prepare_fx(model, self.qconfig_mapping, example_inputs) - elif self.method == "dynamic": + elif self.method == "dynamic": self.qconfig_mapping.set_object_type(torch.nn.Linear, default_dynamic_qconfig) self.qconfig_mapping.set_object_type(torch.nn.LSTM, default_dynamic_qconfig) self.qconfig_mapping.set_object_type(torch.nn.GRU, default_dynamic_qconfig) @@ -147,20 +186,20 @@ def _prepare_model(self, model, example_inputs): raise ValueError(f"Unknown quantization method: {self.method}") except Exception as e: raise RuntimeError(f"Error preparing model for quantization: {e}") - + def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'): "Calibrate the model on CPU (PyTorch quantization is CPU-only)." model.eval() device = torch.device(device) model = model.to(device) - + num_samples = getattr(dataloader, 'n', None) if max_samples is not None and num_samples is not None: num_samples = min(num_samples, max_samples) - + data_iter = dataloader if not self.verbose else tqdm( dataloader, desc="Calibrating", total=num_samples//dataloader.bs if num_samples else None) - + samples_seen = 0 with torch.no_grad(): for i, batch in enumerate(data_iter): @@ -174,7 +213,7 @@ def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'): batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else inputs[0].shape[0] samples_seen += batch_size if max_samples is not None and samples_seen >= max_samples: break - + def _quantize_dynamic(self, model): "Quantize a model with dynamic quantization" try: @@ -192,10 +231,17 @@ def _quantize_torchao(self, model): config = _TORCHAO_CONFIGS[self.method]() if self.verbose: print(f"torchao: applying {self.method} ({type(config).__name__})") + # IntxWeightOnlyConfig supports Conv2d but needs explicit filter_fn + # Exclude depthwise convolutions (groups>1) — PerAxis(0) fails on (C,1,K,K) weights + if self.method in _INTX_METHODS: + filter_fn = lambda m, fqn: (isinstance(m, nn.Linear) or + (isinstance(m, nn.Conv2d) and m.groups == 1)) + else: + filter_fn = None with warnings.catch_warnings(): warnings.simplefilter('ignore') try: - quantize_(model, config) + quantize_(model, config, filter_fn=filter_fn) except ImportError as e: raise ImportError(f"torchao method '{self.method}' requires additional dependencies: {e}") if self.verbose: @@ -219,18 +265,18 @@ def quantize(self, if self.verbose: print(f"Performing dynamic quantization with {self.backend} backend") self._apply_custom_configs() return self._quantize_dynamic(model) - + self._apply_custom_configs() example_batch, _ = calibration_dl.one_batch() - + try: if self.verbose: print(f"Preparing model for {self.method} quantization with {self.backend} backend") model_prepared = self._prepare_model(model, example_batch.cpu()) - + if self.method in ["static", "qat"]: if self.verbose: print(f"Calibrating with up to {max_calibration_samples} samples") self._calibrate_model(model_prepared, calibration_dl, max_samples=max_calibration_samples, device=device) - + if self.verbose: print("Converting to quantized model") try: with self._quantized_engine(): @@ -243,13 +289,51 @@ def quantize(self, return self.quantize(model, calibration_dl, max_calibration_samples, device) else: raise e - + if self.verbose: print("Quantization complete") return quantized_model - + except Exception as e: print(f"Error during quantization: {e}") if self.verbose: import traceback traceback.print_exc() return model + +# %% ../../nbs/quantize/quantizer.ipynb #37r0vysj2l2 +import warnings as _warnings +from collections import OrderedDict as _OrderedDict + +def quantize_mixed( + model: nn.Module, # model to quantize (deepcopied internally) + layer_configs: dict[str, Any | None], # {fqn: torchao_config_or_None} from to_quant_config() + verbose: bool = False, # print per-layer summary +) -> nn.Module: + "Apply per-layer quantization using torchao FqnToConfig. Layers mapped to None are skipped." + if not _HAS_TORCHAO: + raise ImportError("quantize_mixed requires torchao. Install with: pip install torchao") + + from torchao.quantization import quantize_, FqnToConfig + import copy + + model = copy.deepcopy(model).eval() + + # Filter out None entries and validate FQNs + active = {k: v for k, v in layer_configs.items() if v is not None} + if not active: return model + + model_fqns = {n for n, _ in model.named_modules()} + unmatched = set(active) - model_fqns + if unmatched: + _warnings.warn(f"quantize_mixed: {len(unmatched)} FQN(s) not found in model: {list(unmatched)[:5]}") + + if verbose: + for fqn, cfg in layer_configs.items(): + status = type(cfg).__name__ if cfg is not None else "SKIP" + print(f" {fqn:30s} → {status}") + + fqn_config = FqnToConfig(fqn_to_config=_OrderedDict(active)) + with _warnings.catch_warnings(): + _warnings.simplefilter('ignore') + quantize_(model, fqn_config, filter_fn=None) + return model diff --git a/nbs/quantize/quantizer.ipynb b/nbs/quantize/quantizer.ipynb index bb138ce..0e77af6 100644 --- a/nbs/quantize/quantizer.ipynb +++ b/nbs/quantize/quantizer.ipynb @@ -40,46 +40,7 @@ "id": "80613b7a-9ee9-4729-80e0-a33e6406a83e", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "import torch\n", - "import torch.nn as nn\n", - "from fastcore.basics import store_attr\n", - "from torch.ao.quantization import QConfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping\n", - "from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx\n", - "from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver\n", - "from torch.ao.quantization.fake_quantize import FakeQuantize\n", - "from torch.quantization import quantize_dynamic\n", - "from torch.ao.quantization.qconfig import default_dynamic_qconfig\n", - "from typing import Any\n", - "import warnings\n", - "import copy\n", - "from tqdm import tqdm\n", - "\n", - "try:\n", - " from torchao.quantization import quantize_, Int8WeightOnlyConfig\n", - " from torchao.quantization import Int8DynamicActivationInt8WeightConfig\n", - " _HAS_TORCHAO = True\n", - " # INT4 requires additional kernel libraries — check availability\n", - " try:\n", - " from torchao.quantization import Int4WeightOnlyConfig\n", - " _m = nn.Linear(128, 128)\n", - " quantize_(_m, Int4WeightOnlyConfig(group_size=128))\n", - " _HAS_INT4 = True\n", - " del _m\n", - " except (ImportError, Exception):\n", - " _HAS_INT4 = False\n", - "except ImportError:\n", - " _HAS_TORCHAO = False\n", - " _HAS_INT4 = False\n", - "\n", - "_TORCHAO_CONFIGS = {}\n", - "if _HAS_TORCHAO:\n", - " _TORCHAO_CONFIGS['int8_weight_only'] = lambda: Int8WeightOnlyConfig()\n", - " _TORCHAO_CONFIGS['int8_dynamic'] = lambda: Int8DynamicActivationInt8WeightConfig()\n", - "if _HAS_INT4:\n", - " _TORCHAO_CONFIGS['int4_weight_only'] = lambda: Int4WeightOnlyConfig(group_size=128)" - ] + "source": "#| export\nimport torch\nimport torch.nn as nn\nfrom fastcore.basics import store_attr\nfrom torch.ao.quantization import QConfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping\nfrom torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx\nfrom torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver\nfrom torch.ao.quantization.fake_quantize import FakeQuantize\nfrom torch.quantization import quantize_dynamic\nfrom torch.ao.quantization.qconfig import default_dynamic_qconfig\nfrom typing import Any\nimport warnings\nimport copy\nfrom tqdm import tqdm\n\ntry:\n from torchao.quantization import quantize_, Int8WeightOnlyConfig\n from torchao.quantization import Int8DynamicActivationInt8WeightConfig\n _HAS_TORCHAO = True\n # INT4 requires additional kernel libraries — check availability\n try:\n from torchao.quantization import Int4WeightOnlyConfig\n _m = nn.Linear(128, 128)\n quantize_(_m, Int4WeightOnlyConfig(group_size=128))\n _HAS_INT4 = True\n del _m\n except (ImportError, Exception):\n _HAS_INT4 = False\nexcept ImportError:\n _HAS_TORCHAO = False\n _HAS_INT4 = False\n\ntry:\n from torchao.quantization import IntxWeightOnlyConfig\n from torchao.quantization.granularity import PerAxis\n from torchao.quantization.quant_primitives import MappingType\n _HAS_INTX = True\nexcept ImportError:\n _HAS_INTX = False\n\n_TORCHAO_CONFIGS = {}\nif _HAS_TORCHAO:\n _TORCHAO_CONFIGS['int8_weight_only'] = lambda: Int8WeightOnlyConfig()\n _TORCHAO_CONFIGS['int8_dynamic'] = lambda: Int8DynamicActivationInt8WeightConfig()\nif _HAS_INT4:\n _TORCHAO_CONFIGS['int4_weight_only'] = lambda: Int4WeightOnlyConfig(group_size=128)\nif _HAS_INTX:\n _TORCHAO_CONFIGS['intx_int4'] = lambda: IntxWeightOnlyConfig(\n weight_dtype=torch.int4,\n granularity=PerAxis(0),\n mapping_type=MappingType.ASYMMETRIC,\n )\n _TORCHAO_CONFIGS['intx_int8'] = lambda: IntxWeightOnlyConfig(\n weight_dtype=torch.int8,\n granularity=PerAxis(0),\n mapping_type=MappingType.SYMMETRIC,\n )\n\n_OBSERVERS = {\n 'minmax': MinMaxObserver,\n 'histogram': HistogramObserver,\n 'moving_average': MovingAverageMinMaxObserver,\n}" }, { "cell_type": "markdown", @@ -123,219 +84,33 @@ "metadata": {}, "outputs": [], "source": [ - "#| export\n", - "import contextlib\n", - "\n", - "class Quantizer:\n", - " def __init__(self, \n", - " backend: str = \"x86\", # Target backend: 'x86', 'qnnpack', 'fbgemm', or 'torchao'\n", - " method: str = \"static\", # Method: 'static', 'dynamic', 'qat', 'int8_weight_only', 'int8_dynamic'\n", - " qconfig_mapping: dict | None = None, # Optional custom quantization config (legacy backends only)\n", - " custom_configs: dict | None = None, # Custom module-specific configurations\n", - " use_per_tensor: bool = False, # Force per-tensor quantization (legacy backends only)\n", - " verbose: bool = False # Enable verbose output\n", - " ):\n", - " \"Initialize a quantizer with specified backend and options.\"\n", - " store_attr()\n", - "\n", - " if backend == 'torchao':\n", - " if not _HAS_TORCHAO:\n", - " raise ImportError(\"torchao backend requires torchao. Install with: pip install torchao\")\n", - " if method not in _TORCHAO_CONFIGS:\n", - " raise ValueError(f\"Unknown torchao method '{method}'. Available: {list(_TORCHAO_CONFIGS.keys())}\")\n", - " return\n", - "\n", - " # Legacy backend setup\n", - " if qconfig_mapping is None:\n", - " if method == \"qat\":\n", - " self.qconfig_mapping = get_default_qat_qconfig_mapping(backend)\n", - " else:\n", - " self.qconfig_mapping = get_default_qconfig_mapping(backend)\n", - " if use_per_tensor:\n", - " self._update_qconfig_for_per_tensor()\n", - " else:\n", - " self.qconfig_mapping = qconfig_mapping\n", - "\n", - " @contextlib.contextmanager\n", - " def _quantized_engine(self):\n", - " \"Context manager to temporarily set the quantization backend engine.\"\n", - " old_engine = torch.backends.quantized.engine\n", - " torch.backends.quantized.engine = self.backend\n", - " try:\n", - " yield\n", - " finally:\n", - " torch.backends.quantized.engine = old_engine\n", - "\n", - " def _update_qconfig_for_per_tensor(self):\n", - " \"Replace per-channel with per-tensor quantization to avoid conversion issues\"\n", - " if self.verbose:\n", - " print(\"Using per-tensor quantization instead of per-channel\")\n", - " \n", - " if self.method == \"qat\": \n", - " weight_observer = MinMaxObserver.with_args(\n", - " dtype=torch.qint8,\n", - " qscheme=torch.per_tensor_symmetric,\n", - " quant_min=-128,\n", - " quant_max=127\n", - " )\n", - " activation_observer = MovingAverageMinMaxObserver.with_args(\n", - " averaging_constant=0.01,\n", - " quant_min=0,\n", - " quant_max=255\n", - " )\n", - " per_tensor_qconfig = QConfig(\n", - " activation=FakeQuantize.with_args(\n", - " observer=activation_observer, quant_min=0, quant_max=255),\n", - " weight=FakeQuantize.with_args(\n", - " observer=weight_observer, quant_min=-128, quant_max=127))\n", - " else:\n", - " activation_observer = MinMaxObserver.with_args(\n", - " dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255)\n", - " weight_observer = MinMaxObserver.with_args(\n", - " dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127)\n", - " per_tensor_qconfig = QConfig(activation=activation_observer, weight=weight_observer)\n", - " \n", - " self.qconfig_mapping.global_qconfig = per_tensor_qconfig\n", - "\n", - " def _apply_custom_configs(self):\n", - " \"Apply custom quantization configurations to specific modules\"\n", - " if not self.custom_configs: return\n", - " for module_name, config in self.custom_configs.items():\n", - " if self.verbose: print(f\"Setting custom config for {module_name}\")\n", - " self.qconfig_mapping.set_module_name(module_name, config)\n", - " \n", - " def _prepare_model(self, model, example_inputs):\n", - " \"Prepare model for quantization based on selected method\"\n", - " model = model.cpu()\n", - " model = model.train() if self.method == \"qat\" else model.eval()\n", - " \n", - " try:\n", - " with self._quantized_engine():\n", - " if self.method == \"static\":\n", - " return prepare_fx(model, self.qconfig_mapping, example_inputs)\n", - " elif self.method == \"dynamic\": \n", - " self.qconfig_mapping.set_object_type(torch.nn.Linear, default_dynamic_qconfig)\n", - " self.qconfig_mapping.set_object_type(torch.nn.LSTM, default_dynamic_qconfig)\n", - " self.qconfig_mapping.set_object_type(torch.nn.GRU, default_dynamic_qconfig)\n", - " self.qconfig_mapping.set_object_type(torch.nn.RNN, default_dynamic_qconfig)\n", - " if self.custom_configs:\n", - " for module_name, config in self.custom_configs.items():\n", - " self.qconfig_mapping.set_module_name(module_name, config)\n", - " return prepare_fx(model, self.qconfig_mapping, example_inputs)\n", - " elif self.method == \"qat\":\n", - " return prepare_qat_fx(model, self.qconfig_mapping, example_inputs)\n", - " else:\n", - " raise ValueError(f\"Unknown quantization method: {self.method}\")\n", - " except Exception as e:\n", - " raise RuntimeError(f\"Error preparing model for quantization: {e}\")\n", - " \n", - " def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):\n", - " \"Calibrate the model on CPU (PyTorch quantization is CPU-only).\"\n", - " model.eval()\n", - " device = torch.device(device)\n", - " model = model.to(device)\n", - " \n", - " num_samples = getattr(dataloader, 'n', None)\n", - " if max_samples is not None and num_samples is not None:\n", - " num_samples = min(num_samples, max_samples)\n", - " \n", - " data_iter = dataloader if not self.verbose else tqdm(\n", - " dataloader, desc=\"Calibrating\", total=num_samples//dataloader.bs if num_samples else None)\n", - " \n", - " samples_seen = 0\n", - " with torch.no_grad():\n", - " for i, batch in enumerate(data_iter):\n", - " inputs = batch[0] if isinstance(batch, (list, tuple)) and len(batch) >= 1 else batch\n", - " if hasattr(inputs, 'data'): inputs = inputs.data\n", - " if isinstance(inputs, (list, tuple)):\n", - " inputs = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inputs]\n", - " else:\n", - " inputs = inputs.to(device)\n", - " model(inputs)\n", - " batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else inputs[0].shape[0]\n", - " samples_seen += batch_size\n", - " if max_samples is not None and samples_seen >= max_samples: break\n", - " \n", - " def _quantize_dynamic(self, model):\n", - " \"Quantize a model with dynamic quantization\"\n", - " try:\n", - " model_copy = copy.deepcopy(model).cpu().eval()\n", - " qconfig_spec = {nn.Linear, nn.LSTM, nn.GRU, nn.RNN}\n", - " with self._quantized_engine():\n", - " return quantize_dynamic(model_copy, qconfig_spec=qconfig_spec, dtype=torch.qint8, inplace=False)\n", - " except Exception as e:\n", - " print(f\"Dynamic quantization failed with error: {e}\")\n", - " return model\n", - "\n", - " def _quantize_torchao(self, model):\n", - " \"Quantize a model using torchao backend\"\n", - " model = copy.deepcopy(model).eval()\n", - " config = _TORCHAO_CONFIGS[self.method]()\n", - " if self.verbose:\n", - " print(f\"torchao: applying {self.method} ({type(config).__name__})\")\n", - " with warnings.catch_warnings():\n", - " warnings.simplefilter('ignore')\n", - " try:\n", - " quantize_(model, config)\n", - " except ImportError as e:\n", - " raise ImportError(f\"torchao method '{self.method}' requires additional dependencies: {e}\")\n", - " if self.verbose:\n", - " n = sum(1 for m in model.modules() if hasattr(getattr(m, 'weight', None), 'layout_type'))\n", - " print(f\"torchao: quantized {n} layers\")\n", - " return model\n", - "\n", - " def quantize(self, \n", - " model: nn.Module, # Model to quantize\n", - " calibration_dl: Any = None, # Dataloader for calibration (not needed for torchao weight-only)\n", - " max_calibration_samples: int = 100, # Maximum number of samples to use for calibration\n", - " device: str | torch.device = 'cpu' # Device to use for calibration\n", - " ) -> nn.Module:\n", - " \"Quantize a model using the specified backend and method.\"\n", - " # torchao backend\n", - " if self.backend == 'torchao':\n", - " return self._quantize_torchao(model)\n", - "\n", - " # Legacy backends below\n", - " if self.method == \"dynamic\":\n", - " if self.verbose: print(f\"Performing dynamic quantization with {self.backend} backend\")\n", - " self._apply_custom_configs()\n", - " return self._quantize_dynamic(model)\n", - " \n", - " self._apply_custom_configs()\n", - " example_batch, _ = calibration_dl.one_batch()\n", - " \n", - " try:\n", - " if self.verbose: print(f\"Preparing model for {self.method} quantization with {self.backend} backend\")\n", - " model_prepared = self._prepare_model(model, example_batch.cpu())\n", - " \n", - " if self.method in [\"static\", \"qat\"]:\n", - " if self.verbose: print(f\"Calibrating with up to {max_calibration_samples} samples\")\n", - " self._calibrate_model(model_prepared, calibration_dl, max_samples=max_calibration_samples, device=device)\n", - " \n", - " if self.verbose: print(\"Converting to quantized model\")\n", - " try:\n", - " with self._quantized_engine():\n", - " quantized_model = convert_fx(model_prepared)\n", - " except RuntimeError as e:\n", - " if \"Unsupported qscheme: per_channel_affine\" in str(e) and not self.use_per_tensor:\n", - " if self.verbose: print(\"Encountered per_channel_affine error, retrying with per-tensor\")\n", - " self.use_per_tensor = True\n", - " self._update_qconfig_for_per_tensor()\n", - " return self.quantize(model, calibration_dl, max_calibration_samples, device)\n", - " else:\n", - " raise e\n", - " \n", - " if self.verbose: print(\"Quantization complete\")\n", - " return quantized_model\n", - " \n", - " except Exception as e:\n", - " print(f\"Error during quantization: {e}\")\n", - " if self.verbose:\n", - " import traceback\n", - " traceback.print_exc()\n", - " return model" + "#| export\nimport contextlib\n\n# IntxWeightOnlyConfig methods that support Conv2d (need explicit filter_fn)\n_INTX_METHODS = frozenset({'intx_int4', 'intx_int8'})\n\nclass Quantizer:\n def __init__(self, \n backend: str = \"x86\", # Target backend: 'x86', 'qnnpack', 'fbgemm', or 'torchao'\n method: str = \"static\", # Method: 'static', 'dynamic', 'qat', 'int8_weight_only', 'int8_dynamic'\n qconfig_mapping: dict | None = None, # Optional custom quantization config (legacy backends only)\n custom_configs: dict | None = None, # Custom module-specific configurations\n use_per_tensor: bool = False, # Force per-tensor quantization (legacy backends only)\n observer: str = 'minmax', # Activation observer: 'minmax', 'histogram', 'moving_average'\n verbose: bool = False # Enable verbose output\n ):\n \"Initialize a quantizer with specified backend and options.\"\n store_attr()\n\n if observer not in _OBSERVERS:\n raise ValueError(f\"Unknown observer: {observer}. Choose from: {list(_OBSERVERS)}\")\n\n if backend == 'torchao':\n if not _HAS_TORCHAO:\n raise ImportError(\"torchao backend requires torchao. Install with: pip install torchao\")\n if method not in _TORCHAO_CONFIGS:\n raise ValueError(f\"Unknown torchao method '{method}'. Available: {list(_TORCHAO_CONFIGS.keys())}\")\n if observer != 'minmax':\n warnings.warn(\"observer parameter is ignored for torchao backend\")\n return\n\n # Legacy backend setup\n if qconfig_mapping is None:\n if method == \"qat\":\n self.qconfig_mapping = get_default_qat_qconfig_mapping(backend)\n else:\n self.qconfig_mapping = get_default_qconfig_mapping(backend)\n if use_per_tensor:\n self._update_qconfig_for_per_tensor()\n else:\n self.qconfig_mapping = qconfig_mapping\n\n @contextlib.contextmanager\n def _quantized_engine(self):\n \"Context manager to temporarily set the quantization backend engine.\"\n old_engine = torch.backends.quantized.engine\n torch.backends.quantized.engine = self.backend\n try:\n yield\n finally:\n torch.backends.quantized.engine = old_engine\n\n def _update_qconfig_for_per_tensor(self):\n \"Replace per-channel with per-tensor quantization to avoid conversion issues\"\n if self.verbose:\n print(\"Using per-tensor quantization instead of per-channel\")\n\n act_obs_cls = _OBSERVERS[self.observer]\n\n if self.method == \"qat\":\n weight_observer = MinMaxObserver.with_args(\n dtype=torch.qint8,\n qscheme=torch.per_tensor_symmetric,\n quant_min=-128,\n quant_max=127\n )\n activation_observer = act_obs_cls.with_args(\n averaging_constant=0.01,\n quant_min=0,\n quant_max=255\n ) if self.observer == 'moving_average' else act_obs_cls.with_args(\n quant_min=0,\n quant_max=255\n )\n per_tensor_qconfig = QConfig(\n activation=FakeQuantize.with_args(\n observer=activation_observer, quant_min=0, quant_max=255),\n weight=FakeQuantize.with_args(\n observer=weight_observer, quant_min=-128, quant_max=127))\n else:\n activation_observer = act_obs_cls.with_args(\n dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255)\n weight_observer = MinMaxObserver.with_args(\n dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127)\n per_tensor_qconfig = QConfig(activation=activation_observer, weight=weight_observer)\n\n self.qconfig_mapping.global_qconfig = per_tensor_qconfig\n\n def _apply_custom_configs(self):\n \"Apply custom quantization configurations to specific modules\"\n if not self.custom_configs: return\n for module_name, config in self.custom_configs.items():\n if self.verbose: print(f\"Setting custom config for {module_name}\")\n self.qconfig_mapping.set_module_name(module_name, config)\n\n def _prepare_model(self, model, example_inputs):\n \"Prepare model for quantization based on selected method\"\n model = model.cpu()\n model = model.train() if self.method == \"qat\" else model.eval()\n\n try:\n with self._quantized_engine():\n if self.method == \"static\":\n return prepare_fx(model, self.qconfig_mapping, example_inputs)\n elif self.method == \"dynamic\":\n self.qconfig_mapping.set_object_type(torch.nn.Linear, default_dynamic_qconfig)\n self.qconfig_mapping.set_object_type(torch.nn.LSTM, default_dynamic_qconfig)\n self.qconfig_mapping.set_object_type(torch.nn.GRU, default_dynamic_qconfig)\n self.qconfig_mapping.set_object_type(torch.nn.RNN, default_dynamic_qconfig)\n if self.custom_configs:\n for module_name, config in self.custom_configs.items():\n self.qconfig_mapping.set_module_name(module_name, config)\n return prepare_fx(model, self.qconfig_mapping, example_inputs)\n elif self.method == \"qat\":\n return prepare_qat_fx(model, self.qconfig_mapping, example_inputs)\n else:\n raise ValueError(f\"Unknown quantization method: {self.method}\")\n except Exception as e:\n raise RuntimeError(f\"Error preparing model for quantization: {e}\")\n\n def _calibrate_model(self, model, dataloader, max_samples=None, device='cpu'):\n \"Calibrate the model on CPU (PyTorch quantization is CPU-only).\"\n model.eval()\n device = torch.device(device)\n model = model.to(device)\n\n num_samples = getattr(dataloader, 'n', None)\n if max_samples is not None and num_samples is not None:\n num_samples = min(num_samples, max_samples)\n\n data_iter = dataloader if not self.verbose else tqdm(\n dataloader, desc=\"Calibrating\", total=num_samples//dataloader.bs if num_samples else None)\n\n samples_seen = 0\n with torch.no_grad():\n for i, batch in enumerate(data_iter):\n inputs = batch[0] if isinstance(batch, (list, tuple)) and len(batch) >= 1 else batch\n if hasattr(inputs, 'data'): inputs = inputs.data\n if isinstance(inputs, (list, tuple)):\n inputs = [x.to(device) if isinstance(x, torch.Tensor) else x for x in inputs]\n else:\n inputs = inputs.to(device)\n model(inputs)\n batch_size = inputs.shape[0] if isinstance(inputs, torch.Tensor) else inputs[0].shape[0]\n samples_seen += batch_size\n if max_samples is not None and samples_seen >= max_samples: break\n\n def _quantize_dynamic(self, model):\n \"Quantize a model with dynamic quantization\"\n try:\n model_copy = copy.deepcopy(model).cpu().eval()\n qconfig_spec = {nn.Linear, nn.LSTM, nn.GRU, nn.RNN}\n with self._quantized_engine():\n return quantize_dynamic(model_copy, qconfig_spec=qconfig_spec, dtype=torch.qint8, inplace=False)\n except Exception as e:\n print(f\"Dynamic quantization failed with error: {e}\")\n return model\n\n def _quantize_torchao(self, model):\n \"Quantize a model using torchao backend\"\n model = copy.deepcopy(model).eval()\n config = _TORCHAO_CONFIGS[self.method]()\n if self.verbose:\n print(f\"torchao: applying {self.method} ({type(config).__name__})\")\n # IntxWeightOnlyConfig supports Conv2d but needs explicit filter_fn\n # Exclude depthwise convolutions (groups>1) — PerAxis(0) fails on (C,1,K,K) weights\n if self.method in _INTX_METHODS:\n filter_fn = lambda m, fqn: (isinstance(m, nn.Linear) or\n (isinstance(m, nn.Conv2d) and m.groups == 1))\n else:\n filter_fn = None\n with warnings.catch_warnings():\n warnings.simplefilter('ignore')\n try:\n quantize_(model, config, filter_fn=filter_fn)\n except ImportError as e:\n raise ImportError(f\"torchao method '{self.method}' requires additional dependencies: {e}\")\n if self.verbose:\n n = sum(1 for m in model.modules() if hasattr(getattr(m, 'weight', None), 'layout_type'))\n print(f\"torchao: quantized {n} layers\")\n return model\n\n def quantize(self, \n model: nn.Module, # Model to quantize\n calibration_dl: Any = None, # Dataloader for calibration (not needed for torchao weight-only)\n max_calibration_samples: int = 100, # Maximum number of samples to use for calibration\n device: str | torch.device = 'cpu' # Device to use for calibration\n ) -> nn.Module:\n \"Quantize a model using the specified backend and method.\"\n # torchao backend\n if self.backend == 'torchao':\n return self._quantize_torchao(model)\n\n # Legacy backends below\n if self.method == \"dynamic\":\n if self.verbose: print(f\"Performing dynamic quantization with {self.backend} backend\")\n self._apply_custom_configs()\n return self._quantize_dynamic(model)\n\n self._apply_custom_configs()\n example_batch, _ = calibration_dl.one_batch()\n\n try:\n if self.verbose: print(f\"Preparing model for {self.method} quantization with {self.backend} backend\")\n model_prepared = self._prepare_model(model, example_batch.cpu())\n\n if self.method in [\"static\", \"qat\"]:\n if self.verbose: print(f\"Calibrating with up to {max_calibration_samples} samples\")\n self._calibrate_model(model_prepared, calibration_dl, max_samples=max_calibration_samples, device=device)\n\n if self.verbose: print(\"Converting to quantized model\")\n try:\n with self._quantized_engine():\n quantized_model = convert_fx(model_prepared)\n except RuntimeError as e:\n if \"Unsupported qscheme: per_channel_affine\" in str(e) and not self.use_per_tensor:\n if self.verbose: print(\"Encountered per_channel_affine error, retrying with per-tensor\")\n self.use_per_tensor = True\n self._update_qconfig_for_per_tensor()\n return self.quantize(model, calibration_dl, max_calibration_samples, device)\n else:\n raise e\n\n if self.verbose: print(\"Quantization complete\")\n return quantized_model\n\n except Exception as e:\n print(f\"Error during quantization: {e}\")\n if self.verbose:\n import traceback\n traceback.print_exc()\n return model" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "37r0vysj2l2", + "metadata": {}, + "outputs": [], + "source": "#| export\nimport warnings as _warnings\nfrom collections import OrderedDict as _OrderedDict\n\ndef quantize_mixed(\n model: nn.Module, # model to quantize (deepcopied internally)\n layer_configs: dict[str, Any | None], # {fqn: torchao_config_or_None} from to_quant_config()\n verbose: bool = False, # print per-layer summary\n) -> nn.Module:\n \"Apply per-layer quantization using torchao FqnToConfig. Layers mapped to None are skipped.\"\n if not _HAS_TORCHAO:\n raise ImportError(\"quantize_mixed requires torchao. Install with: pip install torchao\")\n\n from torchao.quantization import quantize_, FqnToConfig\n import copy\n\n model = copy.deepcopy(model).eval()\n\n # Filter out None entries and validate FQNs\n active = {k: v for k, v in layer_configs.items() if v is not None}\n if not active: return model\n\n model_fqns = {n for n, _ in model.named_modules()}\n unmatched = set(active) - model_fqns\n if unmatched:\n _warnings.warn(f\"quantize_mixed: {len(unmatched)} FQN(s) not found in model: {list(unmatched)[:5]}\")\n\n if verbose:\n for fqn, cfg in layer_configs.items():\n status = type(cfg).__name__ if cfg is not None else \"SKIP\"\n print(f\" {fqn:30s} → {status}\")\n\n fqn_config = FqnToConfig(fqn_to_config=_OrderedDict(active))\n with _warnings.catch_warnings():\n _warnings.simplefilter('ignore')\n quantize_(model, fqn_config, filter_fn=None)\n return model" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mwzav5fnara", + "metadata": {}, + "outputs": [], + "source": "#| hide\n# Test quantize_mixed\nif _HAS_TORCHAO:\n from torchao.quantization import Int8WeightOnlyConfig\n\n # Mixed: quantize first linear, skip second\n _model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()\n _config = {'0': Int8WeightOnlyConfig(), '2': None}\n _mq = quantize_mixed(_model, _config)\n assert torch.isfinite(_mq(torch.randn(1, 64))).all()\n assert 'AffineQuantized' in type(_mq[0].weight).__name__\n assert type(_mq[2].weight).__name__ == 'Parameter'\n\n # Empty config — model unchanged\n _m2 = nn.Sequential(nn.Linear(32, 16)).eval()\n _m2q = quantize_mixed(_m2, {})\n assert type(_m2q[0].weight).__name__ == 'Parameter'\n\n # All None — model unchanged\n _m3 = nn.Sequential(nn.Linear(32, 16)).eval()\n _m3q = quantize_mixed(_m3, {'0': None})\n assert type(_m3q[0].weight).__name__ == 'Parameter'\n\n # Unmatched FQN — should warn but not crash\n import warnings\n with warnings.catch_warnings(record=True) as w:\n warnings.simplefilter('always')\n _m4 = nn.Sequential(nn.Linear(32, 16)).eval()\n _m4q = quantize_mixed(_m4, {'nonexistent': Int8WeightOnlyConfig()})\n assert len(w) == 1\n assert 'not found in model' in str(w[0].message)\n\n print(\"quantize_mixed tests passed\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "vucbeyqhjq", + "metadata": {}, + "outputs": [], + "source": "#| hide\nfrom copy import deepcopy\nfrom fastcore.test import *\n\n# === IntxWeightOnlyConfig tests ===\nif _HAS_INTX:\n _m_conv = nn.Sequential(\n nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),\n nn.AdaptiveAvgPool2d(1), nn.Flatten(),\n nn.Linear(16, 10),\n ).eval()\n _mq = Quantizer(backend='torchao', method='intx_int4').quantize(_m_conv)\n assert torch.isfinite(_mq(torch.randn(1, 3, 8, 8))).all()\n # Both Conv2d AND Linear should be quantized (intx + filter_fn)\n assert type(_mq[0].weight).__name__ != 'Parameter', f\"Conv2d not quantized: {type(_mq[0].weight).__name__}\"\n assert type(_mq[4].weight).__name__ != 'Parameter', f\"Linear not quantized: {type(_mq[4].weight).__name__}\"\n\n _mq8 = Quantizer(backend='torchao', method='intx_int8').quantize(deepcopy(_m_conv))\n assert torch.isfinite(_mq8(torch.randn(1, 3, 8, 8))).all()\n assert type(_mq8[0].weight).__name__ != 'Parameter' # Conv2d quantized\n assert type(_mq8[4].weight).__name__ != 'Parameter' # Linear quantized\n print(\"IntxWeightOnlyConfig tests passed (Conv2d + Linear)\")\n\n# === Observer parameter tests ===\n_q_hist = Quantizer(backend='x86', method='static', observer='histogram')\ntest_eq(_q_hist.observer, 'histogram')\n\n_q_ma = Quantizer(backend='x86', method='static', observer='moving_average')\ntest_eq(_q_ma.observer, 'moving_average')\n\nwith ExceptionExpected(ValueError):\n Quantizer(observer='invalid')\n\ntest_eq(Quantizer().observer, 'minmax')\n\nif _HAS_TORCHAO:\n import warnings as _w\n with _w.catch_warnings(record=True) as w:\n _w.simplefilter('always')\n Quantizer(backend='torchao', method='int8_weight_only', observer='histogram')\n assert any('ignored' in str(x.message) for x in w)\n\nprint(\"Observer tests passed\")" + }, { "cell_type": "code", "execution_count": null,