We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 560ddf7 commit 7a2d740Copy full SHA for 7a2d740
baselines/opt_sym/eval_ppl.py
@@ -1405,11 +1405,13 @@ def pre_hook(mod, args_):
1405
# Find nearest codewords (use _chunked_nearest like working BNF)
1406
idx = _chunked_nearest(groups, q_cb.to(W.device), chunk_size=100000)
1407
1408
- # Dequantize
+ # Dequantize (returns values on unit sphere)
1409
dq_groups = d_cb.to(W.device)[idx]
1410
dq_vq = dq_groups.reshape(normalized.shape[0], elems_per_p)
1411
1412
if rem > 0:
1413
+ # rem_part is normalized (unit scale), dq_vq is normalized
1414
+ # Both need to be multiplied by absmax to denormalize
1415
rem_part = normalized[:, elems_per_p:]
1416
dequantized = torch.cat([dq_vq, rem_part], dim=1) * absmax
1417
else:
0 commit comments