|
32 | 32 | LayerActivationCollector, |
33 | 33 | _CheckpointState, |
34 | 34 | ) |
35 | | -from modelopt.torch.utils import print_rank_0 |
| 35 | +from modelopt.torch.utils import print_rank_0, same_device_as |
36 | 36 | from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState |
37 | 37 | from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method |
38 | 38 |
|
39 | 39 | from .calib import MseCalibrator, NVFP4MSECalibrator |
40 | 40 | from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context |
41 | | -from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer |
| 41 | +from .nn import ( |
| 42 | + NVFP4StaticQuantizer, |
| 43 | + QuantModule, |
| 44 | + SequentialQuantizer, |
| 45 | + StaticBlockScaleQuantizer, |
| 46 | + TensorQuantizer, |
| 47 | +) |
42 | 48 | from .utils import ( |
43 | 49 | disable_calib, |
44 | 50 | enable_fake_quant, |
|
57 | 63 |
|
58 | 64 | __all__ = [ |
59 | 65 | "awq", |
| 66 | + "laq", |
60 | 67 | "layerwise_calibrate", |
61 | 68 | "local_hessian_calibrate", |
62 | 69 | "max_calibrate", |
@@ -1732,3 +1739,153 @@ def _make_gptq_handle(name, m): |
1732 | 1739 | if torch.cuda.is_available(): |
1733 | 1740 | torch.cuda.empty_cache() |
1734 | 1741 | print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s") |
| 1742 | + |
| 1743 | + |
def _is_quantized_block_scale(quantizer: StaticBlockScaleQuantizer) -> bool:
    """Return True if the quantizer's block scales are themselves FP8 (E4M3) quantized."""
    block_sizes = quantizer._block_sizes
    if block_sizes is None:
        return False
    # scale_bits == (4, 3) denotes E4M3-quantized block scales; a missing key
    # compares unequal to the tuple, matching the original's False result.
    return block_sizes.get("scale_bits", None) == (4, 3)
| 1751 | + |
| 1752 | + |
def _convert_to_static_block_quantizers(model: nn.Module):
    """Convert eligible TensorQuantizers to StaticBlockScaleQuantizer."""
    for _, candidate in model.named_modules():
        # Only enabled TensorQuantizers that have calibrated amax are eligible.
        if not isinstance(candidate, TensorQuantizer) or candidate._disabled:
            continue
        if getattr(candidate, "_amax", None) is None:
            continue
        block_sizes = candidate._block_sizes
        if not candidate.is_static_block_quant or block_sizes is None:
            continue
        # Eligible formats: NVFP4-style ((2, 1) weights with E4M3 block scales)
        # or any plain integer bit-width.
        nvfp4_like = candidate._num_bits == (2, 1) and block_sizes.get("scale_bits") == (4, 3)
        if not (nvfp4_like or isinstance(candidate._num_bits, int)):
            continue
        if _is_quantized_block_scale(candidate):
            # Global amax is needed to derive the per-tensor scale for FP8 block scales.
            global_amax = reduce_amax(candidate._amax.clone().detach(), axis=None)
        else:
            global_amax = None
        StaticBlockScaleQuantizer.from_tensor_quantizer(candidate, global_amax=global_amax)
| 1773 | + |
| 1774 | + |
| 1775 | +def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name): |
| 1776 | + """Run calibration and convert to StaticBlockScaleQuantizer if needed.""" |
| 1777 | + if scale_algorithm is None: |
| 1778 | + scale_algorithm = {"method": "mse"} |
| 1779 | + |
| 1780 | + method = scale_algorithm.get("method") |
| 1781 | + supported = ("mse", "local_hessian", "max") |
| 1782 | + assert method in supported, f"{caller_name}: method must be one of {supported}, got '{method}'" |
| 1783 | + |
| 1784 | + algo_kwargs = {k: v for k, v in scale_algorithm.items() if k != "method"} |
| 1785 | + calib_funcs = { |
| 1786 | + "mse": mse_calibrate, |
| 1787 | + "local_hessian": local_hessian_calibrate, |
| 1788 | + "max": max_calibrate, |
| 1789 | + } |
| 1790 | + calib_funcs[method](model, forward_loop=forward_loop, **algo_kwargs) |
| 1791 | + |
| 1792 | + if method == "max": |
| 1793 | + _convert_to_static_block_quantizers(model) |
| 1794 | + |
| 1795 | + |
def _compute_block_scales(quantizer):
    """Compute per-block and per-tensor scales from a StaticBlockScaleQuantizer.

    Returns:
        Tuple ``(per_block_scale, per_tensor_scale, quantize_scales)`` where
        ``per_tensor_scale`` is None unless the block scales are FP8-quantized.
    """
    from .nn.modules.tensor_quantizer import _amax_to_scale
    from .tensor_quant import scaled_e4m3

    block_amax = quantizer._amax.float()
    q_max = quantizer._quant_max_bound
    quantize_scales = _is_quantized_block_scale(quantizer)
    per_tensor_scale = None

    with same_device_as(block_amax):
        if quantize_scales:
            # Per-tensor scale derived from the global amax; block scales are then
            # requantized to FP8 E4M3 relative to it.
            per_tensor_scale = _amax_to_scale(quantizer._global_amax.float(), q_max)
            # 0.002 ≈ smallest positive FP8 E4M3 value; the floor keeps block
            # scales representable after E4M3 quantization.
            scale_floor = 0.002 * per_tensor_scale.view(-1)
            per_block_scale = scaled_e4m3(
                _amax_to_scale(block_amax, q_max, min_value=scale_floor),
                per_tensor_scale,
                None,
                4,
                3,
            )
        else:
            per_block_scale = _amax_to_scale(block_amax, q_max)

    return per_block_scale, per_tensor_scale, quantize_scales
| 1829 | + |
| 1830 | + |
def _iter_weight_quantizers(model):
    """Yield (module, weight_name, quantizer) for each StaticBlockScaleQuantizer with amax."""
    visited = set()
    for _, owner in model.named_modules():
        if owner in visited:
            continue
        for w_name in weight_attr_names(owner):
            wq = getattr(owner, quantizer_attr_names(w_name).weight_quantizer, None)
            if isinstance(wq, StaticBlockScaleQuantizer) and hasattr(wq, "_amax"):
                visited.add(owner)
                # NOTE(review): only the first matching weight per module is
                # yielded (break below) — presumably one quantized weight per
                # module is expected; confirm for multi-weight modules.
                yield owner, w_name, wq
                break
| 1844 | + |
| 1845 | + |
def _compute_laq_params(quantizer):
    """Compute amax and scale-quantization params for LAQ."""
    scale, per_tensor_scale, quantize_scales = _compute_block_scales(quantizer)
    # Recover amax from the scale: scale = amax / Q_max  =>  amax = scale * Q_max.
    return scale * quantizer._quant_max_bound, per_tensor_scale, quantize_scales
| 1851 | + |
| 1852 | + |
@torch.no_grad()
def laq(
    model: nn.Module,
    forward_loop: ForwardLoop | None = None,
    scale_algorithm: dict | None = None,
    learnable_amax: list | tuple | str = ("post",),  # annotation fixed: default is a tuple
    tied_amax: bool = False,
    **kwargs,
):
    """Run scale calibration then convert to LAQ mode.

    Uses separate pre (quant) and post (dequant) amax values.
    Forward: ``w_q = Q_STE(w / s_pre) * s_post`` where ``s = amax / Q_max``.

    Args:
        model: Quantized model.
        forward_loop: Calibration data forward loop.
        scale_algorithm: Calibration algorithm config to run first.
            Dict with 'method' key: 'mse', 'local_hessian', or 'max'.
            Defaults to {'method': 'mse'} if None.
        learnable_amax: Which amax params are learnable: 'pre', 'post',
            ['pre', 'post'], or []. A string or any sequence of strings.
        tied_amax: If True, pre and post share a single tensor.
        **kwargs: Ignored; accepted for signature compatibility with the other
            calibration entry points in this module.
    """
    _run_scale_calibration(model, forward_loop, scale_algorithm, "laq")

    for module, weight_name, quantizer in _iter_weight_quantizers(model):
        amax, per_tensor_scale, quantize_scales = _compute_laq_params(quantizer)
        # Cast LAQ params to the weight's dtype so they train/store alongside it.
        weight_dtype = getattr(module, weight_name).dtype
        amax = amax.to(weight_dtype)
        if per_tensor_scale is not None:
            per_tensor_scale = per_tensor_scale.to(weight_dtype)
        quantizer.enable_laq(
            amax,
            per_tensor_scale,
            quantize_scales,
            learnable_amax=learnable_amax,
            tied_amax=tied_amax,
        )
0 commit comments