Skip to content

Commit ba78cf0

Browse files
committed
address all comments
Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 22ae793 commit ba78cf0

10 files changed

Lines changed: 429 additions & 171 deletions

File tree

transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
namespace transformer_engine {
1717
namespace fused_router {
1818

19-
template <typename DataType, TopkFuncType TopkFunc = TopkFuncType::Naive>
20-
__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens,
21-
int num_experts, int topk,
22-
int score_function, float *scores,
23-
uint8_t *routing_map,
24-
NVTERoutingMapFormat routing_map_format,
25-
CompType *intermediate_output) {
19+
template <typename DataType, NVTERoutingMapFormat RoutingMapFormat,
20+
TopkFuncType TopkFunc = TopkFuncType::Naive>
21+
__global__ void fused_score_for_moe_aux_loss_forward_kernel(
22+
const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
23+
float *scores, uint8_t *routing_map, CompType *intermediate_output) {
24+
constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8);
2625
/***
2726
* Section: Global Variables/Addresses init
2827
* - Each warp is responsible for one token, and has own shared memory buffer.
@@ -42,7 +41,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
4241
const int bitmap_words_per_warp = (num_experts + 31) / 32;
4342
const int bitmap_row_bytes = (num_experts + 7) / 8;
4443
uint32_t *bitmap_words_buf = nullptr;
45-
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) {
44+
if constexpr (kIsBitmap) {
4645
bitmap_words_buf = reinterpret_cast<uint32_t *>(topk_indices_buf + topk * num_token_per_block);
4746
}
4847
// The address of buffers on the current warp
@@ -77,7 +76,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
7776
intermediate_output[pos_offset + i] = -std::numeric_limits<CompType>::infinity();
7877
}
7978
}
80-
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP) {
79+
if constexpr (!kIsBitmap) {
8180
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
8281
routing_map[pos_offset + i] = 0;
8382
}
@@ -148,7 +147,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
148147
__syncwarp();
149148

150149
// Write the routing_map to the output tensor
151-
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP) {
150+
if constexpr (!kIsBitmap) {
152151
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
153152
routing_map[pos_offset + topk_indices[i]] = 1;
154153
}
@@ -174,18 +173,17 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
174173
}
175174
}
176175

177-
template <typename DataType>
176+
template <typename DataType, NVTERoutingMapFormat RoutingMapFormat>
178177
void fused_score_for_moe_aux_loss_forward_kernel_launcher(
179178
const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
180-
float *scores, uint8_t *routing_map, NVTERoutingMapFormat routing_map_format,
181-
CompType *intermediate_output, cudaStream_t stream) {
179+
float *scores, uint8_t *routing_map, CompType *intermediate_output, cudaStream_t stream) {
182180
// Meta data for the kernel
183181
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
184182
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
185183
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits
186184
+ topk * num_token_per_block * sizeof(CompType) // topk_logits
187185
+ topk * num_token_per_block * sizeof(int); // topk_indices
188-
if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) {
186+
if constexpr (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) {
189187
size_t bitmap_words_per_warp = (num_experts + 31) / 32;
190188
shared_memory_size +=
191189
bitmap_words_per_warp * num_token_per_block * sizeof(uint32_t); // bitmap accumulator
@@ -195,36 +193,76 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher(
195193
// switch at K=16 where naive O(K^2*E) starts to dominate
196194
if (topk < 16) {
197195
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
198-
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Naive>,
196+
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
197+
TopkFuncType::Naive>,
199198
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
200-
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Naive>
199+
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat, TopkFuncType::Naive>
201200
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
202201
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
203-
routing_map_format, intermediate_output);
202+
intermediate_output);
204203
} else {
205204
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
206-
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Radix>,
205+
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat,
206+
TopkFuncType::Radix>,
207207
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
208-
fused_score_for_moe_aux_loss_forward_kernel<DataType, TopkFuncType::Radix>
208+
fused_score_for_moe_aux_loss_forward_kernel<DataType, RoutingMapFormat, TopkFuncType::Radix>
209209
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
210210
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
211-
routing_map_format, intermediate_output);
211+
intermediate_output);
212212
}
213213
NVTE_CHECK_CUDA(cudaGetLastError());
214214
}
215215

216+
// Returns the shape routing_map is required to have for the requested layout:
//   NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 -> {num_tokens, ceil(num_experts / 8)}
//   any other format value (BYTEMAP)  -> {num_tokens, num_experts}
// Both counts are converted to size_t as-is; this helper performs no sign check.
static std::vector<size_t> expected_routing_map_shape(int num_tokens, int num_experts,
                                                      NVTERoutingMapFormat format) {
  const bool packed = (format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8);
  const size_t rows = static_cast<size_t>(num_tokens);
  const size_t experts = static_cast<size_t>(num_experts);
  // A bitmap row stores eight experts per byte, rounded up to a whole byte.
  const size_t cols = packed ? (experts + 7) / 8 : experts;
  return {rows, cols};
}
228+
216229
// Host-side entry point of the forward pass.
//
// Validates the tensor shapes with NVTE_CHECK, then dispatches on routing_map_format
// (compile-time template argument of the launcher) and on the logits dtype.
//
// Shapes enforced here:
//   logits, scores, intermediate_output : [num_tokens, num_experts]
//   routing_map                         : result of expected_routing_map_shape() for
//                                         routing_map_format
// Any routing_map_format other than NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 takes the BYTEMAP path.
// NOTE(review): topk is forwarded without a range check in this function -- confirm that
// callers guarantee a valid value.
void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,
                                          int topk, int score_function, Tensor &scores,
                                          Tensor &routing_map,
                                          NVTERoutingMapFormat routing_map_format,
                                          Tensor &intermediate_output, cudaStream_t stream) {
  // Reject non-positive sizes first so the size_t conversions below are well defined.
  NVTE_CHECK(num_tokens > 0 && num_experts > 0,
             "num_tokens and num_experts must be positive; got num_tokens=", num_tokens,
             ", num_experts=", num_experts);

  const std::vector<size_t> token_expert_shape{static_cast<size_t>(num_tokens),
                                               static_cast<size_t>(num_experts)};
  NVTE_CHECK(logits.data.shape == token_expert_shape,
             "logits shape must be [num_tokens, num_experts]=[", num_tokens, ", ", num_experts,
             "], got ", logits.data.shape);
  NVTE_CHECK(scores.data.shape == token_expert_shape,
             "scores shape must be [num_tokens, num_experts]=[", num_tokens, ", ", num_experts,
             "], got ", scores.data.shape);
  NVTE_CHECK(intermediate_output.data.shape == token_expert_shape,
             "intermediate_output shape must be [num_tokens, num_experts]=[", num_tokens, ", ",
             num_experts, "], got ", intermediate_output.data.shape);

  const bool is_bitmap = (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8);
  const char *format_name = is_bitmap ? "BITMAP_U8" : "BYTEMAP";
  const auto map_shape = expected_routing_map_shape(num_tokens, num_experts, routing_map_format);
  NVTE_CHECK(routing_map.data.shape == map_shape, "routing_map shape mismatch for ", format_name,
             "; expected ", map_shape, ", got ", routing_map.data.shape);

  // Output pointers do not depend on the logits dtype, so cast them once up front.
  float *scores_ptr = reinterpret_cast<float *>(scores.data.dptr);
  uint8_t *routing_map_ptr = reinterpret_cast<uint8_t *>(routing_map.data.dptr);
  CompType *intermediate_ptr = reinterpret_cast<CompType *>(intermediate_output.data.dptr);

// Expands to the dtype switch for one compile-time routing-map format.
#define DISPATCH_AUX_LOSS_FORWARD(RoutingMapFormatVal)                                       \
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(                                                           \
      logits.data.dtype, DataType,                                                           \
      fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType, RoutingMapFormatVal>(   \
          reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,     \
          score_function, scores_ptr, routing_map_ptr, intermediate_ptr, stream););
  if (!is_bitmap) {
    DISPATCH_AUX_LOSS_FORWARD(NVTE_ROUTING_MAP_FORMAT_BYTEMAP)
  } else {
    DISPATCH_AUX_LOSS_FORWARD(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8)
  }
#undef DISPATCH_AUX_LOSS_FORWARD
}
229267

230268
template <typename DataType>
@@ -389,18 +427,32 @@ void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
389427
} // namespace fused_router
390428
} // namespace transformer_engine
391429

430+
// C API (V2) for the forward pass with a caller-selected routing_map layout.
// Each NVTETensor handle is converted with convertNVTETensorCheck and the call is forwarded
// to fused_router::fused_score_for_moe_aux_loss_forward; routing_map_format is passed through
// unchanged.
void nvte_fused_score_for_moe_aux_loss_forward_v2(const NVTETensor logits, int num_tokens,
                                                  int num_experts, int topk, int score_function,
                                                  NVTETensor scores, NVTETensor routing_map,
                                                  NVTERoutingMapFormat routing_map_format,
                                                  const NVTETensor intermediate_output,
                                                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward_v2);
  using namespace transformer_engine;

  // Resolve every handle before the call so the forwarding line stays readable.
  auto logits_tensor = convertNVTETensorCheck(logits);
  auto scores_tensor = convertNVTETensorCheck(scores);
  auto routing_map_tensor = convertNVTETensorCheck(routing_map);
  auto intermediate_tensor = convertNVTETensorCheck(intermediate_output);

  fused_router::fused_score_for_moe_aux_loss_forward(
      *logits_tensor, num_tokens, num_experts, topk, score_function, *scores_tensor,
      *routing_map_tensor, routing_map_format, *intermediate_tensor, stream);
}
443+
444+
// Deprecated V1 entry point, retained so existing C API consumers keep linking.
// It has no routing-map-format parameter: the call is delegated to the V2 entry point with
// the layout fixed to NVTE_ROUTING_MAP_FORMAT_BYTEMAP.
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
                                               int num_experts, int topk, int score_function,
                                               NVTETensor scores, NVTETensor routing_map,
                                               const NVTETensor intermediate_output,
                                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward);
  // V1 callers always get the byte-per-expert layout.
  const NVTERoutingMapFormat legacy_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP;
  nvte_fused_score_for_moe_aux_loss_forward_v2(logits, num_tokens, num_experts, topk,
                                               score_function, scores, routing_map,
                                               legacy_format, intermediate_output, stream);
}
405457

406458
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,

0 commit comments

Comments
 (0)