Skip to content

Commit bfb5924

Browse files
committed
feat: support moe_ffn_hidden_size config
1 parent ea6af08 commit bfb5924

2 files changed

Lines changed: 17 additions & 3 deletions

File tree

infini_train/src/nn/modules/transformer/mlp.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) {
3535
}
3636

3737
// Round ffn_hidden up to the nearest multiple of config.multiple_of
38-
int64_t before_round = ffn_hidden;
3938
ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of;
4039

40+
if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value()
41+
&& config.moe_config->moe_ffn_hidden_size > 0) {
42+
ffn_hidden = config.moe_config->moe_ffn_hidden_size;
43+
}
44+
CHECK_GT(ffn_hidden, 0);
45+
4146
// c_fc: ColumnParallel (input full, output parallel)
4247
modules_[kCFcLayerName] = std::make_shared<parallel::ColumnParallelLinear>(
4348
/*in_features=*/config.n_embd, /*out_features=*/ffn_hidden,

test/transformer/test_transformer_architecture.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,12 +564,13 @@ void TestMoELayerTop2() {
564564
config.n_embd = 32;
565565
config.n_head = 2;
566566
config.n_kv_head = 2;
567-
config.activation_type = nn::MLPType::kGELU;
568-
config.add_bias_linear = true;
567+
config.activation_type = nn::MLPType::kSwiGLU;
568+
config.add_bias_linear = false;
569569
config.ffn_type = nn::FFNType::kMoE;
570570
config.moe_config = nn::MoEConfig{};
571571
config.moe_config->num_experts = 4;
572572
config.moe_config->router_topk = 2;
573+
config.moe_config->moe_ffn_hidden_size = 48;
573574

574575
auto moe = std::make_shared<nn::moe::MoELayer>(config);
575576
auto input = std::make_shared<Tensor>(std::vector<int64_t>{2, 4, config.n_embd}, DataType::kFLOAT32);
@@ -579,6 +580,14 @@ void TestMoELayerTop2() {
579580
CHECK_EQ(output.size(), 1);
580581
CHECK(output[0]->Dims() == input->Dims());
581582

583+
auto state = moe->StateDict();
584+
CHECK(state.contains("experts.expert_0.c_fc.weight"));
585+
CHECK(state.contains("experts.expert_0.c_fc2.weight"));
586+
CHECK(state.contains("experts.expert_0.c_proj.weight"));
587+
CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector<int64_t>({48, config.n_embd}));
588+
CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector<int64_t>({48, config.n_embd}));
589+
CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector<int64_t>({config.n_embd, 48}));
590+
582591
std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl;
583592
}
584593

0 commit comments

Comments
 (0)