@@ -361,27 +361,26 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens,
361361 num_experts, " ], got " , intermediate_output.data .shape );
362362 const auto routing_map_shape =
363363 expected_routing_map_shape (num_tokens, num_experts, routing_map_format);
364- NVTE_CHECK (routing_map.data .shape == routing_map_shape,
365- " routing_map shape mismatch for " ,
364+ NVTE_CHECK (routing_map.data .shape == routing_map_shape, " routing_map shape mismatch for " ,
366365 (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? " BITMAP_U8" : " BYTEMAP" ),
367366 " ; expected " , routing_map_shape, " , got " , routing_map.data .shape );
368367 if (expert_bias.data .dptr != nullptr ) {
369368 NVTE_CHECK (expert_bias.data .shape == std::vector<size_t >{static_cast <size_t >(num_experts)},
370369 " expert_bias shape must be [num_experts]=[" , num_experts, " ], got " ,
371370 expert_bias.data .shape );
372371 }
373- #define ROUTER_FORWARD_DISPATCH (RoutingMapFormatVal ) \
374- TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
375- logits.data .dtype , DataType, \
376- TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
377- expert_bias.data .dtype , BiasType, \
378- fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType, \
379- RoutingMapFormatVal>( \
372+ #define ROUTER_FORWARD_DISPATCH (RoutingMapFormatVal ) \
373+ TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
374+ logits.data .dtype , DataType, \
375+ TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
376+ expert_bias.data .dtype , BiasType, \
377+ fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType, \
378+ RoutingMapFormatVal>( \
380379 reinterpret_cast <DataType *>(logits.data .dptr ), num_tokens, num_experts, topk, \
381- use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, \
382- reinterpret_cast <BiasType *>(expert_bias.data .dptr ), \
383- reinterpret_cast <DataType *>(probs.data .dptr ), \
384- reinterpret_cast <uint8_t *>(routing_map.data .dptr ), \
380+ use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, \
381+ reinterpret_cast <BiasType *>(expert_bias.data .dptr ), \
382+ reinterpret_cast <DataType *>(probs.data .dptr ), \
383+ reinterpret_cast <uint8_t *>(routing_map.data .dptr ), \
385384 reinterpret_cast <CompType *>(intermediate_output.data .dptr ), stream);););
386385 if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
387386 ROUTER_FORWARD_DISPATCH (NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 )
@@ -621,18 +620,17 @@ void fused_topk_with_score_function_backward(const Tensor &routing_map,
621620 " ], got " , grad_logits.data .shape );
622621 const auto routing_map_shape =
623622 expected_routing_map_shape (num_tokens, num_experts, routing_map_format);
624- NVTE_CHECK (routing_map.data .shape == routing_map_shape,
625- " routing_map shape mismatch for " ,
623+ NVTE_CHECK (routing_map.data .shape == routing_map_shape, " routing_map shape mismatch for " ,
626624 (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? " BITMAP_U8" : " BYTEMAP" ),
627625 " ; expected " , routing_map_shape, " , got " , routing_map.data .shape );
628- #define ROUTER_BACKWARD_DISPATCH (RoutingMapFormatVal ) \
629- TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
630- grad_logits.data .dtype , DataType, \
631- fused_topk_with_score_function_backward_kernel_launcher<DataType, RoutingMapFormatVal>( \
632- reinterpret_cast <uint8_t *>(routing_map.data .dptr ), \
633- reinterpret_cast <CompType *>(intermediate_output.data .dptr ), \
634- reinterpret_cast <DataType *>(grad_probs.data .dptr ), num_tokens, num_experts, topk, \
635- use_pre_softmax, scaling_factor, score_function, \
626+ #define ROUTER_BACKWARD_DISPATCH (RoutingMapFormatVal ) \
627+ TE_ROUTER_PROBS_TYPE_SWITCH_ALL ( \
628+ grad_logits.data .dtype , DataType, \
629+ fused_topk_with_score_function_backward_kernel_launcher<DataType, RoutingMapFormatVal>( \
630+ reinterpret_cast <uint8_t *>(routing_map.data .dptr ), \
631+ reinterpret_cast <CompType *>(intermediate_output.data .dptr ), \
632+ reinterpret_cast <DataType *>(grad_probs.data .dptr ), num_tokens, num_experts, topk, \
633+ use_pre_softmax, scaling_factor, score_function, \
636634 reinterpret_cast <DataType *>(grad_logits.data .dptr ), stream););
637635 if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ) {
638636 ROUTER_BACKWARD_DISPATCH (NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 )
@@ -674,11 +672,13 @@ void nvte_fused_topk_with_score_function_forward(
674672 NVTE_ROUTING_MAP_FORMAT_BYTEMAP , intermediate_output, stream);
675673}
676674
677- void nvte_fused_topk_with_score_function_backward_v2 (
678- const NVTETensor routing_map, NVTERoutingMapFormat routing_map_format,
679- const NVTETensor intermediate_output, const NVTETensor grad_probs, int num_tokens,
680- int num_experts, int topk, int use_pre_softmax, float scaling_factor, int score_function,
681- NVTETensor grad_logits, cudaStream_t stream) {
675+ void nvte_fused_topk_with_score_function_backward_v2 (const NVTETensor routing_map,
676+ NVTERoutingMapFormat routing_map_format,
677+ const NVTETensor intermediate_output,
678+ const NVTETensor grad_probs, int num_tokens,
679+ int num_experts, int topk, int use_pre_softmax,
680+ float scaling_factor, int score_function,
681+ NVTETensor grad_logits, cudaStream_t stream) {
682682 NVTE_API_CALL (nvte_fused_topk_with_score_function_backward_v2);
683683 using namespace transformer_engine ;
684684 fused_router::fused_topk_with_score_function_backward (
0 commit comments