@@ -501,10 +501,10 @@ def _make_weight_mse_calibrator(
501501 )
502502 if backend is not None and backend_factory is not None :
503503 if error_func is not None :
504- # Registered backends can 't take a custom error_func; skip Hessian refinement .
504+ # Registered backend factories don 't accept a custom error_func.
505505 warnings .warn (
506- f"local_hessian: backend '{ backend } ' does not support a custom error "
507- "function; skipping Hessian -weighted calibration for this quantizer."
506+ f"backend '{ backend } ' does not support a custom error function; skipping "
507+ "error- function-weighted MSE calibration for this quantizer."
508508 )
509509 return None
510510 return backend_factory (initial_amax , axis , quant_func )
@@ -706,6 +706,80 @@ def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, war
706706 _warn_if_block_size_mismatch (weight_quantizer , block_size , name )
707707
708708
709+ def _is_quant_fused_experts (module : nn .Module ) -> bool :
710+ """Whether ``module`` is a converted HF fused-MoE-experts wrapper with per-expert quantizers."""
711+ return hasattr (module , "_current_expert_idx" ) and hasattr (
712+ module , "gate_up_proj_weight_quantizers"
713+ )
714+
715+
716+ def _register_local_hessian_input_hooks (model , name_to_module , capture , block_size , warned ):
717+ """Register forward hooks feeding each weight's input activations to ``capture``.
718+
719+ Local-Hessian-specific (kept here rather than as a general ``QuantModule`` API): dense
720+ quantized linears hook the layer input; HF fused-MoE experts hook the shared input quantizers,
721+ keyed by the active expert (``_current_expert_idx``). Weights without a hook (conv,
722+ SequentialQuantizer, non-eager experts) fall back to plain MSE. Returns removable handles.
723+ """
724+ handles : list = []
725+
726+ def _make_expert_hook (expert_module , weight_name , quantizers , enabled ):
727+ def _expert_hook (_input_quantizer , args ):
728+ if not args :
729+ return
730+ idx = expert_module ._current_expert_idx
731+ if idx in enabled :
732+ # Read the weight fresh (valid under accelerate/FSDP re-materialization).
733+ capture (quantizers [idx ], getattr (expert_module , weight_name )[idx ], args [0 ])
734+
735+ return _expert_hook
736+
737+ for name , module in name_to_module .items ():
738+ if is_quantized_linear (module ) and isinstance (module .weight_quantizer , TensorQuantizer ):
739+ with enable_weight_access_and_writeback (module , model , name_to_module ):
740+ # ``weight`` may be absent (e.g. TE GroupedLinear exposes weight0..N, not weight);
741+ # such modules have no single 2-D weight to pair and fall back to plain MSE.
742+ weight = getattr (module , "weight" , None )
743+ if weight is None or weight .dim () != 2 or not module .weight_quantizer .is_enabled :
744+ continue
745+ _warn_local_hessian_fallback (
746+ name , weight , module .weight_quantizer , block_size , warned
747+ )
748+
749+ def _dense_hook (linear , args ):
750+ if args :
751+ capture (linear .weight_quantizer , linear .weight , args [0 ])
752+
753+ handles .append (module .register_forward_pre_hook (_dense_hook ))
754+ elif _is_quant_fused_experts (module ):
755+ with enable_weight_access_and_writeback (module , model , name_to_module ):
756+ for weight_name , quantizers_name , input_q_name in (
757+ (
758+ "gate_up_proj" ,
759+ "gate_up_proj_weight_quantizers" ,
760+ "gate_up_proj_input_quantizer" ,
761+ ),
762+ ("down_proj" , "down_proj_weight_quantizers" , "down_proj_input_quantizer" ),
763+ ):
764+ weight = getattr (module , weight_name , None )
765+ quantizers = getattr (module , quantizers_name , None )
766+ input_quantizer = getattr (module , input_q_name , None )
767+ if weight is None or quantizers is None or input_quantizer is None :
768+ continue
769+ _warn_local_hessian_fallback (
770+ f"{ name } .{ weight_name } " , weight [0 ], quantizers [0 ], block_size , warned
771+ )
772+ # Snapshot which experts are enabled now, before the caching forward silences
773+ # all weight quantizers — so we don't capture (and discard) disabled experts.
774+ enabled = {i for i , q in enumerate (quantizers ) if q .is_enabled }
775+ handles .append (
776+ input_quantizer .register_forward_pre_hook (
777+ _make_expert_hook (module , weight_name , quantizers , enabled )
778+ )
779+ )
780+ return handles
781+
782+
709783@torch .no_grad ()
710784def local_hessian_calibrate (
711785 model : nn .Module ,
@@ -767,53 +841,19 @@ def capture(weight_quantizer, weight, input_tensor):
767841 accumulators [id (weight_quantizer )] = acc
768842 acc .accumulate (input_local )
769843
770- # Phase 2: register capture hooks, disable weight fake-quant (input quantizers left as-is,
771- # matching prior behavior), run one forward to accumulate Hessians. Hooks live only for it.
772- handles : list = []
773- silenced_weight_quantizers : list [TensorQuantizer ] = []
844+ # Phase 2: capture each weight's input activations during a forward with weight fake-quant
845+ # disabled (so H = ΣXᵀX reflects full-precision weights); input quantizers are left as-is.
774846 warned : set = set ()
775- seen_modules : set [int ] = set ()
776- for name , module in name_to_module .items ():
777- if not isinstance (module , QuantModule ) or id (module ) in seen_modules :
778- continue
779- seen_modules .add (id (module ))
780- with enable_weight_access_and_writeback (module , model , name_to_module ):
781- captures = module .register_calibration_input_hooks (capture )
782- handles .extend (captures )
783- for weight , weight_quantizer in module .iter_weights_for_calibration ():
784- # Silence weight fake-quant (incl. SequentialQuantizer leaves) so the capture
785- # forward uses full-precision weights and downstream Hessians aren't corrupted.
786- leaves = (
787- list (weight_quantizer )
788- if isinstance (weight_quantizer , SequentialQuantizer )
789- else [weight_quantizer ]
790- )
791- silenced_weight_quantizers .extend (
792- q
793- for q in leaves
794- if isinstance (q , TensorQuantizer ) and q .is_enabled and q ._if_quant
795- )
796- # Only TensorQuantizer weights are refined (same as mse_calibrate); other types
797- # (e.g. SequentialQuantizer) are unsupported and left at their max-cal scale.
798- if not isinstance (weight_quantizer , TensorQuantizer ):
799- if weight_quantizer .is_enabled and "unsupported" not in warned :
800- warned .add ("unsupported" )
801- warn_rank_0 (
802- "local_hessian: only TensorQuantizer weights are calibrated; other "
803- "types (e.g. SequentialQuantizer) stay at their max-calibrated scale."
804- )
805- continue
806- if captures :
807- _warn_local_hessian_fallback (name , weight , weight_quantizer , block_size , warned )
808-
809- for weight_quantizer in silenced_weight_quantizers :
810- weight_quantizer .disable_quant ()
847+ handles = _register_local_hessian_input_hooks (
848+ model , name_to_module , capture , block_size , warned
849+ )
811850 print_rank_0 ("local_hessian: Caching activations and computing local Hessian..." )
812851 try :
813- forward_loop (model )
852+ with set_quantizer_by_cfg_context (
853+ model , [{"quantizer_name" : "*weight_quantizer" , "enable" : False }]
854+ ):
855+ forward_loop (model )
814856 finally :
815- for weight_quantizer in silenced_weight_quantizers :
816- weight_quantizer .enable_quant ()
817857 for handle in handles :
818858 handle .remove ()
819859
0 commit comments