Skip to content

Commit 2495cf3

Browse files
committed
Simplify HIGGS hook to use flatten/pad pattern like working BNF hooks
1 parent f4baf7e commit 2495cf3

File tree

1 file changed

+35
-48
lines changed

1 file changed

+35
-48
lines changed

baselines/opt_sym/eval_ppl.py

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,56 +1386,43 @@ def hook(module, input, output):
13861386

13871387
# Quantization
13881388
if norm_type == 'absmax':
1389-
# VQ quantization setup
1390-
elems_per_p = (actual_rot_bs // p_dim) * p_dim
1391-
rem = actual_rot_bs - elems_per_p
1392-
1393-
# Reshape for VQ: [out_dim * n_rot, actual_rot_bs]
1394-
W_rot_reshaped = W_rot.reshape(out_dim * n_rot, actual_rot_bs)
1395-
1396-
# Compute absmax on VQ-compatible portion (excluding remainder)
1397-
if rem > 0:
1398-
W_for_vq = W_rot_reshaped[:, :elems_per_p]
1399-
else:
1400-
W_for_vq = W_rot_reshaped
1401-
1402-
# Reshape to blocks for absmax: [out_dim * n_rot * elems_per_p / bs, bs]
1403-
W_blocks_vq = W_for_vq.reshape(-1, bs)
1404-
absmax_vals = W_blocks_vq.abs().max(dim=1, keepdim=True)[0]
1405-
absmax_vals = absmax_vals.clamp_min(1e-8)
1406-
1407-
# Normalize
1408-
W_unit_blocks = W_blocks_vq / absmax_vals
1409-
W_unit = W_unit_blocks.reshape(out_dim * n_rot, elems_per_p)
1410-
1411-
# VQ quantization
1412-
groups = W_unit.reshape(-1, p_dim)
1413-
1414-
# Find nearest codewords
1415-
dists = torch.cdist(groups, q_cb.float())
1416-
idx = dists.argmin(dim=1)
1417-
q_groups = q_cb[idx]
1418-
1419-
# Dequantize
1420-
dq_groups = d_cb[idx]
1421-
dq_vq = dq_groups.reshape(out_dim * n_rot, elems_per_p)
1422-
1423-
# Denormalize - absmax has one value per block of bs elements
1424-
# absmax_vals: [out_dim * n_rot * elems_per_p / bs, 1]
1425-
# dq_vq: [out_dim * n_rot, elems_per_p]
1426-
# Need to reshape absmax to [out_dim * n_rot, elems_per_p / bs, 1] and broadcast
1427-
n_blocks_per_row = elems_per_p // bs
1428-
absmax_reshaped = absmax_vals.reshape(out_dim * n_rot, n_blocks_per_row, 1)
1429-
dq_vq_reshaped = dq_vq.reshape(out_dim * n_rot, n_blocks_per_row, bs)
1430-
dq_vq_denorm = (dq_vq_reshaped * absmax_reshaped).reshape(out_dim * n_rot, elems_per_p)
1431-
1432-
if rem > 0:
1433-
rem_part = W_rot_reshaped[:, elems_per_p:]
1434-
dq_blocks = torch.cat([dq_vq_denorm, rem_part], dim=1)
1389+
# Flatten, pad, blockwise absmax quantization (same pattern as install_bnf_hooks)
1390+
flat = W_rot.flatten()
1391+
n = flat.numel()
1392+
pad_n = (bs - n % bs) % bs
1393+
if pad_n > 0:
1394+
flat = torch.nn.functional.pad(flat, (0, pad_n))
1395+
1396+
blocks = flat.reshape(-1, bs)
1397+
absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_(min=1e-12)
1398+
normalized = blocks / absmax
1399+
1400+
# Vector-quantize the normalized blocks (only the leading elems_per_p columns, i.e. whole p_dim-sized groups)
1401+
elems_per_p = (bs // p_dim) * p_dim
1402+
rem = bs - elems_per_p
1403+
1404+
if elems_per_p > 0:
1405+
vq_part = normalized[:, :elems_per_p]
1406+
groups = vq_part.reshape(-1, p_dim)
1407+
1408+
# Find nearest codewords
1409+
dists = torch.cdist(groups, q_cb.float())
1410+
idx = dists.argmin(dim=1)
1411+
1412+
# Dequantize
1413+
dq_groups = d_cb[idx]
1414+
dq_vq = dq_groups.reshape(normalized.shape[0], elems_per_p)
1415+
1416+
if rem > 0:
1417+
rem_part = normalized[:, elems_per_p:]
1418+
dequantized = torch.cat([dq_vq, rem_part], dim=1) * absmax
1419+
else:
1420+
dequantized = dq_vq * absmax
14351421
else:
1436-
dq_blocks = dq_vq_denorm
1422+
# p_dim > bs: no whole VQ group fits in a block, so pass values through unquantized (rescaled by absmax)
1423+
dequantized = normalized * absmax
14371424

1438-
W_q = dq_blocks.reshape(W_rot.shape)
1425+
W_q = dequantized.flatten()[:n].reshape(W_rot.shape)
14391426
else:
14401427
# L2 norm - simpler case
14411428
W_flat = W_rot.reshape(-1, p_dim)

0 commit comments

Comments
 (0)