// =========================================================================
// Scatter: concatenated FP4 → padded per-expert batched FP4
// =========================================================================
// Grid: (num_experts, chunks_per_expert). Each block handles a byte-range
// slice of one expert's total_bytes = max_M * row_bytes, splitting work
// across multiple SMs for bandwidth saturation on wide GPUs (B200: 160 SMs).
//
// NOTE(review): the uint4 fast path assumes input + start*row_bytes and
// output + expert*max_M*row_bytes are 16-byte aligned, i.e. the base
// pointers are 16-byte aligned and row_bytes % 16 == 0 (K % 32 == 0) —
// confirm at call sites.
//
// Data layout:
//   Input: packed_concat [total_tokens * row_bytes] contiguous
__global__ void kMoeScatterNVFP4(
    const uint8_t * __restrict__ input,       // [total_tokens * row_bytes] concatenated rows
    uint8_t * __restrict__ output,            // [num_experts * max_M * row_bytes] padded slots
    const int * __restrict__ expert_offsets,  // [num_experts + 1] cumulative token offsets
    int max_M,                                // padded max tokens per expert
    int row_bytes,                            // K / 2 (packed FP4: 2 values per byte)
    int chunks_per_expert                     // gridDim.y
) {
    int expert = blockIdx.x;
    int chunk  = blockIdx.y;

    int start    = expert_offsets[expert];
    int end      = expert_offsets[expert + 1];
    int n_tokens = end - start;

    // Source: this expert's contiguous rows in the concatenated buffer.
    const uint8_t * src = input + (long long)start * row_bytes;
    // Destination: this expert's padded slot in the batched buffer.
    uint8_t * dst = output + (long long)expert * max_M * row_bytes;

    long long total_bytes = (long long)max_M * row_bytes;    // data + padding
    long long data_bytes  = (long long)n_tokens * row_bytes; // data only

    // This block's byte range, rounded to a 16-byte multiple so uint4 lanes
    // never split across chunks.
    long long bytes_per_chunk = ((total_bytes + chunks_per_expert - 1) / chunks_per_expert + 15) & ~15LL;
    long long my_start = (long long)chunk * bytes_per_chunk;
    long long my_end   = min(my_start + bytes_per_chunk, total_bytes);
    if (my_start >= total_bytes) return;

    int tid    = threadIdx.x;
    int stride = blockDim.x;

    // Process [my_start, my_end): copy from src where i < data_bytes, zero
    // otherwise. Bulk of the range goes through 16-byte uint4 accesses.
    long long vec_start = (my_start + 15) / 16;  // first full uint4 in range
    long long vec_end   = my_end / 16;           // one past the last full uint4

    // Scalar head bytes. my_start is always a multiple of 16 by construction
    // (bytes_per_chunk is 16-aligned), so this loop is currently empty — kept
    // for safety if the chunking scheme ever changes.
    for (long long i = my_start + tid; i < min(vec_start * 16, my_end); i += stride) {
        dst[i] = (i < data_bytes) ? src[i] : 0;
    }

    // Vectorized middle.
    const uint4 * src4 = reinterpret_cast<const uint4 *>(src);
    uint4 * dst4 = reinterpret_cast<uint4 *>(dst);
    uint4 zero4 = make_uint4(0, 0, 0, 0);
    long long data_vec_boundary = data_bytes / 16;  // count of full uint4s inside data

    for (long long i = vec_start + tid; i < vec_end; i += stride) {
        if (i < data_vec_boundary) {
            dst4[i] = src4[i];
        } else if (i * 16 >= data_bytes) {
            dst4[i] = zero4;
        } else {
            // Straddles the data/padding boundary — assemble byte-by-byte.
            // A union gives the staging buffer uint4 alignment; the previous
            // reinterpret_cast of a plain uint8_t[16] was an unaligned
            // type-pun (undefined behavior).
            union { uint8_t b[16]; uint4 v; } tmp;
            const uint8_t * s = src + i * 16;
            for (int k = 0; k < 16; k++) {
                long long pos = i * 16 + k;
                tmp.b[k] = (pos < data_bytes) ? s[k] : 0;
            }
            dst4[i] = tmp.v;
        }
    }

    // Scalar tail bytes.
    for (long long i = vec_end * 16 + tid; i < my_end; i += stride) {
        dst[i] = (i < data_bytes) ? src[i] : 0;
    }
}
106103
107104
// =========================================================================
// Gather: padded per-expert BF16 → concatenated BF16
// =========================================================================
// Grid: (num_experts, chunks_per_expert). Each block handles a byte-range
// slice of one expert's data_bytes = n_tokens * row_bytes.
//
// NOTE(review): the uint4 fast path assumes output + start*row_bytes is
// 16-byte aligned, i.e. row_bytes % 16 == 0 (N % 8 == 0) and a 16-byte
// aligned base pointer — confirm at call sites.
//
// Data layout:
//   Input: D_batched [num_experts * max_M * N] bf16
__global__ void kMoeGatherBF16(
    const uint8_t * __restrict__ input,       // [num_experts * max_M * row_bytes] padded slots
    uint8_t * __restrict__ output,            // [total_tokens * row_bytes] concatenated rows
    const int * __restrict__ expert_offsets,  // [num_experts + 1]
    int max_M,
    int row_bytes,                            // N * 2 (bf16: 2 bytes per element)
    int chunks_per_expert                     // gridDim.y
) {
    const int expert = blockIdx.x;
    const int chunk  = blockIdx.y;

    const int tok_begin = expert_offsets[expert];
    const int tok_count = expert_offsets[expert + 1] - tok_begin;
    if (tok_count <= 0) return;

    // Source: this expert's padded slot; destination: contiguous rows in the
    // concatenated result.
    const uint8_t * src = input + (long long)expert * max_M * row_bytes;
    uint8_t * dst = output + (long long)tok_begin * row_bytes;

    const long long data_bytes = (long long)tok_count * row_bytes;

    // Byte range owned by this block, rounded to a 16-byte multiple so the
    // uint4 lanes never split across chunks.
    const long long chunk_bytes = ((data_bytes + chunks_per_expert - 1) / chunks_per_expert + 15) & ~15LL;
    const long long lo = (long long)chunk * chunk_bytes;
    if (lo >= data_bytes) return;
    const long long hi = min(lo + chunk_bytes, data_bytes);

    const int tid  = threadIdx.x;
    const int step = blockDim.x;

    const long long v_lo = (lo + 15) / 16;  // first whole uint4 in [lo, hi)
    const long long v_hi = hi / 16;         // one past the last whole uint4

    // Leading bytes before the first aligned uint4.
    for (long long i = lo + tid; i < min(v_lo * 16, hi); i += step) {
        dst[i] = src[i];
    }

    // Bulk copy as 16-byte vectors.
    const uint4 * s4 = reinterpret_cast<const uint4 *>(src);
    uint4 * d4 = reinterpret_cast<uint4 *>(dst);
    for (long long i = v_lo + tid; i < v_hi; i += step) {
        d4[i] = s4[i];
    }

    // Trailing bytes after the last whole uint4.
    for (long long i = v_hi * 16 + tid; i < hi; i += step) {
        dst[i] = src[i];
    }
}
@@ -219,6 +226,10 @@ __global__ void kConvertFP32ToBF16(
219226// extern "C" launchers
220227// =========================================================================
221228
// Aim for enough total blocks (num_experts × chunks_per_expert) to saturate
// the GPU's SMs. B200 has 160 SMs; 2x oversubscription helps hide latency.
static constexpr int kTargetBlocks = 320;
232+
// Host launcher: scatter concatenated packed-FP4 activations into padded
// per-expert slots. Asynchronous on `stream`; all pointers are device memory.
extern "C" void cmoe_scatter_nvfp4(
    const void * input,
    void * output,
    const int * expert_offsets,
    int num_experts,
    int max_M,
    int K,
    cudaStream_t stream
) {
    const int row_bytes = K / 2;  // packed FP4: two values per byte

    // Split each expert across enough chunk-blocks that the whole grid
    // reaches roughly kTargetBlocks, never dropping below one per expert.
    const int chunks = max(1, kTargetBlocks / max(num_experts, 1));

    const dim3 grid(num_experts, chunks);
    const dim3 block(256);

    kMoeScatterNVFP4<<<grid, block, 0, stream>>>(
        static_cast<const uint8_t *>(input),
        static_cast<uint8_t *>(output),
        expert_offsets,
        max_M,
        row_bytes,
        chunks
    );
}
245257
// Host launcher: gather padded per-expert BF16 outputs back into one
// concatenated buffer. Asynchronous on `stream`; all pointers are device
// memory.
extern "C" void cmoe_gather_bf16(
    const void * input,
    void * output,
    const int * expert_offsets,
    int num_experts,
    int max_M,
    int N,
    cudaStream_t stream
) {
    const int row_bytes = N * 2;  // bf16: 2 bytes per element

    // Same chunking policy as the scatter launcher: ~kTargetBlocks total
    // blocks, at least one chunk per expert.
    const int chunks = max(1, kTargetBlocks / max(num_experts, 1));

    const dim3 grid(num_experts, chunks);
    const dim3 block(256);

    kMoeGatherBF16<<<grid, block, 0, stream>>>(
        static_cast<const uint8_t *>(input),
        static_cast<uint8_t *>(output),
        expert_offsets,
        max_M,
        row_bytes,
        chunks
    );
}
268282