@@ -1402,12 +1402,11 @@ def pre_hook(mod, args_):
14021402 vq_part = normalized [:, :elems_per_p ]
14031403 groups = vq_part .reshape (- 1 , p_dim )
14041404
1405- # Find nearest codewords
1406- dists = torch .cdist (groups , q_cb .float ())
1407- idx = dists .argmin (dim = 1 )
1405+ # Find nearest codewords (use _chunked_nearest like working BNF)
1406+ idx = _chunked_nearest (groups , q_cb .to (W .device ), chunk_size = 100000 )
14081407
14091408 # Dequantize
1410- dq_groups = d_cb [idx ]
1409+ dq_groups = d_cb . to ( W . device ) [idx ]
14111410 dq_vq = dq_groups .reshape (normalized .shape [0 ], elems_per_p )
14121411
14131412 if rem > 0 :
@@ -1422,9 +1421,8 @@ def pre_hook(mod, args_):
14221421 else :
14231422 # L2 norm
14241423 W_flat = W_rot .reshape (- 1 , p_dim )
1425- dists = torch .cdist (W_flat , q_cb .float ())
1426- idx = dists .argmin (dim = 1 )
1427- W_q = d_cb [idx ].reshape (W_rot .shape )
1424+ idx = _chunked_nearest (W_flat , q_cb .to (W .device ), chunk_size = 100000 )
1425+ W_q = d_cb .to (W .device )[idx ].reshape (W_rot .shape )
14281426
14291427 # Inverse Hadamard if needed
14301428 if p_dim > 1 or norm_type == 'l2' :
0 commit comments