66
77#include " glog/logging.h"
88
9+ #include " infini_train/include/nn/functional.h"
910#include " infini_train/include/nn/modules/transformer/mlp.h"
1011#include " infini_train/include/nn/modules/transformer/moe/moe_utils.h"
12+ #include " infini_train/include/nn/modules/transformer/moe/token_dispatcher.h"
1113#include " infini_train/include/tensor.h"
1214
1315namespace infini_train ::nn::moe {
1416
1517SequentialMLP::SequentialMLP (const TransformerConfig &config) : CloneableModule(kType ), config_(config) {
1618 const auto &moe_config = RequireMoEConfig (config_);
17- CHECK (moe_config.expert_impl == MoEExpertImpl ::kSequential );
19+ CHECK (moe_config.expert_impl == MoEConfig::ExpertImpl ::kSequential );
1820 CHECK_EQ (moe_config.expert_parallel_size , 1 )
1921 << " Current InfiniTrain MoE implementation supports expert_parallel_size=1 only" ;
20- CHECK (moe_config.dispatcher_type == MoEDispatcherType:: kLocal )
21- << " Current InfiniTrain MoE implementation supports local dispatch only" ;
22+ CHECK (moe_config.dispatcher_type == MoEConfig::DispatcherType:: kAllGather )
23+ << " Current InfiniTrain MoE implementation supports AllGather dispatcher only" ;
2224
2325 num_local_experts_ = moe_config.num_experts ;
2426 CHECK_GT (num_local_experts_, 0 );
@@ -29,22 +31,35 @@ SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(
2931}
3032
3133std::vector<std::shared_ptr<Tensor>> SequentialMLP::Forward (const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
32- CHECK_EQ (input_tensors.size (), 2 );
34+ CHECK_EQ (input_tensors.size (), 3 );
3335 auto hidden_states = input_tensors[0 ];
3436 auto routing_probs = input_tensors[1 ];
35- CHECK_EQ (routing_probs->Dims ().back (), num_local_experts_);
37+ auto routing_map = input_tensors[2 ];
38+ std::unique_ptr<MoETokenDispatcher> dispatcher
39+ = std::make_unique<MoEAllGatherTokenDispatcher>(num_local_experts_, config_);
40+ const auto &dispatch = dispatcher->Dispatch (hidden_states, routing_map, routing_probs);
3641
37- std::shared_ptr<Tensor> output = nullptr ;
38- const int64_t expert_dim = static_cast < int64_t >(routing_probs-> Dims (). size ()) - 1 ;
42+ std::vector<std:: shared_ptr<Tensor>> expert_outputs ;
43+ int64_t start = 0 ;
3944 for (int64_t expert_idx = 0 ; expert_idx < num_local_experts_; ++expert_idx) {
45+ const int64_t num_tokens_for_expert = dispatch.metadata .tokens_per_expert_host [expert_idx];
46+ const int64_t end = start + num_tokens_for_expert;
47+ if (num_tokens_for_expert == 0 ) {
48+ start = end;
49+ continue ;
50+ }
51+
52+ auto expert_input = dispatch.permuted_hidden_states ->Slice (0 , start, end);
4053 auto expert_name = std::string (kExpertNamePrefix ) + std::to_string (expert_idx);
41- auto expert_output = (*modules_.at (expert_name))({hidden_states})[0 ];
42- auto expert_prob = routing_probs->Slice (expert_dim, expert_idx, expert_idx + 1 );
43- auto weighted_output = expert_output * expert_prob;
44- output = output == nullptr ? weighted_output : output + weighted_output;
54+ expert_outputs.push_back ((*modules_.at (expert_name))({expert_input})[0 ]);
55+ start = end;
4556 }
57+ CHECK_EQ (start, dispatch.permuted_hidden_states ->Dims ()[0 ]);
58+ CHECK (!expert_outputs.empty ()) << " No tokens were dispatched to any local expert" ;
4659
47- return {output};
60+ auto permuted_expert_output
61+ = expert_outputs.size () == 1 ? expert_outputs[0 ] : nn::function::Concat (expert_outputs, 0 );
62+ return {dispatcher->Combine (permuted_expert_output)};
4863}
4964
5065} // namespace infini_train::nn::moe
0 commit comments