Skip to content

Commit be06982

Browse files
TimDettmers and claude committed
Add quantize_nvfp4_raw() with device-side global_scale, no host sync
New variant accepts global_scale as a GPU tensor (1/abs_max) instead of a host float, avoiding the .item() GPU→CPU sync. Returns raw (packed, scales) without swizzling or QuantState — caller uses scale_to_blocked_batched. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cec86d7 commit be06982

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,26 @@ def _(
470470
return packed, block_scales, ts_out
471471

472472

473+
# Device-side quantize variant: global_scale is a device tensor (no .item() sync).
# Returns (packed, block_scales) — row-major scales without swizzling.
torch.library.define(
    "bitsandbytes::cutlass_fused_quantize_nvfp4_raw",
    "(Tensor A, Tensor global_scale_dev) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::cutlass_fused_quantize_nvfp4_raw")
def _(
    A: torch.Tensor,
    global_scale_dev: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation: only output shapes/dtypes matter here.

    Two FP4 values pack into one output byte; one uint8 scale covers each
    16-element block, so the outputs have numel/2 and numel/16 entries.
    """
    n = A.numel()
    torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
    packed, block_scales = (
        torch.empty(size, dtype=torch.uint8, device=A.device)
        for size in (n // 2, n // 16)
    )
    return packed, block_scales
491+
492+
473493
# Scale reordering for CUTLASS block-scaled GEMM
474494
torch.library.define(
475495
"bitsandbytes::scale_to_blocked",

bitsandbytes/backends/cuda/ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,54 @@ def _(
937937
return _fused_quantize_nvfp4_impl(A, tensor_scale)
938938

939939

940+
@register_kernel("bitsandbytes::cutlass_fused_quantize_nvfp4_raw", "cuda")
def _(
    A: torch.Tensor,
    global_scale_dev: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Device-side quantize: global_scale is a pre-computed device tensor.

    Returns (packed_data, block_scales_rowmajor) — no swizzling, no QuantState.
    The global_scale_dev tensor should contain 1.0/tensor_scale as a float32
    scalar on the GPU (0-dim or 1-element tensor).
    """
    A = A.contiguous()
    n = A.numel()
    torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
    torch._check(
        A.dtype == torch.bfloat16,
        lambda: f"CUTLASS fused quantize requires bfloat16, got {A.dtype}",
    )

    K = 16
    orig_M = n // K
    # Rows of zero-padding needed to reach the next multiple of 128.
    pad_rows = -orig_M % 128
    padded_M = orig_M + pad_rows

    A_flat = A
    if pad_rows:
        # Zero-pad trailing rows so the kernel always sees a 128-row multiple.
        A_2d = torch.nn.functional.pad(A.view(orig_M, K), (0, 0, 0, pad_rows))
        A_flat = A_2d.reshape(-1)

    # Zero-initialized so padded tail entries stay deterministic.
    packed_padded = torch.zeros(padded_M * K // 2, dtype=torch.uint8, device=A.device)
    scales_padded = torch.zeros(padded_M, dtype=torch.uint8, device=A.device)

    _fused_quantize_nvfp4_raw(
        A_flat,
        _get_rotation_matrix(A.device),
        packed_padded,
        scales_padded,
        global_scale_dev.to(dtype=torch.float32).contiguous(),
        padded_M,
    )

    # Drop the padded tail before returning (views into the padded buffers).
    if pad_rows:
        return packed_padded[: orig_M * K // 2], scales_padded[:orig_M]
    return packed_padded, scales_padded
986+
987+
940988
# Scale reordering for CUTLASS block-scaled GEMM
941989
@register_kernel("bitsandbytes::scale_to_blocked", "cuda")
942990
def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:

bitsandbytes/functional.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,33 @@ def quantize_nvfp4(
12041204
return packed, state
12051205

12061206

1207+
def quantize_nvfp4_raw(
    A: torch.Tensor,
    global_scale_dev: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize to NVFP4 with a pre-computed device-side global scale.

    Unlike quantize_nvfp4(), this variant:
    - Takes global_scale as a device tensor (1/abs_max), no .item() sync
    - Skips scale_to_blocked (caller uses scale_to_blocked_batched instead)
    - Returns raw (packed_data, block_scales_rowmajor) without QuantState

    Args:
        A: Input tensor (bfloat16). Must have numel divisible by 16.
        global_scale_dev: Device tensor containing 1.0/tensor_scale (float32).

    Returns:
        Tuple of (packed_data [uint8], block_scales_rowmajor [uint8]).
    """
    # Flatten up front; the op operates on the 1-D element stream.
    flat = A.reshape(-1).contiguous()
    if flat.dtype != torch.bfloat16:
        flat = flat.to(torch.bfloat16)
    return torch.ops.bitsandbytes.cutlass_fused_quantize_nvfp4_raw(flat, global_scale_dev)
1232+
1233+
12071234
def dequantize_nvfp4(
12081235
packed_data: torch.Tensor,
12091236
quant_state: NVFP4QuantState,

0 commit comments

Comments
 (0)