@@ -1372,10 +1372,10 @@ def pre_hook(mod, args_):
13721372 torch .manual_seed (sign_seed )
13731373 had_sign = (2 * (torch .rand (actual_rot_bs , device = W .device ) > 0.5 ).float () - 1 ).to (W .device )
13741374
1375- # Apply sign and Hadamard
1375+ # Apply sign and Hadamard (match working BNF pattern)
13761376 W_signed = W_reshaped * had_sign .unsqueeze (0 )
13771377 H = torch .tensor (hadamard (actual_rot_bs ), dtype = torch .float32 , device = W .device )
1378- W_rot = ( W_signed @ H ) / torch . sqrt (torch . tensor ( actual_rot_bs , dtype = torch . float32 ) )
1378+ W_rot = W_signed @ H . T # No normalization - BNF doesn't divide by sqrt(n )
13791379 else :
13801380 W_rot = W_float
13811381 n_rot = 1
@@ -1424,10 +1424,10 @@ def pre_hook(mod, args_):
14241424 idx = _chunked_nearest (W_flat , q_cb .to (W .device ), chunk_size = 100000 )
14251425 W_q = d_cb .to (W .device )[idx ].reshape (W_rot .shape )
14261426
1427- # Inverse Hadamard if needed
1427+ # Inverse Hadamard if needed (match working BNF pattern)
14281428 if p_dim > 1 or norm_type == 'l2' :
14291429 W_deshaped = W_q .reshape (out_dim * n_rot , actual_rot_bs )
1430- W_unrot = ( W_deshaped @ H . T ) * torch . sqrt (torch . tensor ( actual_rot_bs , dtype = torch . float32 ) )
1430+ W_unrot = W_deshaped @ H # No normalization - BNF doesn't multiply by sqrt(n )
14311431 W_unrot = W_unrot * had_sign .unsqueeze (0 )
14321432 W_final = W_unrot .reshape (W .shape )
14331433 else :
0 commit comments