Skip to content

Commit 7588d2a

Browse files
TimDettmers and claude
committed
Multi-block scatter/gather: saturate B200 memory bandwidth
Use 2D grid (num_experts × chunks_per_expert) instead of single block per expert. Targets 320 total blocks (2× B200 SM count) so all 160 SMs stay busy during scatter and gather. Previous design used only 8 blocks for 8 experts, achieving ~113-286 GB/s vs B200's 8TB/s theoretical. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6059f9e commit 7588d2a

File tree

1 file changed

+81
-67
lines changed

1 file changed

+81
-67
lines changed

csrc/qutlass/moe_scatter_gather.cu

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
// =========================================================================
2424
// Scatter: concatenated FP4 → padded per-expert batched FP4
2525
// =========================================================================
26-
// Each threadblock handles one expert. Threads cooperatively copy
27-
// n_tokens * row_bytes from the concatenated source to the padded
28-
// destination, then zero-fill padding rows.
26+
// Grid: (num_experts, chunks_per_expert). Each block handles a byte-range
27+
// slice of one expert's total_bytes = max_M * row_bytes, splitting work
28+
// across multiple SMs for bandwidth saturation on wide GPUs (B200: 160 SMs).
2929
//
3030
// Data layout:
3131
// Input: packed_concat [total_tokens * row_bytes] contiguous
@@ -37,80 +37,76 @@ __global__ void kMoeScatterNVFP4(
3737
uint8_t* __restrict__ output, // [num_experts * max_M * row_bytes]
3838
const int* __restrict__ expert_offsets, // [num_experts + 1] cumulative token offsets
3939
int max_M, // padded max tokens per expert
40-
int row_bytes // K / 2
40+
int row_bytes, // K / 2
41+
int chunks_per_expert // gridDim.y
4142
) {
4243
int expert = blockIdx.x;
44+
int chunk = blockIdx.y;
45+
4346
int start = expert_offsets[expert];
4447
int end = expert_offsets[expert + 1];
4548
int n_tokens = end - start;
4649

47-
// Source: contiguous in concatenated buffer
4850
const uint8_t* src = input + (long long)start * row_bytes;
49-
50-
// Destination: padded slot for this expert
5151
uint8_t* dst = output + (long long)expert * max_M * row_bytes;
5252

53-
// Total bytes to process for this expert (data + padding)
5453
long long total_bytes = (long long)max_M * row_bytes;
5554
long long data_bytes = (long long)n_tokens * row_bytes;
5655

57-
// Use vectorized uint4 (16-byte) copies where possible
56+
// This block's byte range (aligned to 16 for vectorization)
57+
long long bytes_per_chunk = ((total_bytes + chunks_per_expert - 1) / chunks_per_expert + 15) & ~15LL;
58+
long long my_start = (long long)chunk * bytes_per_chunk;
59+
long long my_end = min(my_start + bytes_per_chunk, total_bytes);
60+
if (my_start >= total_bytes) return;
61+
5862
int tid = threadIdx.x;
5963
int stride = blockDim.x;
6064

61-
// Copy data rows using uint4 vectorization
62-
long long vec_data_bytes = (data_bytes / 16) * 16;
63-
const uint4* src4 = reinterpret_cast<const uint4*>(src);
64-
uint4* dst4 = reinterpret_cast<uint4*>(dst);
65-
long long n_vec = vec_data_bytes / 16;
66-
67-
for (long long i = tid; i < n_vec; i += stride) {
68-
dst4[i] = src4[i];
69-
}
65+
// Process byte range [my_start, my_end) — copy from src where < data_bytes, zero otherwise
66+
// Use uint4 (16-byte) vectorization
67+
long long vec_start = (my_start + 15) / 16; // first full uint4 in range
68+
long long vec_end = my_end / 16; // last full uint4 in range
7069

71-
// Handle remaining bytes in data region
72-
for (long long i = vec_data_bytes + tid; i < data_bytes; i += stride) {
73-
dst[i] = src[i];
70+
// Scalar head bytes
71+
for (long long i = my_start + tid; i < min(vec_start * 16, my_end); i += stride) {
72+
dst[i] = (i < data_bytes) ? src[i] : 0;
7473
}
7574

76-
// Zero-fill padding region using uint4
77-
long long pad_start = data_bytes;
78-
long long pad_bytes = total_bytes - pad_start;
79-
80-
if (pad_bytes > 0) {
81-
// Align pad_start up to 16-byte boundary for vectorized zeroing
82-
long long aligned_pad_start = ((pad_start + 15) / 16) * 16;
83-
84-
// Zero unaligned bytes at start of padding
85-
for (long long i = pad_start + tid; i < aligned_pad_start && i < total_bytes; i += stride) {
86-
dst[i] = 0;
87-
}
88-
89-
// Vectorized zero-fill
90-
uint4 zero4 = make_uint4(0, 0, 0, 0);
91-
long long vec_pad_end = (total_bytes / 16) * 16;
92-
uint4* dst4_pad = reinterpret_cast<uint4*>(dst);
93-
long long vec_start = aligned_pad_start / 16;
94-
long long vec_end = vec_pad_end / 16;
95-
96-
for (long long i = vec_start + tid; i < vec_end; i += stride) {
97-
dst4_pad[i] = zero4;
75+
// Vectorized middle
76+
const uint4* src4 = reinterpret_cast<const uint4*>(src);
77+
uint4* dst4 = reinterpret_cast<uint4*>(dst);
78+
uint4 zero4 = make_uint4(0, 0, 0, 0);
79+
long long data_vec_boundary = data_bytes / 16; // last full uint4 within data
80+
81+
for (long long i = vec_start + tid; i < vec_end; i += stride) {
82+
if (i < data_vec_boundary) {
83+
dst4[i] = src4[i];
84+
} else if (i * 16 >= data_bytes) {
85+
dst4[i] = zero4;
86+
} else {
87+
// Straddles data/padding boundary — byte-by-byte
88+
uint8_t tmp[16];
89+
const uint8_t* s = src + i * 16;
90+
for (int b = 0; b < 16; b++) {
91+
long long pos = i * 16 + b;
92+
tmp[b] = (pos < data_bytes) ? s[b] : 0;
93+
}
94+
dst4[i] = *reinterpret_cast<uint4*>(tmp);
9895
}
96+
}
9997

100-
// Zero remaining bytes at end
101-
for (long long i = vec_pad_end + tid; i < total_bytes; i += stride) {
102-
dst[i] = 0;
103-
}
98+
// Scalar tail bytes
99+
for (long long i = vec_end * 16 + tid; i < my_end; i += stride) {
100+
dst[i] = (i < data_bytes) ? src[i] : 0;
104101
}
105102
}
106103

107104

108105
// =========================================================================
109106
// Gather: padded per-expert BF16 → concatenated BF16
110107
// =========================================================================
111-
// Each threadblock handles one expert. Threads cooperatively copy
112-
// n_tokens * row_elems BF16 values from the padded batched output
113-
// to the concatenated result.
108+
// Grid: (num_experts, chunks_per_expert). Each block handles a byte-range
109+
// slice of one expert's data_bytes = n_tokens * row_bytes.
114110
//
115111
// Data layout:
116112
// Input: D_batched [num_experts * max_M * N] bf16
@@ -122,38 +118,49 @@ __global__ void kMoeGatherBF16(
122118
uint8_t* __restrict__ output, // [total_tokens * row_bytes]
123119
const int* __restrict__ expert_offsets, // [num_experts + 1]
124120
int max_M,
125-
int row_bytes // N * 2
121+
int row_bytes, // N * 2
122+
int chunks_per_expert // gridDim.y
126123
) {
127124
int expert = blockIdx.x;
125+
int chunk = blockIdx.y;
126+
128127
int start = expert_offsets[expert];
129128
int end = expert_offsets[expert + 1];
130129
int n_tokens = end - start;
131-
132130
if (n_tokens <= 0) return;
133131

134-
// Source: padded slot for this expert
135132
const uint8_t* src = input + (long long)expert * max_M * row_bytes;
136-
137-
// Destination: contiguous in concatenated buffer
138133
uint8_t* dst = output + (long long)start * row_bytes;
139134

140135
long long data_bytes = (long long)n_tokens * row_bytes;
141136

137+
// This block's byte range (aligned to 16)
138+
long long bytes_per_chunk = ((data_bytes + chunks_per_expert - 1) / chunks_per_expert + 15) & ~15LL;
139+
long long my_start = (long long)chunk * bytes_per_chunk;
140+
long long my_end = min(my_start + bytes_per_chunk, data_bytes);
141+
if (my_start >= data_bytes) return;
142+
142143
int tid = threadIdx.x;
143144
int stride = blockDim.x;
144145

145-
// Vectorized uint4 copy
146-
long long vec_bytes = (data_bytes / 16) * 16;
146+
// Vectorized uint4 copy over [my_start, my_end)
147+
long long vec_start = (my_start + 15) / 16;
148+
long long vec_end = my_end / 16;
149+
150+
// Scalar head
151+
for (long long i = my_start + tid; i < min(vec_start * 16, my_end); i += stride) {
152+
dst[i] = src[i];
153+
}
154+
155+
// Vectorized middle
147156
const uint4* src4 = reinterpret_cast<const uint4*>(src);
148157
uint4* dst4 = reinterpret_cast<uint4*>(dst);
149-
long long n_vec = vec_bytes / 16;
150-
151-
for (long long i = tid; i < n_vec; i += stride) {
158+
for (long long i = vec_start + tid; i < vec_end; i += stride) {
152159
dst4[i] = src4[i];
153160
}
154161

155-
// Handle remaining bytes
156-
for (long long i = vec_bytes + tid; i < data_bytes; i += stride) {
162+
// Scalar tail
163+
for (long long i = vec_end * 16 + tid; i < my_end; i += stride) {
157164
dst[i] = src[i];
158165
}
159166
}
@@ -219,6 +226,10 @@ __global__ void kConvertFP32ToBF16(
219226
// extern "C" launchers
220227
// =========================================================================
221228

229+
// Target enough total blocks to saturate GPU SMs.
230+
// B200 has 160 SMs; 2× oversubscription hides latency.
231+
static constexpr int kTargetBlocks = 320;
232+
222233
extern "C" void cmoe_scatter_nvfp4(
223234
const void* input,
224235
void* output,
@@ -229,17 +240,18 @@ extern "C" void cmoe_scatter_nvfp4(
229240
cudaStream_t stream
230241
) {
231242
int row_bytes = K / 2; // packed FP4: 2 values per byte
243+
int chunks = max(1, kTargetBlocks / max(num_experts, 1));
232244

233-
// One threadblock per expert, 256 threads
234-
dim3 grid(num_experts);
245+
dim3 grid(num_experts, chunks);
235246
dim3 block(256);
236247

237248
kMoeScatterNVFP4<<<grid, block, 0, stream>>>(
238249
static_cast<const uint8_t*>(input),
239250
static_cast<uint8_t*>(output),
240251
expert_offsets,
241252
max_M,
242-
row_bytes
253+
row_bytes,
254+
chunks
243255
);
244256
}
245257

@@ -253,16 +265,18 @@ extern "C" void cmoe_gather_bf16(
253265
cudaStream_t stream
254266
) {
255267
int row_bytes = N * 2; // bf16: 2 bytes per element
268+
int chunks = max(1, kTargetBlocks / max(num_experts, 1));
256269

257-
dim3 grid(num_experts);
270+
dim3 grid(num_experts, chunks);
258271
dim3 block(256);
259272

260273
kMoeGatherBF16<<<grid, block, 0, stream>>>(
261274
static_cast<const uint8_t*>(input),
262275
static_cast<uint8_t*>(output),
263276
expert_offsets,
264277
max_M,
265-
row_bytes
278+
row_bytes,
279+
chunks
266280
);
267281
}
268282

0 commit comments

Comments
 (0)