Skip to content

Commit d0dfae0

Browse files
committed
minor
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 6d39a4e commit d0dfae0

File tree

2 files changed: +54 additions, −8 deletions

modelopt/torch/quantization/calib/mse.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""MSE-based calibrators for quantization."""
16+
"""Calibrator that returns the MSE amax of all collected tensors."""
1717

1818
import math
1919
from collections.abc import Callable
@@ -28,7 +28,7 @@
2828

2929

3030
class MseCalibrator(_Calibrator):
31-
"""MSE amax search that minimizes error between x and quantized x."""
31+
"""Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x."""
3232

3333
def __init__(
3434
self,
@@ -40,7 +40,20 @@ def __init__(
4040
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
4141
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
4242
):
43-
"""Initialize MSE calibrator with initial amax and search parameters."""
43+
"""Initialize MSE calibrator.
44+
45+
Args:
46+
amax: Initial amax value (required).
47+
axis: Quantization axis. None means per-tensor quantization.
48+
step_size: Step size for amax search. The number of steps is computed as
49+
ceil((stop_multiplier - start_multiplier) / step_size) + 1.
50+
start_multiplier: Starting multiplier for amax search.
51+
stop_multiplier: Ending multiplier for amax search.
52+
quant_func: Function that quantizes input tensor given an amax value.
53+
Should have signature: quant_func(x, amax) -> quantized_x.
54+
error_func: Function to compute error between x and xq.
55+
Default is F.mse_loss(x, xq, reduction='none').
56+
"""
4457
super().__init__(num_bits=None, axis=axis, unsigned=None)
4558
self._initial_amax = amax
4659
self._step_size = step_size
@@ -68,7 +81,11 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
6881

6982
@torch.no_grad()
7083
def collect(self, x: torch.Tensor):
71-
"""Collect tensor statistics for MSE calibration."""
84+
"""Collect input tensor statistics and compute losses for MSE calibration.
85+
86+
Args:
87+
x: Input tensor.
88+
"""
7289
if self._quant_func is None:
7390
raise RuntimeError("Quantization function not set.")
7491

@@ -101,13 +118,16 @@ def collect(self, x: torch.Tensor):
101118
self._losses_sum[step] += loss
102119

103120
def reset(self):
104-
"""Reset collected statistics."""
121+
"""Reset the stored losses and amax value."""
105122
self._losses_sum = None
106123
self._candidates = None
107124
self._amax = None
108125

109126
def clear(self):
110-
"""Clear all state including initial amax."""
127+
"""Clear all cached data to free GPU memory.
128+
129+
Call this after compute_amax() and load_calib_amax() are done.
130+
"""
111131
self._losses_sum = None
112132
self._candidates = None
113133
if self._initial_amax is not None:
@@ -116,7 +136,11 @@ def clear(self):
116136

117137
@torch.no_grad()
118138
def compute_amax(self, verbose: bool = False):
119-
"""Compute optimal amax from collected statistics."""
139+
"""Return the amax value that minimizes quantization error.
140+
141+
Args:
142+
verbose: If True, print the ratio of best_amax to initial_amax.
143+
"""
120144
if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum):
121145
return None
122146

tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,29 @@ def quant_func(x, amax):
212212
assert amax[2] > amax[1]
213213
assert amax[3] > amax[2]
214214

215-
# Test that fp8 sweep generates quantized scales
215+
def test_fp8_sweep_generates_quantized_scales(self, device):
216+
"""Test that the fp8 sweep produces scales that are already FP8-quantized."""
217+
num_blocks = 8
218+
block_size = 16
219+
220+
x = torch.randn(num_blocks, block_size, device=device)
221+
per_block_amax = x.abs().amax(dim=-1)
222+
global_amax = per_block_amax.max()
223+
224+
def quant_func(x, amax):
225+
return static_blockwise_fp4_fake_quant(x, amax, global_amax)
226+
227+
cal = NVFP4MSECalibrator(
228+
amax=per_block_amax,
229+
global_amax=global_amax,
230+
quant_func=quant_func,
231+
)
232+
233+
cal.collect(x)
234+
amax = cal.compute_amax()
235+
236+
# The calibrator sweeps over FP8 candidates, so the resulting scales
237+
# should already be representable in FP8 (i.e., quantize-dequantize is a no-op).
216238
scale = amax.float() / 6.0
217239
scale_fp8_quant_amax = global_amax.float() / 6.0
218240
scale_qdq = scaled_e4m3_impl(scale, scale_fp8_quant_amax)

0 commit comments

Comments (0)