@@ -1386,23 +1386,30 @@ def hook(module, input, output):
13861386
13871387 # Quantization
13881388 if norm_type == 'absmax' :
1389- # Block-wise absmax
1390- W_blocks = W_rot .reshape (- 1 , bs )
1391- absmax_vals = W_blocks .abs ().max (dim = 1 , keepdim = True )[0 ]
1392- absmax_vals = absmax_vals .clamp_min (1e-8 )
1393-
1394- W_unit = W_blocks / absmax_vals
1395-
1396- # VQ quantization
1389+ # VQ quantization setup
13971390 elems_per_p = (actual_rot_bs // p_dim ) * p_dim
13981391 rem = actual_rot_bs - elems_per_p
13991392
1393+ # Reshape for VQ: [out_dim * n_rot, actual_rot_bs]
1394+ W_rot_reshaped = W_rot .reshape (out_dim * n_rot , actual_rot_bs )
1395+
1396+ # Compute absmax on VQ-compatible portion (excluding remainder)
14001397 if rem > 0 :
1401- vq_part = W_unit . reshape ( out_dim * n_rot , actual_rot_bs ) [:, :elems_per_p ]
1398+ W_for_vq = W_rot_reshaped [:, :elems_per_p ]
14021399 else :
1403- vq_part = W_unit
1400+ W_for_vq = W_rot_reshaped
14041401
1405- groups = vq_part .reshape (- 1 , p_dim )
1402+ # Reshape to blocks for absmax: [out_dim * n_rot * elems_per_p / bs, bs]
1403+ W_blocks_vq = W_for_vq .reshape (- 1 , bs )
1404+ absmax_vals = W_blocks_vq .abs ().max (dim = 1 , keepdim = True )[0 ]
1405+ absmax_vals = absmax_vals .clamp_min (1e-8 )
1406+
1407+ # Normalize
1408+ W_unit_blocks = W_blocks_vq / absmax_vals
1409+ W_unit = W_unit_blocks .reshape (out_dim * n_rot , elems_per_p )
1410+
1411+ # VQ quantization
1412+ groups = W_unit .reshape (- 1 , p_dim )
14061413
14071414 # Find nearest codewords
14081415 dists = torch .cdist (groups , q_cb .float ())
@@ -1413,14 +1420,17 @@ def hook(module, input, output):
14131420 dq_groups = d_cb [idx ]
14141421 dq_vq = dq_groups .reshape (out_dim * n_rot , elems_per_p )
14151422
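1423+ # Denormalize - absmax is reshaped to [out_dim * n_rot, elems_per_p / bs]; NOTE(review): this only broadcasts against dq_vq [out_dim * n_rot, elems_per_p] when bs == elems_per_p (or bs == 1) - confirm, otherwise scale per block via dq_vq.reshape(-1, bs) * absmax_vals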
1424+ absmax_reshaped = absmax_vals .reshape (out_dim * n_rot , - 1 )
1425+ dq_vq_denorm = dq_vq * absmax_reshaped
1426+
14161427 if rem > 0 :
1417- rem_part = W_unit . reshape ( out_dim * n_rot , actual_rot_bs ) [:, elems_per_p :]
1418- dq_blocks = torch .cat ([dq_vq , rem_part ], dim = 1 )
1428+ rem_part = W_rot_reshaped [:, elems_per_p :]
1429+ dq_blocks = torch .cat ([dq_vq_denorm , rem_part ], dim = 1 )
14191430 else :
1420- dq_blocks = dq_vq
1431+ dq_blocks = dq_vq_denorm
14211432
1422- # Denormalize
1423- W_q = (dq_blocks * absmax_vals ).reshape (W_rot .shape )
1433+ W_q = dq_blocks .reshape (W_rot .shape )
14241434 else :
14251435 # L2 norm - simpler case
14261436 W_flat = W_rot .reshape (- 1 , p_dim )
0 commit comments