Skip to content

Commit 0a12b73

Browse files
committed
Fix HIGGS hook: use pre_hook (before forward) instead of post_hook
1 parent 2495cf3 commit 0a12b73

File tree

1 file changed: +26 additions, −24 deletions

baselines/opt_sym/eval_ppl.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,15 +1343,15 @@ def main():
13431343
key = (k, p)
13441344
quant_cb, deq_cb = unique_configs[key]
13451345

1346-
# Use existing BNF hook installation (inline)
1347-
# Reuse the closure pattern from install_bnf_hooks/install_l2_hooks
1348-
sign_key = f"hadamard_sign_{args.seed}"
1349-
had_sign = None
1350-
1351-
def make_hook(q_cb, d_cb, bs, rot_bs, sign_seed, norm_type, p_dim):
1352-
def hook(module, input, output):
1353-
# Use module weight directly (quantized in-place)
1354-
W = module.weight.data
1346+
# Use pre_hook (runs BEFORE forward) to quantize weights before use
1347+
# Then post_hook to restore original weights
1348+
def make_pre_hook(q_cb, d_cb, bs, rot_bs, sign_seed, norm_type, p_dim):
1349+
def pre_hook(mod, args_):
1350+
if not _hooks_enabled:
1351+
return
1352+
# Save original weight
1353+
mod._orig_weight = mod.weight.data.clone()
1354+
W = mod.weight.data
13551355
dtype = W.dtype
13561356
W_float = W.float()
13571357

@@ -1366,14 +1366,11 @@ def hook(module, input, output):
13661366
else:
13671367
actual_rot_bs = rot_bs
13681368

1369-
# Reshape for rotation
13701369
W_reshaped = W_float.reshape(out_dim * n_rot, actual_rot_bs)
13711370

13721371
# Get or create sign vector
1373-
nonlocal had_sign
1374-
if had_sign is None:
1375-
torch.manual_seed(sign_seed)
1376-
had_sign = (2 * (torch.rand(actual_rot_bs, device=W.device) > 0.5).float() - 1).to(W.device)
1372+
torch.manual_seed(sign_seed)
1373+
had_sign = (2 * (torch.rand(actual_rot_bs, device=W.device) > 0.5).float() - 1).to(W.device)
13771374

13781375
# Apply sign and Hadamard
13791376
W_signed = W_reshaped * had_sign.unsqueeze(0)
@@ -1386,7 +1383,7 @@ def hook(module, input, output):
13861383

13871384
# Quantization
13881385
if norm_type == 'absmax':
1389-
# Flatten, pad, blockwise absmax quantization (same pattern as install_bnf_hooks)
1386+
# Flatten, pad, blockwise absmax quantization
13901387
flat = W_rot.flatten()
13911388
n = flat.numel()
13921389
pad_n = (bs - n % bs) % bs
@@ -1419,12 +1416,11 @@ def hook(module, input, output):
14191416
else:
14201417
dequantized = dq_vq * absmax
14211418
else:
1422-
# p_dim > bs, can't do VQ - keep normalized
14231419
dequantized = normalized * absmax
14241420

14251421
W_q = dequantized.flatten()[:n].reshape(W_rot.shape)
14261422
else:
1427-
# L2 norm - simpler case
1423+
# L2 norm
14281424
W_flat = W_rot.reshape(-1, p_dim)
14291425
dists = torch.cdist(W_flat, q_cb.float())
14301426
idx = dists.argmin(dim=1)
@@ -1439,15 +1435,21 @@ def hook(module, input, output):
14391435
else:
14401436
W_final = W_q.reshape(W.shape)
14411437

1442-
module.weight.data = W_final.to(dtype)
1443-
return output
1444-
return hook
1438+
mod.weight.data = W_final.to(dtype)
1439+
return pre_hook
14451440

1446-
handle = module.register_forward_hook(
1447-
make_hook(quant_cb, deq_cb, args.blocksize, args.rot_blocksize,
1448-
args.seed, args.norm, p)
1441+
def post_hook(mod, args_, output):
1442+
# Restore original weight after forward
1443+
if hasattr(mod, '_orig_weight'):
1444+
mod.weight.data = mod._orig_weight
1445+
del mod._orig_weight
1446+
1447+
h1 = module.register_forward_pre_hook(
1448+
make_pre_hook(quant_cb, deq_cb, args.blocksize, args.rot_blocksize,
1449+
args.seed, args.norm, p)
14491450
)
1450-
hooks.append(handle)
1451+
h2 = module.register_forward_hook(post_hook)
1452+
hooks.extend([h1, h2])
14511453

14521454
print(f"Installed {len(hooks)} per-layer quantization hooks")
14531455
effective_bits = avg_bits

0 commit comments

Comments (0)