Skip to content

Commit f4baf7e

Browse files
committed
Fix absmax broadcasting for block-wise denormalization
1 parent 5cff1c3 commit f4baf7e

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

baselines/opt_sym/eval_ppl.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1420,9 +1420,14 @@ def hook(module, input, output):
1420 1420
dq_groups = d_cb[idx]
1421 1421
dq_vq = dq_groups.reshape(out_dim * n_rot, elems_per_p)
1422 1422

1423-
# Denormalize - reshape absmax to match dq_vq shape
1424-
absmax_reshaped = absmax_vals.reshape(out_dim * n_rot, -1)
1425-
dq_vq_denorm = dq_vq * absmax_reshaped
1423+
# Denormalize - absmax has one value per block of bs elements
1424+
# absmax_vals: [out_dim * n_rot * elems_per_p / bs, 1]
1425+
# dq_vq: [out_dim * n_rot, elems_per_p]
1426+
# Need to reshape absmax to [out_dim * n_rot, elems_per_p / bs, 1] and broadcast
1427+
n_blocks_per_row = elems_per_p // bs
1428+
absmax_reshaped = absmax_vals.reshape(out_dim * n_rot, n_blocks_per_row, 1)
1429+
dq_vq_reshaped = dq_vq.reshape(out_dim * n_rot, n_blocks_per_row, bs)
1430+
dq_vq_denorm = (dq_vq_reshaped * absmax_reshaped).reshape(out_dim * n_rot, elems_per_p)
1426 1431

1427 1432
if rem > 0:
1428 1433
rem_part = W_rot_reshaped[:, elems_per_p:]

0 commit comments

Comments
 (0)