@@ -87,8 +87,10 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
8787 int num_cols, int topk, float coeff,
8888 DataType* aux_loss, float * Coeff_buf,
8989 cudaStream_t stream) {
90- NVTE_CHECK (num_experts == num_cols, " Number of experts (" , num_experts,
91- " ) must be equal to number of input columns (" , num_cols, " )." );
90+ NVTE_CHECK (num_cols > 0 , " num_cols must be positive, got " , num_cols);
91+ NVTE_CHECK (num_experts > 0 , " num_experts must be positive, got " , num_experts);
92+ NVTE_CHECK (num_cols % num_experts == 0 , " Number of input columns (" , num_cols,
93+ " ) must be a multiple of number of experts (" , num_experts, " )." );
9294
9395 // Round up to a multiple of warp size for correct warp shuffles.
9496 const int block_size = ((std::min (1024 , num_cols) + static_cast <int >(kThreadsPerWarp ) - 1 ) /
@@ -98,7 +100,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
98100
99101 // One CompType per thread in shared memory.
100102 const size_t smem_size = block_size * sizeof (CompType);
101- check_shared_memory_capacity_num_experts (smem_size, num_experts );
103+ check_shared_memory_capacity_num_experts (smem_size, num_cols );
102104
103105 // Compute final coefficient and zero the float accumulator (Coeff_buf[1]) before launch.
104106 const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
0 commit comments