Skip to content

Commit 5cff1c3

Browse files
committed
Fix HIGGS absmax shape handling for proper denormalization
1 parent 8ff002f commit 5cff1c3

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

baselines/opt_sym/eval_ppl.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,23 +1386,30 @@ def hook(module, input, output):
13861386

13871387
# Quantization
13881388
if norm_type == 'absmax':
1389-
# Block-wise absmax
1390-
W_blocks = W_rot.reshape(-1, bs)
1391-
absmax_vals = W_blocks.abs().max(dim=1, keepdim=True)[0]
1392-
absmax_vals = absmax_vals.clamp_min(1e-8)
1393-
1394-
W_unit = W_blocks / absmax_vals
1395-
1396-
# VQ quantization
1389+
# VQ quantization setup
13971390
elems_per_p = (actual_rot_bs // p_dim) * p_dim
13981391
rem = actual_rot_bs - elems_per_p
13991392

1393+
# Reshape for VQ: [out_dim * n_rot, actual_rot_bs]
1394+
W_rot_reshaped = W_rot.reshape(out_dim * n_rot, actual_rot_bs)
1395+
1396+
# Compute absmax on VQ-compatible portion (excluding remainder)
14001397
if rem > 0:
1401-
vq_part = W_unit.reshape(out_dim * n_rot, actual_rot_bs)[:, :elems_per_p]
1398+
W_for_vq = W_rot_reshaped[:, :elems_per_p]
14021399
else:
1403-
vq_part = W_unit
1400+
W_for_vq = W_rot_reshaped
14041401

1405-
groups = vq_part.reshape(-1, p_dim)
1402+
# Reshape to blocks for absmax: [out_dim * n_rot * elems_per_p / bs, bs]
1403+
W_blocks_vq = W_for_vq.reshape(-1, bs)
1404+
absmax_vals = W_blocks_vq.abs().max(dim=1, keepdim=True)[0]
1405+
absmax_vals = absmax_vals.clamp_min(1e-8)
1406+
1407+
# Normalize
1408+
W_unit_blocks = W_blocks_vq / absmax_vals
1409+
W_unit = W_unit_blocks.reshape(out_dim * n_rot, elems_per_p)
1410+
1411+
# VQ quantization
1412+
groups = W_unit.reshape(-1, p_dim)
14061413

14071414
# Find nearest codewords
14081415
dists = torch.cdist(groups, q_cb.float())
@@ -1413,14 +1420,17 @@ def hook(module, input, output):
14131420
dq_groups = d_cb[idx]
14141421
dq_vq = dq_groups.reshape(out_dim * n_rot, elems_per_p)
14151422

1423+
# Denormalize - reshape absmax to match dq_vq shape
1424+
absmax_reshaped = absmax_vals.reshape(out_dim * n_rot, -1)
1425+
dq_vq_denorm = dq_vq * absmax_reshaped
1426+
14161427
if rem > 0:
1417-
rem_part = W_unit.reshape(out_dim * n_rot, actual_rot_bs)[:, elems_per_p:]
1418-
dq_blocks = torch.cat([dq_vq, rem_part], dim=1)
1428+
rem_part = W_rot_reshaped[:, elems_per_p:]
1429+
dq_blocks = torch.cat([dq_vq_denorm, rem_part], dim=1)
14191430
else:
1420-
dq_blocks = dq_vq
1431+
dq_blocks = dq_vq_denorm
14211432

1422-
# Denormalize
1423-
W_q = (dq_blocks * absmax_vals).reshape(W_rot.shape)
1433+
W_q = dq_blocks.reshape(W_rot.shape)
14241434
else:
14251435
# L2 norm - simpler case
14261436
W_flat = W_rot.reshape(-1, p_dim)

0 commit comments

Comments
 (0)