@@ -1343,113 +1343,104 @@ def main():
13431343 key = (k , p )
13441344 quant_cb , deq_cb = unique_configs [key ]
13451345
1346- # Use pre_hook (runs BEFORE forward) to quantize weights before use
1347- # Then post_hook to restore original weights
1348- def make_pre_hook (q_cb , d_cb , bs , rot_bs , sign_seed , norm_type , p_dim ):
1349- def pre_hook (mod , args_ ):
1350- if not _hooks_enabled :
1351- return
1352- # Save original weight
1353- mod ._orig_weight = mod .weight .data .clone ()
1354- W = mod .weight .data
1355- dtype = W .dtype
1356- W_float = W .float ()
1357-
1358- out_dim , in_dim = W .shape
1359-
1360- # Apply Hadamard rotation
1361- if p_dim > 1 or norm_type == 'l2' :
1362- n_rot = in_dim // rot_bs
1363- if n_rot == 0 :
1364- n_rot = 1
1365- actual_rot_bs = in_dim
1366- else :
1367- actual_rot_bs = rot_bs
1368-
1369- W_reshaped = W_float .reshape (out_dim * n_rot , actual_rot_bs )
1370-
1371- # Get or create sign vector
1372- torch .manual_seed (sign_seed )
1373- had_sign = (2 * (torch .rand (actual_rot_bs , device = W .device ) > 0.5 ).float () - 1 ).to (W .device )
1374-
1375- # Apply sign and Hadamard (match working BNF pattern)
1376- W_signed = W_reshaped * had_sign .unsqueeze (0 )
1377- H = torch .tensor (hadamard (actual_rot_bs ), dtype = torch .float32 , device = W .device )
1378- W_rot = W_signed @ H .T # No normalization - BNF doesn't divide by sqrt(n)
1379- else :
1380- W_rot = W_float
1346+ # Store per-layer config for custom BNF hook
1347+ module ._higgs_quant_cb = quant_cb .to (device )
1348+ module ._higgs_deq_cb = deq_cb .to (device )
1349+ module ._higgs_p = p
1350+
1351+ # Install custom BNF hooks that check for per-layer codebooks
1352+ H_block = make_hadamard_block (args .rot_blocksize , device )
1353+ max_in = max (m .weight .shape [1 ] for m in model .modules ()
1354+ if isinstance (m , nn .Linear ))
1355+ signs = make_random_signs (max_in , args .seed , device )
1356+
1357+ def make_higgs_pre_hook (bs , rot_bs , sign_seed , norm_type ):
1358+ def hook (mod , args_ ):
1359+ if not _hooks_enabled :
1360+ return
1361+ # Check if this module has HIGGS config
1362+ if not hasattr (mod , '_higgs_quant_cb' ):
1363+ return
1364+
1365+ q_cb = mod ._higgs_quant_cb
1366+ d_cb = mod ._higgs_deq_cb
1367+ p_dim = mod ._higgs_p
1368+
1369+ mod ._orig_weight = mod .weight .data .clone ()
1370+ W = mod .weight .data .float ()
1371+ out_dim , in_dim = W .shape
1372+
1373+ # Apply Hadamard rotation
1374+ if p_dim > 1 or norm_type == 'l2' :
1375+ n_rot = in_dim // rot_bs
1376+ if n_rot == 0 :
13811377 n_rot = 1
13821378 actual_rot_bs = in_dim
1383-
1384- # Quantization
1385- if norm_type == 'absmax' :
1386- # Flatten, pad, blockwise absmax quantization
1387- flat = W_rot .flatten ()
1388- n = flat .numel ()
1389- pad_n = (bs - n % bs ) % bs
1390- if pad_n > 0 :
1391- flat = torch .nn .functional .pad (flat , (0 , pad_n ))
1392-
1393- blocks = flat .reshape (- 1 , bs )
1394- absmax = blocks .abs ().amax (dim = 1 , keepdim = True ).clamp_ (min = 1e-12 )
1395- normalized = blocks / absmax
1396-
1397- # VQ quantization on normalized blocks
1398- elems_per_p = (bs // p_dim ) * p_dim
1399- rem = bs - elems_per_p
1400-
1401- if elems_per_p > 0 :
1402- vq_part = normalized [:, :elems_per_p ]
1403- groups = vq_part .reshape (- 1 , p_dim )
1404-
1405- # Find nearest codewords (use _chunked_nearest like working BNF)
1406- idx = _chunked_nearest (groups , q_cb .to (W .device ), chunk_size = 100000 )
1407-
1408- # Dequantize (returns values on unit sphere)
1409- dq_groups = d_cb .to (W .device )[idx ]
1410- dq_vq = dq_groups .reshape (normalized .shape [0 ], elems_per_p )
1411-
1412- if rem > 0 :
1413- # rem_part is normalized (unit scale), dq_vq is normalized
1414- # Both need to be multiplied by absmax to denormalize
1415- rem_part = normalized [:, elems_per_p :]
1416- dequantized = torch .cat ([dq_vq , rem_part ], dim = 1 ) * absmax
1417- else :
1418- dequantized = dq_vq * absmax
1419- else :
1420- dequantized = normalized * absmax
1421-
1422- W_q = dequantized .flatten ()[:n ].reshape (W_rot .shape )
14231379 else :
1424- # L2 norm
1425- W_flat = W_rot .reshape (- 1 , p_dim )
1426- idx = _chunked_nearest (W_flat , q_cb .to (W .device ), chunk_size = 100000 )
1427- W_q = d_cb .to (W .device )[idx ].reshape (W_rot .shape )
1428-
1429- # Inverse Hadamard if needed (match working BNF pattern)
1430- if p_dim > 1 or norm_type == 'l2' :
1431- W_deshaped = W_q .reshape (out_dim * n_rot , actual_rot_bs )
1432- W_unrot = W_deshaped @ H # No normalization - BNF doesn't multiply by sqrt(n)
1433- W_unrot = W_unrot * had_sign .unsqueeze (0 )
1434- W_final = W_unrot .reshape (W .shape )
1380+ actual_rot_bs = rot_bs
1381+
1382+ W_reshaped = W .reshape (out_dim * n_rot , actual_rot_bs )
1383+ W_signed = W_reshaped * signs [:actual_rot_bs ].unsqueeze (0 )
1384+ W_rot = W_signed @ H_block .T
1385+ else :
1386+ W_rot = W
1387+ n_rot = 1
1388+ actual_rot_bs = in_dim
1389+
1390+ # Quantize
1391+ flat = W_rot .flatten ()
1392+ n = flat .numel ()
1393+ pad_n = (bs - n % bs ) % bs
1394+ if pad_n > 0 :
1395+ flat = torch .nn .functional .pad (flat , (0 , pad_n ))
1396+
1397+ blocks = flat .reshape (- 1 , bs )
1398+ absmax = blocks .abs ().amax (dim = 1 , keepdim = True ).clamp_ (min = 1e-12 )
1399+ normalized = blocks / absmax
1400+
1401+ elems_per_p = (bs // p_dim ) * p_dim
1402+ rem = bs - elems_per_p
1403+
1404+ if elems_per_p > 0 :
1405+ vq_part = normalized [:, :elems_per_p ]
1406+ groups = vq_part .reshape (- 1 , p_dim )
1407+ idx = _chunked_nearest (groups , q_cb , chunk_size = 100000 )
1408+ dq_groups = d_cb [idx ]
1409+ dq_vq = dq_groups .reshape (normalized .shape [0 ], elems_per_p )
1410+
1411+ if rem > 0 :
1412+ rem_part = normalized [:, elems_per_p :]
1413+ dequantized = torch .cat ([dq_vq , rem_part ], dim = 1 ) * absmax
14351414 else :
1436- W_final = W_q .reshape (W .shape )
1415+ dequantized = dq_vq * absmax
1416+ else :
1417+ dequantized = normalized * absmax
1418+
1419+ dequantized = dequantized .flatten ()[:n ].reshape (W_rot .shape )
14371420
1438- mod .weight .data = W_final .to (dtype )
1439- return pre_hook
1421+ # Inverse Hadamard
1422+ if p_dim > 1 or norm_type == 'l2' :
1423+ W_deshaped = dequantized .reshape (out_dim * n_rot , actual_rot_bs )
1424+ W_unrot = W_deshaped @ H_block
1425+ W_unrot = W_unrot * signs [:actual_rot_bs ].unsqueeze (0 )
1426+ W_final = W_unrot .reshape (W .shape )
1427+ else :
1428+ W_final = dequantized .reshape (W .shape )
14401429
1441- def post_hook (mod , args_ , output ):
1442- # Restore original weight after forward
1443- if hasattr (mod , '_orig_weight' ):
1444- mod .weight .data = mod ._orig_weight
1445- del mod ._orig_weight
1430+ mod .weight .data = W_final .to (mod .weight .dtype )
1431+ return hook
14461432
1447- h1 = module .register_forward_pre_hook (
1448- make_pre_hook (quant_cb , deq_cb , args .blocksize , args .rot_blocksize ,
1449- args .seed , args .norm , p )
1450- )
1451- h2 = module .register_forward_hook (post_hook )
1452- hooks .extend ([h1 , h2 ])
1433+ def post_hook (mod , args_ , output ):
1434+ if hasattr (mod , '_orig_weight' ):
1435+ mod .weight .data = mod ._orig_weight
1436+ del mod ._orig_weight
1437+
1438+ for name , module in model .named_modules ():
1439+ if isinstance (module , nn .Linear ) and hasattr (module , '_higgs_quant_cb' ):
1440+ h1 = module .register_forward_pre_hook (
1441+ make_higgs_pre_hook (args .blocksize , args .rot_blocksize , args .seed , args .norm ))
1442+ h2 = module .register_forward_hook (post_hook )
1443+ hooks .extend ([h1 , h2 ])
14531444
14541445 print (f"Installed { len (hooks )} per-layer quantization hooks" )
14551446 effective_bits = avg_bits
0 commit comments