Skip to content

Commit acc6aa6

Browse files
committed
Inline local-Hessian activation capture; drop the QuantModule hook API
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 4a7e675 commit acc6aa6

5 files changed

Lines changed: 121 additions & 154 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,10 @@ def _make_weight_mse_calibrator(
501501
)
502502
if backend is not None and backend_factory is not None:
503503
if error_func is not None:
504-
# Registered backends can't take a custom error_func; skip Hessian refinement.
504+
# Registered backend factories don't accept a custom error_func.
505505
warnings.warn(
506-
f"local_hessian: backend '{backend}' does not support a custom error "
507-
"function; skipping Hessian-weighted calibration for this quantizer."
506+
f"backend '{backend}' does not support a custom error function; skipping "
507+
"error-function-weighted MSE calibration for this quantizer."
508508
)
509509
return None
510510
return backend_factory(initial_amax, axis, quant_func)
@@ -706,6 +706,80 @@ def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, war
706706
_warn_if_block_size_mismatch(weight_quantizer, block_size, name)
707707

708708

709+
def _is_quant_fused_experts(module: nn.Module) -> bool:
710+
"""Whether ``module`` is a converted HF fused-MoE-experts wrapper with per-expert quantizers."""
711+
return hasattr(module, "_current_expert_idx") and hasattr(
712+
module, "gate_up_proj_weight_quantizers"
713+
)
714+
715+
716+
def _register_local_hessian_input_hooks(model, name_to_module, capture, block_size, warned):
717+
"""Register forward hooks feeding each weight's input activations to ``capture``.
718+
719+
Local-Hessian-specific (kept here rather than as a general ``QuantModule`` API): dense
720+
quantized linears hook the layer input; HF fused-MoE experts hook the shared input quantizers,
721+
keyed by the active expert (``_current_expert_idx``). Weights without a hook (conv,
722+
SequentialQuantizer, non-eager experts) fall back to plain MSE. Returns removable handles.
723+
"""
724+
handles: list = []
725+
726+
def _make_expert_hook(expert_module, weight_name, quantizers, enabled):
727+
def _expert_hook(_input_quantizer, args):
728+
if not args:
729+
return
730+
idx = expert_module._current_expert_idx
731+
if idx in enabled:
732+
# Read the weight fresh (valid under accelerate/FSDP re-materialization).
733+
capture(quantizers[idx], getattr(expert_module, weight_name)[idx], args[0])
734+
735+
return _expert_hook
736+
737+
for name, module in name_to_module.items():
738+
if is_quantized_linear(module) and isinstance(module.weight_quantizer, TensorQuantizer):
739+
with enable_weight_access_and_writeback(module, model, name_to_module):
740+
# ``weight`` may be absent (e.g. TE GroupedLinear exposes weight0..N, not weight);
741+
# such modules have no single 2-D weight to pair and fall back to plain MSE.
742+
weight = getattr(module, "weight", None)
743+
if weight is None or weight.dim() != 2 or not module.weight_quantizer.is_enabled:
744+
continue
745+
_warn_local_hessian_fallback(
746+
name, weight, module.weight_quantizer, block_size, warned
747+
)
748+
749+
def _dense_hook(linear, args):
750+
if args:
751+
capture(linear.weight_quantizer, linear.weight, args[0])
752+
753+
handles.append(module.register_forward_pre_hook(_dense_hook))
754+
elif _is_quant_fused_experts(module):
755+
with enable_weight_access_and_writeback(module, model, name_to_module):
756+
for weight_name, quantizers_name, input_q_name in (
757+
(
758+
"gate_up_proj",
759+
"gate_up_proj_weight_quantizers",
760+
"gate_up_proj_input_quantizer",
761+
),
762+
("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"),
763+
):
764+
weight = getattr(module, weight_name, None)
765+
quantizers = getattr(module, quantizers_name, None)
766+
input_quantizer = getattr(module, input_q_name, None)
767+
if weight is None or quantizers is None or input_quantizer is None:
768+
continue
769+
_warn_local_hessian_fallback(
770+
f"{name}.{weight_name}", weight[0], quantizers[0], block_size, warned
771+
)
772+
# Snapshot which experts are enabled now, before the caching forward silences
773+
# all weight quantizers — so we don't capture (and discard) disabled experts.
774+
enabled = {i for i, q in enumerate(quantizers) if q.is_enabled}
775+
handles.append(
776+
input_quantizer.register_forward_pre_hook(
777+
_make_expert_hook(module, weight_name, quantizers, enabled)
778+
)
779+
)
780+
return handles
781+
782+
709783
@torch.no_grad()
710784
def local_hessian_calibrate(
711785
model: nn.Module,
@@ -767,53 +841,19 @@ def capture(weight_quantizer, weight, input_tensor):
767841
accumulators[id(weight_quantizer)] = acc
768842
acc.accumulate(input_local)
769843

770-
# Phase 2: register capture hooks, disable weight fake-quant (input quantizers left as-is,
771-
# matching prior behavior), run one forward to accumulate Hessians. Hooks live only for it.
772-
handles: list = []
773-
silenced_weight_quantizers: list[TensorQuantizer] = []
844+
# Phase 2: capture each weight's input activations during a forward with weight fake-quant
845+
# disabled (so H = ΣXᵀX reflects full-precision weights); input quantizers are left as-is.
774846
warned: set = set()
775-
seen_modules: set[int] = set()
776-
for name, module in name_to_module.items():
777-
if not isinstance(module, QuantModule) or id(module) in seen_modules:
778-
continue
779-
seen_modules.add(id(module))
780-
with enable_weight_access_and_writeback(module, model, name_to_module):
781-
captures = module.register_calibration_input_hooks(capture)
782-
handles.extend(captures)
783-
for weight, weight_quantizer in module.iter_weights_for_calibration():
784-
# Silence weight fake-quant (incl. SequentialQuantizer leaves) so the capture
785-
# forward uses full-precision weights and downstream Hessians aren't corrupted.
786-
leaves = (
787-
list(weight_quantizer)
788-
if isinstance(weight_quantizer, SequentialQuantizer)
789-
else [weight_quantizer]
790-
)
791-
silenced_weight_quantizers.extend(
792-
q
793-
for q in leaves
794-
if isinstance(q, TensorQuantizer) and q.is_enabled and q._if_quant
795-
)
796-
# Only TensorQuantizer weights are refined (same as mse_calibrate); other types
797-
# (e.g. SequentialQuantizer) are unsupported and left at their max-cal scale.
798-
if not isinstance(weight_quantizer, TensorQuantizer):
799-
if weight_quantizer.is_enabled and "unsupported" not in warned:
800-
warned.add("unsupported")
801-
warn_rank_0(
802-
"local_hessian: only TensorQuantizer weights are calibrated; other "
803-
"types (e.g. SequentialQuantizer) stay at their max-calibrated scale."
804-
)
805-
continue
806-
if captures:
807-
_warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned)
808-
809-
for weight_quantizer in silenced_weight_quantizers:
810-
weight_quantizer.disable_quant()
847+
handles = _register_local_hessian_input_hooks(
848+
model, name_to_module, capture, block_size, warned
849+
)
811850
print_rank_0("local_hessian: Caching activations and computing local Hessian...")
812851
try:
813-
forward_loop(model)
852+
with set_quantizer_by_cfg_context(
853+
model, [{"quantizer_name": "*weight_quantizer", "enable": False}]
854+
):
855+
forward_loop(model)
814856
finally:
815-
for weight_quantizer in silenced_weight_quantizers:
816-
weight_quantizer.enable_quant()
817857
for handle in handles:
818858
handle.remove()
819859

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import contextlib
1919
import warnings
20-
from collections.abc import Callable
2120
from typing import Any
2221

2322
import torch
@@ -128,17 +127,6 @@ def iter_weights_for_calibration(self):
128127
weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer)
129128
yield getattr(self, weight_name), weight_quantizer
130129

131-
def register_calibration_input_hooks(
132-
self, callback: Callable[[TensorQuantizer, torch.Tensor, torch.Tensor], None]
133-
) -> list:
134-
"""Register forward hooks calling ``callback(weight_quantizer, weight, input)`` per weight.
135-
136-
Activation-side counterpart to :meth:`iter_weights_for_calibration`, used by
137-
activation-aware calibration (e.g. local-Hessian). Returns removable handles; the base
138-
default is ``[]`` (no pairing available -> plain weight calibration). Override per module.
139-
"""
140-
return []
141-
142130
def fold_weight(self, keep_attrs: bool = False):
143131
"""Fold the weight for faster eval."""
144132
# Handle all attributes that end with _weight_quantizer
@@ -259,27 +247,6 @@ def _setup(self):
259247
self._register_temp_attribute("_enable_weight_quantization", False)
260248
self._register_dynamic_attribute("weight", self._get_quantized_weight)
261249

262-
def register_calibration_input_hooks(self, callback):
263-
"""Pair the weight quantizer with the forward input.
264-
265-
Only a 2-D weight with an enabled ``TensorQuantizer`` is hooked; conv (4-D) and
266-
``SequentialQuantizer`` weights are unsupported and fall back to plain calibration.
267-
"""
268-
weight = getattr(self, "weight", None)
269-
if (
270-
weight is None
271-
or weight.dim() != 2
272-
or not isinstance(self.weight_quantizer, TensorQuantizer)
273-
or not self.weight_quantizer.is_enabled
274-
):
275-
return []
276-
277-
def _pre_hook(module, args):
278-
if args:
279-
callback(module.weight_quantizer, module.weight, args[0])
280-
281-
return [self.register_forward_pre_hook(_pre_hook)]
282-
283250

284251
class _LegacyQuantInputBaseMixin:
285252
"""A mixin to support legacy quantized modules which needs to have an __init__ method."""

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -918,36 +918,6 @@ def iter_weights_for_calibration(self):
918918
for idx, q in enumerate(quantizers):
919919
yield weight[idx], q
920920

921-
def register_calibration_input_hooks(self, callback):
922-
"""Pair each per-expert weight quantizer with its routed input activation.
923-
924-
Hooks the shared input quantizers, which the eager ``F.linear`` path calls per expert
925-
while ``_current_expert_idx`` is set. Batched/grouped kernels never call them, so those
926-
experts get no capture (fall back to plain weight calibration).
927-
"""
928-
handles = []
929-
for weight_name, quantizers_name, input_quantizer_name in (
930-
("gate_up_proj", "gate_up_proj_weight_quantizers", "gate_up_proj_input_quantizer"),
931-
("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"),
932-
):
933-
weight = getattr(self, weight_name, None)
934-
quantizers = getattr(self, quantizers_name, None)
935-
input_quantizer = getattr(self, input_quantizer_name, None)
936-
if weight is None or quantizers is None or input_quantizer is None:
937-
continue
938-
939-
def _pre_hook(_iq, args, _weight_name=weight_name, _quantizers=quantizers):
940-
if not args:
941-
return
942-
idx = self._current_expert_idx
943-
weight_quantizer = _quantizers[idx]
944-
if weight_quantizer.is_enabled:
945-
# Read the weight fresh (valid under accelerate/FSDP re-materialization).
946-
callback(weight_quantizer, getattr(self, _weight_name)[idx], args[0])
947-
948-
handles.append(input_quantizer.register_forward_pre_hook(_pre_hook))
949-
return handles
950-
951921
def fold_weight(self, keep_attrs: bool = False):
952922
"""Fold per-expert weight quantizers into the fused 3-D weights.
953923

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -658,9 +658,8 @@ def forward_loop(m):
658658

659659
self._cleanup_registry(expert_type)
660660

661-
def test_local_hessian_per_expert_capture_and_refinement(self):
662-
"""The plugin's extension point pairs each per-expert weight quantizer with its routed
663-
input, and local_hessian uses that to refine every expert's weight amax."""
661+
def test_local_hessian_refines_per_expert_weights(self):
662+
"""local_hessian captures each expert's routed activations and refines its weight amax."""
664663
model = _TinyMoEModel()
665664
expert_type = type(model.moe.experts)
666665
self._cleanup_registry(expert_type)
@@ -685,28 +684,25 @@ def forward_loop(m):
685684
expert_quantizers = list(experts.gate_up_proj_weight_quantizers) + list(
686685
experts.down_proj_weight_quantizers
687686
)
688-
689-
# Extension point captures per-expert (weight_quantizer, weight_slice, cin).
690-
captured = []
691-
handles = experts.register_calibration_input_hooks(
692-
lambda wq, w, x: captured.append((id(wq), tuple(w.shape), x.shape[-1]))
693-
)
694-
assert len(handles) == 2 # one pre-hook per shared input quantizer (gate_up, down)
695-
with torch.no_grad():
696-
model(torch.randn(1, 8, HIDDEN_DIM))
697-
for h in handles:
698-
h.remove()
699-
valid_ids = {id(q) for q in expert_quantizers}
700-
shapes = {(2 * INTERMEDIATE_DIM, HIDDEN_DIM), (HIDDEN_DIM, INTERMEDIATE_DIM)}
701-
assert captured and all(
702-
wq_id in valid_ids and shape in shapes and cin == shape[1]
703-
for wq_id, shape, cin in captured
704-
)
705-
706-
# End-to-end: local_hessian refines per-expert weight amax via that capture.
707687
max_amax = {id(q): q.amax.clone() for q in expert_quantizers if q.amax is not None}
688+
# Expected (cout, cin) keyed by quantizer id, to verify each Hessian pairs with its
689+
# own expert's weight slice (catches gate_up/down swaps and stale-index mis-pairing).
690+
expected_shape = {}
691+
for quantizers, weight in (
692+
(experts.gate_up_proj_weight_quantizers, experts.gate_up_proj),
693+
(experts.down_proj_weight_quantizers, experts.down_proj),
694+
):
695+
for i, q in enumerate(quantizers):
696+
expected_shape[id(q)] = (weight[i].shape[0], weight[i].shape[1])
697+
708698
local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True)
709-
assert any(a.num_samples > 0 for a in model._local_hessian_accumulators.values())
699+
700+
# Each captured Hessian is keyed to a real per-expert quantizer with the matching weight
701+
# shape, spans multiple distinct experts, and the refinement moved at least one amax.
702+
routed = {qid: a for qid, a in model._local_hessian_accumulators.items() if a.num_samples}
703+
assert len(routed) >= 2, "expected multiple distinct experts to capture Hessians"
704+
for qid, acc in routed.items():
705+
assert (acc.cout, acc.cin) == expected_shape[qid]
710706
assert all(q.amax is not None and torch.isfinite(q.amax).all() for q in expert_quantizers)
711707
assert any(
712708
id(q) in max_amax and not torch.allclose(q.amax, max_amax[id(q)])

tests/unit/torch/quantization/test_local_hessian.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,36 +158,30 @@ def test_no_forward_loop_is_skipped(self):
158158
assert all(torch.equal(before[n], a) for n, a in _weight_amaxes(model).items())
159159

160160

161-
class TestActivationCaptureExtensionPoint:
162-
"""The extension point that decouples local-Hessian capture from module type."""
161+
class TestLocalHessianFallbacks:
162+
"""Weights local-Hessian can't pair with an input fall back to plain MSE (no Hessian)."""
163163

164-
def test_dense_captures_and_conv_falls_back(self):
164+
def test_conv_weight_falls_back_without_crash(self):
165165
torch.manual_seed(0)
166-
model = SimpleLinear()
167-
mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop())
168-
captured = []
169-
handles = model.net[0].register_calibration_input_hooks(
170-
lambda wq, w, x: captured.append((tuple(w.shape), x.shape[-1]))
171-
)
172-
assert len(handles) == 1
173-
with torch.no_grad():
174-
model(torch.randn(2, 16))
175-
for h in handles:
176-
h.remove()
177-
assert captured and captured[0] == ((32, 16), 16) # cin from activation matches weight
178-
179-
conv = SimpleConv()
180-
mtq.quantize(conv, INT8_WEIGHT_CFG, forward_loop=lambda m: m(SimpleConv.get_input()))
181-
assert conv.net[0].register_calibration_input_hooks(lambda *a: None) == [] # 4-D weight
182-
183-
def test_sequential_quantizer_weight_not_hooked(self):
166+
model = SimpleConv() # 4-D conv weights — no single 2-D weight to pair
167+
forward_loop = lambda m: m(SimpleConv.get_input()) # noqa: E731
168+
mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=forward_loop)
169+
local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True)
170+
conv = model.net[0]
171+
assert id(conv.weight_quantizer) not in model._local_hessian_accumulators
172+
assert conv.weight_quantizer.amax is not None # still calibrated via plain MSE
173+
174+
def test_sequential_quantizer_weight_falls_back_without_crash(self):
184175
torch.manual_seed(0)
185176
model = SimpleLinear()
186177
mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop())
187178
linear = model.net[0]
188179
linear.weight_quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer())
189-
assert linear.register_calibration_input_hooks(lambda *a: None) == [] # unsupported
180+
local_hessian_calibrate(model, _make_forward_loop(), fp8_scale_sweep=False, debug=True)
181+
assert id(linear.weight_quantizer) not in model._local_hessian_accumulators
182+
190183

184+
class TestBlockSizeMismatchWarning:
191185
def test_block_size_mismatch_warns_only_on_mismatch(self):
192186
def q(block):
193187
return TensorQuantizer(

0 commit comments

Comments
 (0)