@@ -968,33 +968,6 @@ def eval_ppl(model, testenc, seqlen, device, compare_fp16=False):
968968# HIGGS support
969969# ============================================================
970970
def compute_absmax_codebook(k, p, device='cuda'):
    """Return the uniform (absmax) codebook for k-bit, p-dimensional VQ.

    Builds 2**k evenly spaced levels on [-1, 1]; for p > 1 the codebook is
    the full Cartesian product of those levels, one axis per dimension
    (so the codebook holds (2**k)**p codewords).

    Args:
        k: bits per dimension (2**k levels per axis).
        p: vector dimension of each codeword.
        device: torch device for the returned tensor.

    Returns:
        Tensor of shape (num_codewords, p).
    """
    levels = torch.linspace(-1, 1, 2 ** k, device=device)
    if p == 1:
        return levels.unsqueeze(1)
    axes = torch.meshgrid(*(levels for _ in range(p)), indexing='ij')
    return torch.stack([axis.reshape(-1) for axis in axes], dim=1)
981-
982-
def compute_l2_codebook(k, p, device='cuda'):
    """Compute a codebook for L2 quantization.

    Builds a uniform grid over [-1, 1]^p containing at least 2**k points,
    then keeps the 2**k grid points closest to the origin so the codebook
    is exactly indexable with k bits.

    Args:
        k: total bits per codeword (codebook size is 2**k).
        p: vector dimension of each codeword.
        device: torch device for the returned tensor.

    Returns:
        Tensor of shape (2**k, p).
    """
    n_entries = 2 ** k
    if p == 1:
        # BUG FIX: the original built int(n_entries ** 1) + 1 = 2**k + 1 levels
        # and returned them all — one more codeword than k bits can index.
        return torch.linspace(-1, 1, n_entries, device=device).view(-1, 1)
    # int(.) + 1 per-axis levels guarantees the grid has >= n_entries points.
    values = torch.linspace(-1, 1, int(n_entries ** (1 / p)) + 1, device=device)
    grids = torch.meshgrid(*([values] * p), indexing='ij')
    codebook = torch.stack([g.flatten() for g in grids], dim=1)
    # Keep only the n_entries grid points closest to the origin.
    norms = torch.norm(codebook, dim=1)
    _, indices = torch.sort(norms)
    return codebook[indices[:n_entries]]
996-
997-
998971def load_higgs_assignment (assignment_path ):
999972 """Load HIGGS bitwidth assignment from JSON file.
1000973
@@ -1338,54 +1311,139 @@ def main():
13381311 opt = options [opt_idx ]
13391312 print (f" { opt ['config_str' ]} : { count } layers" )
13401313
1314+ # Pre-compute codebooks for each unique (k, p) in the assignment
1315+ dev = next (model .parameters ()).device
1316+ unique_configs = {}
1317+ for opt_idx in set (assignment .values ()):
1318+ opt = options [opt_idx ]
1319+ k , p = opt ['k' ], opt ['p' ]
1320+ key = (k , p )
1321+ if key not in unique_configs :
1322+ # Use existing BNF codebook computation
1323+ quant_cb , deq_cb , _ , _ = compute_codebook (
1324+ k , p , blocksize = args .blocksize , device = dev
1325+ )
1326+ unique_configs [key ] = (quant_cb , deq_cb )
1327+ print (f" Computed codebook for k={ k } , p={ p } : { quant_cb .shape } " )
1328+
13411329 # Get all linear layers
13421330 linear_layers = []
13431331 for name , module in model .named_modules ():
13441332 if isinstance (module , nn .Linear ) and 'embed' not in name and 'lm_head' not in name :
13451333 linear_layers .append (module )
13461334
1347- # Apply per-layer quantization
1348- dev = next (model .parameters ()).device
1335+ # Apply per-layer quantization hooks
13491336 hooks = []
13501337 for layer_idx , module in enumerate (linear_layers ):
13511338 if layer_idx not in assignment :
13521339 continue
13531340 opt_idx = assignment [layer_idx ]
13541341 opt = options [opt_idx ]
13551342 k , p = opt ['k' ], opt ['p' ]
1343+ key = (k , p )
1344+ quant_cb , deq_cb = unique_configs [key ]
13561345
1357- # Compute codebook for this layer
1358- if args .norm == 'l2' :
1359- quant_cb = compute_l2_codebook (k , p , device = dev )
1360- else :
1361- quant_cb = compute_absmax_codebook (k , p , device = dev )
1346+ # Use existing BNF hook installation (inline)
1347+ # Reuse the closure pattern from install_bnf_hooks/install_l2_hooks
1348+ sign_key = f"hadamard_sign_{ args .seed } "
1349+ had_sign = None
13621350
1363- # Create hook
1364- def make_hook (k_val , p_val , cb , norm , bs , rot_bs ):
def make_hook(q_cb, d_cb, bs, rot_bs, sign_seed, norm_type, p_dim):
    """Build a forward hook that quantizes a Linear layer's weight in place.

    The hook quantizes ``module.weight`` once (first forward), then becomes a
    no-op. Pipeline: optional sign-flip + Hadamard rotation over blocks of
    ``rot_bs`` input features, block-wise absmax scaling (``norm_type ==
    'absmax'``) or plain L2 nearest-codeword lookup, vector quantization in
    groups of ``p_dim`` against ``q_cb``, dequantization via ``d_cb``, then the
    inverse rotation. Assumes the weight's element count is divisible by ``bs``
    (absmax path) and ``in_dim`` by ``rot_bs`` when rotating — TODO confirm
    against caller.

    Args:
        q_cb: (n_codewords, p_dim) quantization codebook.
        d_cb: (n_codewords, p_dim) dequantization codebook (same indexing).
        bs: absmax normalization block size.
        rot_bs: Hadamard rotation block size (power of two for hadamard()).
        sign_seed: RNG seed for the random ±1 sign vector.
        norm_type: 'absmax' or 'l2'.
        p_dim: vector-quantization group size.

    Returns:
        A callable suitable for ``module.register_forward_hook``.
    """
    # Per-hook sign cache: built lazily on first call, sized to this layer's
    # actual rotation block (avoids sharing one vector across layers whose
    # rotation sizes differ).
    had_sign = None

    def hook(module, input, output):
        nonlocal had_sign
        # BUG FIX: restore the one-shot guard (present in the previous
        # implementation) so the weight is not re-quantized on every forward.
        if getattr(module, '_higgs_quantized', False):
            return output

        W = module.weight.data
        dtype = W.dtype
        W_float = W.float()
        out_dim, in_dim = W.shape

        rotate = p_dim > 1 or norm_type == 'l2'
        if rotate:
            n_rot = in_dim // rot_bs
            if n_rot == 0:
                # Layer narrower than the rotation block: rotate it whole.
                n_rot = 1
                actual_rot_bs = in_dim
            else:
                actual_rot_bs = rot_bs

            W_reshaped = W_float.reshape(out_dim * n_rot, actual_rot_bs)

            if had_sign is None:
                torch.manual_seed(sign_seed)
                had_sign = (2 * (torch.rand(actual_rot_bs, device=W.device) > 0.5).float() - 1).to(W.device)

            # Orthonormal transform: random signs then H / sqrt(n).
            W_signed = W_reshaped * had_sign.unsqueeze(0)
            H = torch.tensor(hadamard(actual_rot_bs), dtype=torch.float32, device=W.device)
            scale = actual_rot_bs ** 0.5
            W_rot = (W_signed @ H) / scale
        else:
            W_rot = W_float
            n_rot = 1
            actual_rot_bs = in_dim

        if norm_type == 'absmax':
            # Block-wise absmax normalization into [-1, 1].
            W_blocks = W_rot.reshape(-1, bs)
            absmax_vals = W_blocks.abs().max(dim=1, keepdim=True)[0].clamp_min(1e-8)
            W_unit = W_blocks / absmax_vals

            # VQ in groups of p_dim; a tail of `rem` elements per rotation row
            # that doesn't fill a group is passed through unquantized.
            elems_per_p = (actual_rot_bs // p_dim) * p_dim
            rem = actual_rot_bs - elems_per_p
            W_unit_rows = W_unit.reshape(out_dim * n_rot, actual_rot_bs)

            groups = W_unit_rows[:, :elems_per_p].reshape(-1, p_dim)
            idx = torch.cdist(groups, q_cb.float()).argmin(dim=1)
            dq_vq = d_cb[idx].reshape(out_dim * n_rot, elems_per_p)

            if rem > 0:
                dq_blocks = torch.cat([dq_vq, W_unit_rows[:, elems_per_p:]], dim=1)
            else:
                dq_blocks = dq_vq

            # BUG FIX: re-block to `bs` before denormalizing; absmax_vals is
            # (numel/bs, 1) and only broadcasts against rows of width bs.
            W_q = (dq_blocks.reshape(-1, bs) * absmax_vals).reshape(W_rot.shape)
        else:
            # L2: straight nearest-codeword lookup, no per-block scaling.
            W_flat = W_rot.reshape(-1, p_dim)
            idx = torch.cdist(W_flat, q_cb.float()).argmin(dim=1)
            W_q = d_cb[idx].reshape(W_rot.shape)

        if rotate:
            W_deshaped = W_q.reshape(out_dim * n_rot, actual_rot_bs)
            # BUG FIX: the inverse of (x @ H) / sqrt(n) is (y @ H.T) / sqrt(n)
            # (H @ H.T == n * I); multiplying by sqrt(n) scaled weights by n.
            W_unrot = (W_deshaped @ H.T) / scale
            W_unrot = W_unrot * had_sign.unsqueeze(0)  # signs are ±1: self-inverse
            W_final = W_unrot.reshape(W.shape)
        else:
            W_final = W_q.reshape(W.shape)

        module.weight.data = W_final.to(dtype)
        module._higgs_quantized = True
        return output
    return hook
13861443
13871444 handle = module .register_forward_hook (
1388- make_hook (k , p , quant_cb , args .norm , args .blocksize , args .rot_blocksize )
1445+ make_hook (quant_cb , deq_cb , args .blocksize , args .rot_blocksize ,
1446+ args .seed , args .norm , p )
13891447 )
13901448 hooks .append (handle )
13911449
0 commit comments