Skip to content

Commit 8ff002f

Browse files
TimDettmers and claude
committed
Fix HIGGS evaluation to use proper BNF codebook computation
- Remove simplified codebook helpers
- Use existing compute_codebook() with proper kappa correction
- Add full Hadamard rotation and absmax normalization in hooks

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4e329bd commit 8ff002f

File tree

1 file changed

+110
-52
lines changed

1 file changed

+110
-52
lines changed

baselines/opt_sym/eval_ppl.py

Lines changed: 110 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -968,33 +968,6 @@ def eval_ppl(model, testenc, seqlen, device, compare_fp16=False):
968968
# HIGGS support
969969
# ============================================================
970970

971-
def compute_absmax_codebook(k, p, device='cuda'):
972-
"""Compute uniform codebook for absmax quantization."""
973-
n_entries = 2**k
974-
values = torch.linspace(-1, 1, n_entries, device=device)
975-
if p == 1:
976-
return values.view(-1, 1)
977-
else:
978-
grids = torch.meshgrid(*([values] * p), indexing='ij')
979-
codebook = torch.stack([g.flatten() for g in grids], dim=1)
980-
return codebook
981-
982-
983-
def compute_l2_codebook(k, p, device='cuda'):
984-
"""Compute codebook for L2 quantization."""
985-
n_entries = 2**k
986-
values = torch.linspace(-1, 1, int(n_entries ** (1/p)) + 1, device=device)
987-
if p == 1:
988-
return values.view(-1, 1)
989-
else:
990-
grids = torch.meshgrid(*([values] * p), indexing='ij')
991-
codebook = torch.stack([g.flatten() for g in grids], dim=1)
992-
# Keep only n_entries closest to origin
993-
norms = torch.norm(codebook, dim=1)
994-
_, indices = torch.sort(norms)
995-
return codebook[indices[:n_entries]]
996-
997-
998971
def load_higgs_assignment(assignment_path):
999972
"""Load HIGGS bitwidth assignment from JSON file.
1000973
@@ -1338,54 +1311,139 @@ def main():
13381311
opt = options[opt_idx]
13391312
print(f" {opt['config_str']}: {count} layers")
13401313

1314+
# Pre-compute codebooks for each unique (k, p) in the assignment
1315+
dev = next(model.parameters()).device
1316+
unique_configs = {}
1317+
for opt_idx in set(assignment.values()):
1318+
opt = options[opt_idx]
1319+
k, p = opt['k'], opt['p']
1320+
key = (k, p)
1321+
if key not in unique_configs:
1322+
# Use existing BNF codebook computation
1323+
quant_cb, deq_cb, _, _ = compute_codebook(
1324+
k, p, blocksize=args.blocksize, device=dev
1325+
)
1326+
unique_configs[key] = (quant_cb, deq_cb)
1327+
print(f" Computed codebook for k={k}, p={p}: {quant_cb.shape}")
1328+
13411329
# Get all linear layers
13421330
linear_layers = []
13431331
for name, module in model.named_modules():
13441332
if isinstance(module, nn.Linear) and 'embed' not in name and 'lm_head' not in name:
13451333
linear_layers.append(module)
13461334

1347-
# Apply per-layer quantization
1348-
dev = next(model.parameters()).device
1335+
# Apply per-layer quantization hooks
13491336
hooks = []
13501337
for layer_idx, module in enumerate(linear_layers):
13511338
if layer_idx not in assignment:
13521339
continue
13531340
opt_idx = assignment[layer_idx]
13541341
opt = options[opt_idx]
13551342
k, p = opt['k'], opt['p']
1343+
key = (k, p)
1344+
quant_cb, deq_cb = unique_configs[key]
13561345

1357-
# Compute codebook for this layer
1358-
if args.norm == 'l2':
1359-
quant_cb = compute_l2_codebook(k, p, device=dev)
1360-
else:
1361-
quant_cb = compute_absmax_codebook(k, p, device=dev)
1346+
# Use existing BNF hook installation (inline)
1347+
# Reuse the closure pattern from install_bnf_hooks/install_l2_hooks
1348+
sign_key = f"hadamard_sign_{args.seed}"
1349+
had_sign = None
13621350

1363-
# Create hook
1364-
def make_hook(k_val, p_val, cb, norm, bs, rot_bs):
1351+
def make_hook(q_cb, d_cb, bs, rot_bs, sign_seed, norm_type, p_dim):
13651352
def hook(module, input, output):
1366-
if not hasattr(module, '_higgs_quantized'):
1367-
w = module.weight.data
1368-
weight_dtype = w.dtype
1369-
1370-
# Apply Hadamard rotation if L2
1371-
if norm == 'l2':
1372-
# Simple rotation (without full block structure for now)
1373-
w_flat = w.reshape(-1, p_val).float()
1353+
# Use module weight directly (quantized in-place)
1354+
W = module.weight.data
1355+
dtype = W.dtype
1356+
W_float = W.float()
1357+
1358+
out_dim, in_dim = W.shape
1359+
1360+
# Apply Hadamard rotation
1361+
if p_dim > 1 or norm_type == 'l2':
1362+
n_rot = in_dim // rot_bs
1363+
if n_rot == 0:
1364+
n_rot = 1
1365+
actual_rot_bs = in_dim
13741366
else:
1375-
w_flat = w.reshape(-1, p_val).float()
1367+
actual_rot_bs = rot_bs
1368+
1369+
# Reshape for rotation
1370+
W_reshaped = W_float.reshape(out_dim * n_rot, actual_rot_bs)
1371+
1372+
# Get or create sign vector
1373+
nonlocal had_sign
1374+
if had_sign is None:
1375+
torch.manual_seed(sign_seed)
1376+
had_sign = (2 * (torch.rand(actual_rot_bs, device=W.device) > 0.5).float() - 1).to(W.device)
1377+
1378+
# Apply sign and Hadamard
1379+
W_signed = W_reshaped * had_sign.unsqueeze(0)
1380+
H = torch.tensor(hadamard(actual_rot_bs), dtype=torch.float32, device=W.device)
1381+
W_rot = (W_signed @ H) / torch.sqrt(torch.tensor(actual_rot_bs, dtype=torch.float32))
1382+
else:
1383+
W_rot = W_float
1384+
n_rot = 1
1385+
actual_rot_bs = in_dim
1386+
1387+
# Quantization
1388+
if norm_type == 'absmax':
1389+
# Block-wise absmax
1390+
W_blocks = W_rot.reshape(-1, bs)
1391+
absmax_vals = W_blocks.abs().max(dim=1, keepdim=True)[0]
1392+
absmax_vals = absmax_vals.clamp_min(1e-8)
1393+
1394+
W_unit = W_blocks / absmax_vals
1395+
1396+
# VQ quantization
1397+
elems_per_p = (actual_rot_bs // p_dim) * p_dim
1398+
rem = actual_rot_bs - elems_per_p
1399+
1400+
if rem > 0:
1401+
vq_part = W_unit.reshape(out_dim * n_rot, actual_rot_bs)[:, :elems_per_p]
1402+
else:
1403+
vq_part = W_unit
1404+
1405+
groups = vq_part.reshape(-1, p_dim)
13761406

1377-
# VQ quantize
1378-
dists = torch.cdist(w_flat, cb.float())
1379-
indices = dists.argmin(dim=1)
1380-
w_q = cb[indices].reshape(w.shape).to(weight_dtype)
1407+
# Find nearest codewords
1408+
dists = torch.cdist(groups, q_cb.float())
1409+
idx = dists.argmin(dim=1)
1410+
q_groups = q_cb[idx]
13811411

1382-
module.weight.data = w_q
1383-
module._higgs_quantized = True
1412+
# Dequantize
1413+
dq_groups = d_cb[idx]
1414+
dq_vq = dq_groups.reshape(out_dim * n_rot, elems_per_p)
1415+
1416+
if rem > 0:
1417+
rem_part = W_unit.reshape(out_dim * n_rot, actual_rot_bs)[:, elems_per_p:]
1418+
dq_blocks = torch.cat([dq_vq, rem_part], dim=1)
1419+
else:
1420+
dq_blocks = dq_vq
1421+
1422+
# Denormalize
1423+
W_q = (dq_blocks * absmax_vals).reshape(W_rot.shape)
1424+
else:
1425+
# L2 norm - simpler case
1426+
W_flat = W_rot.reshape(-1, p_dim)
1427+
dists = torch.cdist(W_flat, q_cb.float())
1428+
idx = dists.argmin(dim=1)
1429+
W_q = d_cb[idx].reshape(W_rot.shape)
1430+
1431+
# Inverse Hadamard if needed
1432+
if p_dim > 1 or norm_type == 'l2':
1433+
W_deshaped = W_q.reshape(out_dim * n_rot, actual_rot_bs)
1434+
W_unrot = (W_deshaped @ H.T) * torch.sqrt(torch.tensor(actual_rot_bs, dtype=torch.float32))
1435+
W_unrot = W_unrot * had_sign.unsqueeze(0)
1436+
W_final = W_unrot.reshape(W.shape)
1437+
else:
1438+
W_final = W_q.reshape(W.shape)
1439+
1440+
module.weight.data = W_final.to(dtype)
13841441
return output
13851442
return hook
13861443

13871444
handle = module.register_forward_hook(
1388-
make_hook(k, p, quant_cb, args.norm, args.blocksize, args.rot_blocksize)
1445+
make_hook(quant_cb, deq_cb, args.blocksize, args.rot_blocksize,
1446+
args.seed, args.norm, p)
13891447
)
13901448
hooks.append(handle)
13911449

0 commit comments

Comments (0)