@@ -1386,56 +1386,43 @@ def hook(module, input, output):
13861386
13871387 # Quantization
13881388 if norm_type == 'absmax' :
1389- # VQ quantization setup
1390- elems_per_p = (actual_rot_bs // p_dim ) * p_dim
1391- rem = actual_rot_bs - elems_per_p
1392-
1393- # Reshape for VQ: [out_dim * n_rot, actual_rot_bs]
1394- W_rot_reshaped = W_rot .reshape (out_dim * n_rot , actual_rot_bs )
1395-
1396- # Compute absmax on VQ-compatible portion (excluding remainder)
1397- if rem > 0 :
1398- W_for_vq = W_rot_reshaped [:, :elems_per_p ]
1399- else :
1400- W_for_vq = W_rot_reshaped
1401-
1402- # Reshape to blocks for absmax: [out_dim * n_rot * elems_per_p / bs, bs]
1403- W_blocks_vq = W_for_vq .reshape (- 1 , bs )
1404- absmax_vals = W_blocks_vq .abs ().max (dim = 1 , keepdim = True )[0 ]
1405- absmax_vals = absmax_vals .clamp_min (1e-8 )
1406-
1407- # Normalize
1408- W_unit_blocks = W_blocks_vq / absmax_vals
1409- W_unit = W_unit_blocks .reshape (out_dim * n_rot , elems_per_p )
1410-
1411- # VQ quantization
1412- groups = W_unit .reshape (- 1 , p_dim )
1413-
1414- # Find nearest codewords
1415- dists = torch .cdist (groups , q_cb .float ())
1416- idx = dists .argmin (dim = 1 )
1417- q_groups = q_cb [idx ]
1418-
1419- # Dequantize
1420- dq_groups = d_cb [idx ]
1421- dq_vq = dq_groups .reshape (out_dim * n_rot , elems_per_p )
1422-
1423- # Denormalize - absmax has one value per block of bs elements
1424- # absmax_vals: [out_dim * n_rot * elems_per_p / bs, 1]
1425- # dq_vq: [out_dim * n_rot, elems_per_p]
1426- # Need to reshape absmax to [out_dim * n_rot, elems_per_p / bs, 1] and broadcast
1427- n_blocks_per_row = elems_per_p // bs
1428- absmax_reshaped = absmax_vals .reshape (out_dim * n_rot , n_blocks_per_row , 1 )
1429- dq_vq_reshaped = dq_vq .reshape (out_dim * n_rot , n_blocks_per_row , bs )
1430- dq_vq_denorm = (dq_vq_reshaped * absmax_reshaped ).reshape (out_dim * n_rot , elems_per_p )
1431-
1432- if rem > 0 :
1433- rem_part = W_rot_reshaped [:, elems_per_p :]
1434- dq_blocks = torch .cat ([dq_vq_denorm , rem_part ], dim = 1 )
1389+ # Flatten, pad, blockwise absmax quantization (same pattern as install_bnf_hooks)
1390+ flat = W_rot .flatten ()
1391+ n = flat .numel ()
1392+ pad_n = (bs - n % bs ) % bs
1393+ if pad_n > 0 :
1394+ flat = torch .nn .functional .pad (flat , (0 , pad_n ))
1395+
1396+ blocks = flat .reshape (- 1 , bs )
1397+ absmax = blocks .abs ().amax (dim = 1 , keepdim = True ).clamp_ (min = 1e-12 )
1398+ normalized = blocks / absmax
1399+
1400+ # VQ quantization on normalized blocks
1401+ elems_per_p = (bs // p_dim ) * p_dim
1402+ rem = bs - elems_per_p
1403+
1404+ if elems_per_p > 0 :
1405+ vq_part = normalized [:, :elems_per_p ]
1406+ groups = vq_part .reshape (- 1 , p_dim )
1407+
1408+ # Find nearest codewords
1409+ dists = torch .cdist (groups , q_cb .float ())
1410+ idx = dists .argmin (dim = 1 )
1411+
1412+ # Dequantize
1413+ dq_groups = d_cb [idx ]
1414+ dq_vq = dq_groups .reshape (normalized .shape [0 ], elems_per_p )
1415+
1416+ if rem > 0 :
1417+ rem_part = normalized [:, elems_per_p :]
1418+ dequantized = torch .cat ([dq_vq , rem_part ], dim = 1 ) * absmax
1419+ else :
1420+ dequantized = dq_vq * absmax
14351421 else :
1436- dq_blocks = dq_vq_denorm
1422+ # p_dim > bs, can't do VQ - keep normalized
1423+ dequantized = normalized * absmax
14371424
1438- W_q = dq_blocks .reshape (W_rot .shape )
1425+ W_q = dequantized . flatten ()[: n ] .reshape (W_rot .shape )
14391426 else :
14401427 # L2 norm - simpler case
14411428 W_flat = W_rot .reshape (- 1 , p_dim )
0 commit comments