1616namespace transformer_engine {
1717namespace fused_router {
1818
19- template <typename DataType, TopkFuncType TopkFunc = TopkFuncType::Naive>
20- __global__ void fused_score_for_moe_aux_loss_forward_kernel (const DataType *logits, int num_tokens,
21- int num_experts, int topk,
22- int score_function, float *scores,
23- uint8_t *routing_map,
24- NVTERoutingMapFormat routing_map_format,
25- CompType *intermediate_output) {
19+ template <typename DataType, NVTERoutingMapFormat RoutingMapFormat,
20+ TopkFuncType TopkFunc = TopkFuncType::Naive>
21+ __global__ void fused_score_for_moe_aux_loss_forward_kernel (
22+ const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
23+ float *scores, uint8_t *routing_map, CompType *intermediate_output) {
24+ constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 );
2625 /* **
2726 * Section: Global Variables/Addresses init
2827 * - Each warp is responsible for one token, and has own shared memory buffer.
@@ -42,7 +41,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
4241 const int bitmap_words_per_warp = (num_experts + 31 ) / 32 ;
4342 const int bitmap_row_bytes = (num_experts + 7 ) / 8 ;
4443 uint32_t *bitmap_words_buf = nullptr ;
45- if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
44+ if constexpr ( kIsBitmap ) {
4645 bitmap_words_buf = reinterpret_cast <uint32_t *>(topk_indices_buf + topk * num_token_per_block);
4746 }
4847 // The address of buffers on the current warp
@@ -77,7 +76,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
7776 intermediate_output[pos_offset + i] = -std::numeric_limits<CompType>::infinity ();
7877 }
7978 }
80- if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP ) {
79+ if constexpr (! kIsBitmap ) {
8180 for (int i = lane_id; i < num_experts; i += kThreadsPerWarp ) {
8281 routing_map[pos_offset + i] = 0 ;
8382 }
@@ -148,7 +147,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
148147 __syncwarp ();
149148
150149 // Write the routing_map to the output tensor
151- if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP ) {
150+ if constexpr (! kIsBitmap ) {
152151 for (int i = lane_id; i < topk; i += kThreadsPerWarp ) {
153152 routing_map[pos_offset + topk_indices[i]] = 1 ;
154153 }
@@ -174,18 +173,17 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
174173 }
175174}
176175
177- template <typename DataType>
176+ template <typename DataType, NVTERoutingMapFormat RoutingMapFormat >
178177void fused_score_for_moe_aux_loss_forward_kernel_launcher (
179178 const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
180- float *scores, uint8_t *routing_map, NVTERoutingMapFormat routing_map_format,
181- CompType *intermediate_output, cudaStream_t stream) {
179+ float *scores, uint8_t *routing_map, CompType *intermediate_output, cudaStream_t stream) {
182180 // Meta data for the kernel
183181 size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp ;
184182 size_t grid_size = (num_tokens + num_token_per_block - 1 ) / num_token_per_block;
185183 size_t shared_memory_size = num_experts * num_token_per_block * sizeof (CompType) // logits
186184 + topk * num_token_per_block * sizeof (CompType) // topk_logits
187185 + topk * num_token_per_block * sizeof (int ); // topk_indices
188- if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
186+ if constexpr (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
189187 size_t bitmap_words_per_warp = (num_experts + 31 ) / 32 ;
190188 shared_memory_size +=
191189 bitmap_words_per_warp * num_token_per_block * sizeof (uint32_t ); // bitmap accumulator
@@ -195,36 +193,76 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
195193 // switch at K=16 where naive O(K^2*E) starts to dominate
196194 if (topk < 16 ) {
197195 NVTE_CHECK_CUDA (cudaFuncSetAttribute (
198- fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Naive>,
196+ fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
197+ TopkFuncType::Naive>,
199198 cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
200- fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Naive>
199+ fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat, TopkFuncType::Naive>
201200 <<<grid_size, kThreadsPerBlock , shared_memory_size, stream>>> (
202201 logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
203- routing_map_format, intermediate_output);
202+ intermediate_output);
204203 } else {
205204 NVTE_CHECK_CUDA (cudaFuncSetAttribute (
206- fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Radix>,
205+ fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
206+ TopkFuncType::Radix>,
207207 cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
208- fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Radix>
208+ fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat, TopkFuncType::Radix>
209209 <<<grid_size, kThreadsPerBlock , shared_memory_size, stream>>> (
210210 logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
211- routing_map_format, intermediate_output);
211+ intermediate_output);
212212 }
213213 NVTE_CHECK_CUDA (cudaGetLastError ());
214214}
215215
216+ // Build the expected routing_map shape for a given NVTERoutingMapFormat.
217+ // BYTEMAP -> [num_tokens, num_experts]
218+ // BITMAP_U8 -> [num_tokens, ceil(num_experts/8)]
219+ static std::vector<size_t > expected_routing_map_shape (int num_tokens, int num_experts,
220+ NVTERoutingMapFormat format) {
221+ const size_t t = static_cast <size_t >(num_tokens);
222+ const size_t e = static_cast <size_t >(num_experts);
223+ if (format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
224+ return {t, (e + 7 ) / 8 };
225+ }
226+ return {t, e};
227+ }
228+
216229void fused_score_for_moe_aux_loss_forward (const Tensor &logits, int num_tokens, int num_experts,
217230 int topk, int score_function, Tensor &scores,
218231 Tensor &routing_map,
219232 NVTERoutingMapFormat routing_map_format,
220233 Tensor &intermediate_output, cudaStream_t stream) {
221- TE_ROUTER_PROBS_TYPE_SWITCH_ALL (
222- logits.data .dtype , DataType,
223- fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType>(
224- reinterpret_cast <DataType *>(logits.data .dptr ), num_tokens, num_experts, topk,
225- score_function, reinterpret_cast <float *>(scores.data .dptr ),
226- reinterpret_cast <uint8_t *>(routing_map.data .dptr ), routing_map_format,
234+ NVTE_CHECK (num_tokens > 0 && num_experts > 0 ,
235+ " num_tokens and num_experts must be positive; got num_tokens=" , num_tokens,
236+ " , num_experts=" , num_experts);
237+ const std::vector<size_t > dense_shape{static_cast <size_t >(num_tokens),
238+ static_cast <size_t >(num_experts)};
239+ NVTE_CHECK (logits.data .shape == dense_shape, " logits shape must be [num_tokens, num_experts]=[" ,
240+ num_tokens, " , " , num_experts, " ], got " , logits.data .shape );
241+ NVTE_CHECK (scores.data .shape == dense_shape, " scores shape must be [num_tokens, num_experts]=[" ,
242+ num_tokens, " , " , num_experts, " ], got " , scores.data .shape );
243+ NVTE_CHECK (intermediate_output.data .shape == dense_shape,
244+ " intermediate_output shape must be [num_tokens, num_experts]=[" , num_tokens, " , " ,
245+ num_experts, " ], got " , intermediate_output.data .shape );
246+ const auto routing_map_shape =
247+ expected_routing_map_shape (num_tokens, num_experts, routing_map_format);
248+ NVTE_CHECK (routing_map.data .shape == routing_map_shape,
249+ " routing_map shape mismatch for " ,
250+ (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? " BITMAP_U8" : " BYTEMAP" ),
251+ " ; expected " , routing_map_shape, " , got " , routing_map.data .shape );
252+ #define AUX_LOSS_FORWARD_DISPATCH (RoutingMapFormatVal ) \
253+ TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
254+ logits.data .dtype , DataType, \
255+ fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType, RoutingMapFormatVal>( \
256+ reinterpret_cast <DataType *>(logits.data .dptr ), num_tokens, num_experts, topk, \
257+ score_function, reinterpret_cast <float *>(scores.data .dptr ), \
258+ reinterpret_cast <uint8_t *>(routing_map.data .dptr ), \
227259 reinterpret_cast <CompType *>(intermediate_output.data .dptr ), stream););
260+ if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
261+ AUX_LOSS_FORWARD_DISPATCH (NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 )
262+ } else {
263+ AUX_LOSS_FORWARD_DISPATCH (NVTE_ROUTING_MAP_FORMAT_BYTEMAP )
264+ }
265+ #undef AUX_LOSS_FORWARD_DISPATCH
228266}
229267
230268template <typename DataType>
@@ -389,18 +427,32 @@ void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
389427} // namespace fused_router
390428} // namespace transformer_engine
391429
430+ void nvte_fused_score_for_moe_aux_loss_forward_v2 (const NVTETensor logits, int num_tokens,
431+ int num_experts, int topk, int score_function,
432+ NVTETensor scores, NVTETensor routing_map,
433+ NVTERoutingMapFormat routing_map_format,
434+ const NVTETensor intermediate_output,
435+ cudaStream_t stream) {
436+ NVTE_API_CALL (nvte_fused_score_for_moe_aux_loss_forward_v2);
437+ using namespace transformer_engine ;
438+ fused_router::fused_score_for_moe_aux_loss_forward (
439+ *convertNVTETensorCheck (logits), num_tokens, num_experts, topk, score_function,
440+ *convertNVTETensorCheck (scores), *convertNVTETensorCheck (routing_map), routing_map_format,
441+ *convertNVTETensorCheck (intermediate_output), stream);
442+ }
443+
444+ // Deprecated V1 entry point: forwards to the V2 above with the BYTEMAP layout.
445+ // Kept for ABI compatibility with external C API consumers.
392446void nvte_fused_score_for_moe_aux_loss_forward (const NVTETensor logits, int num_tokens,
393447 int num_experts, int topk, int score_function,
394448 NVTETensor scores, NVTETensor routing_map,
395- NVTERoutingMapFormat routing_map_format,
396449 const NVTETensor intermediate_output,
397450 cudaStream_t stream) {
398451 NVTE_API_CALL (nvte_fused_score_for_moe_aux_loss_forward);
399- using namespace transformer_engine ;
400- fused_router::fused_score_for_moe_aux_loss_forward (
401- *convertNVTETensorCheck (logits), num_tokens, num_experts, topk, score_function,
402- *convertNVTETensorCheck (scores), *convertNVTETensorCheck (routing_map), routing_map_format,
403- *convertNVTETensorCheck (intermediate_output), stream);
452+ nvte_fused_score_for_moe_aux_loss_forward_v2 (logits, num_tokens, num_experts, topk,
453+ score_function, scores, routing_map,
454+ NVTE_ROUTING_MAP_FORMAT_BYTEMAP ,
455+ intermediate_output, stream);
404456}
405457
406458void nvte_fused_score_for_moe_aux_loss_backward (const NVTETensor intermediate_output,
0 commit comments