
Commit 215f643

realAsma and claude committed
Preserve weight dtype for LAQ amax and per-tensor scales
- StaticBlockScaleQuantizer.enable_laq no longer forces float32 on the _amax_pre, _amax_post, and _per_tensor_scale buffers/parameters; they now inherit the dtype of the tensors passed in.
- laq() calibration casts amax and per_tensor_scale to the weight dtype before calling enable_laq, so the quantizer matches the module precision (bf16/fp16) instead of silently upcasting to fp32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 8866b80 commit 215f643

2 files changed

Lines changed: 9 additions & 5 deletions


modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 0 deletions
@@ -1818,6 +1818,10 @@ def laq(
 
     for module, weight_name, quantizer in _iter_weight_quantizers(model):
         amax, per_tensor_scale, quantize_scales = _compute_laq_params(quantizer)
+        weight_dtype = getattr(module, weight_name).dtype
+        amax = amax.to(weight_dtype)
+        if per_tensor_scale is not None:
+            per_tensor_scale = per_tensor_scale.to(weight_dtype)
         quantizer.enable_laq(
             amax,
             per_tensor_scale,
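
Taken on its own, the calibration-side change just aligns the dtype of the computed statistics with the owning module's weight before they are handed to the quantizer. A minimal, self-contained sketch of that pattern (plain PyTorch with illustrative names, not the modelopt API):

import torch
import torch.nn as nn

module = nn.Linear(8, 8).to(torch.bfloat16)  # stand-in for a bf16-weighted module
weight_name = "weight"

# Calibration statistics are typically produced in float32.
amax = module.weight.detach().abs().amax().float()
per_tensor_scale = torch.tensor(0.5, dtype=torch.float32)

# Cast both to the weight dtype so downstream buffers inherit the module precision.
weight_dtype = getattr(module, weight_name).dtype
amax = amax.to(weight_dtype)
if per_tensor_scale is not None:
    per_tensor_scale = per_tensor_scale.to(weight_dtype)

print(amax.dtype, per_tensor_scale.dtype)  # torch.bfloat16 torch.bfloat16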

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 5 additions & 5 deletions
@@ -1455,18 +1455,18 @@ def enable_laq(
         learn = {learnable_amax} if isinstance(learnable_amax, str) else set(learnable_amax)
 
         if "post" in learn:
-            self._amax_post = nn.Parameter(amax.clone().detach().float(), requires_grad=True)
+            self._amax_post = nn.Parameter(amax.clone().detach(), requires_grad=True)
         else:
-            self.register_buffer("_amax_post", amax.clone().detach().float())
+            self.register_buffer("_amax_post", amax.clone().detach())
 
         if not tied_amax:
             if "pre" in learn:
-                self._amax_pre = nn.Parameter(amax.clone().detach().float(), requires_grad=True)
+                self._amax_pre = nn.Parameter(amax.clone().detach(), requires_grad=True)
             else:
-                self.register_buffer("_amax_pre", amax.clone().detach().float())
+                self.register_buffer("_amax_pre", amax.clone().detach())
 
         if per_tensor_scale is not None:
-            self.register_buffer("_per_tensor_scale", per_tensor_scale.clone().detach().float())
+            self.register_buffer("_per_tensor_scale", per_tensor_scale.clone().detach())
         self._quantize_scales = quantize_scales
         self._laq = True
         self._learnable_amax = sorted(learn)
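
The quantizer-side change is the mirror image: .clone().detach() preserves whatever dtype the caller passed in, while the removed .float() pinned the registered buffers and parameters to float32 regardless of module precision. A small standalone sketch of the difference (plain PyTorch, not the StaticBlockScaleQuantizer class itself):

import torch
import torch.nn as nn

amax = torch.tensor([0.75], dtype=torch.bfloat16)  # amax already cast to the weight dtype

# Previous behavior: .float() upcasts, so the parameter is fp32 even for bf16 modules.
old_param = nn.Parameter(amax.clone().detach().float(), requires_grad=True)
assert old_param.dtype == torch.float32

# New behavior: the parameter (or buffer) inherits the dtype of the tensor passed in.
new_param = nn.Parameter(amax.clone().detach(), requires_grad=True)
assert new_param.dtype == torch.bfloat16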

0 commit comments
