Skip to content

Commit 3786bcb

Browse files
TimDettmers and claude
committed
fix: Use flat padding for NVFP4 output quantization
Row-level padding corrupted block scales across row boundaries. Instead, flatten M*N and pad only the flat vector if needed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 18f068f commit 3786bcb

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

bitsandbytes/functional.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,22 +1268,19 @@ def gemm_nvfp4_to_nvfp4(
12681268
D_fp32.mul_(alpha)
12691269

12701270
# Step 3: Quantize FP32 output → NVFP4
1271-
# Reshape to 2D (M, N) for quantization
1271+
# quantize_nvfp4 works on flattened data in blocks of 16.
1272+
# We need M*N to be divisible by 16. If not, pad the flat vector.
12721273
M = A_state.shape[0]
12731274
N = B_state.shape[0]
1274-
D_2d = D_fp32.reshape(M, N)
1275-
1276-
# Pad N to multiple of 16 if needed for quantization
1277-
N_padded = ((N + 15) // 16) * 16
1278-
if N_padded != N:
1279-
D_padded = torch.zeros(M, N_padded, dtype=D_fp32.dtype, device=D_fp32.device)
1280-
D_padded[:, :N] = D_2d
1281-
packed, out_state = quantize_nvfp4(D_padded.reshape(-1))
1282-
# Adjust state shape to reflect actual (unpadded) output
1283-
out_state.shape = (M, N)
1284-
else:
1285-
packed, out_state = quantize_nvfp4(D_2d.reshape(-1))
1286-
out_state.shape = (M, N)
1275+
numel = M * N
1276+
D_flat = D_fp32.reshape(-1)
1277+
1278+
numel_padded = ((numel + 15) // 16) * 16
1279+
if numel_padded != numel:
1280+
D_flat = torch.nn.functional.pad(D_flat, (0, numel_padded - numel))
1281+
1282+
packed, out_state = quantize_nvfp4(D_flat)
1283+
out_state.shape = (M, N)
12871284

12881285
return packed, out_state
12891286

0 commit comments

Comments (0)