Skip to content

Commit 830491d

Browse files
TimDettmers and claude
committed
Add scatter/gather CUDA kernels for MoE batched pipeline
Scatter copies packed FP4 from concatenated token layout to padded per-expert batched layout with zero-filling. Gather copies BF16 results back. Both use one threadblock per expert with vectorized uint4 loads/stores. Includes C entry points, op definitions, registered kernels, and convenience functions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 381f611 commit 830491d

File tree

5 files changed

+346
-0
lines changed

5 files changed

+346
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ if(BUILD_CUDA)
343343
add_library(nvfp4_common OBJECT
344344
csrc/qutlass/scale_reorder.cu
345345
csrc/qutlass/fused_quantize_nv.cu
346+
csrc/qutlass/moe_scatter_gather.cu
346347
)
347348
set_target_properties(nvfp4_common PROPERTIES
348349
CUDA_ARCHITECTURES "${_NVFP4_COMMON_ARCHS}"

bitsandbytes/_ops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,44 @@ def _(blocked_scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
540540
return torch.empty(H * W, dtype=torch.uint8, device=blocked_scales.device)
541541

542542

543+
# MoE scatter: concatenated FP4 → padded per-expert batched FP4
torch.library.define(
    "bitsandbytes::moe_scatter_nvfp4",
    "(Tensor packed_concat, Tensor expert_offsets, int max_M, int K, int num_experts) -> Tensor",
)


@register_fake("bitsandbytes::moe_scatter_nvfp4")
def _(
    packed_concat: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    # Packed FP4 holds two values per byte, so each row occupies K // 2 bytes.
    out_numel = num_experts * max_M * (K // 2)
    return torch.empty(out_numel, dtype=torch.uint8, device=packed_concat.device)
560+
561+
562+
# MoE gather: padded per-expert BF16 → concatenated BF16
torch.library.define(
    "bitsandbytes::moe_gather_bf16",
    "(Tensor D_batched, Tensor expert_offsets, int max_M, int N, int num_experts, int total_tokens) -> Tensor",
)


@register_fake("bitsandbytes::moe_gather_bf16")
def _(
    D_batched: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    N: int,
    num_experts: int,
    total_tokens: int,
) -> torch.Tensor:
    # One BF16 row of length N per token, concatenated across all experts.
    out_numel = total_tokens * N
    return torch.empty(out_numel, dtype=torch.bfloat16, device=D_batched.device)
579+
580+
543581
# NVFP4 GEMM (A @ B^T with block-scaled FP4 inputs)
544582
torch.library.define(
545583
"bitsandbytes::gemm_nvfp4",

bitsandbytes/backends/cuda/ops.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,58 @@ def _(
10551055
return out
10561056

10571057

1058+
@register_kernel("bitsandbytes::moe_scatter_nvfp4", "cuda")
def _(
    packed_concat: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Scatter concatenated FP4 data to padded per-expert batched layout.

    The output is allocated uninitialized; the CUDA kernel zero-fills the
    padding rows itself.
    """
    # Packed FP4: two values per byte, so each row is K // 2 bytes.
    out = torch.empty(
        num_experts * max_M * (K // 2),
        dtype=torch.uint8,
        device=packed_concat.device,
    )
    with _cuda_device_of(packed_concat):
        args = (
            get_ptr(packed_concat),
            get_ptr(out),
            get_ptr(expert_offsets),
            ct.c_int(max_M),
            ct.c_int(K),
            ct.c_int(num_experts),
            _get_tensor_stream(packed_concat),
        )
        lib.cmoe_scatter_nvfp4(*args)
    return out
1082+
1083+
1084+
@register_kernel("bitsandbytes::moe_gather_bf16", "cuda")
def _(
    D_batched: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    N: int,
    num_experts: int,
    total_tokens: int,
) -> torch.Tensor:
    """Gather BF16 results from padded per-expert layout to concatenated output."""
    # One row of N bf16 values per token across all experts.
    out = torch.empty(
        total_tokens * N,
        dtype=torch.bfloat16,
        device=D_batched.device,
    )
    with _cuda_device_of(D_batched):
        args = (
            get_ptr(D_batched),
            get_ptr(out),
            get_ptr(expert_offsets),
            ct.c_int(max_M),
            ct.c_int(N),
            ct.c_int(num_experts),
            _get_tensor_stream(D_batched),
        )
        lib.cmoe_gather_bf16(*args)
    return out
1108+
1109+
10581110
# Hand-written NVFP4 GEMM (SM_120+)
10591111
#
10601112
# Uses mma.sync.aligned.block_scale instructions for small-M decode.

bitsandbytes/functional.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,56 @@ def scale_to_blocked_batched(
12801280
)
12811281

12821282

1283+
def moe_scatter_nvfp4(
    packed_concat: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Scatter concatenated FP4 data into a padded per-expert batched layout.

    Thin wrapper over the ``bitsandbytes::moe_scatter_nvfp4`` custom op.

    Args:
        packed_concat: Packed FP4 data [total_tokens * K/2] (uint8).
        expert_offsets: Cumulative token offsets [num_experts + 1] (int32, device).
        max_M: Padded max tokens per expert (128-aligned).
        K: Hidden dimension.
        num_experts: Number of experts.

    Returns:
        Padded batched FP4 data [num_experts * max_M * K/2] (uint8, zero-padded).
    """
    return torch.ops.bitsandbytes.moe_scatter_nvfp4(
        packed_concat,
        expert_offsets,
        max_M,
        K,
        num_experts,
    )
1305+
1306+
1307+
def moe_gather_bf16(
    D_batched: torch.Tensor,
    expert_offsets: torch.Tensor,
    max_M: int,
    N: int,
    num_experts: int,
    total_tokens: int,
) -> torch.Tensor:
    """Gather BF16 results from a padded per-expert layout into a concatenated tensor.

    Thin wrapper over the ``bitsandbytes::moe_gather_bf16`` custom op.

    Args:
        D_batched: Batched BF16 output [num_experts * max_M * N] (bf16).
        expert_offsets: Cumulative token offsets [num_experts + 1] (int32, device).
        max_M: Padded max tokens per expert.
        N: Output dimension.
        num_experts: Number of experts.
        total_tokens: Total tokens across all experts.

    Returns:
        Concatenated BF16 output [total_tokens * N].
    """
    return torch.ops.bitsandbytes.moe_gather_bf16(
        D_batched,
        expert_offsets,
        max_M,
        N,
        num_experts,
        total_tokens,
    )
1331+
1332+
12831333
def dequantize_nvfp4(
12841334
packed_data: torch.Tensor,
12851335
quant_state: NVFP4QuantState,

csrc/qutlass/moe_scatter_gather.cu

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
/*
2+
* Scatter and gather kernels for MoE batched NVFP4 GEMM pipeline.
3+
*
4+
* Scatter: copies packed FP4 data from concatenated token layout to
5+
* padded per-expert batched layout. Zero-fills padding rows.
6+
*
7+
* Gather: copies BF16 results from padded per-expert batched layout
8+
* back to concatenated token layout.
9+
*
10+
* Both kernels use one threadblock per expert with vectorized 128-bit
11+
* (uint4) loads/stores for bandwidth efficiency.
12+
*/
13+
14+
#include <cuda_runtime.h>
15+
#include <cstdint>
16+
17+
// =========================================================================
18+
// Scatter: concatenated FP4 → padded per-expert batched FP4
19+
// =========================================================================
20+
// Each threadblock handles one expert. Threads cooperatively copy
21+
// n_tokens * row_bytes from the concatenated source to the padded
22+
// destination, then zero-fill padding rows.
23+
//
24+
// Data layout:
25+
// Input: packed_concat [total_tokens * row_bytes] contiguous
26+
// Output: packed_batched [num_experts * max_M * row_bytes] with zero padding
27+
//
28+
// row_bytes = K / 2 (packed FP4: 2 values per byte)
29+
__global__ void kMoeScatterNVFP4(
    const uint8_t* __restrict__ input,       // [total_tokens * row_bytes]
    uint8_t* __restrict__ output,            // [num_experts * max_M * row_bytes]
    const int* __restrict__ expert_offsets,  // [num_experts + 1] cumulative token offsets
    int max_M,                               // padded max tokens per expert
    int row_bytes                            // K / 2 (packed FP4: 2 values per byte)
) {
    // One threadblock per expert: copy that expert's token rows from the
    // concatenated source, then zero-fill the padding rows.
    const int expert = blockIdx.x;
    const int start = expert_offsets[expert];
    const int n_tokens = expert_offsets[expert + 1] - start;

    // Source: contiguous slice of the concatenated buffer.
    const uint8_t* src = input + (long long)start * row_bytes;
    // Destination: this expert's padded slot.
    uint8_t* dst = output + (long long)expert * max_M * row_bytes;

    const long long total_bytes = (long long)max_M * row_bytes;
    const long long data_bytes = (long long)n_tokens * row_bytes;

    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // uint4 (128-bit) accesses require 16-byte alignment. `src` starts at an
    // arbitrary token offset, so vectorize only when both pointers are
    // aligned (fix: the original dereferenced uint4 unconditionally, which is
    // a misaligned-address fault whenever start * row_bytes % 16 != 0).
    const bool src_dst_vec_ok =
        ((reinterpret_cast<uintptr_t>(src) | reinterpret_cast<uintptr_t>(dst)) & 15) == 0;
    const bool dst_vec_ok = (reinterpret_cast<uintptr_t>(dst) & 15) == 0;

    // ---- Copy the data region -------------------------------------------
    const long long vec_data_bytes = src_dst_vec_ok ? (data_bytes / 16) * 16 : 0;
    if (vec_data_bytes > 0) {
        const uint4* src4 = reinterpret_cast<const uint4*>(src);
        uint4* dst4 = reinterpret_cast<uint4*>(dst);
        const long long n_vec = vec_data_bytes / 16;
        for (long long i = tid; i < n_vec; i += stride) {
            dst4[i] = src4[i];
        }
    }
    // Scalar tail (and full copy when vectorization is unavailable).
    for (long long i = vec_data_bytes + tid; i < data_bytes; i += stride) {
        dst[i] = src[i];
    }

    // ---- Zero-fill the padding region -----------------------------------
    const long long pad_start = data_bytes;
    if (pad_start >= total_bytes) return;

    if (dst_vec_ok) {
        const long long aligned_pad_start = ((pad_start + 15) / 16) * 16;
        const long long vec_pad_end = (total_bytes / 16) * 16;

        // Scalar head: bytes before the first 16-byte boundary.
        for (long long i = pad_start + tid; i < aligned_pad_start && i < total_bytes; i += stride) {
            dst[i] = 0;
        }

        // Vectorized middle.
        if (vec_pad_end > aligned_pad_start) {
            const uint4 zero4 = make_uint4(0, 0, 0, 0);
            uint4* dst4 = reinterpret_cast<uint4*>(dst);
            const long long vec_end = vec_pad_end / 16;
            for (long long i = aligned_pad_start / 16 + tid; i < vec_end; i += stride) {
                dst4[i] = zero4;
            }
        }

        // Scalar tail. Clamp the start to pad_start: the original began at
        // vec_pad_end unconditionally, which would clobber data bytes when
        // total_bytes is not a multiple of 16.
        const long long tail_start = vec_pad_end > pad_start ? vec_pad_end : pad_start;
        for (long long i = tail_start + tid; i < total_bytes; i += stride) {
            dst[i] = 0;
        }
    } else {
        // Misaligned destination: plain byte-wise zero-fill.
        for (long long i = pad_start + tid; i < total_bytes; i += stride) {
            dst[i] = 0;
        }
    }
}
100+
101+
102+
// =========================================================================
103+
// Gather: padded per-expert BF16 → concatenated BF16
104+
// =========================================================================
105+
// Each threadblock handles one expert. Threads cooperatively copy
106+
// n_tokens * row_elems BF16 values from the padded batched output
107+
// to the concatenated result.
108+
//
109+
// Data layout:
110+
// Input: D_batched [num_experts * max_M * N] bf16
111+
// Output: D_concat [total_tokens * N] bf16
112+
//
113+
// row_bytes = N * 2 (bf16 = 2 bytes per element)
114+
__global__ void kMoeGatherBF16(
    const uint8_t* __restrict__ input,       // [num_experts * max_M * row_bytes]
    uint8_t* __restrict__ output,            // [total_tokens * row_bytes]
    const int* __restrict__ expert_offsets,  // [num_experts + 1] cumulative token offsets
    int max_M,                               // padded max tokens per expert
    int row_bytes                            // N * 2 (bf16 = 2 bytes per element)
) {
    // One threadblock per expert: copy that expert's rows from its padded
    // slot back into the concatenated result buffer.
    const int expert = blockIdx.x;
    const int start = expert_offsets[expert];
    const int n_tokens = expert_offsets[expert + 1] - start;

    if (n_tokens <= 0) return;  // empty expert: nothing to gather

    // Source: this expert's padded slot; destination: contiguous slice.
    const uint8_t* src = input + (long long)expert * max_M * row_bytes;
    uint8_t* dst = output + (long long)start * row_bytes;

    const long long data_bytes = (long long)n_tokens * row_bytes;
    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // uint4 (128-bit) accesses require 16-byte alignment. `dst` starts at an
    // arbitrary token offset, so vectorize only when both pointers are
    // aligned (fix: the original dereferenced uint4 unconditionally, which is
    // a misaligned-address fault whenever start * row_bytes % 16 != 0).
    const bool vec_ok =
        ((reinterpret_cast<uintptr_t>(src) | reinterpret_cast<uintptr_t>(dst)) & 15) == 0;
    const long long vec_bytes = vec_ok ? (data_bytes / 16) * 16 : 0;

    if (vec_bytes > 0) {
        const uint4* src4 = reinterpret_cast<const uint4*>(src);
        uint4* dst4 = reinterpret_cast<uint4*>(dst);
        const long long n_vec = vec_bytes / 16;
        for (long long i = tid; i < n_vec; i += stride) {
            dst4[i] = src4[i];
        }
    }

    // Scalar tail (and full copy when vectorization is unavailable).
    for (long long i = vec_bytes + tid; i < data_bytes; i += stride) {
        dst[i] = src[i];
    }
}
154+
155+
156+
// =========================================================================
157+
// extern "C" launchers
158+
// =========================================================================
159+
160+
// C entry point: scatter concatenated packed-FP4 rows into the padded
// per-expert batched layout. All pointers are device pointers; the launch is
// asynchronous on `stream`.
//
// Preconditions: K is even (two FP4 values per byte); expert_offsets has
// num_experts + 1 int32 entries resident on the device.
extern "C" void cmoe_scatter_nvfp4(
    const void* input,
    void* output,
    const int* expert_offsets,
    int max_M,
    int K,
    int num_experts,
    cudaStream_t stream
) {
    // A zero-dimension grid is an invalid launch configuration; with no
    // experts there is nothing to scatter anyway.
    if (num_experts <= 0) return;

    const int row_bytes = K / 2;  // packed FP4: 2 values per byte

    // One threadblock per expert; 256 threads cooperatively copy the
    // expert's rows and zero-fill its padding.
    const dim3 grid(num_experts);
    const dim3 block(256);

    kMoeScatterNVFP4<<<grid, block, 0, stream>>>(
        static_cast<const uint8_t*>(input),
        static_cast<uint8_t*>(output),
        expert_offsets,
        max_M,
        row_bytes
    );
}
183+
184+
// C entry point: gather BF16 rows from the padded per-expert batched layout
// back into a concatenated buffer. All pointers are device pointers; the
// launch is asynchronous on `stream`.
extern "C" void cmoe_gather_bf16(
    const void* input,
    void* output,
    const int* expert_offsets,
    int max_M,
    int N,
    int num_experts,
    cudaStream_t stream
) {
    // A zero-dimension grid is an invalid launch configuration; with no
    // experts there is nothing to gather anyway.
    if (num_experts <= 0) return;

    const int row_bytes = N * 2;  // bf16: 2 bytes per element

    // One threadblock per expert; 256 threads cooperatively copy rows.
    const dim3 grid(num_experts);
    const dim3 block(256);

    kMoeGatherBF16<<<grid, block, 0, stream>>>(
        static_cast<const uint8_t*>(input),
        static_cast<uint8_t*>(output),
        expert_offsets,
        max_M,
        row_bytes
    );
}

0 commit comments

Comments
 (0)