Skip to content

Commit f412e29

Browse files
committed
fix: address PR review feedback (realAsma)

- Fold pre_quant_scale on GPU before the .cpu() move (perf fix).
- Use torch.allclose instead of torch.equal in the test (nit).

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
1 parent c6b93b9 commit f412e29

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]:
     Returns (sd_key, quantized_weight_on_cpu, inp_q_key_or_None).
     """
     w = item.weight
-    w_quant = item.quantizer(w.float()).to(w.dtype).cpu()
+    w_quant = item.quantizer(w.float()).to(w.dtype)
     if item.inp_q is not None:
         scale = item.inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
         w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
-    return item.sd_key, w_quant, item.inp_q_key
+    return item.sd_key, w_quant.cpu(), item.inp_q_key


 def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):

tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def forward_loop(model):

     assert seq_sd.keys() == par_sd.keys(), "Key mismatch between sequential and parallel export"
     for key in seq_sd:
-        assert torch.equal(seq_sd[key], par_sd[key]), (
+        assert torch.allclose(seq_sd[key], par_sd[key]), (
             f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}"
         )

0 commit comments

Comments (0)