Skip to content

Commit 560ddf7 (parent: 2aff419)

Fix Hadamard transform: remove sqrt(n) normalization to match BNF

File tree: 1 file changed (+4 lines, -4 lines)

baselines/opt_sym/eval_ppl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,10 +1372,10 @@ def pre_hook(mod, args_):
     torch.manual_seed(sign_seed)
     had_sign = (2 * (torch.rand(actual_rot_bs, device=W.device) > 0.5).float() - 1).to(W.device)

-    # Apply sign and Hadamard
+    # Apply sign and Hadamard (match working BNF pattern)
     W_signed = W_reshaped * had_sign.unsqueeze(0)
     H = torch.tensor(hadamard(actual_rot_bs), dtype=torch.float32, device=W.device)
-    W_rot = (W_signed @ H) / torch.sqrt(torch.tensor(actual_rot_bs, dtype=torch.float32))
+    W_rot = W_signed @ H.T  # No normalization - BNF doesn't divide by sqrt(n)
 else:
     W_rot = W_float
     n_rot = 1
@@ -1424,10 +1424,10 @@ def pre_hook(mod, args_):
     idx = _chunked_nearest(W_flat, q_cb.to(W.device), chunk_size=100000)
     W_q = d_cb.to(W.device)[idx].reshape(W_rot.shape)

-    # Inverse Hadamard if needed
+    # Inverse Hadamard if needed (match working BNF pattern)
     if p_dim > 1 or norm_type == 'l2':
         W_deshaped = W_q.reshape(out_dim * n_rot, actual_rot_bs)
-        W_unrot = (W_deshaped @ H.T) * torch.sqrt(torch.tensor(actual_rot_bs, dtype=torch.float32))
+        W_unrot = W_deshaped @ H  # No normalization - BNF doesn't multiply by sqrt(n)
         W_unrot = W_unrot * had_sign.unsqueeze(0)
         W_final = W_unrot.reshape(W.shape)
     else:

Comments (0)