@@ -1343,15 +1343,15 @@ def main():
13431343 key = (k , p )
13441344 quant_cb , deq_cb = unique_configs [key ]
13451345
1346- # Use existing BNF hook installation (inline)
1347- # Reuse the closure pattern from install_bnf_hooks/install_l2_hooks
1348- sign_key = f"hadamard_sign_ { args . seed } "
1349- had_sign = None
1350-
1351- def make_hook ( q_cb , d_cb , bs , rot_bs , sign_seed , norm_type , p_dim ):
1352- def hook ( module , input , output ):
1353- # Use module weight directly (quantized in-place )
1354- W = module .weight .data
1346+ # Use pre_hook (runs BEFORE forward) to quantize weights before use
1347+ # Then post_hook to restore original weights
1348+ def make_pre_hook ( q_cb , d_cb , bs , rot_bs , sign_seed , norm_type , p_dim ):
1349+ def pre_hook ( mod , args_ ):
1350+ if not _hooks_enabled :
1351+ return
1352+ # Keep a copy of the original weight so post_hook can restore it after forward
1353+ mod . _orig_weight = mod . weight . data . clone ( )
1354+ W = mod .weight .data
13551355 dtype = W .dtype
13561356 W_float = W .float ()
13571357
@@ -1366,14 +1366,11 @@ def hook(module, input, output):
13661366 else :
13671367 actual_rot_bs = rot_bs
13681368
1369- # Reshape for rotation
13701369 W_reshaped = W_float .reshape (out_dim * n_rot , actual_rot_bs )
13711370
13721371 # Rebuild the sign vector on every call (reseeds the global torch RNG with sign_seed)
1373- nonlocal had_sign
1374- if had_sign is None :
1375- torch .manual_seed (sign_seed )
1376- had_sign = (2 * (torch .rand (actual_rot_bs , device = W .device ) > 0.5 ).float () - 1 ).to (W .device )
1372+ torch .manual_seed (sign_seed )
1373+ had_sign = (2 * (torch .rand (actual_rot_bs , device = W .device ) > 0.5 ).float () - 1 ).to (W .device )
13771374
13781375 # Apply sign and Hadamard
13791376 W_signed = W_reshaped * had_sign .unsqueeze (0 )
@@ -1386,7 +1383,7 @@ def hook(module, input, output):
13861383
13871384 # Quantization
13881385 if norm_type == 'absmax' :
1389- # Flatten, pad, blockwise absmax quantization (same pattern as install_bnf_hooks)
1386+ # Flatten, pad, blockwise absmax quantization
13901387 flat = W_rot .flatten ()
13911388 n = flat .numel ()
13921389 pad_n = (bs - n % bs ) % bs
@@ -1419,12 +1416,11 @@ def hook(module, input, output):
14191416 else :
14201417 dequantized = dq_vq * absmax
14211418 else :
1422- # p_dim > bs, can't do VQ - keep normalized
14231419 dequantized = normalized * absmax
14241420
14251421 W_q = dequantized .flatten ()[:n ].reshape (W_rot .shape )
14261422 else :
1427- # L2 norm - simpler case
1423+ # L2 norm
14281424 W_flat = W_rot .reshape (- 1 , p_dim )
14291425 dists = torch .cdist (W_flat , q_cb .float ())
14301426 idx = dists .argmin (dim = 1 )
@@ -1439,15 +1435,21 @@ def hook(module, input, output):
14391435 else :
14401436 W_final = W_q .reshape (W .shape )
14411437
1442- module .weight .data = W_final .to (dtype )
1443- return output
1444- return hook
1438+ mod .weight .data = W_final .to (dtype )
1439+ return pre_hook
14451440
1446- handle = module .register_forward_hook (
1447- make_hook (quant_cb , deq_cb , args .blocksize , args .rot_blocksize ,
1448- args .seed , args .norm , p )
1441+ def post_hook (mod , args_ , output ):
1442+ # Restore the weight saved by pre_hook after forward; nothing to do if pre_hook returned early
1443+ if hasattr (mod , '_orig_weight' ):
1444+ mod .weight .data = mod ._orig_weight
1445+ del mod ._orig_weight
1446+
1447+ h1 = module .register_forward_pre_hook (
1448+ make_pre_hook (quant_cb , deq_cb , args .blocksize , args .rot_blocksize ,
1449+ args .seed , args .norm , p )
14491450 )
1450- hooks .append (handle )
1451+ h2 = module .register_forward_hook (post_hook )
1452+ hooks .extend ([h1 , h2 ])
14511453
14521454 print (f"Installed { len (hooks )} per-layer quantization hooks" )
14531455 effective_bits = avg_bits
0 commit comments