
Commit b159527

Add LAQ (Learnable Amax Quantization) algorithm
Also clean up llm_qat example configs and fix pad_token_id handling.

Preserve weight dtype for LAQ amax and per-tensor scales:
- StaticBlockScaleQuantizer.enable_laq no longer forces float32 on _amax_pre,
  _amax_post, and _per_tensor_scale buffers/parameters; they now inherit the
  dtype of the passed tensors.
- laq() calibration casts amax and per_tensor_scale to the weight dtype before
  calling enable_laq, so the quantizer matches module precision (bf16/fp16)
  instead of silently upcasting to fp32.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent f246115 commit b159527

10 files changed

Lines changed: 1074 additions & 21 deletions
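For orientation, here is a minimal usage sketch of how the new algorithm would likely be invoked through the standard modelopt.torch.quantization entry point. It assumes the usual config-dict layout ("quant_cfg" plus "algorithm") and that the "laq" method string is accepted there; NVFP4_DEFAULT_CFG, calib_dataloader, and model are placeholders, and the accepted keys are defined by LAQConfig in the diff below.

import copy

import modelopt.torch.quantization as mtq

# Start from a block-scaled weight format config (placeholder choice).
config = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)

# Select the LAQ algorithm; keys mirror the LAQConfig fields added in this commit.
config["algorithm"] = {
    "method": "laq",
    "learnable_amax": ["post"],            # default: only the post-dequant amax learns
    "tied_amax": False,                    # keep separate pre/post amax tensors
    "scale_algorithm": {"method": "mse"},  # scale calibration run before LAQ conversion
}

def forward_loop(m):
    # calib_dataloader is a placeholder for the user's calibration data.
    for batch in calib_dataloader:
        m(batch)

model = mtq.quantize(model, config, forward_loop)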


modelopt/torch/quantization/config.py

Lines changed: 64 additions & 0 deletions
@@ -1551,6 +1551,70 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig):
     )
 
 
+class LAQConfig(QuantizeAlgorithmConfig):
+    """Config for LAQ (Learnt Amax Quantization) algorithm.
+
+    LAQ uses separate learnable pre-quantization and post-dequantization amax
+    values. Forward: ``w_q = Q_STE(w / s_pre) * s_post`` where ``s = amax / Q_max``.
+
+    ``learnable_amax`` controls which amax parameters are learnable vs frozen:
+    - ``["pre", "post"]``: both learnable
+    - ``"post"`` or ``["post"]``: only post learnable, pre frozen
+    - ``"pre"`` or ``["pre"]``: only pre learnable, post frozen
+    - ``[]``: both frozen (static scales)
+
+    ``tied_amax`` makes pre and post share a single tensor (requires both to
+    have the same learnable state, i.e. ``learnable_amax`` must be
+    ``["pre", "post"]`` or ``[]``).
+    """
+
+    method: Literal["laq"] = ModeloptField("laq")
+
+    learnable_amax: list[Literal["pre", "post"]] | Literal["pre", "post"] = ModeloptField(
+        default=["post"],
+        title="Which amax parameters are learnable.",
+        description=(
+            "Which amax params are learnable. "
+            "'pre', 'post', ['pre', 'post'], or []. "
+            "Defaults to ['post'] (post-only learnable)."
+        ),
+    )
+
+    tied_amax: bool = ModeloptField(
+        default=False,
+        title="Tie pre and post amax into a single tensor.",
+        description=(
+            "If True, pre and post share one underlying tensor. "
+            "Requires both to have the same learnable state."
+        ),
+    )
+
+    scale_algorithm: dict | None = ModeloptField(
+        default=None,
+        title="Scale calibration algorithm to run first.",
+        description=(
+            "Dict with 'method' key: 'mse', 'local_hessian', or 'max'. "
+            "Optional keys include 'fp8_scale_sweep' for FP4 formats. "
+            "Defaults to {'method': 'mse'} if None."
+        ),
+    )
+
+    @model_validator(mode="after")
+    def _validate_tied_amax(self):
+        """Validate tied_amax is compatible with learnable_amax."""
+        learn = self.learnable_amax
+        if isinstance(learn, str):
+            learn = [learn]
+        learn_set = set(learn)
+        if self.tied_amax:
+            if learn_set not in (set(), {"pre", "post"}):
+                raise ValueError(
+                    f"tied_amax=True requires learnable_amax to be [] or ['pre', 'post'], "
+                    f"got {self.learnable_amax}"
+                )
+        return self
+
+
 QuantizeQuantCfgType = list[QuantizerCfgEntry]
 
 _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None
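A small sketch of the validator behavior added above. It assumes LAQConfig is importable from modelopt.torch.quantization.config as in this diff, and that pydantic's ValidationError (a ValueError subclass) carries the message raised by _validate_tied_amax.

from modelopt.torch.quantization.config import LAQConfig

LAQConfig(learnable_amax=["pre", "post"], tied_amax=True)  # OK: both learnable
LAQConfig(learnable_amax=[], tied_amax=True)               # OK: both frozen
LAQConfig(learnable_amax="post", tied_amax=False)          # OK: string form is normalized

try:
    LAQConfig(learnable_amax="post", tied_amax=True)       # rejected by _validate_tied_amax
except ValueError as err:
    print(err)  # tied_amax=True requires learnable_amax to be [] or ['pre', 'post'], ...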

modelopt/torch/quantization/conversion.py

Lines changed: 6 additions & 5 deletions
@@ -36,10 +36,10 @@
     normalize_quant_cfg_list,
 )
 from .nn import (
-    NVFP4StaticQuantizer,
     QuantModule,
     QuantModuleRegistry,
     SequentialQuantizer,
+    StaticBlockScaleQuantizer,
     SVDQuantLinear,
     TensorQuantizer,
 )
@@ -131,10 +131,11 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
         name = get_unwrapped_name(name, model)
         state = quantizer_state_dict[name]
         # TODO: Add a registry for TensorQuantizers and avoid this manual conversion.
-        if state.get("_is_nvfp4_static_quantizer") and not isinstance(
-            module, NVFP4StaticQuantizer
-        ):
-            NVFP4StaticQuantizer.from_tensor_quantizer(module)
+        if (
+            state.get("_is_static_block_scale_quantizer")
+            or state.get("_is_nvfp4_static_quantizer")  # legacy checkpoint compat
+        ) and not isinstance(module, StaticBlockScaleQuantizer):
+            StaticBlockScaleQuantizer.from_tensor_quantizer(module)
         module.set_from_modelopt_state(quantizer_state_dict[name])
 
     for name, module in model.named_modules():
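For context on the restore path above: after this change, checkpoints carrying either the legacy "_is_nvfp4_static_quantizer" flag or the new "_is_static_block_scale_quantizer" flag should restore into a StaticBlockScaleQuantizer. A hedged sketch of the typical entry point; build_model and the checkpoint path are placeholders.

import modelopt.torch.opt as mto

model = build_model()  # placeholder: user-provided model constructor
# Restores quantizer state; legacy and new flags both map to
# StaticBlockScaleQuantizer via restore_quantizer_state above.
model = mto.restore(model, "quantized_ckpt.pth")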

modelopt/torch/quantization/mode.py

Lines changed: 14 additions & 0 deletions
@@ -38,6 +38,7 @@
     AWQLiteCalibConfig,
     CompressConfig,
     GPTQCalibConfig,
+    LAQConfig,
     LocalHessianCalibConfig,
     MaxCalibConfig,
     MseCalibConfig,
@@ -60,6 +61,7 @@
 from .model_calib import (
     awq,
     gptq,
+    laq,
     layerwise_calibrate,
     local_hessian_calibrate,
     max_calibrate,
@@ -524,3 +526,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
         return GPTQCalibConfig
 
     _calib_func = gptq
+
+
+@CalibrateModeRegistry.register_mode
+class LAQModeDescriptor(BaseCalibrateModeDescriptor):
+    """Mode for LAQ (Learnt Amax Quantization) algorithm."""
+
+    @property
+    def config_class(self) -> type[QuantizeAlgorithmConfig]:
+        """Specifies the config class for the mode."""
+        return LAQConfig
+
+    _calib_func = laq

modelopt/torch/quantization/model_calib.py

Lines changed: 159 additions & 2 deletions
@@ -32,13 +32,19 @@
     LayerActivationCollector,
     _CheckpointState,
 )
-from modelopt.torch.utils import print_rank_0
+from modelopt.torch.utils import print_rank_0, same_device_as
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
 from .calib import MseCalibrator, NVFP4MSECalibrator
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
-from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
+from .nn import (
+    NVFP4StaticQuantizer,
+    QuantModule,
+    SequentialQuantizer,
+    StaticBlockScaleQuantizer,
+    TensorQuantizer,
+)
 from .utils import (
     disable_calib,
     enable_fake_quant,
@@ -57,6 +63,7 @@
 
 __all__ = [
     "awq",
+    "laq",
     "layerwise_calibrate",
    "local_hessian_calibrate",
     "max_calibrate",
@@ -1732,3 +1739,153 @@ def _make_gptq_handle(name, m):
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     print_rank_0(f"GPTQ time: {time.time() - total_start:.2f}s")
+
+
+def _is_quantized_block_scale(quantizer: StaticBlockScaleQuantizer) -> bool:
+    if quantizer._block_sizes is None:
+        return False
+    scale_bits = quantizer._block_sizes.get("scale_bits", None)
+    if scale_bits is None:
+        return False
+    return scale_bits == (4, 3)
+
+
+def _convert_to_static_block_quantizers(model: nn.Module):
+    """Convert eligible TensorQuantizers to StaticBlockScaleQuantizer."""
+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer) and not module._disabled:
+            if not hasattr(module, "_amax") or module._amax is None:
+                continue
+            is_static_block_scale = (
+                module.is_static_block_quant
+                and module._block_sizes is not None
+                and (
+                    (module._num_bits == (2, 1) and module._block_sizes.get("scale_bits") == (4, 3))
+                    or isinstance(module._num_bits, int)
+                )
+            )
+            if is_static_block_scale:
+                if _is_quantized_block_scale(module):
+                    global_amax = reduce_amax(module._amax.clone().detach(), axis=None)
+                else:
+                    global_amax = None
+                StaticBlockScaleQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
+
+
+def _run_scale_calibration(model, forward_loop, scale_algorithm, caller_name):
+    """Run calibration and convert to StaticBlockScaleQuantizer if needed."""
+    if scale_algorithm is None:
+        scale_algorithm = {"method": "mse"}
+
+    method = scale_algorithm.get("method")
+    supported = ("mse", "local_hessian", "max")
+    assert method in supported, f"{caller_name}: method must be one of {supported}, got '{method}'"
+
+    algo_kwargs = {k: v for k, v in scale_algorithm.items() if k != "method"}
+    calib_funcs = {
+        "mse": mse_calibrate,
+        "local_hessian": local_hessian_calibrate,
+        "max": max_calibrate,
+    }
+    calib_funcs[method](model, forward_loop=forward_loop, **algo_kwargs)
+
+    if method == "max":
+        _convert_to_static_block_quantizers(model)
+
+
+def _compute_block_scales(quantizer):
+    """Compute per-block and per-tensor scales from a StaticBlockScaleQuantizer.
+
+    Returns (per_block_scale, per_tensor_scale, quantize_scales).
+    """
+    from .nn.modules.tensor_quantizer import _amax_to_scale
+    from .tensor_quant import scaled_e4m3
+
+    amax = quantizer._amax.float()
+    max_representable = quantizer._quant_max_bound
+    quantize_scales = _is_quantized_block_scale(quantizer)
+    per_tensor_scale = None
+
+    with same_device_as(amax):
+        if quantize_scales:
+            global_amax = quantizer._global_amax.float()
+            per_tensor_scale = _amax_to_scale(global_amax, max_representable)
+            per_block_scale = scaled_e4m3(
+                _amax_to_scale(
+                    amax,
+                    max_representable,
+                    min_value=0.002
+                    * per_tensor_scale.view(-1),  # 0.002 ≈ smallest positive FP8 E4M3 value
+                ),
+                per_tensor_scale,
+                None,
+                4,
+                3,
+            )
+        else:
+            per_block_scale = _amax_to_scale(amax, max_representable)
+
+    return per_block_scale, per_tensor_scale, quantize_scales
+
+
+def _iter_weight_quantizers(model):
+    """Yield (module, weight_name, quantizer) for each StaticBlockScaleQuantizer with amax."""
+    seen_modules = set()
+    for name, module in model.named_modules():
+        if module in seen_modules:
+            continue
+        for weight_name in weight_attr_names(module):
+            wq_name = quantizer_attr_names(weight_name).weight_quantizer
+            quantizer = getattr(module, wq_name, None)
+            if isinstance(quantizer, StaticBlockScaleQuantizer) and hasattr(quantizer, "_amax"):
+                seen_modules.add(module)
+                yield module, weight_name, quantizer
+                break
+
+
+def _compute_laq_params(quantizer):
+    """Compute amax and scale-quantization params for LAQ."""
+    per_block_scale, per_tensor_scale, quantize_scales = _compute_block_scales(quantizer)
+    amax = per_block_scale * quantizer._quant_max_bound
+    return amax, per_tensor_scale, quantize_scales
+
+
+@torch.no_grad()
+def laq(
+    model: nn.Module,
+    forward_loop: ForwardLoop | None = None,
+    scale_algorithm: dict | None = None,
+    learnable_amax: list | str = ("post",),
+    tied_amax: bool = False,
+    **kwargs,
+):
+    """Run scale calibration then convert to LAQ mode.
+
+    Uses separate pre (quant) and post (dequant) amax values.
+    Forward: ``w_q = Q_STE(w / s_pre) * s_post`` where ``s = amax / Q_max``.
+
+    Args:
+        model: Quantized model.
+        forward_loop: Calibration data forward loop.
+        scale_algorithm: Calibration algorithm config to run first.
+            Dict with 'method' key: 'mse', 'local_hessian', or 'max'.
+            Defaults to {'method': 'mse'} if None.
+        learnable_amax: Which amax params are learnable: 'pre', 'post',
+            ['pre', 'post'], or [].
+        tied_amax: If True, pre and post share a single tensor.
+    """
+    _run_scale_calibration(model, forward_loop, scale_algorithm, "laq")
+
+    for module, weight_name, quantizer in _iter_weight_quantizers(model):
+        amax, per_tensor_scale, quantize_scales = _compute_laq_params(quantizer)
+        weight_dtype = getattr(module, weight_name).dtype
+        amax = amax.to(weight_dtype)
+        if per_tensor_scale is not None:
+            per_tensor_scale = per_tensor_scale.to(weight_dtype)
+        quantizer.enable_laq(
+            amax,
+            per_tensor_scale,
+            quantize_scales,
+            learnable_amax=learnable_amax,
+            tied_amax=tied_amax,
+        )
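Since the forward rule ``w_q = Q_STE(w / s_pre) * s_post`` is the core of LAQ, here is a self-contained toy sketch of that math in plain PyTorch. It is illustrative only: a simple symmetric integer rounding with a straight-through estimator, not the StaticBlockScaleQuantizer code path, and the q_max of 7 (signed 4-bit) is an arbitrary choice.

import torch

def laq_fake_quant(w: torch.Tensor, amax_pre: torch.Tensor, amax_post: torch.Tensor,
                   q_max: float = 7.0) -> torch.Tensor:
    s_pre = amax_pre / q_max                    # pre-quantization scale
    s_post = amax_post / q_max                  # post-dequantization scale
    q = torch.clamp(torch.round(w / s_pre), -q_max, q_max)
    q = (q - w / s_pre).detach() + w / s_pre    # straight-through estimator (Q_STE)
    return q * s_post

w = torch.randn(4, 8, requires_grad=True)
amax_pre = w.abs().amax().detach()                 # frozen pre amax
amax_post = torch.nn.Parameter(amax_pre.clone())   # learnable post amax (the LAQ default)
laq_fake_quant(w, amax_pre, amax_post).sum().backward()
print(amax_post.grad)                              # the post amax receives gradients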
