Skip to content

Commit 2aff419

Browse files
committed
Use _chunked_nearest for VQ lookup (matches working BNF)
1 parent: 0a12b73 · commit: 2aff419

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

baselines/opt_sym/eval_ppl.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1402,12 +1402,11 @@ def pre_hook(mod, args_):
14021402
vq_part = normalized[:, :elems_per_p]
14031403
groups = vq_part.reshape(-1, p_dim)
14041404

1405-
# Find nearest codewords
1406-
dists = torch.cdist(groups, q_cb.float())
1407-
idx = dists.argmin(dim=1)
1405+
# Find nearest codewords (use _chunked_nearest like working BNF)
1406+
idx = _chunked_nearest(groups, q_cb.to(W.device), chunk_size=100000)
14081407

14091408
# Dequantize
1410-
dq_groups = d_cb[idx]
1409+
dq_groups = d_cb.to(W.device)[idx]
14111410
dq_vq = dq_groups.reshape(normalized.shape[0], elems_per_p)
14121411

14131412
if rem > 0:
@@ -1422,9 +1421,8 @@ def pre_hook(mod, args_):
14221421
else:
14231422
# L2 norm
14241423
W_flat = W_rot.reshape(-1, p_dim)
1425-
dists = torch.cdist(W_flat, q_cb.float())
1426-
idx = dists.argmin(dim=1)
1427-
W_q = d_cb[idx].reshape(W_rot.shape)
1424+
idx = _chunked_nearest(W_flat, q_cb.to(W.device), chunk_size=100000)
1425+
W_q = d_cb.to(W.device)[idx].reshape(W_rot.shape)
14281426

14291427
# Inverse Hadamard if needed
14301428
if p_dim > 1 or norm_type == 'l2':

0 commit comments

Comments (0)