1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- """MSE-based calibrators for quantization."""
16+ """Calibrator that returns the MSE amax of all collected tensors."""
1717
1818import math
1919from collections .abc import Callable
2828
2929
3030class MseCalibrator (_Calibrator ):
31- """MSE amax search that minimizes error between x and quantized x."""
31+ """Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x."""
3232
3333 def __init__ (
3434 self ,
@@ -40,7 +40,20 @@ def __init__(
4040 quant_func : Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ] | None = None ,
4141 error_func : Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ] | None = None ,
4242 ):
43- """Initialize MSE calibrator with initial amax and search parameters."""
43+ """Initialize MSE calibrator.
44+
45+ Args:
46+ amax: Initial amax value (required).
47+ axis: Quantization axis. None means per-tensor quantization.
48+ step_size: Step size for amax search. The number of steps is computed as
49+ ceil((stop_multiplier - start_multiplier) / step_size) + 1.
50+ start_multiplier: Starting multiplier for amax search.
51+ stop_multiplier: Ending multiplier for amax search.
52+ quant_func: Function that quantizes input tensor given an amax value.
53+ Should have signature: quant_func(x, amax) -> quantized_x.
54+ error_func: Function to compute error between x and xq.
55+ Default is F.mse_loss(x, xq, reduction='none').
56+ """
4457 super ().__init__ (num_bits = None , axis = axis , unsigned = None )
4558 self ._initial_amax = amax
4659 self ._step_size = step_size
@@ -68,7 +81,11 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
6881
6982 @torch .no_grad ()
7083 def collect (self , x : torch .Tensor ):
71- """Collect tensor statistics for MSE calibration."""
84+ """Collect input tensor statistics and compute losses for MSE calibration.
85+
86+ Args:
87+ x: Input tensor.
88+ """
7289 if self ._quant_func is None :
7390 raise RuntimeError ("Quantization function not set." )
7491
@@ -101,13 +118,16 @@ def collect(self, x: torch.Tensor):
101118 self ._losses_sum [step ] += loss
102119
103120 def reset (self ):
104- """Reset collected statistics ."""
121+ """Reset the stored losses and amax value ."""
105122 self ._losses_sum = None
106123 self ._candidates = None
107124 self ._amax = None
108125
109126 def clear (self ):
110- """Clear all state including initial amax."""
127+ """Clear all cached data to free GPU memory.
128+
129+ Call this after compute_amax() and load_calib_amax() are done.
130+ """
111131 self ._losses_sum = None
112132 self ._candidates = None
113133 if self ._initial_amax is not None :
@@ -116,7 +136,11 @@ def clear(self):
116136
117137 @torch .no_grad ()
118138 def compute_amax (self , verbose : bool = False ):
119- """Compute optimal amax from collected statistics."""
139+ """Return the amax value that minimizes quantization error.
140+
141+ Args:
142+ verbose: If True, print the ratio of best_amax to initial_amax.
143+ """
120144 if self ._losses_sum is None or not any (loss is not None for loss in self ._losses_sum ):
121145 return None
122146
0 commit comments