File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -65,6 +65,23 @@ def forward_loop(model):
6565 folded_model = deepcopy (model )
6666 fold_weight (folded_model )
6767 expected_weights = {k : v for k , v in folded_model .state_dict ().items () if "quantizer" not in k }
68+ # fold_weight only applies the weight quantizer's fake-quant; it does NOT fold
69+ # input_quantizer.pre_quant_scale (pqs) into the weight. The export path computes
70+ # w_exported = fake_quant(W) * pqs[None, :]
71+ # for modules whose input_quantizer is disabled but still carries a pqs (AWQ weight-only).
72+ # Apply the same pqs fold here so that expected_weights matches the export output.
73+ for module_name , module in folded_model .named_modules ():
74+ inp_q = getattr (module , "input_quantizer" , None )
75+ if (
76+ inp_q is not None
77+ and not inp_q .is_enabled
78+ and getattr (inp_q , "_pre_quant_scale" , None ) is not None
79+ ):
80+ w_key = f"{ module_name } .weight" if module_name else "weight"
81+ if w_key in expected_weights :
82+ w = expected_weights [w_key ]
83+ scale = inp_q ._pre_quant_scale .squeeze ().to (device = w .device )
84+ expected_weights [w_key ] = (w * scale [None , :]).to (w .dtype )
6885 del folded_model
6986
7087 # Snapshot model state before export to verify it is not mutated
You can’t perform that action at this time.
0 commit comments