Skip to content

Commit 95ec556

Browse files
TimDettmers and claude committed
Add weighted gather kernel to moe_scatter_gather.cu
Port cmoe_weighted_gather_bf16 from SM_120 branch. This adds a fused gather + weight multiply + FP32 accumulate + BF16 convert kernel that replaces separate gather + scale + sum operations. The kernel uses atomicAdd for cross-expert accumulation with minimal contention. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7d04203 commit 95ec556

File tree

1 file changed

+115
-2
lines changed

1 file changed

+115
-2
lines changed

csrc/qutlass/moe_scatter_gather.cu

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
/*
22
* Scatter and gather kernels for MoE batched NVFP4 GEMM pipeline.
33
*
4-
* Scatter: copies packed FP4 data from concatenated token layout to
4+
* Scatter: copies packed FP4/uint8 data from concatenated token layout to
55
* padded per-expert batched layout. Zero-fills padding rows.
6+
* Works for both packed FP4 activations (row_bytes = K/2) and
7+
* scale factors (same kernel, different row_bytes).
68
*
79
* Gather: copies BF16 results from padded per-expert batched layout
810
* back to concatenated token layout.
911
*
10-
* Both kernels use one threadblock per expert with vectorized 128-bit
12+
* Weighted gather: fused gather + multiply by expert gating weight +
13+
* atomicAdd into output. Single kernel replaces gather + scale + sum.
14+
*
15+
* All kernels use one threadblock per expert with vectorized 128-bit
1116
* (uint4) loads/stores for bandwidth efficiency.
1217
*/
1318

1419
#include <cuda_runtime.h>
20+
#include <cuda_bf16.h>
1521
#include <cstdint>
1622

1723
// =========================================================================
@@ -153,6 +159,62 @@ __global__ void kMoeGatherBF16(
153159
}
154160

155161

162+
// =========================================================================
163+
// Weighted gather: padded per-expert BF16 → FP32 accumulate → BF16 output
164+
// =========================================================================
165+
// Two-phase operation (both launched from one extern "C" call):
166+
// Phase 1: kMoeWeightedGatherAccum — read BF16 expert output, multiply by
167+
// gating weight, atomicAdd into FP32 workspace.
168+
// Phase 2: kConvertFP32ToBF16 — convert FP32 workspace to BF16 output.
169+
//
170+
// Uses a token-parallel layout: grid = (total_assignments,) where each
171+
// assignment is a (token_id, expert_id, weight) triple. Atomic contention
172+
// is minimal — at most top_k experts write to the same token row, and with
173+
// N=4096 elements spread across 256 threads, collisions are rare.
174+
//
175+
// FP32 accumulation avoids BF16 rounding error across top_k additions.
176+
// The final conversion to BF16 rounds once at the end.
177+
178+
// One threadblock per (token, expert, slot, weight) assignment: scale the
// expert's BF16 output row by its gating weight and atomically accumulate
// it into the token's FP32 workspace row.
//
// Grid layout: grid.x = total_assignments, block = 1D (any size).
// Contention on atomicAdd is low: at most top_k assignments target the
// same token row, and lanes within a block touch disjoint elements.
__global__ void kMoeWeightedGatherAccum(
    const __nv_bfloat16* __restrict__ D_batched, // [num_experts * max_M * N]
    float* __restrict__ workspace,               // [num_tokens * N] fp32, zero-initialized
    const int* __restrict__ token_ids,           // [total_assignments]
    const int* __restrict__ expert_ids,          // [total_assignments]
    const int* __restrict__ slot_ids,            // [total_assignments]
    const float* __restrict__ weights,           // [total_assignments]
    int max_M,
    int N
) {
    const int a = blockIdx.x;
    const float gate = weights[a];

    // Source row inside this expert's padded output (64-bit arithmetic to
    // avoid overflow on large expert counts / hidden dims).
    const long long src_row = (long long)expert_ids[a] * max_M + slot_ids[a];
    const __nv_bfloat16* row_in = D_batched + src_row * N;
    // Destination row in the per-token FP32 accumulator.
    float* row_out = workspace + (long long)token_ids[a] * N;

    // Block-stride sweep over the hidden dimension.
    for (int j = threadIdx.x; j < N; j += blockDim.x) {
        atomicAdd(&row_out[j], __bfloat162float(row_in[j]) * gate);
    }
}
205+
206+
// Element-wise FP32 -> BF16 conversion: the single, final rounding step of
// the weighted gather pipeline.
//
// Fix: the original computed `blockIdx.x * blockDim.x + threadIdx.x` in
// 32-bit int, which overflows for element counts approaching 2^31, and it
// required the grid to exactly cover the input. This version widens
// `n_elements` to long long (int arguments from existing callers convert
// implicitly, so call sites remain valid) and uses a 64-bit grid-stride
// loop, so any grid size covers all elements.
__global__ void kConvertFP32ToBF16(
    const float* __restrict__ input,     // [n_elements]
    __nv_bfloat16* __restrict__ output,  // [n_elements]
    long long n_elements
) {
    const long long stride = (long long)gridDim.x * blockDim.x;
    for (long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
         idx < n_elements;
         idx += stride) {
        output[idx] = __float2bfloat16(input[idx]);
    }
}
216+
217+
156218
// =========================================================================
157219
// extern "C" launchers
158220
// =========================================================================
@@ -203,3 +265,54 @@ extern "C" void cmoe_gather_bf16(
203265
row_bytes
204266
);
205267
}
268+
269+
// Fused weighted gather: for each (token, expert, slot, weight) assignment,
// accumulate weight * D_batched[expert][slot][:] into the token's output row.
// FP32 accumulation in a caller-provided workspace avoids BF16 rounding
// error across top_k additions; a second kernel rounds to BF16 once at the
// end. All pointers are device pointers; both kernels and the memset are
// enqueued on `stream` — no host synchronization happens here.
//
// Fix: the original computed `int n_elements = num_tokens * N`, which
// overflows 32-bit int for large batches (e.g. N = 4096 and > ~512K tokens)
// before the cast to size_t in the memset. The element count is now 64-bit
// throughout, and degenerate sizes (num_tokens/N <= 0) return early.
extern "C" void cmoe_weighted_gather_bf16(
    const void* D_batched,     // [num_experts * max_M * N] bf16
    void* output_bf16,         // [num_tokens * N] bf16, final output
    float* workspace_fp32,     // [num_tokens * N] fp32, caller-managed scratch
    const int* token_ids,      // [total_assignments]
    const int* expert_ids,     // [total_assignments]
    const int* slot_ids,       // [total_assignments]
    const float* weights,      // [total_assignments]
    int total_assignments,
    int num_tokens,
    int max_M,
    int N,
    cudaStream_t stream
) {
    if (total_assignments <= 0 || num_tokens <= 0 || N <= 0) return;

    // 64-bit element count — `num_tokens * N` must not be formed in int.
    const long long n_elements = (long long)num_tokens * N;

    // Zero the FP32 accumulator before phase 1.
    cudaMemsetAsync(workspace_fp32, 0, (size_t)n_elements * sizeof(float), stream);

    // Phase 1: one block per assignment; weighted atomicAdd into workspace.
    {
        dim3 grid(total_assignments);
        dim3 block(256);

        kMoeWeightedGatherAccum<<<grid, block, 0, stream>>>(
            static_cast<const __nv_bfloat16*>(D_batched),
            workspace_fp32,
            token_ids,
            expert_ids,
            slot_ids,
            weights,
            max_M,
            N
        );
    }

    // Phase 2: round the FP32 accumulator to BF16 exactly once.
    {
        const int threads = 256;
        const long long blocks = (n_elements + threads - 1) / threads;

        kConvertFP32ToBF16<<<(unsigned int)blocks, threads, 0, stream>>>(
            workspace_fp32,
            static_cast<__nv_bfloat16*>(output_bf16),
            n_elements
        );
    }
}

0 commit comments

Comments
 (0)