Skip to content

Commit 381f611

Browse files
TimDettmersclaude
and committed
Add Python wrapper for cscale_to_blocked_batched (batched scale swizzle)
Exposes the existing C kernel for batched per-expert scale reordering from row-major to CUTLASS block-scaled layout. Adds op definition, registered kernel, and convenience function that computes metadata from expert_offsets. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent be06982 commit 381f611

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,29 @@ def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
505505
return torch.empty(out_size, dtype=torch.uint8, device=scales.device)
506506

507507

# Batched scale reordering for MoE: row-major → per-expert swizzled
# Schema note: all sizing metadata (W, num_experts, max_row_blocks,
# total_out_bytes) is passed as plain ints so the fake kernel below can
# compute the output allocation without touching device data.
torch.library.define(
    "bitsandbytes::scale_to_blocked_batched",
    "(Tensor scales_rowmajor, Tensor expert_row_offsets, Tensor expert_M, "
    "Tensor expert_out_offsets, int W, int num_experts, int max_row_blocks, "
    "int total_out_bytes) -> Tensor",
)


@register_fake("bitsandbytes::scale_to_blocked_batched")
def _(
    scales_rowmajor: torch.Tensor,
    expert_row_offsets: torch.Tensor,
    expert_M: torch.Tensor,
    expert_out_offsets: torch.Tensor,
    W: int,
    num_experts: int,
    max_row_blocks: int,
    total_out_bytes: int,
) -> torch.Tensor:
    # Fake (meta) kernel: only the output's size/dtype/device matter here.
    # Returns a flat uint8 buffer of the caller-computed total size, matching
    # the allocation done by the real CUDA kernel registration.
    return torch.empty(total_out_bytes, dtype=torch.uint8, device=scales_rowmajor.device)
508531
# Inverse scale reordering: CUTLASS block-scaled layout → row-major
509532
torch.library.define(
510533
"bitsandbytes::scale_from_blocked",

bitsandbytes/backends/cuda/ops.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,38 @@ def _(blocked_scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
10231023
return out
10241024

10251025

@register_kernel("bitsandbytes::scale_to_blocked_batched", "cuda")
def _(
    scales_rowmajor: torch.Tensor,
    expert_row_offsets: torch.Tensor,
    expert_M: torch.Tensor,
    expert_out_offsets: torch.Tensor,
    W: int,
    num_experts: int,
    max_row_blocks: int,
    total_out_bytes: int,
) -> torch.Tensor:
    """Batched scale swizzle: row-major → per-expert CUTLASS block-scaled layout.

    Input: concatenated row-major scales from quantize_nvfp4_raw.
    Output: contiguous buffer with independently swizzled per-expert blocks.
    """
    # zeros (not empty) is deliberate: the C kernel receives no total size, so
    # any padding bytes between/after experts it does not write must already be
    # zero. NOTE(review): inferred from the zero-fill choice here — confirm
    # against the cscale_to_blocked_batched kernel's write pattern.
    out = torch.zeros(total_out_bytes, dtype=torch.uint8, device=scales_rowmajor.device)
    # Ensure the kernel launches on the device owning the input tensor.
    with _cuda_device_of(scales_rowmajor):
        # Argument order is fixed by the C ABI of cscale_to_blocked_batched.
        # total_out_bytes is intentionally not forwarded: the kernel derives
        # each expert's extent from expert_out_offsets / max_row_blocks.
        lib.cscale_to_blocked_batched(
            get_ptr(scales_rowmajor),
            get_ptr(out),
            get_ptr(expert_row_offsets),
            get_ptr(expert_M),
            get_ptr(expert_out_offsets),
            ct.c_int(W),
            ct.c_int(num_experts),
            ct.c_int(max_row_blocks),
            _get_tensor_stream(scales_rowmajor),
        )
    return out
10261058
# Hand-written NVFP4 GEMM (SM_120+)
10271059
#
10281060
# Uses mma.sync.aligned.block_scale instructions for small-M decode.

bitsandbytes/functional.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,55 @@ def quantize_nvfp4_raw(
12311231
return packed, block_scales
12321232

12331233

def scale_to_blocked_batched(
    scales_rowmajor: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Swizzle concatenated row-major scales into per-expert CUTLASS layout.

    Computes the per-expert metadata (row offsets, row counts, output byte
    offsets) expected by ``bitsandbytes::scale_to_blocked_batched`` and
    dispatches to the registered kernel.

    Args:
        scales_rowmajor: Concatenated row-major block scales
            [total_tokens * K/16] (uint8).
        expert_offsets: Cumulative token offsets [num_experts + 1]
            (int32, device).
        max_M: Max tokens per expert (padded to 128 alignment).
        K: Hidden dimension; must be a multiple of 16 (one scale per
            16-element block).
        num_experts: Number of experts.

    Returns:
        Contiguous uint8 buffer with per-expert swizzled scales for batched
        GEMM. Every expert occupies the same extent
        (ceil(max_M/128) * ceil(W/4) * 512 bytes), so expert ``i`` starts at
        ``i * per_expert_bytes``.

    Raises:
        ValueError: If ``K`` is not a multiple of 16.
    """
    if K % 16 != 0:
        raise ValueError(f"K must be a multiple of 16, got {K}")
    W = K // 16  # scale columns: one uint8 scale per 16-element block
    n_col_blocks = (W + 3) // 4

    # The input is [total_tokens, W] row-major, so an expert's first scale row
    # is exactly its first token: token offsets double as row offsets, and the
    # per-expert row count is the per-expert token count.
    expert_row_offsets = expert_offsets[:-1].to(torch.int32)
    expert_M_dev = (expert_offsets[1:] - expert_offsets[:-1]).to(torch.int32)

    # All experts are padded to the same max_M, so each gets an identical
    # output extent and the output offsets form an arithmetic progression.
    max_row_blocks = (max_M + 127) // 128
    per_expert_bytes = max_row_blocks * n_col_blocks * 512
    expert_out_offsets = (
        torch.arange(num_experts, dtype=torch.int32, device=scales_rowmajor.device)
        * per_expert_bytes
    )
    total_out_bytes = num_experts * per_expert_bytes

    return torch.ops.bitsandbytes.scale_to_blocked_batched(
        scales_rowmajor,
        expert_row_offsets,
        expert_M_dev,
        expert_out_offsets,
        W,
        num_experts,
        max_row_blocks,
        total_out_bytes,
    )
12341283
def dequantize_nvfp4(
12351284
packed_data: torch.Tensor,
12361285
quant_state: NVFP4QuantState,

0 commit comments

Comments
 (0)