Skip to content

Commit 18f068f

Browse files
TimDettmers and claude committed
feat: Add GEMM NVFP4 output epilogue and QuantState serialization
- gemm_nvfp4_to_nvfp4(): chains GEMM + output quantization for layer chaining without dequantizing between layers. Supports alpha scaling and handles non-aligned N dimensions via padding.
- NVFP4QuantState.state_dict()/from_state_dict(): serialization support for saving and loading quantized model weights.
- Tests: NVFP4 output correctness, alpha scaling, non-aligned shapes, and serialization round-trip.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7e85a3f commit 18f068f

File tree

2 files changed

+253
-0
lines changed

2 files changed

+253
-0
lines changed

bitsandbytes/functional.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,34 @@ def to(self, device):
11151115
rotated=self.rotated,
11161116
)
11171117

1118+
def state_dict(self) -> dict:
    """Export this quantization state as a plain dictionary (e.g. for torch.save).

    Tensors are stored as-is; shape and dtype are converted to
    JSON-friendly forms (list and string) for portability.
    """
    out = dict(
        packed_data=self.packed_data,
        block_scales=self.block_scales,
        tensor_scale=self.tensor_scale,
    )
    out["shape"] = list(self.shape)
    out["dtype"] = str(self.dtype)
    out["rotated"] = self.rotated
    return out
1128+
1129+
@classmethod
def from_state_dict(cls, d: dict, device="cpu") -> "NVFP4QuantState":
    """Reconstruct an NVFP4QuantState from a dict produced by ``state_dict()``.

    Args:
        d: Dictionary with keys ``packed_data``, ``block_scales``,
            ``tensor_scale``, ``shape``, ``dtype``, and ``rotated``.
        device: Device the restored tensors are moved to (default ``"cpu"``).

    Returns:
        A new NVFP4QuantState with all fields restored on ``device``.

    Raises:
        ValueError: If ``d["dtype"]`` is not a recognized dtype string.
    """
    dtype_map = {
        "torch.float16": torch.float16,
        "torch.bfloat16": torch.bfloat16,
        "torch.float32": torch.float32,
    }
    dtype_str = d["dtype"]
    if dtype_str not in dtype_map:
        # Previously an unrecognized dtype silently fell back to float16,
        # which would corrupt a save/load round-trip; fail loudly instead.
        raise ValueError(f"Unsupported dtype in state dict: {dtype_str!r}")
    return cls(
        packed_data=d["packed_data"].to(device),
        block_scales=d["block_scales"].to(device),
        tensor_scale=float(d["tensor_scale"]),
        shape=tuple(d["shape"]),
        dtype=dtype_map[dtype_str],
        rotated=bool(d["rotated"]),
    )
1145+
11181146

11191147
def quantize_nvfp4(
11201148
A: torch.Tensor,
@@ -1209,6 +1237,57 @@ def gemm_nvfp4(
12091237
)
12101238

12111239

1240+
def gemm_nvfp4_to_nvfp4(
    A_data: torch.Tensor,
    A_state: NVFP4QuantState,
    B_data: torch.Tensor,
    B_state: NVFP4QuantState,
    alpha: float = 1.0,
) -> tuple[torch.Tensor, NVFP4QuantState]:
    """Compute ``alpha * (A @ B^T)`` in FP32 and requantize the result to NVFP4.

    Keeping the output in NVFP4 (packed E2M1 + E4M3 block scales + FP32
    tensor scale) lets consecutive layers be chained without dequantizing
    to a higher-precision dtype between them.

    Args:
        A_data: Packed FP4 payload for A (M*K/2 bytes).
        A_state: Quantization state describing A as an M x K matrix.
        B_data: Packed FP4 payload for B (N*K/2 bytes, laid out as N rows of K).
        B_state: Quantization state describing B as an N x K matrix.
        alpha: Scalar applied to the GEMM result before requantization.

    Returns:
        Tuple of ``(packed_output, NVFP4QuantState)`` for the M x N output.
    """
    rows, cols = A_state.shape[0], B_state.shape[0]

    # FP32 accumulation; fold the optional scalar in before quantization.
    result = gemm_nvfp4(A_data, A_state, B_data, B_state)
    if alpha != 1.0:
        result.mul_(alpha)
    result = result.reshape(rows, cols)

    # Quantization requires the column count rounded up to a multiple of 16;
    # zero-pad on the right when the output is not aligned.
    cols_padded = ((cols + 15) // 16) * 16
    if cols_padded == cols:
        to_quantize = result
    else:
        to_quantize = torch.zeros(rows, cols_padded, dtype=result.dtype, device=result.device)
        to_quantize[:, :cols] = result

    packed, out_state = quantize_nvfp4(to_quantize.reshape(-1))
    # Report the logical (unpadded) output shape to callers.
    # NOTE(review): in the padded branch the packed buffer still holds
    # rows * cols_padded elements while shape records (rows, cols) — confirm
    # dequantize_nvfp4 tolerates this (exercised by the non-aligned-N test).
    out_state.shape = (rows, cols)
    return packed, out_state
1289+
1290+
12121291
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
12131292
def quantize(
12141293
A: Tensor,

tests/test_gemm_nvfp4.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,5 +338,179 @@ def test_gemm_tall_skinny(self, M, N, K):
338338
assert rel_err < 0.01, f"Relative error {rel_err:.6f} too large for {M}x{N}x{K}"
339339

340340

341+
class TestGemmNVFP4Output:
    """Test GEMM with NVFP4 output (layer chaining) via Python API."""

    def test_gemm_nvfp4_output_basic(self):
        """GEMM with NVFP4 output: quantize → GEMM → quantize output → dequantize → compare."""
        import pytest

        # Previously this test errored (rather than skipped) on CPU-only machines.
        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        from bitsandbytes.functional import (
            dequantize_nvfp4,
            gemm_nvfp4_to_nvfp4,
            quantize_nvfp4,
        )

        torch.manual_seed(42)
        M, N, K = 32, 32, 64

        A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
        B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")

        # Quantize inputs
        A_packed, A_state = quantize_nvfp4(A_float)
        B_packed, B_state = quantize_nvfp4(B_float)

        # GEMM with NVFP4 output
        out_packed, out_state = gemm_nvfp4_to_nvfp4(A_packed, A_state, B_packed, B_state)

        # Dequantize output
        D_deq = dequantize_nvfp4(out_packed, out_state, out_dtype=torch.float32)

        # Reference: dequantize inputs → matmul
        A_deq = dequantize_nvfp4(A_packed, A_state, out_dtype=torch.float32)
        B_deq = dequantize_nvfp4(B_packed, B_state, out_dtype=torch.float32)
        D_ref = A_deq @ B_deq.T

        # NVFP4 output adds a second layer of quantization error
        ref_mag = D_ref.abs().mean().item()
        mean_err = (D_deq - D_ref).abs().mean().item()
        rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err

        print(f"GEMM NVFP4 output (M={M}, N={N}, K={K}):")
        print(f"  Reference magnitude: {ref_mag:.4f}")
        print(f"  Mean abs error: {mean_err:.4f}")
        print(f"  Relative error: {rel_err:.4f}")
        print(f"  Output shape: {D_deq.shape}")

        assert D_deq.shape == (M, N), f"Wrong shape: {D_deq.shape}"
        # Double quantization error: once for inputs, once for output
        assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"

    def test_gemm_nvfp4_output_alpha(self):
        """GEMM with alpha scaling and NVFP4 output."""
        import pytest

        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        from bitsandbytes.functional import (
            dequantize_nvfp4,
            gemm_nvfp4,
            gemm_nvfp4_to_nvfp4,
            quantize_nvfp4,
        )

        torch.manual_seed(123)
        M, N, K = 16, 16, 64
        alpha = 2.5

        A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
        B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")

        A_packed, A_state = quantize_nvfp4(A_float)
        B_packed, B_state = quantize_nvfp4(B_float)

        # GEMM without alpha (FP32 output)
        D_fp32 = gemm_nvfp4(A_packed, A_state, B_packed, B_state)

        # GEMM with alpha and NVFP4 output
        out_packed, out_state = gemm_nvfp4_to_nvfp4(
            A_packed, A_state, B_packed, B_state, alpha=alpha
        )
        D_nvfp4 = dequantize_nvfp4(out_packed, out_state, out_dtype=torch.float32)

        # Reference: alpha * FP32 output
        D_ref = D_fp32 * alpha

        # Verify alpha is reflected in the output (within NVFP4 quantization error)
        ref_mag = D_ref.abs().mean().item()
        mean_err = (D_nvfp4 - D_ref).abs().mean().item()
        rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err

        print(f"Alpha test (alpha={alpha}): rel_err={rel_err:.4f}")
        assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"

    def test_gemm_nvfp4_output_non_aligned_N(self):
        """GEMM with NVFP4 output where N is not a multiple of 16."""
        import pytest

        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        from bitsandbytes.functional import (
            dequantize_nvfp4,
            gemm_nvfp4_to_nvfp4,
            quantize_nvfp4,
        )

        torch.manual_seed(77)
        M, N, K = 16, 24, 64  # N=24, not multiple of 16

        A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
        B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")

        A_packed, A_state = quantize_nvfp4(A_float)
        B_packed, B_state = quantize_nvfp4(B_float)

        out_packed, out_state = gemm_nvfp4_to_nvfp4(A_packed, A_state, B_packed, B_state)
        D_deq = dequantize_nvfp4(out_packed, out_state, out_dtype=torch.float32)

        # Reference
        A_deq = dequantize_nvfp4(A_packed, A_state, out_dtype=torch.float32)
        B_deq = dequantize_nvfp4(B_packed, B_state, out_dtype=torch.float32)
        D_ref = A_deq @ B_deq.T

        assert D_deq.shape == (M, N), f"Wrong shape: {D_deq.shape}"
        ref_mag = D_ref.abs().mean().item()
        mean_err = (D_deq - D_ref).abs().mean().item()
        rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err
        print(f"Non-aligned N test ({M}x{N}x{K}): rel_err={rel_err:.4f}")
        assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"
458+
459+
460+
class TestNVFP4QuantStateSerialization:
    """Test NVFP4QuantState save/load."""

    def test_state_dict_round_trip(self):
        """Serialize and deserialize NVFP4QuantState."""
        import pytest

        # Previously this test errored (rather than skipped) on CPU-only machines.
        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        from bitsandbytes.functional import NVFP4QuantState, dequantize_nvfp4, quantize_nvfp4

        torch.manual_seed(42)
        x = torch.randn(256, dtype=torch.float32, device="cuda")
        packed, state = quantize_nvfp4(x)

        # Serialize
        sd = state.state_dict()
        assert "packed_data" in sd
        assert "block_scales" in sd
        assert "tensor_scale" in sd
        assert "shape" in sd
        assert "dtype" in sd

        # Deserialize
        state2 = NVFP4QuantState.from_state_dict(sd, device="cuda")

        # Verify fields match
        assert torch.equal(state.packed_data, state2.packed_data)
        assert torch.equal(state.block_scales, state2.block_scales)
        assert state.tensor_scale == state2.tensor_scale
        assert state.shape == state2.shape
        assert state.dtype == state2.dtype
        assert state.rotated == state2.rotated

        # Verify dequantization produces same result
        out1 = dequantize_nvfp4(packed, state, out_dtype=torch.float32)
        out2 = dequantize_nvfp4(state2.packed_data, state2, out_dtype=torch.float32)
        assert torch.equal(out1, out2), "Dequantized outputs differ after serialization"

    def test_state_dict_save_load_file(self):
        """Save to file and reload."""
        import os
        import tempfile

        import pytest

        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        from bitsandbytes.functional import NVFP4QuantState, quantize_nvfp4

        torch.manual_seed(99)
        x = torch.randn(128, dtype=torch.float16, device="cuda")
        _, state = quantize_nvfp4(x)

        # NamedTemporaryFile cannot be reopened by name on Windows while it is
        # still open; write into a temporary directory instead.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "state.pt")
            torch.save(state.state_dict(), path)
            # NOTE(review): payload is tensors + primitives, so
            # weights_only=True may suffice on recent torch — verify.
            loaded = torch.load(path, weights_only=False)
        state2 = NVFP4QuantState.from_state_dict(loaded, device="cuda")

        assert torch.equal(state.packed_data, state2.packed_data)
        assert state.tensor_scale == state2.tensor_scale
        assert state.dtype == state2.dtype
514+
341515
if __name__ == "__main__":
342516
pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)