Skip to content

Commit 69a74fd

Browse files
committed
feat: implement MoETokenDispatcher base class and MoEAllGatherTokenDispatcher
1 parent abcfbef commit 69a74fd

9 files changed

Lines changed: 408 additions & 25 deletions

File tree

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "infini_train/include/autograd/function.h"
7+
8+
namespace infini_train {
9+
class Tensor;
10+
}
11+
12+
namespace infini_train::autograd {
13+
14+
// Autograd function registered under the type name "ScatterAddFunction".
// The constructor only stores `dim` and `output_dims`; the implementation file passes both,
// together with the two inputs (values, indices), to the kernel registered as "GatherBackward".
// NOTE(review): int64_t is used here but <cstdint> is not included by this header directly;
// presumably it arrives transitively through function.h or <vector> — confirm.
class ScatterAdd : public Function {
15+
public:
16+
// Type name handed to the Function base-class constructor.
static constexpr char kType[] = "ScatterAddFunction";
17+
18+
// Stores `dim` and a copy of `output_dims`; no other work is done at construction time.
ScatterAdd(int64_t dim, const std::vector<int64_t> &output_dims)
19+
: Function(kType), dim_(dim), output_dims_(output_dims) {}
20+
21+
// The implementation checks for exactly two inputs: {values, indices}.
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
22+
// The implementation saves input_tensors[1] (the indices) for Backward and ignores output_tensors.
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
23+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
24+
// The implementation returns {grad_values, nullptr}: no gradient is produced for the indices input.
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
25+
26+
private:
27+
// Dimension argument forwarded to both kernels ("GatherBackward" in Forward, "GatherForward" in Backward).
int64_t dim_ = 0;
28+
// Output dims forwarded to the "GatherBackward" kernel in Forward.
std::vector<int64_t> output_dims_;
29+
};
30+
31+
} // namespace infini_train::autograd

infini_train/include/nn/modules/transformer/moe/moe_utils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,32 @@
99

1010
namespace infini_train::nn::moe {
1111

12+
// Bookkeeping for a token permutation: returned by BuildPermutationMetadata, embedded in
// PermutationResult, and passed back into Unpermute.
// NOTE(review): the exact meaning of the index tensors is not visible in this header; the notes on
// them below are assumptions drawn from the field names — confirm against the implementation.
struct PermutationMetadata {
13+
// presumably the order of routed entries after sorting by expert — TODO confirm
std::shared_ptr<Tensor> sorted_indices;
14+
// presumably the source-row index used to gather each permuted row — TODO confirm
std::shared_ptr<Tensor> gather_indices;
15+
// presumably the (token, expert) route each permuted row came from — TODO confirm
std::shared_ptr<Tensor> route_indices;
16+
// presumably the tensor form of the per-expert token counts — TODO confirm
std::shared_ptr<Tensor> tokens_per_expert;
17+
// Per-expert token counts as a host vector; SequentialMLP::Forward indexes it by expert index
// to size each expert's slice of the permuted hidden states.
std::vector<int64_t> tokens_per_expert_host;
18+
};
19+
20+
// Return type of Permute and of MoETokenDispatcher::Dispatch.
struct PermutationResult {
21+
// Hidden states laid out per expert along dim 0: SequentialMLP::Forward slices consecutive
// [start, end) row ranges whose lengths come from metadata.tokens_per_expert_host.
std::shared_ptr<Tensor> permuted_hidden_states;
22+
// presumably the routing probabilities reordered the same way as the hidden states — TODO confirm
std::shared_ptr<Tensor> permuted_probs;
23+
// Index tensors and per-expert counts describing this permutation.
PermutationMetadata metadata;
24+
};
25+
1226
std::vector<std::shared_ptr<Tensor>> TopkRoutingWithScoreFunction(const std::shared_ptr<Tensor> &logits, int64_t topk,
1327
bool use_pre_softmax,
1428
std::optional<float> scaling_factor,
1529
const MoEConfig::RouterScoreFunction &score_function);
1630

1731
const MoEConfig &RequireMoEConfig(const TransformerConfig &config);
32+
// Builds the permutation bookkeeping from a routing map.
// NOTE(review): the expected shape/dtype of routing_map is not visible in this header — confirm.
PermutationMetadata BuildPermutationMetadata(const std::shared_ptr<Tensor> &routing_map);
33+
// Produces permuted hidden states, permuted probs and the matching metadata.
// NOTE(review): the `_2d` parameter names suggest 2-D inputs; presumably (tokens, hidden) for the
// hidden states and (tokens, experts) for probs and map — confirm.
PermutationResult Permute(const std::shared_ptr<Tensor> &hidden_states_2d,
34+
const std::shared_ptr<Tensor> &routing_probs_2d,
35+
const std::shared_ptr<Tensor> &routing_map_2d);
36+
// presumably the inverse of Permute: maps permuted rows back to token order using `metadata`,
// with `restore_shape` giving the dims of the returned tensor — TODO confirm against the implementation.
std::shared_ptr<Tensor> Unpermute(const std::shared_ptr<Tensor> &permuted_hidden_states,
37+
const std::shared_ptr<Tensor> &permuted_probs, const PermutationMetadata &metadata,
38+
const std::vector<int64_t> &restore_shape);
1839

1940
} // namespace infini_train::nn::moe
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <memory>
5+
#include <vector>
6+
7+
#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h"
8+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
9+
#include "infini_train/include/tensor.h"
10+
11+
namespace infini_train::nn::moe {
12+
13+
// Abstract base class for MoE token dispatchers. It exposes two public, non-virtual entry points
// (Dispatch and Combine) and six pure-virtual hooks for subclasses; the constructor is protected,
// so instances can only be created through a subclass.
// NOTE(review): Dispatch and Combine are defined outside this header; presumably each runs its
// three *Preprocess / Token* / *Postprocess hooks in declaration order — confirm.
class MoETokenDispatcher {
14+
public:
15+
// Virtual so a subclass can be destroyed through a MoETokenDispatcher pointer
// (SequentialMLP::Forward owns one via std::unique_ptr<MoETokenDispatcher>).
virtual ~MoETokenDispatcher() = default;
16+
17+
// Returns a reference rather than a copy.
// NOTE(review): presumably the reference is to dispatch_, so the result should only be used
// while this dispatcher is alive — confirm.
const PermutationResult &Dispatch(const std::shared_ptr<Tensor> &tokens, const std::shared_ptr<Tensor> &routing_map,
18+
const std::shared_ptr<Tensor> &probs);
19+
// Const entry point; SequentialMLP::Forward passes it the concatenated expert outputs.
std::shared_ptr<Tensor> Combine(const std::shared_ptr<Tensor> &hidden_states) const;
20+
21+
protected:
22+
explicit MoETokenDispatcher(const TransformerConfig &config);
23+
24+
// Dispatch-side hooks. DispatchPreprocess and DispatchPostprocess are non-const,
// TokenDispatch is const.
virtual std::vector<std::shared_ptr<Tensor>> DispatchPreprocess(const std::shared_ptr<Tensor> &tokens,
25+
const std::shared_ptr<Tensor> &routing_map,
26+
const std::shared_ptr<Tensor> &probs)
27+
= 0;
28+
virtual std::vector<std::shared_ptr<Tensor>> TokenDispatch(const std::shared_ptr<Tensor> &hidden_states,
29+
const std::shared_ptr<Tensor> &probs) const
30+
= 0;
31+
virtual const PermutationResult &DispatchPostprocess(const std::shared_ptr<Tensor> &hidden_states,
32+
const std::shared_ptr<Tensor> &probs)
33+
= 0;
34+
// Combine-side hooks; all three are const, matching the const Combine entry point.
virtual std::shared_ptr<Tensor> CombinePreprocess(const std::shared_ptr<Tensor> &hidden_states) const = 0;
35+
virtual std::shared_ptr<Tensor> TokenCombine(const std::shared_ptr<Tensor> &hidden_states) const = 0;
36+
virtual std::shared_ptr<Tensor> CombinePostprocess(const std::shared_ptr<Tensor> &hidden_states) const = 0;
37+
38+
// State shared with subclasses. The config is held by value, not by reference.
// NOTE(review): which hook writes each of the members below is not visible in this header;
// presumably they are cached during Dispatch for later use by Combine — confirm.
TransformerConfig config_;
39+
PermutationResult dispatch_;
40+
std::vector<int64_t> hidden_dims_;
41+
std::shared_ptr<Tensor> routing_map_;
42+
std::shared_ptr<Tensor> local_map_;
43+
std::shared_ptr<Tensor> local_probs_;
44+
int64_t num_tokens_ = 0;
45+
int64_t hidden_size_ = 0;
46+
};
47+
48+
// Concrete dispatcher that overrides all six MoETokenDispatcher hooks as private members, so it
// is driven only through the base class's public Dispatch / Combine.
// SequentialMLP::Forward constructs one with (num_local_experts_, config_).
class MoEAllGatherTokenDispatcher : public MoETokenDispatcher {
49+
public:
50+
MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config);
51+
52+
private:
53+
// Dispatch-side overrides.
std::vector<std::shared_ptr<Tensor>> DispatchPreprocess(const std::shared_ptr<Tensor> &tokens,
54+
const std::shared_ptr<Tensor> &routing_map,
55+
const std::shared_ptr<Tensor> &probs) override;
56+
std::vector<std::shared_ptr<Tensor>> TokenDispatch(const std::shared_ptr<Tensor> &hidden_states,
57+
const std::shared_ptr<Tensor> &probs) const override;
58+
const PermutationResult &DispatchPostprocess(const std::shared_ptr<Tensor> &hidden_states,
59+
const std::shared_ptr<Tensor> &probs) override;
60+
// Combine-side overrides.
std::shared_ptr<Tensor> CombinePreprocess(const std::shared_ptr<Tensor> &hidden_states) const override;
61+
std::shared_ptr<Tensor> TokenCombine(const std::shared_ptr<Tensor> &hidden_states) const override;
62+
std::shared_ptr<Tensor> CombinePostprocess(const std::shared_ptr<Tensor> &hidden_states) const override;
63+
64+
// presumably initialised from the constructor's num_local_experts argument — TODO confirm
int64_t num_local_experts_ = 0;
65+
};
66+
67+
} // namespace infini_train::nn::moe
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include "infini_train/include/autograd/scatter_add.h"
2+
3+
#include "glog/logging.h"
4+
5+
#include "infini_train/include/dispatcher.h"
6+
#include "infini_train/include/tensor.h"
7+
8+
namespace infini_train::autograd {
9+
10+
// Forward pass. `input_tensors` must be exactly {values, indices} (enforced by CHECK_EQ).
// Returns a single-element vector holding the kernel's output tensor.
std::vector<std::shared_ptr<Tensor>> ScatterAdd::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
11+
CHECK_EQ(input_tensors.size(), 2);
12+
const auto &values = input_tensors[0];
13+
const auto &indices = input_tensors[1];
14+
// The kernel is selected by the device type of `values`; the device of `indices` is not consulted.
auto device = values->GetDevice().type();
15+
// Calls the kernel registered under the name "GatherBackward" instead of a dedicated
// scatter-add kernel, passing it values, indices, dim_ and output_dims_.
auto output = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "GatherBackward"}, values, indices,
16+
dim_, output_dims_);
17+
return {output};
18+
}
19+
20+
// Saves the indices tensor (input_tensors[1]) so Backward can read it; the forward outputs are
// not needed and the second parameter is left unnamed.
// NOTE(review): input_tensors is indexed without its own size check; presumably Forward's
// CHECK_EQ(input_tensors.size(), 2) has already run on the same inputs — confirm the call order.
void ScatterAdd::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
21+
const std::vector<std::shared_ptr<Tensor>> &) {
22+
saved_tensors_ = {input_tensors[1]};
23+
}
24+
25+
// Backward pass. `grad_outputs` must hold exactly one tensor (enforced by CHECK_EQ).
// Returns {grad_values, nullptr}, one entry per forward input (values, indices).
std::vector<std::shared_ptr<Tensor>> ScatterAdd::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
26+
CHECK_EQ(grad_outputs.size(), 1);
27+
const auto &grad_output = grad_outputs[0];
28+
// The indices tensor stored by SetupContext.
const auto &indices = saved_tensors_[0];
29+
auto device = grad_output->GetDevice().type();
30+
// Calls the kernel registered under the name "GatherForward" on grad_output with the saved
// indices and dim_; this is the counterpart of Forward, which calls "GatherBackward".
auto grad_values
31+
= Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "GatherForward"}, grad_output, indices, dim_);
32+
// nullptr in the second slot: no gradient is returned for the indices input.
return {grad_values, nullptr};
33+
}
34+
35+
} // namespace infini_train::autograd

infini_train/src/kernels/cpu/concat.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
#include <algorithm>
1+
#include <cstddef>
22
#include <memory>
33
#include <numeric>
4-
#include <utility>
54
#include <vector>
65

76
#include "glog/logging.h"
@@ -42,23 +41,24 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
4241
const int64_t K_total = std::accumulate(Ks.begin(), Ks.end(), int64_t{0});
4342
output_dims[dim] = K_total;
4443

45-
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);
44+
auto output = std::make_shared<Tensor>(output_dims, dtype, device);
4645

4746
const int64_t outer_size
4847
= std::accumulate(output_dims.begin(), output_dims.begin() + dim, 1LL, std::multiplies<int64_t>());
4948
const int64_t inner_size
5049
= std::accumulate(output_dims.begin() + dim + 1, output_dims.end(), 1LL, std::multiplies<int64_t>());
51-
const size_t elem_size = sizeof(float);
50+
const size_t elem_size = kDataTypeToSize.at(dtype);
5251

53-
float *dst_ptr_base = static_cast<float *>(output->DataPtr());
52+
auto *dst_ptr_base = static_cast<std::byte *>(output->DataPtr());
5453
for (int64_t n = 0; n < outer_size; ++n) {
5554
int64_t offset_k = 0;
56-
float *dst_block = dst_ptr_base + n * K_total * inner_size;
55+
auto *dst_block = dst_ptr_base + n * K_total * inner_size * elem_size;
5756

5857
for (size_t i = 0; i < inputs.size(); ++i) {
5958
const int64_t Ki = Ks[i];
60-
const float *src_ptr = static_cast<const float *>(inputs[i]->DataPtr()) + n * Ki * inner_size;
61-
float *dst_ptr = dst_block + offset_k * inner_size;
59+
const auto *src_ptr
60+
= static_cast<const std::byte *>(inputs[i]->DataPtr()) + n * Ki * inner_size * elem_size;
61+
auto *dst_ptr = dst_block + offset_k * inner_size * elem_size;
6262
std::memcpy(dst_ptr, src_ptr, static_cast<size_t>(Ki) * inner_size * elem_size);
6363
offset_k += Ki;
6464
}

infini_train/src/kernels/cpu/transform.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include <cmath>
2+
#include <cstddef>
3+
#include <cstring>
24
#include <memory>
35

46
#include "glog/logging.h"
@@ -167,14 +169,15 @@ std::shared_ptr<Tensor> RepeatInterleaveForward(const std::shared_ptr<Tensor> &i
167169
output_dims[dim] = dim_size * repeat;
168170
auto output = std::make_shared<Tensor>(output_dims, input->Dtype(), input->GetDevice());
169171

170-
const float *input_ptr = static_cast<const float *>(input->DataPtr());
171-
float *output_ptr = static_cast<float *>(output->DataPtr());
172+
const size_t elem_size = kDataTypeToSize.at(input->Dtype());
173+
const auto *input_ptr = static_cast<const std::byte *>(input->DataPtr());
174+
auto *output_ptr = static_cast<std::byte *>(output->DataPtr());
172175

173176
for (int64_t o = 0; o < outer; ++o) {
174177
for (int64_t i = 0; i < dim_size; ++i) {
175178
for (int r = 0; r < repeat; ++r) {
176-
std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner),
177-
input_ptr + ((o * dim_size + i) * inner), sizeof(float) * inner);
179+
std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner * elem_size),
180+
input_ptr + ((o * dim_size + i) * inner * elem_size), elem_size * inner);
178181
}
179182
}
180183
}

infini_train/src/nn/modules/transformer/moe/experts.cc

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@
66

77
#include "glog/logging.h"
88

9+
#include "infini_train/include/nn/functional.h"
910
#include "infini_train/include/nn/modules/transformer/mlp.h"
1011
#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h"
12+
#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h"
1113
#include "infini_train/include/tensor.h"
1214

1315
namespace infini_train::nn::moe {
1416

1517
SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) {
1618
const auto &moe_config = RequireMoEConfig(config_);
17-
CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential);
19+
CHECK(moe_config.expert_impl == MoEConfig::ExpertImpl::kSequential);
1820
CHECK_EQ(moe_config.expert_parallel_size, 1)
1921
<< "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only";
20-
CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal)
21-
<< "Current InfiniTrain MoE implementation supports local dispatch only";
22+
CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather)
23+
<< "Current InfiniTrain MoE implementation supports AllGather dispatcher only";
2224

2325
num_local_experts_ = moe_config.num_experts;
2426
CHECK_GT(num_local_experts_, 0);
@@ -29,22 +31,35 @@ SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(
2931
}
3032

3133
// Diff view of SequentialMLP::Forward: a gutter number ending in '-' marks a removed line, one
// ending in '+' marks an added line, and a doubled number marks an unchanged line.
// The added version expects {hidden_states, routing_probs, routing_map}. It dispatches tokens
// through a MoEAllGatherTokenDispatcher, runs each expert on its own row slice of the permuted
// hidden states, joins the expert outputs along dim 0 and returns the dispatcher's Combine result.
std::vector<std::shared_ptr<Tensor>> SequentialMLP::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
32-
CHECK_EQ(input_tensors.size(), 2);
34+
CHECK_EQ(input_tensors.size(), 3);
3335
auto hidden_states = input_tensors[0];
3436
auto routing_probs = input_tensors[1];
35-
CHECK_EQ(routing_probs->Dims().back(), num_local_experts_);
37+
auto routing_map = input_tensors[2];
38+
// A new dispatcher is constructed on every Forward call. `dispatch` is a reference returned by
// it, and is only used below while `dispatcher` is still in scope.
std::unique_ptr<MoETokenDispatcher> dispatcher
39+
= std::make_unique<MoEAllGatherTokenDispatcher>(num_local_experts_, config_);
40+
const auto &dispatch = dispatcher->Dispatch(hidden_states, routing_map, routing_probs);
3641

37-
std::shared_ptr<Tensor> output = nullptr;
38-
const int64_t expert_dim = static_cast<int64_t>(routing_probs->Dims().size()) - 1;
42+
std::vector<std::shared_ptr<Tensor>> expert_outputs;
43+
// Running row offset into permuted_hidden_states; advanced by each expert's token count.
int64_t start = 0;
3944
for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) {
45+
// NOTE(review): tokens_per_expert_host is indexed without a bounds check; presumably it has
// num_local_experts_ entries — confirm.
const int64_t num_tokens_for_expert = dispatch.metadata.tokens_per_expert_host[expert_idx];
46+
const int64_t end = start + num_tokens_for_expert;
47+
// Experts that received no tokens are skipped, so no empty slice is passed to an expert module.
if (num_tokens_for_expert == 0) {
48+
start = end;
49+
continue;
50+
}
51+
52+
// Rows [start, end) along dim 0 are the tokens assigned to this expert.
auto expert_input = dispatch.permuted_hidden_states->Slice(0, start, end);
4053
auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx);
41-
auto expert_output = (*modules_.at(expert_name))({hidden_states})[0];
42-
auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1);
43-
auto weighted_output = expert_output * expert_prob;
44-
output = output == nullptr ? weighted_output : output + weighted_output;
54+
expert_outputs.push_back((*modules_.at(expert_name))({expert_input})[0]);
55+
start = end;
4556
}
57+
// Every permuted row must have been consumed by exactly one expert slice.
CHECK_EQ(start, dispatch.permuted_hidden_states->Dims()[0]);
58+
CHECK(!expert_outputs.empty()) << "No tokens were dispatched to any local expert";
4659

47-
return {output};
60+
// Concat is skipped when only one expert produced output; otherwise outputs are joined along
// dim 0 in expert order, matching the order in which the slices were taken.
auto permuted_expert_output
61+
= expert_outputs.size() == 1 ? expert_outputs[0] : nn::function::Concat(expert_outputs, 0);
62+
return {dispatcher->Combine(permuted_expert_output)};
4863
}
4964

5065
} // namespace infini_train::nn::moe

0 commit comments

Comments
 (0)