@@ -148,41 +148,50 @@ def _resmooth_experts_for_export(
148148 for experts in expert_groups :
149149 if not experts :
150150 continue
151- pqs_list = _collect_expert_pre_quant_scales (experts )
152- if pqs_list is None :
151+ pre_quant_scales_list = _collect_expert_pre_quant_scales (experts )
152+ if pre_quant_scales_list is None :
153153 continue
154154
155- avg_pqs = torch .stack (pqs_list ).mean (0 )
155+ avg_pre_quant_scale = torch .stack (pre_quant_scales_list ).mean (0 )
 156156 # Guard against degenerate calibration where a channel's scale is zero:
 157157 # a zero avg_pre_quant_scale would produce an inf ratio and corrupt the exported weight.
158- avg_pqs = avg_pqs .clamp (min = torch .finfo (torch .float32 ).tiny )
158+ avg_pre_quant_scale = avg_pre_quant_scale .clamp (min = torch .finfo (torch .float32 ).tiny )
159159
160160 for ex in experts :
161161 nm = id_to_name .get (id (ex ))
162162 if nm is None or f"{ nm } .weight" not in state_dict :
163163 continue
164- old_pqs = ex .input_quantizer ._pre_quant_scale
165- avg_on_dev = avg_pqs .to (device = old_pqs .device , dtype = old_pqs .dtype )
166- if torch .equal (old_pqs , avg_on_dev ):
164+ old_pre_quant_scale = ex .input_quantizer ._pre_quant_scale
165+ avg_pre_quant_scale = avg_pre_quant_scale .to (
166+ device = old_pre_quant_scale .device , dtype = old_pre_quant_scale .dtype
167+ )
168+ if torch .equal (old_pre_quant_scale , avg_pre_quant_scale ):
167169 continue
168- w = state_dict [f"{ nm } .weight" ]
169- ratio = (old_pqs / avg_pqs ).to (dtype = torch .float32 , device = w .device )
170- state_dict [f"{ nm } .weight" ] = (w .float () * ratio [None , :]).to (w .dtype )
170+ weight = state_dict [f"{ nm } .weight" ]
171+ updated_weight = (
172+ weight .to (torch .float32 )
173+ * old_pre_quant_scale .to (dtype = torch .float32 , device = weight .device )
174+ / avg_pre_quant_scale .to (dtype = torch .float32 , device = weight .device )
175+ ).to (weight .dtype )
176+ state_dict [f"{ nm } .weight" ] = updated_weight
171177 requant_weights .add (f"{ nm } .weight" )
172178
173179 iq0 = experts [0 ].input_quantizer
174- max_in_amax : torch .Tensor | None = None
180+ synced_amax : torch .Tensor | None = None
175181 if iq0 .is_enabled :
176182 amaxes = [e .input_quantizer .amax for e in experts ]
177183 if all (a is not None for a in amaxes ):
178- max_in_amax = merge_amax_tensors_for_group (amaxes )
184+ synced_amax = merge_amax_tensors_for_group (amaxes )
179185
180- avg_out = avg_pqs .detach ().clone ()
186+ avg_pre_quant_scale_output = avg_pre_quant_scale .detach ().clone ()
181187 for ex in experts :
182188 nm = id_to_name .get (id (ex ))
183189 if nm is None :
184190 continue
185- out [get_unwrapped_name (f"{ nm } .input_quantizer" , model )] = (avg_out , max_in_amax )
191+ out [get_unwrapped_name (f"{ nm } .input_quantizer" , model )] = (
192+ avg_pre_quant_scale_output ,
193+ synced_amax ,
194+ )
186195
187196 return out , requant_weights
188197
0 commit comments