Skip to content

Commit b758cc5

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 900cc32 commit b758cc5

1 file changed

Lines changed: 23 additions & 14 deletions

File tree

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -148,41 +148,50 @@ def _resmooth_experts_for_export(
148148
for experts in expert_groups:
149149
if not experts:
150150
continue
151-
pqs_list = _collect_expert_pre_quant_scales(experts)
152-
if pqs_list is None:
151+
pre_quant_scales_list = _collect_expert_pre_quant_scales(experts)
152+
if pre_quant_scales_list is None:
153153
continue
154154

155-
avg_pqs = torch.stack(pqs_list).mean(0)
155+
avg_pre_quant_scale = torch.stack(pre_quant_scales_list).mean(0)
156156
# Guard against degenerate calibration where a channel's scale is zero:
157157
# zero avg_pqs would produce inf ratio and corrupt the exported weight.
158-
avg_pqs = avg_pqs.clamp(min=torch.finfo(torch.float32).tiny)
158+
avg_pre_quant_scale = avg_pre_quant_scale.clamp(min=torch.finfo(torch.float32).tiny)
159159

160160
for ex in experts:
161161
nm = id_to_name.get(id(ex))
162162
if nm is None or f"{nm}.weight" not in state_dict:
163163
continue
164-
old_pqs = ex.input_quantizer._pre_quant_scale
165-
avg_on_dev = avg_pqs.to(device=old_pqs.device, dtype=old_pqs.dtype)
166-
if torch.equal(old_pqs, avg_on_dev):
164+
old_pre_quant_scale = ex.input_quantizer._pre_quant_scale
165+
avg_pre_quant_scale = avg_pre_quant_scale.to(
166+
device=old_pre_quant_scale.device, dtype=old_pre_quant_scale.dtype
167+
)
168+
if torch.equal(old_pre_quant_scale, avg_pre_quant_scale):
167169
continue
168-
w = state_dict[f"{nm}.weight"]
169-
ratio = (old_pqs / avg_pqs).to(dtype=torch.float32, device=w.device)
170-
state_dict[f"{nm}.weight"] = (w.float() * ratio[None, :]).to(w.dtype)
170+
weight = state_dict[f"{nm}.weight"]
171+
updated_weight = (
172+
weight.to(torch.float32)
173+
* old_pre_quant_scale.to(dtype=torch.float32, device=weight.device)
174+
/ avg_pre_quant_scale.to(dtype=torch.float32, device=weight.device)
175+
).to(weight.dtype)
176+
state_dict[f"{nm}.weight"] = updated_weight
171177
requant_weights.add(f"{nm}.weight")
172178

173179
iq0 = experts[0].input_quantizer
174-
max_in_amax: torch.Tensor | None = None
180+
synced_amax: torch.Tensor | None = None
175181
if iq0.is_enabled:
176182
amaxes = [e.input_quantizer.amax for e in experts]
177183
if all(a is not None for a in amaxes):
178-
max_in_amax = merge_amax_tensors_for_group(amaxes)
184+
synced_amax = merge_amax_tensors_for_group(amaxes)
179185

180-
avg_out = avg_pqs.detach().clone()
186+
avg_pre_quant_scale_output = avg_pre_quant_scale.detach().clone()
181187
for ex in experts:
182188
nm = id_to_name.get(id(ex))
183189
if nm is None:
184190
continue
185-
out[get_unwrapped_name(f"{nm}.input_quantizer", model)] = (avg_out, max_in_amax)
191+
out[get_unwrapped_name(f"{nm}.input_quantizer", model)] = (
192+
avg_pre_quant_scale_output,
193+
synced_amax,
194+
)
186195

187196
return out, requant_weights
188197

0 commit comments

Comments (0)