Skip to content

Commit b736d14

Browse files
TimDettmers and claude committed
test: Validate quantized size formulas for streaming quantizer
Tests compute_quantized_sizes() against actual quantize_kbit() output for all k values (2-5), standard N values (128-12288), edge cases where N is not a multiple of 128 (100, 300, 1000), and K values (128-5120). 294 tests, all pass — formulas exactly match kernel output. Also includes GLM-4.7 specific tests for q_proj (12288×5120, NF4) and expert gate_proj (1536×5120, NF2). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 14ff086 commit b736d14

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed

tests/test_quantized_sizes.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Validate that quantized tensor size formulas exactly match actual quantize_kbit output.
2+
3+
These formulas are used by the streaming quantizer (two-pass) to compute the
4+
safetensors header before any GPU quantization happens. If the formulas are
5+
wrong, the safetensors file will be corrupted.
6+
7+
The formulas (from _ops.py):
8+
N_padded = ceil(N / 128) * 128
9+
n_elements = N_padded * K
10+
num_blocks = ceil(n_elements / 32)
11+
packed_numel = num_blocks * k + k # int32 elements
12+
absmax_numel = num_blocks + 1 # float32 elements
13+
codebook_numel = 2^k # float32 elements
14+
"""
15+
16+
import pytest
17+
import torch
18+
19+
import bitsandbytes.functional as F
20+
21+
22+
def compute_quantized_sizes(N: int, K: int, k: int) -> dict:
    """Compute quantized tensor sizes for a weight matrix [N, K].

    This is the formula that the streaming quantizer will use.
    """
    # Rows are padded up to the next multiple of 128 before quantization.
    rows_padded = ((N + 127) // 128) * 128
    total_elements = rows_padded * K
    # ceil(total_elements / 32): one quantization block per 32 elements.
    blocks = (total_elements + 31) // 32

    return {
        "N_padded": rows_padded,
        "n_elements": total_elements,
        "num_blocks": blocks,
        # int32 elements: k words per block plus k trailing words.
        "packed_numel": blocks * k + k,
        # float32 elements: one absmax per block plus one extra.
        "absmax_numel": blocks + 1,
        # float32 elements: 2^k codebook entries.
        "codebook_numel": 2 ** k,
    }
43+
44+
45+
# Standard N values (multiples of 128 — no padding is applied to these)
N_VALUES_STANDARD = [128, 256, 512, 768, 1024, 1536, 2048, 4096, 12288]
# Edge case N values (NOT multiples of 128 — exercise the pad-to-128 path)
N_VALUES_EDGE = [100, 300, 1000]
# K values (column counts of the weight matrix)
K_VALUES = [128, 512, 1024, 2048, 4096, 5120]
# k values (quantization bit widths under test)
K_BIT_VALUES = [2, 3, 4, 5]
53+
54+
55+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="quantize_kbit test requires a CUDA device")
@pytest.mark.parametrize("k", K_BIT_VALUES)
@pytest.mark.parametrize("K", K_VALUES)
@pytest.mark.parametrize("N", N_VALUES_STANDARD + N_VALUES_EDGE)
def test_quantized_sizes_match(N, K, k):
    """Verify formula-predicted sizes match actual quantize_kbit output.

    Quantizes a random fp32 tensor of the padded size and checks that the
    packed / absmax / codebook element counts equal the formula predictions
    for every (N, K, k) combination.
    """
    predicted = compute_quantized_sizes(N, K, k)
    N_padded = predicted["N_padded"]

    # Create a tensor with the padded size — the streaming quantizer pads
    # N up to a multiple of 128 before handing data to the kernel.
    A = torch.randn(N_padded * K, device="cuda", dtype=torch.float32)

    # Actually quantize
    packed, absmax, codebook = F.quantize_kbit(A, k=k, absmax_format="fp32")

    # Compare sizes
    assert packed.numel() == predicted["packed_numel"], (
        f"packed size mismatch for N={N}, K={K}, k={k}: "
        f"got {packed.numel()}, expected {predicted['packed_numel']}"
    )
    assert absmax.numel() == predicted["absmax_numel"], (
        f"absmax size mismatch for N={N}, K={K}, k={k}: "
        f"got {absmax.numel()}, expected {predicted['absmax_numel']}"
    )
    assert codebook.numel() == predicted["codebook_numel"], (
        f"codebook size mismatch for N={N}, K={K}, k={k}: "
        f"got {codebook.numel()}, expected {predicted['codebook_numel']}"
    )

    # Verify N_padded is the smallest multiple of 128 that is >= N.
    assert N_padded >= N
    assert N_padded % 128 == 0
    assert N_padded - N < 128
87+
88+
89+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="codebook is created on a CUDA device")
@pytest.mark.parametrize("k", K_BIT_VALUES)
def test_codebook_size(k):
    """Verify codebook has 2^k entries for every supported bit width."""
    codebook = F.create_normal_float_codebook(k, device="cuda")
    assert codebook.numel() == (1 << k)
94+
95+
96+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="quantize_kbit test requires a CUDA device")
def test_glm47_q_proj_sizes():
    """Verify formula with GLM-4.7 q_proj dimensions (a real-world case)."""
    # GLM-4.7: num_heads=96, head_dim=128 → N=12288, K=5120; NF4 (k=4).
    N, K, k = 12288, 5120, 4
    predicted = compute_quantized_sizes(N, K, k)

    assert predicted["N_padded"] == 12288  # already multiple of 128
    assert predicted["n_elements"] == 12288 * 5120
    assert predicted["num_blocks"] == -(12288 * 5120 // -32)
    assert predicted["packed_numel"] == predicted["num_blocks"] * 4 + 4
    assert predicted["codebook_numel"] == 16

    # Verify against actual quantization
    A = torch.randn(predicted["n_elements"], device="cuda", dtype=torch.float32)
    packed, absmax, codebook = F.quantize_kbit(A, k=k, absmax_format="fp32")
    assert packed.numel() == predicted["packed_numel"]
    assert absmax.numel() == predicted["absmax_numel"]
113+
114+
115+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="quantize_kbit test requires a CUDA device")
def test_glm47_expert_gate_sizes():
    """Verify formula with GLM-4.7 expert gate_proj dimensions."""
    # GLM-4.7 expert: intermediate=1536, hidden=5120, NF2 (k=2).
    N, K, k = 1536, 5120, 2
    predicted = compute_quantized_sizes(N, K, k)

    assert predicted["N_padded"] == 1536  # already multiple of 128
    assert predicted["codebook_numel"] == 4  # 2^2 = 4 for NF2

    # Verify against actual quantization
    A = torch.randn(predicted["n_elements"], device="cuda", dtype=torch.float32)
    packed, absmax, codebook = F.quantize_kbit(A, k=k, absmax_format="fp32")
    assert packed.numel() == predicted["packed_numel"]
    assert absmax.numel() == predicted["absmax_numel"]

0 commit comments

Comments
 (0)