Skip to content

Commit 325934c

Browse files
committed
[Common] Allow expanded columns in fused MoE aux loss
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
1 parent a014300 commit 325934c

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

transformer_engine/common/fused_router/fused_moe_aux_loss.cu

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -87,8 +87,10 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
87 87   int num_cols, int topk, float coeff,
88 88   DataType* aux_loss, float* Coeff_buf,
89 89   cudaStream_t stream) {
90    - NVTE_CHECK(num_experts == num_cols, "Number of experts (", num_experts,
91    - ") must be equal to number of input columns (", num_cols, ").");
   90 + NVTE_CHECK(num_cols > 0, "num_cols must be positive, got ", num_cols);
   91 + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts);
   92 + NVTE_CHECK(num_cols % num_experts == 0, "Number of input columns (", num_cols,
   93 + ") must be a multiple of number of experts (", num_experts, ").");
92 94
93 95   // Round up to a multiple of warp size for correct warp shuffles.
94 96   const int block_size = ((std::min(1024, num_cols) + static_cast<int>(kThreadsPerWarp) - 1) /
@@ -98,7 +100,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
 98 100
 99 101   // One CompType per thread in shared memory.
100 102   const size_t smem_size = block_size * sizeof(CompType);
101     - check_shared_memory_capacity_num_experts(smem_size, num_experts);
    103 + check_shared_memory_capacity_num_experts(smem_size, num_cols);
102 104
103 105   // Compute final coefficient and zero the float accumulator (Coeff_buf[1]) before launch.
104 106   const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;

0 commit comments

Comments
 (0)