Skip to content

Commit 4ff6844

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ba78cf0 commit 4ff6844

5 files changed

Lines changed: 66 additions & 67 deletions

File tree

transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ namespace fused_router {
1818

1919
template <typename DataType, NVTERoutingMapFormat RoutingMapFormat,
2020
TopkFuncType TopkFunc = TopkFuncType::Naive>
21-
__global__ void fused_score_for_moe_aux_loss_forward_kernel(
22-
const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
23-
float *scores, uint8_t *routing_map, CompType *intermediate_output) {
21+
__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens,
22+
int num_experts, int topk,
23+
int score_function, float *scores,
24+
uint8_t *routing_map,
25+
CompType *intermediate_output) {
2426
constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8);
2527
/***
2628
* Section: Global Variables/Addresses init
@@ -192,19 +194,19 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
192194
// Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float;
193195
// switch at K=16 where naive O(K^2*E) starts to dominate
194196
if (topk < 16) {
195-
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
196-
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
197-
TopkFuncType::Naive>,
198-
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
197+
NVTE_CHECK_CUDA(
198+
cudaFuncSetAttribute(fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
199+
TopkFuncType::Naive>,
200+
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
199201
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat, TopkFuncType::Naive>
200202
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
201203
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
202204
intermediate_output);
203205
} else {
204-
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
205-
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
206-
TopkFuncType::Radix>,
207-
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
206+
NVTE_CHECK_CUDA(
207+
cudaFuncSetAttribute(fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
208+
TopkFuncType::Radix>,
209+
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
208210
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat, TopkFuncType::Radix>
209211
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
210212
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
@@ -245,17 +247,16 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens,
245247
num_experts, "], got ", intermediate_output.data.shape);
246248
const auto routing_map_shape =
247249
expected_routing_map_shape(num_tokens, num_experts, routing_map_format);
248-
NVTE_CHECK(routing_map.data.shape == routing_map_shape,
249-
"routing_map shape mismatch for ",
250+
NVTE_CHECK(routing_map.data.shape == routing_map_shape, "routing_map shape mismatch for ",
250251
(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? "BITMAP_U8" : "BYTEMAP"),
251252
"; expected ", routing_map_shape, ", got ", routing_map.data.shape);
252-
#define AUX_LOSS_FORWARD_DISPATCH(RoutingMapFormatVal) \
253-
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
254-
logits.data.dtype, DataType, \
255-
fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType, RoutingMapFormatVal>( \
256-
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk, \
257-
score_function, reinterpret_cast<float *>(scores.data.dptr), \
258-
reinterpret_cast<uint8_t *>(routing_map.data.dptr), \
253+
#define AUX_LOSS_FORWARD_DISPATCH(RoutingMapFormatVal) \
254+
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
255+
logits.data.dtype, DataType, \
256+
fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType, RoutingMapFormatVal>( \
257+
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk, \
258+
score_function, reinterpret_cast<float *>(scores.data.dptr), \
259+
reinterpret_cast<uint8_t *>(routing_map.data.dptr), \
259260
reinterpret_cast<CompType *>(intermediate_output.data.dptr), stream););
260261
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) {
261262
AUX_LOSS_FORWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
@@ -449,10 +450,9 @@ void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_
449450
const NVTETensor intermediate_output,
450451
cudaStream_t stream) {
451452
NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward);
452-
nvte_fused_score_for_moe_aux_loss_forward_v2(logits, num_tokens, num_experts, topk,
453-
score_function, scores, routing_map,
454-
NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
455-
intermediate_output, stream);
453+
nvte_fused_score_for_moe_aux_loss_forward_v2(
454+
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
455+
NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, stream);
456456
}
457457

458458
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,

transformer_engine/common/fused_router/fused_topk_with_score_function.cu

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -361,27 +361,26 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens,
361361
num_experts, "], got ", intermediate_output.data.shape);
362362
const auto routing_map_shape =
363363
expected_routing_map_shape(num_tokens, num_experts, routing_map_format);
364-
NVTE_CHECK(routing_map.data.shape == routing_map_shape,
365-
"routing_map shape mismatch for ",
364+
NVTE_CHECK(routing_map.data.shape == routing_map_shape, "routing_map shape mismatch for ",
366365
(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? "BITMAP_U8" : "BYTEMAP"),
367366
"; expected ", routing_map_shape, ", got ", routing_map.data.shape);
368367
if (expert_bias.data.dptr != nullptr) {
369368
NVTE_CHECK(expert_bias.data.shape == std::vector<size_t>{static_cast<size_t>(num_experts)},
370369
"expert_bias shape must be [num_experts]=[", num_experts, "], got ",
371370
expert_bias.data.shape);
372371
}
373-
#define ROUTER_FORWARD_DISPATCH(RoutingMapFormatVal) \
374-
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
375-
logits.data.dtype, DataType, \
376-
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
377-
expert_bias.data.dtype, BiasType, \
378-
fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType, \
379-
RoutingMapFormatVal>( \
372+
#define ROUTER_FORWARD_DISPATCH(RoutingMapFormatVal) \
373+
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
374+
logits.data.dtype, DataType, \
375+
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
376+
expert_bias.data.dtype, BiasType, \
377+
fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType, \
378+
RoutingMapFormatVal>( \
380379
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk, \
381-
use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, \
382-
reinterpret_cast<BiasType *>(expert_bias.data.dptr), \
383-
reinterpret_cast<DataType *>(probs.data.dptr), \
384-
reinterpret_cast<uint8_t *>(routing_map.data.dptr), \
380+
use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, \
381+
reinterpret_cast<BiasType *>(expert_bias.data.dptr), \
382+
reinterpret_cast<DataType *>(probs.data.dptr), \
383+
reinterpret_cast<uint8_t *>(routing_map.data.dptr), \
385384
reinterpret_cast<CompType *>(intermediate_output.data.dptr), stream);););
386385
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) {
387386
ROUTER_FORWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
@@ -621,18 +620,17 @@ void fused_topk_with_score_function_backward(const Tensor &routing_map,
621620
"], got ", grad_logits.data.shape);
622621
const auto routing_map_shape =
623622
expected_routing_map_shape(num_tokens, num_experts, routing_map_format);
624-
NVTE_CHECK(routing_map.data.shape == routing_map_shape,
625-
"routing_map shape mismatch for ",
623+
NVTE_CHECK(routing_map.data.shape == routing_map_shape, "routing_map shape mismatch for ",
626624
(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? "BITMAP_U8" : "BYTEMAP"),
627625
"; expected ", routing_map_shape, ", got ", routing_map.data.shape);
628-
#define ROUTER_BACKWARD_DISPATCH(RoutingMapFormatVal) \
629-
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
630-
grad_logits.data.dtype, DataType, \
631-
fused_topk_with_score_function_backward_kernel_launcher<DataType, RoutingMapFormatVal>( \
632-
reinterpret_cast<uint8_t *>(routing_map.data.dptr), \
633-
reinterpret_cast<CompType *>(intermediate_output.data.dptr), \
634-
reinterpret_cast<DataType *>(grad_probs.data.dptr), num_tokens, num_experts, topk, \
635-
use_pre_softmax, scaling_factor, score_function, \
626+
#define ROUTER_BACKWARD_DISPATCH(RoutingMapFormatVal) \
627+
TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \
628+
grad_logits.data.dtype, DataType, \
629+
fused_topk_with_score_function_backward_kernel_launcher<DataType, RoutingMapFormatVal>( \
630+
reinterpret_cast<uint8_t *>(routing_map.data.dptr), \
631+
reinterpret_cast<CompType *>(intermediate_output.data.dptr), \
632+
reinterpret_cast<DataType *>(grad_probs.data.dptr), num_tokens, num_experts, topk, \
633+
use_pre_softmax, scaling_factor, score_function, \
636634
reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
637635
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) {
638636
ROUTER_BACKWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
@@ -674,11 +672,13 @@ void nvte_fused_topk_with_score_function_forward(
674672
NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, stream);
675673
}
676674

677-
void nvte_fused_topk_with_score_function_backward_v2(
678-
const NVTETensor routing_map, NVTERoutingMapFormat routing_map_format,
679-
const NVTETensor intermediate_output, const NVTETensor grad_probs, int num_tokens,
680-
int num_experts, int topk, int use_pre_softmax, float scaling_factor, int score_function,
681-
NVTETensor grad_logits, cudaStream_t stream) {
675+
void nvte_fused_topk_with_score_function_backward_v2(const NVTETensor routing_map,
676+
NVTERoutingMapFormat routing_map_format,
677+
const NVTETensor intermediate_output,
678+
const NVTETensor grad_probs, int num_tokens,
679+
int num_experts, int topk, int use_pre_softmax,
680+
float scaling_factor, int score_function,
681+
NVTETensor grad_logits, cudaStream_t stream) {
682682
NVTE_API_CALL(nvte_fused_topk_with_score_function_backward_v2);
683683
using namespace transformer_engine;
684684
fused_router::fused_topk_with_score_function_backward(

transformer_engine/common/include/transformer_engine/fused_router.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,13 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map,
122122
* \param[out] grad_logits Gradient of logits.
123123
* \param[in] stream CUDA stream used for the operation.
124124
*/
125-
void nvte_fused_topk_with_score_function_backward_v2(
126-
const NVTETensor routing_map, NVTERoutingMapFormat routing_map_format,
127-
const NVTETensor intermediate_output, const NVTETensor grad_probs, int num_tokens,
128-
int num_experts, int topk, int use_pre_softmax, float scaling_factor, int score_function,
129-
NVTETensor grad_logits, cudaStream_t stream);
125+
void nvte_fused_topk_with_score_function_backward_v2(const NVTETensor routing_map,
126+
NVTERoutingMapFormat routing_map_format,
127+
const NVTETensor intermediate_output,
128+
const NVTETensor grad_probs, int num_tokens,
129+
int num_experts, int topk, int use_pre_softmax,
130+
float scaling_factor, int score_function,
131+
NVTETensor grad_logits, cudaStream_t stream);
130132

131133
/*! \brief Forward pass for computing scores/routing map for auxiliary loss (deprecated).
132134
*

transformer_engine/jax/csrc/extensions/router.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI(
4545
// dim depends on the requested format: num_experts for BYTEMAP, ceil(num_experts/8)
4646
// for BITMAP_U8. Keeping this 2D also lets the kernel's shape NVTE_CHECKs fire.
4747
auto routing_map_format_nvte = static_cast<NVTERoutingMapFormat>(routing_map_format);
48-
size_t routing_map_trailing =
49-
(routing_map_format_nvte == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
50-
? static_cast<size_t>((num_experts + 7) / 8)
51-
: static_cast<size_t>(num_experts);
48+
size_t routing_map_trailing = (routing_map_format_nvte == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
49+
? static_cast<size_t>((num_experts + 7) / 8)
50+
: static_cast<size_t>(num_experts);
5251
auto routing_map_shape =
5352
std::vector<size_t>{static_cast<size_t>(num_tokens), routing_map_trailing};
5453
auto routing_map_tensor = TensorWrapper(routing_map, routing_map_shape, DType::kByte);
@@ -144,10 +143,9 @@ Error_Type FusedTopkWithScoreFunctionBackwardFFI(
144143
grad_logits_tensor.data(), stream);
145144
} else {
146145
auto routing_map_format_nvte = static_cast<NVTERoutingMapFormat>(routing_map_format);
147-
size_t routing_map_trailing =
148-
(routing_map_format_nvte == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
149-
? static_cast<size_t>((num_experts + 7) / 8)
150-
: static_cast<size_t>(num_experts);
146+
size_t routing_map_trailing = (routing_map_format_nvte == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
147+
? static_cast<size_t>((num_experts + 7) / 8)
148+
: static_cast<size_t>(num_experts);
151149
auto routing_map_shape =
152150
std::vector<size_t>{static_cast<size_t>(num_tokens), routing_map_trailing};
153151
auto routing_map_tensor =

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ void init_router_bindings(pybind11::module &m) {
145145
py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"),
146146
py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("grad_logits"),
147147
py::arg("topk"), py::arg("use_pre_softmax"), py::arg("scaling_factor"),
148-
py::arg("score_function"),
149-
py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
148+
py::arg("score_function"), py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
150149
"Fused topk with score function bwd");
151150
m.def("fused_score_for_moe_aux_loss_fwd", &fused_score_for_moe_aux_loss_fwd, py::arg("logits"),
152151
py::arg("topk"), py::arg("score_function"),

0 commit comments

Comments
 (0)