Skip to content

Commit 8fc75c3

Browse files
committed
feat: support topk_router
1 parent 766e679 commit 8fc75c3

7 files changed

Lines changed: 126 additions & 73 deletions

File tree

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@ class Tensor;
1111

1212
namespace infini_train::autograd {
1313

14-
class Top1Mask : public Function {
14+
class TopKMask : public Function {
1515
public:
16-
static constexpr char kType[] = "Top1MaskFunction";
16+
static constexpr char kType[] = "TopKMaskFunction";
1717

18-
Top1Mask() : Function(kType) {}
18+
explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {}
1919

2020
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
2121
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
2222
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
2323
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
24+
25+
private:
26+
int64_t topk_ = 1;
2427
};
2528

2629
} // namespace infini_train::autograd

infini_train/include/nn/modules/transformer/transformer_config.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ enum class NormType {
3131
};
3232

/// Router kinds selectable for an MoE layer.
enum class MoERouterType {
    /// Top-k router.
    kTopK,
};
3636

3737
enum class MoEDispatcherType {
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "infini_train/include/autograd/moe.h"
1+
#include "infini_train/include/autograd/topk_mask.h"
22

33
#include "glog/logging.h"
44

@@ -7,25 +7,26 @@
77

88
namespace infini_train::autograd {
99

10-
std::vector<std::shared_ptr<Tensor>> Top1Mask::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
10+
std::vector<std::shared_ptr<Tensor>> TopKMask::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
1111
CHECK_EQ(input_tensors.size(), 1);
12+
CHECK_GT(topk_, 0);
1213
const auto &input = input_tensors[0];
1314
auto device = input->GetDevice().type();
14-
return {Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "Top1MaskForward"}, input)};
15+
return {Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "TopKMaskForward"}, input, topk_)};
1516
}
1617

17-
void Top1Mask::SetupContext(const std::vector<std::shared_ptr<Tensor>> &,
18+
void TopKMask::SetupContext(const std::vector<std::shared_ptr<Tensor>> &,
1819
const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
1920
saved_tensors_ = {output_tensors[0]};
2021
}
2122

22-
std::vector<std::shared_ptr<Tensor>> Top1Mask::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
23+
std::vector<std::shared_ptr<Tensor>> TopKMask::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
2324
CHECK_EQ(grad_outputs.size(), 1);
2425
const auto &grad_output = grad_outputs[0];
2526
const auto &mask_values = saved_tensors_[0];
2627
auto device = grad_output->GetDevice().type();
2728
return {
28-
Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "Top1MaskBackward"}, grad_output, mask_values)};
29+
Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "TopKMaskBackward"}, grad_output, mask_values)};
2930
}
3031

3132
} // namespace infini_train::autograd
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
#include <limits>
12
#include <memory>
3+
#include <vector>
24

35
#include "glog/logging.h"
46

@@ -7,13 +9,15 @@
79

810
namespace infini_train::kernels::cpu {
911

10-
std::shared_ptr<Tensor> Top1MaskForward(const std::shared_ptr<Tensor> &input) {
11-
CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only";
12+
std::shared_ptr<Tensor> TopKMaskForward(const std::shared_ptr<Tensor> &input, int64_t topk) {
13+
CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only";
1214
CHECK_GE(input->Dims().size(), 1);
1315

1416
const auto &dims = input->Dims();
1517
const int64_t num_experts = dims.back();
1618
CHECK_GT(num_experts, 0);
19+
CHECK_GT(topk, 0);
20+
CHECK_LE(topk, num_experts);
1721
const int64_t rows = input->NumElements() / num_experts;
1822

1923
auto output = std::make_shared<Tensor>(dims, input->Dtype(), input->GetDevice());
@@ -22,24 +26,41 @@ std::shared_ptr<Tensor> Top1MaskForward(const std::shared_ptr<Tensor> &input) {
2226
const float *in = static_cast<const float *>(input->DataPtr());
2327
float *out = static_cast<float *>(output->DataPtr());
2428
for (int64_t row = 0; row < rows; ++row) {
25-
int64_t best_idx = 0;
26-
float best_value = in[row * num_experts];
27-
for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) {
28-
const float value = in[row * num_experts + expert_idx];
29-
if (value > best_value) {
30-
best_value = value;
31-
best_idx = expert_idx;
29+
const int64_t row_offset = row * num_experts;
30+
std::vector<bool> selected_experts(num_experts, false);
31+
float selected_sum = 0.0f;
32+
for (int64_t selected = 0; selected < topk; ++selected) {
33+
int64_t best_idx = -1;
34+
float best_value = -std::numeric_limits<float>::infinity();
35+
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
36+
if (selected_experts[expert_idx]) {
37+
continue;
38+
}
39+
const float value = in[row_offset + expert_idx];
40+
if (value > best_value) {
41+
best_value = value;
42+
best_idx = expert_idx;
43+
}
44+
}
45+
CHECK_GE(best_idx, 0);
46+
selected_experts[best_idx] = true;
47+
out[row_offset + best_idx] = best_value;
48+
selected_sum += best_value;
49+
}
50+
if (topk > 1 && selected_sum != 0.0f) {
51+
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
52+
out[row_offset + expert_idx]
53+
= out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum;
3254
}
3355
}
34-
out[row * num_experts + best_idx] = best_value;
3556
}
3657

3758
return output;
3859
}
3960

40-
std::shared_ptr<Tensor> Top1MaskBackward(const std::shared_ptr<Tensor> &grad_output,
61+
std::shared_ptr<Tensor> TopKMaskBackward(const std::shared_ptr<Tensor> &grad_output,
4162
const std::shared_ptr<Tensor> &mask_values) {
42-
CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only";
63+
CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only";
4364
CHECK(mask_values->Dtype() == DataType::kFLOAT32);
4465
CHECK(grad_output->Dims() == mask_values->Dims());
4566

@@ -58,10 +79,10 @@ std::shared_ptr<Tensor> Top1MaskBackward(const std::shared_ptr<Tensor> &grad_out
5879

5980
} // namespace infini_train::kernels::cpu
6081

61-
#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \
82+
#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \
6283
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)
6384

64-
REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward)
65-
REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward)
85+
REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward)
86+
REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward)
6687

67-
#undef REGISTER_CPU_TOP1_MASK_KERNEL
88+
#undef REGISTER_CPU_TOPK_MASK_KERNEL

infini_train/src/kernels/cuda/top1_mask.cu renamed to infini_train/src/kernels/cuda/topk_mask.cu

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,44 @@
1111
namespace infini_train::kernels::cuda {
1212

1313
template <typename T>
14-
__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows,
15-
int64_t num_experts) {
14+
__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows,
15+
int64_t num_experts, int64_t topk) {
1616
int64_t row = blockIdx.x * blockDim.x + threadIdx.x;
1717
if (row >= rows) {
1818
return;
1919
}
2020

2121
const int64_t offset = row * num_experts;
22-
int64_t best_idx = 0;
23-
float best_value = static_cast<float>(input[offset]);
24-
for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) {
22+
float selected_sum = 0.0f;
23+
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
2524
const float value = static_cast<float>(input[offset + expert_idx]);
26-
if (value > best_value) {
27-
best_value = value;
28-
best_idx = expert_idx;
25+
int64_t rank = 0;
26+
for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) {
27+
const float other_value = static_cast<float>(input[offset + other_idx]);
28+
if (other_value > value || (other_value == value && other_idx < expert_idx)) {
29+
++rank;
30+
}
2931
}
32+
const bool selected = rank < topk;
33+
output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f);
34+
selected_sum += selected ? value : 0.0f;
3035
}
31-
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
32-
output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f);
36+
if (topk > 1 && selected_sum != 0.0f) {
37+
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
38+
if (static_cast<float>(output[offset + expert_idx]) != 0.0f) {
39+
output[offset + expert_idx] = T(static_cast<float>(output[offset + expert_idx]) / selected_sum);
40+
}
41+
}
3342
}
3443
}
3544

36-
std::shared_ptr<Tensor> Top1MaskForward(const std::shared_ptr<Tensor> &input) {
45+
std::shared_ptr<Tensor> TopKMaskForward(const std::shared_ptr<Tensor> &input, int64_t topk) {
3746
CHECK_GE(input->Dims().size(), 1);
3847
const auto &dims = input->Dims();
3948
const int64_t num_experts = dims.back();
4049
CHECK_GT(num_experts, 0);
50+
CHECK_GT(topk, 0);
51+
CHECK_LE(topk, num_experts);
4152
const int64_t rows = input->NumElements() / num_experts;
4253

4354
auto output = std::make_shared<Tensor>(dims, input->Dtype(), input->GetDevice());
@@ -52,16 +63,16 @@ std::shared_ptr<Tensor> Top1MaskForward(const std::shared_ptr<Tensor> &input) {
5263
core::cuda::DispatchCudaFunc<INFINI_ALL_FLOATING_TYPES>(
5364
input->Dtype(),
5465
[=]<typename T>() {
55-
Top1MaskForwardKernel<T><<<blocks, threads, 0, stream>>>(
56-
static_cast<const T *>(input->DataPtr()), static_cast<T *>(output->DataPtr()), rows, num_experts);
66+
TopKMaskForwardKernel<T><<<blocks, threads, 0, stream>>>(
67+
static_cast<const T *>(input->DataPtr()), static_cast<T *>(output->DataPtr()), rows, num_experts, topk);
5768
},
58-
"CUDA Top1MaskForward");
69+
"CUDA TopKMaskForward");
5970

6071
return output;
6172
}
6273

6374
template <typename T>
64-
__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values,
75+
__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values,
6576
T *__restrict__ grad_input, int64_t total_elements) {
6677
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
6778
if (idx >= total_elements) {
@@ -70,7 +81,7 @@ __global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const
7081
grad_input[idx] = static_cast<float>(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f);
7182
}
7283

73-
std::shared_ptr<Tensor> Top1MaskBackward(const std::shared_ptr<Tensor> &grad_output,
84+
std::shared_ptr<Tensor> TopKMaskBackward(const std::shared_ptr<Tensor> &grad_output,
7485
const std::shared_ptr<Tensor> &mask_values) {
7586
CHECK(grad_output->Dims() == mask_values->Dims());
7687
CHECK(grad_output->Dtype() == mask_values->Dtype());
@@ -87,21 +98,21 @@ std::shared_ptr<Tensor> Top1MaskBackward(const std::shared_ptr<Tensor> &grad_out
8798
core::cuda::DispatchCudaFunc<INFINI_ALL_FLOATING_TYPES>(
8899
grad_output->Dtype(),
89100
[=]<typename T>() {
90-
Top1MaskBackwardKernel<T><<<blocks, threads, 0, stream>>>(
101+
TopKMaskBackwardKernel<T><<<blocks, threads, 0, stream>>>(
91102
static_cast<const T *>(grad_output->DataPtr()), static_cast<const T *>(mask_values->DataPtr()),
92103
static_cast<T *>(grad_input->DataPtr()), total_elements);
93104
},
94-
"CUDA Top1MaskBackward");
105+
"CUDA TopKMaskBackward");
95106

96107
return grad_input;
97108
}
98109

99110
} // namespace infini_train::kernels::cuda
100111

101-
#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \
112+
#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \
102113
REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name)
103114

104-
REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward)
105-
REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward)
115+
REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward)
116+
REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward)
106117

107-
#undef REGISTER_CUDA_TOP1_MASK_KERNEL
118+
#undef REGISTER_CUDA_TOPK_MASK_KERNEL

infini_train/src/nn/modules/transformer/moe/router.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "glog/logging.h"
77

88
#include "infini_train/include/autograd/linear.h"
9-
#include "infini_train/include/autograd/moe.h"
9+
#include "infini_train/include/autograd/topk_mask.h"
1010
#include "infini_train/include/nn/functional.h"
1111
#include "infini_train/include/nn/init.h"
1212
#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h"
@@ -17,8 +17,9 @@ namespace infini_train::nn::moe {
1717
TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) {
1818
const auto &moe_config = RequireMoEConfig(config_);
1919
CHECK(moe_config.router_type == MoERouterType::kTopK);
20-
CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only";
2120
CHECK_GT(moe_config.num_experts, 0);
21+
CHECK_GT(moe_config.router_topk, 0);
22+
CHECK_LE(moe_config.router_topk, moe_config.num_experts);
2223

2324
parameters_[kParamWeightName]
2425
= std::make_shared<Tensor>(std::vector<int64_t>{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32,
@@ -43,7 +44,8 @@ std::vector<std::shared_ptr<Tensor>> TopKRouter::Forward(const std::vector<std::
4344

4445
auto logits = std::make_shared<autograd::Linear>()->Apply(linear_inputs)[0];
4546
auto scores = function::Softmax(logits, -1);
46-
auto routing_probs = std::make_shared<autograd::Top1Mask>()->Apply({scores})[0];
47+
const auto &moe_config = RequireMoEConfig(config_);
48+
auto routing_probs = std::make_shared<autograd::TopKMask>(moe_config.router_topk)->Apply({scores})[0];
4749
return {routing_probs};
4850
}
4951

test/transformer/test_transformer_architecture.cc

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,10 @@ void TestStateDict() {
527527
}
528528

529529
// ============================================================================
530-
// Test 11: MoE Layer MVP
530+
// Test 11: MoE Layer
531531
// ============================================================================
532532
void TestMoELayer() {
533-
std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl;
533+
std::cout << "\n=== Test 11: MoE Layer ===" << std::endl;
534534

535535
nn::TransformerConfig config;
536536
config.n_embd = 32;
@@ -543,29 +543,43 @@ void TestMoELayer() {
543543
config.moe_config->num_experts = 2;
544544
config.moe_config->router_topk = 1;
545545

546-
try {
547-
auto moe = std::make_shared<nn::moe::MoELayer>(config);
548-
auto input = std::make_shared<Tensor>(std::vector<int64_t>{2, 4, config.n_embd}, DataType::kFLOAT32);
549-
input->Uniform();
546+
auto moe = std::make_shared<nn::moe::MoELayer>(config);
547+
auto input = std::make_shared<Tensor>(std::vector<int64_t>{2, 4, config.n_embd}, DataType::kFLOAT32);
548+
input->Uniform();
550549

551-
auto output = (*moe)({input});
552-
if (output.size() != 1) {
553-
std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl;
554-
return;
555-
}
556-
if (output[0]->Dims() != input->Dims()) {
557-
std::cout << "FAIL: MoELayer output shape mismatch" << std::endl;
558-
return;
559-
}
550+
auto output = (*moe)({input});
551+
CHECK_EQ(output.size(), 1);
552+
CHECK(output[0]->Dims() == input->Dims());
560553

561-
auto params = moe->Parameters();
562-
if (params.empty()) {
563-
std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl;
564-
return;
565-
}
554+
auto params = moe->Parameters();
555+
CHECK(!params.empty());
566556

567-
std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl;
568-
} catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; }
557+
std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl;
558+
}
559+
560+
void TestMoELayerTop2() {
561+
std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl;
562+
563+
nn::TransformerConfig config;
564+
config.n_embd = 32;
565+
config.n_head = 2;
566+
config.n_kv_head = 2;
567+
config.activation_type = nn::MLPType::kGELU;
568+
config.add_bias_linear = true;
569+
config.ffn_type = nn::FFNType::kMoE;
570+
config.moe_config = nn::MoEConfig{};
571+
config.moe_config->num_experts = 4;
572+
config.moe_config->router_topk = 2;
573+
574+
auto moe = std::make_shared<nn::moe::MoELayer>(config);
575+
auto input = std::make_shared<Tensor>(std::vector<int64_t>{2, 4, config.n_embd}, DataType::kFLOAT32);
576+
input->Uniform();
577+
578+
auto output = (*moe)({input});
579+
CHECK_EQ(output.size(), 1);
580+
CHECK(output[0]->Dims() == input->Dims());
581+
582+
std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl;
569583
}
570584

571585
// ============================================================================
@@ -591,6 +605,7 @@ int main(int argc, char *argv[]) {
591605
TestRopeUtils();
592606
TestStateDict();
593607
TestMoELayer();
608+
TestMoELayerTop2();
594609

595610
std::cout << "\n========================================" << std::endl;
596611
std::cout << " All Tests Completed" << std::endl;

0 commit comments

Comments
 (0)