|
1 | 1 | /* |
2 | 2 | * Scatter and gather kernels for MoE batched NVFP4 GEMM pipeline. |
3 | 3 | * |
4 | | - * Scatter: copies packed FP4 data from concatenated token layout to |
| 4 | + * Scatter: copies packed FP4/uint8 data from concatenated token layout to |
5 | 5 | * padded per-expert batched layout. Zero-fills padding rows. |
| 6 | + * Works for both packed FP4 activations (row_bytes = K/2) and |
| 7 | + * scale factors (same kernel, different row_bytes). |
6 | 8 | * |
7 | 9 | * Gather: copies BF16 results from padded per-expert batched layout |
8 | 10 | * back to concatenated token layout. |
9 | 11 | * |
10 | | - * Both kernels use one threadblock per expert with vectorized 128-bit |
| 12 | + * Weighted gather: fused gather + multiply by expert gating weight + |
| 13 | + * atomicAdd into output. Single kernel replaces gather + scale + sum. |
| 14 | + * |
| 15 | + * All kernels use one threadblock per expert with vectorized 128-bit |
11 | 16 | * (uint4) loads/stores for bandwidth efficiency. |
12 | 17 | */ |
13 | 18 |
|
14 | 19 | #include <cuda_runtime.h> |
| 20 | +#include <cuda_bf16.h> |
15 | 21 | #include <cstdint> |
16 | 22 |
|
17 | 23 | // ========================================================================= |
@@ -153,6 +159,62 @@ __global__ void kMoeGatherBF16( |
153 | 159 | } |
154 | 160 |
|
155 | 161 |
|
| 162 | +// ========================================================================= |
| 163 | +// Weighted gather: padded per-expert BF16 → FP32 accumulate → BF16 output |
| 164 | +// ========================================================================= |
| 165 | +// Two-phase operation (both launched from one extern "C" call): |
| 166 | +// Phase 1: kMoeWeightedGatherAccum — read BF16 expert output, multiply by |
| 167 | +// gating weight, atomicAdd into FP32 workspace. |
| 168 | +// Phase 2: kConvertFP32ToBF16 — convert FP32 workspace to BF16 output. |
| 169 | +// |
| 170 | +// Uses a token-parallel layout: grid = (total_assignments,) where each |
| 171 | +// assignment is a (token_id, expert_id, weight) triple. Atomic contention |
| 172 | +// is minimal — at most top_k experts write to the same token row, and with |
| 173 | +// typical sizes (e.g. N = 4096 elements spread across 256 threads) collisions are rare. |
| 174 | +// |
| 175 | +// FP32 accumulation avoids BF16 rounding error across top_k additions. |
| 176 | +// The final conversion to BF16 rounds once at the end. |
| 177 | + |
| 178 | +__global__ void kMoeWeightedGatherAccum( |
| 179 | + const __nv_bfloat16* __restrict__ D_batched, // [num_experts * max_M * N] |
| 180 | + float* __restrict__ workspace, // [num_tokens * N] fp32, zero-initialized |
| 181 | + const int* __restrict__ token_ids, // [total_assignments] |
| 182 | + const int* __restrict__ expert_ids, // [total_assignments] |
| 183 | + const int* __restrict__ slot_ids, // [total_assignments] |
| 184 | + const float* __restrict__ weights, // [total_assignments] |
| 185 | + int max_M, |
| 186 | + int N |
| 187 | +) { |
| 188 | + int assign = blockIdx.x; |
| 189 | + int token_id = token_ids[assign]; |
| 190 | + int expert_id = expert_ids[assign]; |
| 191 | + int slot_id = slot_ids[assign]; |
| 192 | + float w = weights[assign]; |
| 193 | + |
| 194 | + const __nv_bfloat16* src = D_batched + ((long long)expert_id * max_M + slot_id) * N; |
| 195 | + float* dst = workspace + (long long)token_id * N; |
| 196 | + |
| 197 | + int tid = threadIdx.x; |
| 198 | + int stride = blockDim.x; |
| 199 | + |
| 200 | + for (int i = tid; i < N; i += stride) { |
| 201 | + float val = __bfloat162float(src[i]) * w; |
| 202 | + atomicAdd(&dst[i], val); |
| 203 | + } |
| 204 | +} |
| 205 | + |
// Elementwise FP32 → BF16 conversion (one rounding step at the end of the
// weighted-gather pipeline).
// Grid layout: 1-D grid of 1-D blocks; a grid-stride loop makes the kernel
// correct for any launch configuration and any element count.
// n_elements is 64-bit: num_tokens * N can exceed INT_MAX (the previous
// `int` parameter and `int idx` computation silently overflowed); existing
// call sites passing `int` promote implicitly, so this is backward-compatible.
__global__ void kConvertFP32ToBF16(
    const float* __restrict__ input,     // [n_elements]
    __nv_bfloat16* __restrict__ output,  // [n_elements]
    long long n_elements
) {
    // 64-bit index and stride: blockIdx.x * blockDim.x alone can overflow
    // 32 bits on very large launches.
    long long stride = (long long)gridDim.x * blockDim.x;
    for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
         i < n_elements;
         i += stride) {
        output[i] = __float2bfloat16(input[i]);
    }
}
| 216 | + |
| 217 | + |
156 | 218 | // ========================================================================= |
157 | 219 | // extern "C" launchers |
158 | 220 | // ========================================================================= |
@@ -203,3 +265,54 @@ extern "C" void cmoe_gather_bf16( |
203 | 265 | row_bytes |
204 | 266 | ); |
205 | 267 | } |
| 268 | + |
// Host launcher for the fused weighted gather:
//   Phase 1: kMoeWeightedGatherAccum — BF16 expert rows * gating weight,
//            atomicAdd into the FP32 workspace.
//   Phase 2: kConvertFP32ToBF16 — single rounding from FP32 to BF16 output.
// All work is enqueued on `stream`; the caller owns synchronization and the
// FP32 scratch buffer. Preconditions: the index arrays describe valid
// (token_id < num_tokens, expert_id, slot_id < max_M) triples — not checked
// here (device-side validation would require extra kernels).
extern "C" void cmoe_weighted_gather_bf16(
    const void* D_batched,      // [num_experts * max_M * N] bf16
    void* output_bf16,          // [num_tokens * N] bf16, final output
    float* workspace_fp32,      // [num_tokens * N] fp32, caller-managed scratch
    const int* token_ids,       // [total_assignments]
    const int* expert_ids,      // [total_assignments]
    const int* slot_ids,        // [total_assignments]
    const float* weights,       // [total_assignments]
    int total_assignments,
    int num_tokens,
    int max_M,
    int N,
    cudaStream_t stream
) {
    // 64-bit element count: num_tokens * N can exceed INT_MAX (e.g. 1M tokens
    // at N = 4096); the previous `int` product silently overflowed.
    long long n_elements = (long long)num_tokens * (long long)N;
    if (n_elements <= 0) return;

    if (total_assignments <= 0) {
        // No expert contributes to any token, so the correct result is all
        // zeros. A BF16 value with all-zero bits is +0.0, so a byte-wise
        // memset is a valid zero-fill. (Previously this path returned early
        // and left stale data in output_bf16.)
        cudaMemsetAsync(output_bf16, 0,
                        (size_t)n_elements * sizeof(__nv_bfloat16), stream);
        return;
    }

    // Zero the FP32 workspace — kMoeWeightedGatherAccum accumulates into it.
    cudaMemsetAsync(workspace_fp32, 0,
                    (size_t)n_elements * sizeof(float), stream);

    // Phase 1: weighted accumulate into FP32 workspace.
    // One block per assignment; 256 threads stride the N-element row.
    {
        dim3 grid(total_assignments);
        dim3 block(256);

        kMoeWeightedGatherAccum<<<grid, block, 0, stream>>>(
            static_cast<const __nv_bfloat16*>(D_batched),
            workspace_fp32,
            token_ids,
            expert_ids,
            slot_ids,
            weights,
            max_M,
            N
        );
    }

    // Phase 2: convert FP32 workspace → BF16 output (rounds once, at the end).
    {
        int threads = 256;
        // Ceil-div in 64-bit so the block count itself cannot overflow.
        long long blocks = (n_elements + threads - 1) / threads;

        kConvertFP32ToBF16<<<(unsigned int)blocks, threads, 0, stream>>>(
            workspace_fp32,
            static_cast<__nv_bfloat16*>(output_bf16),
            n_elements
        );
    }
}
0 commit comments