
Commit 766e679

feat: implement MoE infrastructure
1 parent 60e5605 commit 766e679

15 files changed: 542 additions & 1 deletion


Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "infini_train/include/autograd/function.h"
+
+namespace infini_train {
+class Tensor;
+}
+
+namespace infini_train::autograd {
+
+class Top1Mask : public Function {
+public:
+    static constexpr char kType[] = "Top1MaskFunction";
+
+    Top1Mask() : Function(kType) {}
+
+    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
+    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
+                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
+    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
+};
+
+} // namespace infini_train::autograd
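
A brief usage note (not part of the commit): the declaration above follows the framework's autograd-Function pattern, so calling code would construct a Top1Mask and run it over the router logits. A minimal sketch, assuming the Function base class exposes a PyTorch-style Apply entry point (Apply is not shown in this diff):

// Hypothetical caller; Apply(...) is an assumption about the Function base API.
std::shared_ptr<infini_train::Tensor> ApplyTop1Mask(const std::shared_ptr<infini_train::Tensor> &router_logits) {
    auto fn = std::make_shared<infini_train::autograd::Top1Mask>();
    return fn->Apply({router_logits})[0]; // keeps each row's top-1 logit in place, zeroes the rest
}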
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/modules/transformer/transformer_config.h"
+
+namespace infini_train::nn::moe {
+
+class SequentialMLP : public CloneableModule<SequentialMLP> {
+public:
+    static constexpr char kType[] = "SequentialMLP";
+    static constexpr char kExpertNamePrefix[] = "expert_";
+
+    explicit SequentialMLP(const TransformerConfig &config);
+
+    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
+
+private:
+    TransformerConfig config_;
+    int64_t num_local_experts_ = 0;
+};
+
+} // namespace infini_train::nn::moe
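
The num_local_experts_ member suggests the expert list is partitioned across expert-parallel ranks, presumably as num_experts / expert_parallel_size: with num_experts = 8 and expert_parallel_size = 2, each rank would own 4 experts, registered under the "expert_" name prefix. (This split is an inference from the field names; the constructor body is not among the rendered hunks.)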
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/modules/transformer/transformer_config.h"
+
+namespace infini_train::nn::moe {
+
+class MoELayer : public CloneableModule<MoELayer> {
+public:
+    static constexpr char kType[] = "MoELayer";
+    static constexpr char kRouterLayerName[] = "router";
+    static constexpr char kExpertsLayerName[] = "experts";
+
+    explicit MoELayer(const TransformerConfig &config);
+
+    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
+
+private:
+    TransformerConfig config_;
+};
+
+} // namespace infini_train::nn::moe
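
Putting router and experts together, a rough usage sketch: only the constructor and Forward signatures come from this diff, and placing the helper inside the same namespace as the headers is an assumption about where TransformerConfig resolves from.

// Hypothetical helper: route [batch, seq_len, n_embd] activations through the MoE block.
namespace infini_train::nn::moe {

std::vector<std::shared_ptr<Tensor>> RunMoEFFN(const TransformerConfig &config,
                                               const std::shared_ptr<Tensor> &hidden_states) {
    MoELayer layer(config); // per the constants above, builds "router" and "experts" submodules
    return layer.Forward({hidden_states});
}

} // namespace infini_train::nn::moe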
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "infini_train/include/nn/modules/transformer/transformer_config.h"
+
+namespace infini_train::nn::moe {
+
+const MoEConfig &RequireMoEConfig(const TransformerConfig &config);
+
+} // namespace infini_train::nn::moe
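
The helper's definition sits in one of the 15 changed files not rendered on this page. Given the std::optional<MoEConfig> member added to TransformerConfig below, a plausible body is a validated unwrap, sketched here (not the commit's actual code):

namespace infini_train::nn::moe {

// Sketch only: fail fast if an MoE module is built without an MoE config.
const MoEConfig &RequireMoEConfig(const TransformerConfig &config) {
    CHECK(config.moe_config.has_value()) << "MoE modules require TransformerConfig::moe_config to be set";
    return *config.moe_config;
}

} // namespace infini_train::nn::moe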
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/modules/transformer/transformer_config.h"
+
+namespace infini_train::nn::moe {
+
+class TopKRouter : public CloneableModule<TopKRouter> {
+public:
+    static constexpr char kType[] = "TopKRouter";
+    static constexpr char kParamWeightName[] = "weight";
+    static constexpr char kParamBiasName[] = "bias";
+
+    explicit TopKRouter(const TransformerConfig &config);
+
+    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
+
+private:
+    TransformerConfig config_;
+};
+
+} // namespace infini_train::nn::moe

infini_train/include/nn/modules/transformer/transformer_config.h

Lines changed: 33 additions & 0 deletions
@@ -20,11 +20,42 @@ enum class MLPType {
     kSwiGLU // SwiGLU activation
 };
 
+enum class FFNType {
+    kDense, // Standard dense MLP
+    kMoE    // Mixture-of-Experts MLP
+};
+
 enum class NormType {
     kLayerNorm, // LayerNorm
     kRMSNorm    // RMSNorm
 };
 
+enum class MoERouterType {
+    kTopK // Top-k router. The initial implementation supports top-1.
+};
+
+enum class MoEDispatcherType {
+    kLocal,    // No cross-rank token exchange
+    kAllGather // Reserved for expert parallel MoE
+};
+
+enum class MoEExpertImpl {
+    kSequential // Run local experts sequentially
+};
+
+struct MoEConfig {
+    int64_t num_experts = 0;
+    int64_t expert_parallel_size = 1;
+    int64_t router_topk = 1;
+    float aux_loss_coeff = 0.0f;
+    std::optional<float> expert_capacity_factor = std::nullopt;
+    bool pad_expert_input_to_capacity = false;
+    int64_t moe_ffn_hidden_size = 0;
+    MoERouterType router_type = MoERouterType::kTopK;
+    MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal;
+    MoEExpertImpl expert_impl = MoEExpertImpl::kSequential;
+};
+
 struct TransformerConfig {
     int64_t block_size = 1024; // Max seq_len
     int64_t vocab_size = 50304; // Vocab size
@@ -36,6 +67,7 @@ struct TransformerConfig {
 
     AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
     MLPType activation_type = MLPType::kGELU;                // MLP activation type
+    FFNType ffn_type = FFNType::kDense;                      // Feed-forward module type
     NormType norm_type = NormType::kLayerNorm;               // Normalization type
 
     bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block,
@@ -48,6 +80,7 @@ struct TransformerConfig {
     float ffn_expansion_ratio = 4.0f;               // MLP output: n_embd * ffn_expansion_ratio
     std::optional<float> ffn_dim_multiplier = 1.5f; // FFN dim multiplier
     int64_t multiple_of = 256;                      // FFN dims must be multiple of this number
+    std::optional<MoEConfig> moe_config = std::nullopt;
 
     // RoPE config
     float rope_theta = 500000.0f; // theta in RoPE
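
With these additions, opting a model into MoE amounts to flipping ffn_type and attaching a MoEConfig. A minimal sketch (the numeric values are illustrative, not defaults from this commit):

TransformerConfig config;
config.ffn_type = FFNType::kMoE;

MoEConfig moe;
moe.num_experts = 8;            // dispatcher_type defaults to kLocal, so all 8 experts stay on-rank
moe.router_topk = 1;            // only top-1 is exercised by this commit's Top1Mask kernels
moe.moe_ffn_hidden_size = 3072; // e.g. 4 * n_embd for n_embd = 768; illustrative
config.moe_config = moe;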

infini_train/src/autograd/moe.cc

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#include "infini_train/include/autograd/moe.h"
+
+#include "glog/logging.h"
+
+#include "infini_train/include/dispatcher.h"
+#include "infini_train/include/tensor.h"
+
+namespace infini_train::autograd {
+
+std::vector<std::shared_ptr<Tensor>> Top1Mask::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
+    CHECK_EQ(input_tensors.size(), 1);
+    const auto &input = input_tensors[0];
+    auto device = input->GetDevice().type();
+    return {Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "Top1MaskForward"}, input)};
+}
+
+void Top1Mask::SetupContext(const std::vector<std::shared_ptr<Tensor>> &,
+                            const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
+    saved_tensors_ = {output_tensors[0]};
+}
+
+std::vector<std::shared_ptr<Tensor>> Top1Mask::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
+    CHECK_EQ(grad_outputs.size(), 1);
+    const auto &grad_output = grad_outputs[0];
+    const auto &mask_values = saved_tensors_[0];
+    auto device = grad_output->GetDevice().type();
+    return {
+        Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "Top1MaskBackward"}, grad_output, mask_values)};
+}
+
+} // namespace infini_train::autograd
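
To make the semantics concrete: for a row of router logits [1.2, 3.4, 2.0], the "Top1MaskForward" kernel returns [0.0, 3.4, 0.0], keeping the winning logit in place and zeroing the rest. SetupContext saves that output as the mask, so Backward maps an incoming gradient [g0, g1, g2] to [0.0, g1, 0.0]: gradient flows only through the selected expert's logit.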
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+#include <memory>
+
+#include "glog/logging.h"
+
+#include "infini_train/include/dispatcher.h"
+#include "infini_train/include/tensor.h"
+
+namespace infini_train::kernels::cpu {
+
+std::shared_ptr<Tensor> Top1MaskForward(const std::shared_ptr<Tensor> &input) {
+    CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only";
+    CHECK_GE(input->Dims().size(), 1);
+
+    const auto &dims = input->Dims();
+    const int64_t num_experts = dims.back();
+    CHECK_GT(num_experts, 0);
+    const int64_t rows = input->NumElements() / num_experts;
+
+    auto output = std::make_shared<Tensor>(dims, input->Dtype(), input->GetDevice());
+    output->Fill(0.0f);
+
+    const float *in = static_cast<const float *>(input->DataPtr());
+    float *out = static_cast<float *>(output->DataPtr());
+    for (int64_t row = 0; row < rows; ++row) {
+        int64_t best_idx = 0;
+        float best_value = in[row * num_experts];
+        for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) {
+            const float value = in[row * num_experts + expert_idx];
+            if (value > best_value) {
+                best_value = value;
+                best_idx = expert_idx;
+            }
+        }
+        out[row * num_experts + best_idx] = best_value;
+    }
+
+    return output;
+}
+
+std::shared_ptr<Tensor> Top1MaskBackward(const std::shared_ptr<Tensor> &grad_output,
+                                         const std::shared_ptr<Tensor> &mask_values) {
+    CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only";
+    CHECK(mask_values->Dtype() == DataType::kFLOAT32);
+    CHECK(grad_output->Dims() == mask_values->Dims());
+
+    auto grad_input = std::make_shared<Tensor>(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice());
+    grad_input->Fill(0.0f);
+
+    const float *grad = static_cast<const float *>(grad_output->DataPtr());
+    const float *mask = static_cast<const float *>(mask_values->DataPtr());
+    float *out = static_cast<float *>(grad_input->DataPtr());
+    for (int64_t i = 0; i < static_cast<int64_t>(grad_output->NumElements()); ++i) {
+        out[i] = mask[i] != 0.0f ? grad[i] : 0.0f;
+    }
+
+    return grad_input;
+}
+
+} // namespace infini_train::kernels::cpu
+
+#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name)                                                                     \
+    REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)
+
+REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward)
+REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward)
+
+#undef REGISTER_CPU_TOP1_MASK_KERNEL
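
Note the shape convention shared by both backends: all leading dimensions are flattened against the trailing expert dimension, rows = NumElements() / num_experts, so a [8, 512, 16] router-logit tensor is processed as 8 * 512 = 4096 independent rows of 16 logits. Ties keep the first maximum, and this CPU path is float32-only per the CHECK at the top, unlike the templated CUDA version below.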
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+#include "glog/logging.h"
+
+#include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/core/runtime/device_guard.h"
+#include "infini_train/include/dispatcher.h"
+#include "infini_train/include/tensor.h"
+
+#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h"
+#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h"
+
+namespace infini_train::kernels::cuda {
+
+template <typename T>
+__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows,
+                                      int64_t num_experts) {
+    int64_t row = blockIdx.x * blockDim.x + threadIdx.x;
+    if (row >= rows) {
+        return;
+    }
+
+    const int64_t offset = row * num_experts;
+    int64_t best_idx = 0;
+    float best_value = static_cast<float>(input[offset]);
+    for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) {
+        const float value = static_cast<float>(input[offset + expert_idx]);
+        if (value > best_value) {
+            best_value = value;
+            best_idx = expert_idx;
+        }
+    }
+    for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+        output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f);
+    }
+}
+
+std::shared_ptr<Tensor> Top1MaskForward(const std::shared_ptr<Tensor> &input) {
+    CHECK_GE(input->Dims().size(), 1);
+    const auto &dims = input->Dims();
+    const int64_t num_experts = dims.back();
+    CHECK_GT(num_experts, 0);
+    const int64_t rows = input->NumElements() / num_experts;
+
+    auto output = std::make_shared<Tensor>(dims, input->Dtype(), input->GetDevice());
+
+    auto device = input->GetDevice();
+    const auto &stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
+                             infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
+                             ->cuda_stream();
+    const int threads = 256;
+    const int blocks = static_cast<int>((rows + threads - 1) / threads);
+
+    core::cuda::DispatchCudaFunc<INFINI_ALL_FLOATING_TYPES>(
+        input->Dtype(),
+        [=]<typename T>() {
+            Top1MaskForwardKernel<T><<<blocks, threads, 0, stream>>>(
+                static_cast<const T *>(input->DataPtr()), static_cast<T *>(output->DataPtr()), rows, num_experts);
+        },
+        "CUDA Top1MaskForward");
+
+    return output;
+}
+
+template <typename T>
+__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values,
+                                       T *__restrict__ grad_input, int64_t total_elements) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_elements) {
+        return;
+    }
+    grad_input[idx] = static_cast<float>(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f);
+}
+
+std::shared_ptr<Tensor> Top1MaskBackward(const std::shared_ptr<Tensor> &grad_output,
+                                         const std::shared_ptr<Tensor> &mask_values) {
+    CHECK(grad_output->Dims() == mask_values->Dims());
+    CHECK(grad_output->Dtype() == mask_values->Dtype());
+    auto grad_input = std::make_shared<Tensor>(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice());
+
+    auto device = grad_output->GetDevice();
+    const auto &stream = dynamic_cast<infini_train::core::cuda::CudaStream *>(
+                             infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
+                             ->cuda_stream();
+    const int64_t total_elements = grad_output->NumElements();
+    const int threads = 256;
+    const int blocks = static_cast<int>((total_elements + threads - 1) / threads);
+
+    core::cuda::DispatchCudaFunc<INFINI_ALL_FLOATING_TYPES>(
+        grad_output->Dtype(),
+        [=]<typename T>() {
+            Top1MaskBackwardKernel<T><<<blocks, threads, 0, stream>>>(
+                static_cast<const T *>(grad_output->DataPtr()), static_cast<const T *>(mask_values->DataPtr()),
+                static_cast<T *>(grad_input->DataPtr()), total_elements);
+        },
+        "CUDA Top1MaskBackward");
+
+    return grad_input;
+}
+
+} // namespace infini_train::kernels::cuda
+
+#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name)                                                                    \
+    REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name)
+
+REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward)
+REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward)
+
+#undef REGISTER_CUDA_TOP1_MASK_KERNEL
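
The forward launch assigns one thread per row: with threads = 256 and the [8, 512, 16] example above, rows = 4096 gives blocks = (4096 + 255) / 256 = 16, each thread scanning its 16 logits serially; comparisons are done in float even for half-precision T. That is a sensible baseline for small expert counts; a warp-per-row reduction is the usual next optimization if num_experts grows large.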
