Skip to content

Commit c2d7ec7

Browse files
committed
Fixing quantization int8 packing bug for NF4 and FP4
1 parent e54dc12 commit c2d7ec7

File tree

2 files changed

+18
-22
lines changed

2 files changed

+18
-22
lines changed

csrc/kernels.cu

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise(
431431
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
432432
}
433433

434-
unsigned char packed_4bit = 0;
435434
switch (DATA_TYPE) {
436435
case General8bit:
437436
#pragma unroll NUM_PER_TH
@@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise(
445444
case FP4:
446445
#pragma unroll NUM_PER_TH
447446
for (int j = 0; j < NUM_PER_TH / 2; j++) {
448-
packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
449-
packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
450-
qvals[j] = packed_4bit;
447+
qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
448+
qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
451449
}
452450
break;
453451
case NF4:
454452
#pragma unroll NUM_PER_TH
455453
for (int j = 0; j < NUM_PER_TH / 2; j++) {
456-
packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
457-
packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
458-
qvals[j] = packed_4bit;
454+
qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
455+
qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
459456
}
460457
break;
461458
}

tests/test_functional.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,21 +1125,20 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11251125

11261126
# With larger block sizes, we can expect this to blow up.
11271127
# At blocksize>=1024, don't even bother looking at relerr.
1128-
if blocksize <= 64:
1129-
assert err.item() < 0.1
1130-
assert relerr.item() < 0.28
1131-
elif blocksize <= 256:
1132-
assert err.item() < 0.11
1133-
assert relerr.item() < 0.30
1134-
elif blocksize <= 512:
1135-
assert err.item() < 0.12
1136-
assert relerr.item() < 0.31
1137-
elif quant_type == "fp4":
1138-
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
1139-
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
1140-
else:
1141-
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
1142-
assert err.item() < math.log2(blocksize) * 8e-2
1128+
#
1129+
# Actually, the above is not true anymore after fixing the integer packing bug.
1130+
# The following values were taken from averaging 1k samples per test configuration after fixing the bug.
1131+
error_dict = dict()
1132+
error_dict["fp4"] = dict()
1133+
error_dict["nf4"] = dict()
1134+
error_dict["fp4"]["err"] = {64: 0.096545, 128: 0.102947, 256: 0.108685, 512: 0.114087, 1024: 0.119312, 2048: 0.124460, 4096: 0.129573}
1135+
error_dict["fp4"]["rel_err"] = {64: 0.260130, 128: 0.275734, 256: 0.289842, 512: 0.302852, 1024: 0.314982, 2048: 0.326402, 4096: 0.337228}
1136+
1137+
error_dict["nf4"]["err"] = {64: 0.072792, 128: 0.076835, 256: 0.080326, 512: 0.083535, 1024: 0.086603, 2048: 0.089592, 4096: 0.092537}
1138+
error_dict["nf4"]["rel_err"] = {64: 0.203299, 128: 0.215252, 256: 0.226044, 512: 0.236021, 1024: 0.245365, 2048: 0.254146, 4096: 0.262457}
1139+
1140+
assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
1141+
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3
11431142

11441143
@pytest.mark.parametrize("device", get_available_devices())
11451144
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])

0 commit comments

Comments
 (0)