Skip to content

Commit a2c8d53

Browse files
committed
Rewrite HIGGS hooks using shared Hadamard matrix like working BNF
1 parent 7a2d740 commit a2c8d53

File tree

1 file changed

+91
-100
lines changed

1 file changed

+91
-100
lines changed

baselines/opt_sym/eval_ppl.py

Lines changed: 91 additions & 100 deletions
Original file line number · Diff line number · Diff line change
@@ -1343,113 +1343,104 @@ def main():
13431343
key = (k, p)
13441344
quant_cb, deq_cb = unique_configs[key]
13451345

1346-
# Use pre_hook (runs BEFORE forward) to quantize weights before use
1347-
# Then post_hook to restore original weights
1348-
def make_pre_hook(q_cb, d_cb, bs, rot_bs, sign_seed, norm_type, p_dim):
1349-
def pre_hook(mod, args_):
1350-
if not _hooks_enabled:
1351-
return
1352-
# Save original weight
1353-
mod._orig_weight = mod.weight.data.clone()
1354-
W = mod.weight.data
1355-
dtype = W.dtype
1356-
W_float = W.float()
1357-
1358-
out_dim, in_dim = W.shape
1359-
1360-
# Apply Hadamard rotation
1361-
if p_dim > 1 or norm_type == 'l2':
1362-
n_rot = in_dim // rot_bs
1363-
if n_rot == 0:
1364-
n_rot = 1
1365-
actual_rot_bs = in_dim
1366-
else:
1367-
actual_rot_bs = rot_bs
1368-
1369-
W_reshaped = W_float.reshape(out_dim * n_rot, actual_rot_bs)
1370-
1371-
# Get or create sign vector
1372-
torch.manual_seed(sign_seed)
1373-
had_sign = (2 * (torch.rand(actual_rot_bs, device=W.device) > 0.5).float() - 1).to(W.device)
1374-
1375-
# Apply sign and Hadamard (match working BNF pattern)
1376-
W_signed = W_reshaped * had_sign.unsqueeze(0)
1377-
H = torch.tensor(hadamard(actual_rot_bs), dtype=torch.float32, device=W.device)
1378-
W_rot = W_signed @ H.T # No normalization - BNF doesn't divide by sqrt(n)
1379-
else:
1380-
W_rot = W_float
1346+
# Store per-layer config for custom BNF hook
1347+
module._higgs_quant_cb = quant_cb.to(device)
1348+
module._higgs_deq_cb = deq_cb.to(device)
1349+
module._higgs_p = p
1350+
1351+
# Install custom BNF hooks that check for per-layer codebooks
1352+
H_block = make_hadamard_block(args.rot_blocksize, device)
1353+
max_in = max(m.weight.shape[1] for m in model.modules()
1354+
if isinstance(m, nn.Linear))
1355+
signs = make_random_signs(max_in, args.seed, device)
1356+
1357+
def make_higgs_pre_hook(bs, rot_bs, sign_seed, norm_type):
1358+
def hook(mod, args_):
1359+
if not _hooks_enabled:
1360+
return
1361+
# Check if this module has HIGGS config
1362+
if not hasattr(mod, '_higgs_quant_cb'):
1363+
return
1364+
1365+
q_cb = mod._higgs_quant_cb
1366+
d_cb = mod._higgs_deq_cb
1367+
p_dim = mod._higgs_p
1368+
1369+
mod._orig_weight = mod.weight.data.clone()
1370+
W = mod.weight.data.float()
1371+
out_dim, in_dim = W.shape
1372+
1373+
# Apply Hadamard rotation
1374+
if p_dim > 1 or norm_type == 'l2':
1375+
n_rot = in_dim // rot_bs
1376+
if n_rot == 0:
13811377
n_rot = 1
13821378
actual_rot_bs = in_dim
1383-
1384-
# Quantization
1385-
if norm_type == 'absmax':
1386-
# Flatten, pad, blockwise absmax quantization
1387-
flat = W_rot.flatten()
1388-
n = flat.numel()
1389-
pad_n = (bs - n % bs) % bs
1390-
if pad_n > 0:
1391-
flat = torch.nn.functional.pad(flat, (0, pad_n))
1392-
1393-
blocks = flat.reshape(-1, bs)
1394-
absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_(min=1e-12)
1395-
normalized = blocks / absmax
1396-
1397-
# VQ quantization on normalized blocks
1398-
elems_per_p = (bs // p_dim) * p_dim
1399-
rem = bs - elems_per_p
1400-
1401-
if elems_per_p > 0:
1402-
vq_part = normalized[:, :elems_per_p]
1403-
groups = vq_part.reshape(-1, p_dim)
1404-
1405-
# Find nearest codewords (use _chunked_nearest like working BNF)
1406-
idx = _chunked_nearest(groups, q_cb.to(W.device), chunk_size=100000)
1407-
1408-
# Dequantize (returns values on unit sphere)
1409-
dq_groups = d_cb.to(W.device)[idx]
1410-
dq_vq = dq_groups.reshape(normalized.shape[0], elems_per_p)
1411-
1412-
if rem > 0:
1413-
# rem_part is normalized (unit scale), dq_vq is normalized
1414-
# Both need to be multiplied by absmax to denormalize
1415-
rem_part = normalized[:, elems_per_p:]
1416-
dequantized = torch.cat([dq_vq, rem_part], dim=1) * absmax
1417-
else:
1418-
dequantized = dq_vq * absmax
1419-
else:
1420-
dequantized = normalized * absmax
1421-
1422-
W_q = dequantized.flatten()[:n].reshape(W_rot.shape)
14231379
else:
1424-
# L2 norm
1425-
W_flat = W_rot.reshape(-1, p_dim)
1426-
idx = _chunked_nearest(W_flat, q_cb.to(W.device), chunk_size=100000)
1427-
W_q = d_cb.to(W.device)[idx].reshape(W_rot.shape)
1428-
1429-
# Inverse Hadamard if needed (match working BNF pattern)
1430-
if p_dim > 1 or norm_type == 'l2':
1431-
W_deshaped = W_q.reshape(out_dim * n_rot, actual_rot_bs)
1432-
W_unrot = W_deshaped @ H # No normalization - BNF doesn't multiply by sqrt(n)
1433-
W_unrot = W_unrot * had_sign.unsqueeze(0)
1434-
W_final = W_unrot.reshape(W.shape)
1380+
actual_rot_bs = rot_bs
1381+
1382+
W_reshaped = W.reshape(out_dim * n_rot, actual_rot_bs)
1383+
W_signed = W_reshaped * signs[:actual_rot_bs].unsqueeze(0)
1384+
W_rot = W_signed @ H_block.T
1385+
else:
1386+
W_rot = W
1387+
n_rot = 1
1388+
actual_rot_bs = in_dim
1389+
1390+
# Quantize
1391+
flat = W_rot.flatten()
1392+
n = flat.numel()
1393+
pad_n = (bs - n % bs) % bs
1394+
if pad_n > 0:
1395+
flat = torch.nn.functional.pad(flat, (0, pad_n))
1396+
1397+
blocks = flat.reshape(-1, bs)
1398+
absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_(min=1e-12)
1399+
normalized = blocks / absmax
1400+
1401+
elems_per_p = (bs // p_dim) * p_dim
1402+
rem = bs - elems_per_p
1403+
1404+
if elems_per_p > 0:
1405+
vq_part = normalized[:, :elems_per_p]
1406+
groups = vq_part.reshape(-1, p_dim)
1407+
idx = _chunked_nearest(groups, q_cb, chunk_size=100000)
1408+
dq_groups = d_cb[idx]
1409+
dq_vq = dq_groups.reshape(normalized.shape[0], elems_per_p)
1410+
1411+
if rem > 0:
1412+
rem_part = normalized[:, elems_per_p:]
1413+
dequantized = torch.cat([dq_vq, rem_part], dim=1) * absmax
14351414
else:
1436-
W_final = W_q.reshape(W.shape)
1415+
dequantized = dq_vq * absmax
1416+
else:
1417+
dequantized = normalized * absmax
1418+
1419+
dequantized = dequantized.flatten()[:n].reshape(W_rot.shape)
14371420

1438-
mod.weight.data = W_final.to(dtype)
1439-
return pre_hook
1421+
# Inverse Hadamard
1422+
if p_dim > 1 or norm_type == 'l2':
1423+
W_deshaped = dequantized.reshape(out_dim * n_rot, actual_rot_bs)
1424+
W_unrot = W_deshaped @ H_block
1425+
W_unrot = W_unrot * signs[:actual_rot_bs].unsqueeze(0)
1426+
W_final = W_unrot.reshape(W.shape)
1427+
else:
1428+
W_final = dequantized.reshape(W.shape)
14401429

1441-
def post_hook(mod, args_, output):
1442-
# Restore original weight after forward
1443-
if hasattr(mod, '_orig_weight'):
1444-
mod.weight.data = mod._orig_weight
1445-
del mod._orig_weight
1430+
mod.weight.data = W_final.to(mod.weight.dtype)
1431+
return hook
14461432

1447-
h1 = module.register_forward_pre_hook(
1448-
make_pre_hook(quant_cb, deq_cb, args.blocksize, args.rot_blocksize,
1449-
args.seed, args.norm, p)
1450-
)
1451-
h2 = module.register_forward_hook(post_hook)
1452-
hooks.extend([h1, h2])
1433+
def post_hook(mod, args_, output):
1434+
if hasattr(mod, '_orig_weight'):
1435+
mod.weight.data = mod._orig_weight
1436+
del mod._orig_weight
1437+
1438+
for name, module in model.named_modules():
1439+
if isinstance(module, nn.Linear) and hasattr(module, '_higgs_quant_cb'):
1440+
h1 = module.register_forward_pre_hook(
1441+
make_higgs_pre_hook(args.blocksize, args.rot_blocksize, args.seed, args.norm))
1442+
h2 = module.register_forward_hook(post_hook)
1443+
hooks.extend([h1, h2])
14531444

14541445
print(f"Installed {len(hooks)} per-layer quantization hooks")
14551446
effective_bits = avg_bits

0 commit comments

Comments (0)