diff --git a/CMakeLists.txt b/CMakeLists.txt index e25de71d..cc3d5ee1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,6 +199,14 @@ add_executable(gpt2 ) link_infini_train_exe(gpt2) +add_executable(tiny_mixtral + example/tiny_mixtral/main.cc + example/common/tiny_shakespeare_dataset.cc + example/common/utils.cc + example/tiny_mixtral/checkpoint_loader.cc +) +link_infini_train_exe(tiny_mixtral) + add_executable(llama3 example/llama3/main.cc example/common/tiny_shakespeare_dataset.cc diff --git a/example/tiny_mixtral/checkpoint_loader.cc b/example/tiny_mixtral/checkpoint_loader.cc new file mode 100644 index 00000000..1e27ac53 --- /dev/null +++ b/example/tiny_mixtral/checkpoint_loader.cc @@ -0,0 +1,167 @@ +#include "example/tiny_mixtral/checkpoint_loader.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/datatype.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/tensor.h" + +#include "example/common/utils.h" +#include "example/tiny_mixtral/config.h" + +namespace nn = infini_train::nn; + +namespace { + +constexpr int32_t kTinyMixtralLLMCMagic = 20260513; +constexpr int32_t kTinyMixtralLLMCVersion = 2; +constexpr int64_t kLLMCHeaderEntries = 256; + +} // namespace + +namespace tiny_mixtral { + +namespace { + +template +void CompareCheckpointValue(const std::string &name, const T &checkpoint_value, const T &runtime_value) { + CHECK_EQ(checkpoint_value, runtime_value) << name << " value from checkpoint (" << checkpoint_value + << ") is not equal to runtime config value (" << runtime_value << ")"; +} + +} // namespace + +nn::TransformerConfig ConfigFromLLMC(const std::string &filepath) { + std::ifstream ifs(filepath, std::ios::binary); + CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath; + const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs); + CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath; + CHECK_EQ(infini_train::BytesToType(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic); + CHECK_EQ(infini_train::BytesToType(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion); + + auto config = TinyMixtralConfig(); + config.block_size = infini_train::BytesToType(header, 2 * sizeof(int32_t)); + config.vocab_size = infini_train::BytesToType(header, 3 * sizeof(int32_t)); + config.original_vocab_size = config.vocab_size; + config.n_layer = infini_train::BytesToType(header, 4 * sizeof(int32_t)); + config.n_head = infini_train::BytesToType(header, 5 * sizeof(int32_t)); + config.n_kv_head = infini_train::BytesToType(header, 6 * sizeof(int32_t)); + config.n_embd = infini_train::BytesToType(header, 7 * sizeof(int32_t)); + config.ffn_expansion_ratio = infini_train::BytesToType(header, 9 * sizeof(int32_t)); + // Header slots 10 and 11 store dense-MLP helpers; MoE expert size is stored in moe_ffn_hidden_size. + config.norm_eps = infini_train::BytesToType(header, 12 * sizeof(int32_t)); + config.rope_theta = infini_train::BytesToType(header, 13 * sizeof(int32_t)); + config.use_scaled_rope = infini_train::BytesToType(header, 14 * sizeof(int32_t)) != 0; + + nn::MoEConfig moe_config; + moe_config.num_experts = infini_train::BytesToType(header, 8 * sizeof(int32_t)); + moe_config.expert_parallel_size = 1; + moe_config.router_topk = infini_train::BytesToType(header, 15 * sizeof(int32_t)); + moe_config.moe_ffn_hidden_size = infini_train::BytesToType(header, 16 * sizeof(int32_t)); + moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; + moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; + config.moe_config = moe_config; + SanitizeTinyMixtralConfig(config); + return config; +} + +void CheckLLMCConfig(const std::string &filepath, const nn::TransformerConfig &expected_config) { + SanitizeTinyMixtralConfig(expected_config); + const auto checkpoint_config = ConfigFromLLMC(filepath); + CompareCheckpointValue("block_size", checkpoint_config.block_size, expected_config.block_size); + CompareCheckpointValue("vocab_size", checkpoint_config.vocab_size, expected_config.vocab_size); + CompareCheckpointValue("original_vocab_size", checkpoint_config.original_vocab_size, + expected_config.original_vocab_size); + CompareCheckpointValue("n_layer", checkpoint_config.n_layer, expected_config.n_layer); + CompareCheckpointValue("n_head", checkpoint_config.n_head, expected_config.n_head); + CompareCheckpointValue("n_kv_head", checkpoint_config.n_kv_head, expected_config.n_kv_head); + CompareCheckpointValue("n_embd", checkpoint_config.n_embd, expected_config.n_embd); + CompareCheckpointValue("ffn_expansion_ratio", checkpoint_config.ffn_expansion_ratio, + expected_config.ffn_expansion_ratio); + CompareCheckpointValue("norm_eps", checkpoint_config.norm_eps, expected_config.norm_eps); + CompareCheckpointValue("rope_theta", checkpoint_config.rope_theta, expected_config.rope_theta); + CompareCheckpointValue("use_scaled_rope", checkpoint_config.use_scaled_rope, expected_config.use_scaled_rope); + + CHECK(expected_config.moe_config.has_value()) << "tiny Mixtral runtime config requires MoE config"; + const auto &checkpoint_moe = checkpoint_config.moe_config.value(); + const auto &expected_moe = expected_config.moe_config.value(); + CompareCheckpointValue("num_experts", checkpoint_moe.num_experts, expected_moe.num_experts); + CompareCheckpointValue("router_topk", checkpoint_moe.router_topk, expected_moe.router_topk); + CompareCheckpointValue("moe_ffn_hidden_size", checkpoint_moe.moe_ffn_hidden_size, expected_moe.moe_ffn_hidden_size); +} + +std::shared_ptr LoadFromLLMC(const std::string &filepath, + const nn::TransformerConfig &expected_config) { + CheckLLMCConfig(filepath, expected_config); + auto model = std::make_shared(expected_config); + + std::ifstream ifs(filepath, std::ios::binary); + CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath; + const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs); + CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath; + CHECK_EQ(infini_train::BytesToType(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic); + CHECK_EQ(infini_train::BytesToType(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion); + + const auto &config = expected_config; + auto state = model->StateDict(); + auto read_tensor_by_state_key = [&](const std::string &name) { + CHECK(state.contains(name)) << "Model state_dict does not contain " << name; + std::shared_ptr tensor = state.at(name); + CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32) + << "Only float32 tiny Mixtral LLMC files are supported: " << name; + infini_train::ReadMatrixAllFloat(ifs, static_cast(tensor->DataPtr()), tensor->NumElements(), 1); + CHECK(ifs) << "Failed to read tensor " << name; + }; + + auto read_projection_into_packed_qkv = [&](const std::string &packed_qkv_name, int64_t row_offset, int64_t num_rows, + const std::string &projection_name) { + CHECK(state.contains(packed_qkv_name)) << "Model state_dict does not contain " << packed_qkv_name; + std::shared_ptr tensor = state.at(packed_qkv_name); + CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32) + << "Only float32 tiny Mixtral LLMC files are supported: " << projection_name; + CHECK_EQ(tensor->Dims().size(), 2); + CHECK_GE(row_offset, 0); + CHECK_GT(num_rows, 0); + CHECK_LE(row_offset + num_rows, tensor->Dims()[0]); + const int64_t cols = tensor->Dims()[1]; + auto *data = static_cast(tensor->DataPtr()) + row_offset * cols; + infini_train::ReadMatrixAllFloat(ifs, data, num_rows, cols); + CHECK(ifs) << "Failed to read tensor rows " << projection_name; + }; + + const auto &moe_config = config.moe_config.value(); + read_tensor_by_state_key("transformer.wte.weight"); + for (int64_t layer = 0; layer < config.n_layer; ++layer) { + const std::string prefix = "transformer.h." + std::to_string(layer); + read_tensor_by_state_key(prefix + ".ln_1.weight"); + const auto c_attn_name = prefix + ".attn.c_attn.weight"; + const int64_t head_dim = config.n_embd / config.n_head; + const int64_t q_rows = config.n_head * head_dim; + const int64_t kv_rows = config.n_kv_head * head_dim; + read_projection_into_packed_qkv(c_attn_name, 0, q_rows, c_attn_name + ".q_proj"); + read_projection_into_packed_qkv(c_attn_name, q_rows, kv_rows, c_attn_name + ".k_proj"); + read_projection_into_packed_qkv(c_attn_name, q_rows + kv_rows, kv_rows, c_attn_name + ".v_proj"); + read_tensor_by_state_key(prefix + ".attn.c_proj.weight"); + read_tensor_by_state_key(prefix + ".ln_2.weight"); + read_tensor_by_state_key(prefix + ".mlp.router.weight"); + for (int64_t expert = 0; expert < moe_config.num_experts; ++expert) { + const std::string expert_prefix = prefix + ".mlp.experts.expert_" + std::to_string(expert); + read_tensor_by_state_key(expert_prefix + ".c_fc2.weight"); // Mixtral w1/gate_proj + read_tensor_by_state_key(expert_prefix + ".c_fc.weight"); // Mixtral w3/up_proj + read_tensor_by_state_key(expert_prefix + ".c_proj.weight"); // Mixtral w2/down_proj + } + } + read_tensor_by_state_key("transformer.ln_f.weight"); + read_tensor_by_state_key("lm_head.weight"); + + CHECK_EQ(ifs.peek(), std::ifstream::traits_type::eof()) << "Unexpected trailing bytes in tiny Mixtral LLMC file"; + return model; +} + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/checkpoint_loader.h b/example/tiny_mixtral/checkpoint_loader.h new file mode 100644 index 00000000..738538ad --- /dev/null +++ b/example/tiny_mixtral/checkpoint_loader.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn { +class TransformerModel; +} // namespace infini_train::nn + +namespace tiny_mixtral { + +infini_train::nn::TransformerConfig ConfigFromLLMC(const std::string &filepath); + +void CheckLLMCConfig(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config); + +std::shared_ptr +LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config); + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/config.h b/example/tiny_mixtral/config.h new file mode 100644 index 00000000..5293fa93 --- /dev/null +++ b/example/tiny_mixtral/config.h @@ -0,0 +1,76 @@ +#pragma once + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace nn = infini_train::nn; + +namespace tiny_mixtral { + +inline nn::TransformerConfig TinyMixtralConfig() { + nn::TransformerConfig config; + config.block_size = 32768; // Same as Mixtral/Megatron --max-position-embeddings. + config.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000. + config.original_vocab_size = 128256; + config.n_layer = 32; + config.n_head = 4; // Scaled down; preserves Mixtral 4:1 GQA ratio. + config.n_kv_head = 1; // Scaled down with n_head; real Mixtral uses 8 KV heads. + config.n_embd = 512; // Scaled down from Mixtral/Megatron --hidden-size 4096. + config.attention_type = nn::AttentionType::kRoPE; + config.activation_type = nn::MLPType::kSwiGLU; + config.ffn_type = nn::FFNType::kMoE; + config.norm_type = nn::NormType::kRMSNorm; + config.add_bias_linear = false; + config.add_bias_lm_head = false; + config.tie_weights = false; + config.ffn_expansion_ratio = 3.5f; + config.norm_eps = 1e-5f; + config.rope_theta = 1000000.0f; + config.use_scaled_rope = false; + + nn::MoEConfig moe_config; + moe_config.num_experts = 8; + moe_config.expert_parallel_size = 1; // Single-rank validation scale. + moe_config.router_topk = 2; + moe_config.moe_ffn_hidden_size = 1792; // Scaled down as 512 * 3.5; real Mixtral uses 14336. + moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; // Single-rank validation path. + moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; // Local correctness path. + config.moe_config = moe_config; + return config; +} + +inline void SanitizeTinyMixtralConfig(const nn::TransformerConfig &c) { + CHECK_GT(c.block_size, 0); + CHECK_GT(c.vocab_size, 0); + CHECK_GE(c.vocab_size, c.original_vocab_size); + CHECK_GT(c.n_layer, 0); + CHECK_GT(c.n_head, 0); + CHECK_GT(c.n_kv_head, 0); + CHECK_LE(c.n_kv_head, c.n_head); + CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA"; + CHECK_GT(c.n_embd, 0); + CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK(c.attention_type == nn::AttentionType::kRoPE) << "tiny Mixtral requires RoPE attention"; + CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "tiny Mixtral requires SwiGLU activation"; + CHECK(c.ffn_type == nn::FFNType::kMoE) << "tiny Mixtral requires MoE FFN"; + CHECK(c.norm_type == nn::NormType::kRMSNorm) << "tiny Mixtral requires RMSNorm"; + CHECK(!c.add_bias_linear) << "tiny Mixtral has no bias in linear layers"; + CHECK(!c.add_bias_lm_head) << "tiny Mixtral has no bias in lm_head"; + CHECK(!c.tie_weights) << "tiny Mixtral does not tie embedding and lm_head weights"; + CHECK(!c.use_scaled_rope) << "tiny Mixtral precision validation keeps scaled RoPE disabled"; + CHECK(c.moe_config.has_value()) << "tiny Mixtral requires MoE config"; + + const auto &moe = c.moe_config.value(); + CHECK_GT(moe.num_experts, 0); + CHECK_EQ(moe.expert_parallel_size, 1) << "tiny Mixtral single-rank validation expects EP=1"; + CHECK_GT(moe.router_topk, 0); + CHECK_LE(moe.router_topk, moe.num_experts); + CHECK_GT(moe.moe_ffn_hidden_size, 0); + CHECK(moe.token_dispatcher_type == nn::MoEConfig::TokenDispatcherType::kAllGather) + << "tiny Mixtral uses the Megatron-style AllGather dispatcher"; + CHECK(moe.expert_impl == nn::MoEConfig::ExpertImpl::kSequential) + << "tiny Mixtral validation uses SequentialMLP experts"; +} + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/main.cc b/example/tiny_mixtral/main.cc new file mode 100644 index 00000000..e38d0346 --- /dev/null +++ b/example/tiny_mixtral/main.cc @@ -0,0 +1,161 @@ +#include +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" + +#include "example/common/tiny_shakespeare_dataset.h" +#include "example/tiny_mixtral/checkpoint_loader.h" +#include "example/tiny_mixtral/config.h" +#include "infini_train/include/autocast.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dataloader.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/modules/loss.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +DEFINE_string(input_bin, "", "input .bin to train on"); +DEFINE_uint32(micro_batch_size, 4, "micro batch size per training step"); +DEFINE_uint32(global_batch_size, 4, "global batch size across gradient accumulation and data parallelism"); +DEFINE_uint32(sequence_length, 64, "sequence length"); +DEFINE_uint32(num_iteration, 10, "number of training iterations"); +DEFINE_double(learning_rate, 1e-4, "SGD learning rate"); +DEFINE_string(llmc_filepath, "", + "optional PyTorch-generated tiny Mixtral LLMC model file path to load before training"); +DEFINE_string(device, "cpu", "Training device: cpu or cuda."); +DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +DEFINE_uint32(log_interval, 1, "Print train loss every N steps. 0 disables step loss logging."); +DEFINE_bool(print_timing, false, "Print training-loop elapsed time and token throughput."); + +namespace { + +using infini_train::Device; +using infini_train::Tensor; + +constexpr char kDtypeFP32[] = "float32"; +constexpr char kDtypeBF16[] = "bfloat16"; + +void ValidateRuntimeFlags(const infini_train::nn::TransformerConfig &config) { + CHECK(!FLAGS_input_bin.empty()) << "tiny Mixtral training requires --input_bin"; + CHECK_GT(FLAGS_micro_batch_size, 0); + CHECK_GT(FLAGS_global_batch_size, 0); + CHECK_EQ(FLAGS_global_batch_size % FLAGS_micro_batch_size, 0) + << "global_batch_size must be divisible by micro_batch_size"; + CHECK_GT(FLAGS_sequence_length, 0); + CHECK_LE(FLAGS_sequence_length, config.block_size) << "sequence_length must be <= model max positions (block_size)"; +} + +} // namespace + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + + infini_train::nn::parallel::global::InitAllEnv( + /*nthread_per_process=*/1, + /*tensor_parallel_size=*/1, + /*sequence_parallel_enabled=*/false, + /*pipeline_parallel_size=*/1, + /*virtual_pipeline_parallel_size=*/1); + + infini_train::nn::TransformerConfig model_config = tiny_mixtral::TinyMixtralConfig(); + tiny_mixtral::SanitizeTinyMixtralConfig(model_config); + std::shared_ptr model = nullptr; + if (!FLAGS_llmc_filepath.empty()) { + model = tiny_mixtral::LoadFromLLMC(FLAGS_llmc_filepath, model_config); + } else { + model = std::make_shared(model_config); + } + ValidateRuntimeFlags(model_config); + + Device train_device; + if (FLAGS_device == "cuda") { + train_device = Device(Device::DeviceType::kCUDA, 0); + model->To(train_device); + } else { + CHECK_EQ(FLAGS_device, "cpu") << "Unsupported training device: " << FLAGS_device; + train_device = Device(); + } + + infini_train::DistributedDataLoader train_loader( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_micro_batch_size, + /*ddp_rank=*/0, /*ddp_world_size=*/1); + auto train_iter = train_loader.begin(); + + infini_train::DataType dtype; + if (FLAGS_dtype == kDtypeFP32) { + dtype = infini_train::DataType::kFLOAT32; + } else if (FLAGS_dtype == kDtypeBF16) { + dtype = infini_train::DataType::kBFLOAT16; + } else { + LOG(FATAL) << "Datatype " << FLAGS_dtype << " not supported."; + } + + auto loss_fn = std::make_shared(); + auto optimizer + = infini_train::optimizers::SGD::Create(static_cast(FLAGS_learning_rate))(model->Parameters()); + + auto device_impl = infini_train::core::GetDeviceGuardImpl(train_device.type()); + std::vector step_duration_ms; + step_duration_ms.reserve(FLAGS_num_iteration); + const uint32_t grad_accum_steps = FLAGS_global_batch_size / FLAGS_micro_batch_size; + const double tokens_per_step = static_cast(FLAGS_global_batch_size) * FLAGS_sequence_length; + for (uint32_t step = 0; step < FLAGS_num_iteration; ++step) { + device_impl->SynchronizeDevice(train_device); + const auto step_start_time = std::chrono::steady_clock::now(); + + optimizer->ZeroGrad(); + float lossf = 0.0f; + for (uint32_t micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + infini_train::AutocastGuard autocast_guard(train_device.type(), dtype); + if (train_iter == train_loader.end()) { + train_iter = train_loader.begin(); + } + auto [x_cpu, y_cpu] = *train_iter; + ++train_iter; + auto x = std::make_shared(x_cpu->To(train_device)); + auto y = std::make_shared(y_cpu->To(train_device)); + auto logits = (*model)({x})[0]; + auto loss = (*loss_fn)({logits, y})[0]; + auto loss_cpu = loss->To(Device()); + lossf += static_cast(loss_cpu.DataPtr())[0] / grad_accum_steps; + loss = loss / static_cast(grad_accum_steps); + autocast_guard.Disable(); + loss->Backward(); + } + optimizer->Step(); + + device_impl->SynchronizeDevice(train_device); + const auto step_end_time = std::chrono::steady_clock::now(); + const double duration_ms = std::chrono::duration(step_end_time - step_start_time).count(); + step_duration_ms.push_back(duration_ms); + + if (FLAGS_log_interval > 0 && ((step + 1) % FLAGS_log_interval == 0 || step + 1 == FLAGS_num_iteration)) { + std::cout << std::format( + "step {:4d}/{} | train loss {:.6f} | norm -1.0000 | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s)", step + 1, + FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_ms, tokens_per_step / (duration_ms / 1e3)) + << std::endl; + } + } + if (!step_duration_ms.empty()) { + double duration_sum_ms = 0.0; + for (size_t idx = step_duration_ms.size() > 1 ? 1 : 0; idx < step_duration_ms.size(); ++idx) { + duration_sum_ms += step_duration_ms[idx]; + } + const size_t averaged_steps + = step_duration_ms.size() > 1 ? step_duration_ms.size() - 1 : step_duration_ms.size(); + std::cout << std::format("final {} iters avg: {:.3f}ms", averaged_steps, duration_sum_ms / averaged_steps) + << std::endl; + } + + gflags::ShutDownCommandLineFlags(); + google::ShutdownGoogleLogging(); + return 0; +} diff --git a/infini_train/include/autograd/scatter_add.h b/infini_train/include/autograd/scatter_add.h new file mode 100644 index 00000000..3adc1586 --- /dev/null +++ b/infini_train/include/autograd/scatter_add.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class ScatterAdd : public Function { +public: + static constexpr char kType[] = "ScatterAddFunction"; + + ScatterAdd(int64_t dim, const std::vector &output_dims) + : Function(kType), dim_(dim), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t dim_ = 0; + std::vector output_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/topk.h b/infini_train/include/autograd/topk.h new file mode 100644 index 00000000..7752efca --- /dev/null +++ b/infini_train/include/autograd/topk.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +// FIXME(dcj): Align this API with torch.topk and return both values and indices from Forward once +// InfiniTrain autograd supports marking individual outputs as non-differentiable. Today indices +// are exposed through TopIndices() to avoid waiting for gradients on metadata outputs. +class TopK : public Function { +public: + static constexpr char kType[] = "TopKFunction"; + + explicit TopK(int64_t topk, int64_t dim = -1, bool largest = true, bool sorted = true) + : Function(kType), topk_(topk), dim_(dim), largest_(largest), sorted_(sorted) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + + std::shared_ptr TopIndices() const; + +private: + int64_t topk_ = 1; + int64_t dim_ = -1; + bool largest_ = true; + bool sorted_ = true; + std::shared_ptr top_indices_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/experts.h b/infini_train/include/nn/modules/transformer/moe/experts.h new file mode 100644 index 00000000..a3dda7f0 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/experts.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class SequentialMLP : public CloneableModule { +public: + static constexpr char kType[] = "SequentialMLP"; + static constexpr char kExpertNamePrefix[] = "expert_"; + + explicit SequentialMLP(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_layer.h b/infini_train/include/nn/modules/transformer/moe/moe_layer.h new file mode 100644 index 00000000..e5fdb3ab --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_layer.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class MoELayer : public CloneableModule { +public: + static constexpr char kType[] = "MoELayer"; + static constexpr char kRouterLayerName[] = "router"; + static constexpr char kExpertsLayerName[] = "experts"; + + explicit MoELayer(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h new file mode 100644 index 00000000..f6941049 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +struct PermutationMetadata { + std::shared_ptr sorted_indices; + std::shared_ptr gather_indices; + std::shared_ptr route_indices; + std::shared_ptr tokens_per_expert; + std::vector tokens_per_expert_host; +}; + +struct PermutationResult { + std::shared_ptr permuted_hidden_states; + std::shared_ptr permuted_probs; + PermutationMetadata metadata; +}; + +std::vector> TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, + bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function); + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config); +PermutationMetadata BuildPermutationMetadata(const std::shared_ptr &routing_map); +PermutationResult Permute(const std::shared_ptr &hidden_states_2d, + const std::shared_ptr &routing_probs_2d, + const std::shared_ptr &routing_map_2d); +std::shared_ptr Unpermute(const std::shared_ptr &permuted_hidden_states, + const std::shared_ptr &permuted_probs, const PermutationMetadata &metadata, + const std::vector &restore_shape); + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/router.h b/infini_train/include/nn/modules/transformer/moe/router.h new file mode 100644 index 00000000..1279c217 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/router.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class TopKRouter : public CloneableModule { +public: + static constexpr char kType[] = "TopKRouter"; + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + + explicit TopKRouter(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h b/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h new file mode 100644 index 00000000..f9e3c614 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +class MoETokenDispatcher { +public: + virtual ~MoETokenDispatcher() = default; + + const PermutationResult &Dispatch(const std::shared_ptr &tokens, const std::shared_ptr &routing_map, + const std::shared_ptr &probs); + std::shared_ptr Combine(const std::shared_ptr &hidden_states) const; + +protected: + explicit MoETokenDispatcher(const TransformerConfig &config); + + virtual std::vector> DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) + = 0; + virtual std::vector> TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const + = 0; + virtual const PermutationResult &DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) + = 0; + virtual std::shared_ptr CombinePreprocess(const std::shared_ptr &hidden_states) const = 0; + virtual std::shared_ptr TokenCombine(const std::shared_ptr &hidden_states) const = 0; + virtual std::shared_ptr CombinePostprocess(const std::shared_ptr &hidden_states) const = 0; + + TransformerConfig config_; + PermutationResult dispatch_; + std::vector hidden_dims_; + std::shared_ptr routing_map_; + std::shared_ptr local_map_; + std::shared_ptr local_probs_; + int64_t num_tokens_ = 0; + int64_t hidden_size_ = 0; +}; + +class MoEAllGatherTokenDispatcher : public MoETokenDispatcher { +public: + MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config); + +private: + std::vector> DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) override; + std::vector> TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const override; + const PermutationResult &DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) override; + std::shared_ptr CombinePreprocess(const std::shared_ptr &hidden_states) const override; + std::shared_ptr TokenCombine(const std::shared_ptr &hidden_states) const override; + std::shared_ptr CombinePostprocess(const std::shared_ptr &hidden_states) const override; + + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 448e7b30..713ce58f 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -15,11 +15,45 @@ enum class MLPType { kSwiGLU // SwiGLU activation }; +enum class FFNType { + kDense, // Standard dense MLP + kMoE // Mixture-of-Experts MLP +}; + enum class NormType { kLayerNorm, // LayerNorm kRMSNorm // RMSNorm }; +struct MoEConfig { + enum class RouterScoreFunction { + kSoftmax, + kSigmoid, + }; + + enum class TokenDispatcherType { + kAllGather, // Megatron-style AllGather dispatcher. Degenerates to local dispatch when TP=EP=1. + kAllToAll // Megatron-style AllToAll dispatcher for expert parallel MoE. + }; + + enum class ExpertImpl { + kSequential // Run local experts sequentially + }; + + int64_t num_experts = 0; + int64_t expert_parallel_size = 1; + int64_t router_topk = 1; + bool router_pre_softmax = false; + std::optional router_topk_scaling_factor = std::nullopt; + RouterScoreFunction router_score_function = RouterScoreFunction::kSoftmax; + float aux_loss_coeff = 0.0f; + std::optional expert_capacity_factor = std::nullopt; + bool pad_expert_input_to_capacity = false; + int64_t moe_ffn_hidden_size = 0; + TokenDispatcherType token_dispatcher_type = TokenDispatcherType::kAllGather; + ExpertImpl expert_impl = ExpertImpl::kSequential; +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -31,6 +65,7 @@ struct TransformerConfig { AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type MLPType activation_type = MLPType::kGELU; // MLP activation type + FFNType ffn_type = FFNType::kDense; // Feed-forward module type NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, @@ -43,6 +78,7 @@ struct TransformerConfig { float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier int64_t multiple_of = 256; // FFN dims must be multiple of this number + std::optional moe_config = std::nullopt; // RoPE config float rope_theta = 500000.0f; // theta in RoPE diff --git a/infini_train/src/autograd/scatter_add.cc b/infini_train/src/autograd/scatter_add.cc new file mode 100644 index 00000000..428f4f08 --- /dev/null +++ b/infini_train/src/autograd/scatter_add.cc @@ -0,0 +1,35 @@ +#include "infini_train/include/autograd/scatter_add.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> ScatterAdd::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &values = input_tensors[0]; + const auto &indices = input_tensors[1]; + auto device = values->GetDevice().type(); + auto output = Dispatcher::Instance().Call>({device, "GatherBackward"}, values, indices, + dim_, output_dims_); + return {output}; +} + +void ScatterAdd::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + saved_tensors_ = {input_tensors[1]}; +} + +std::vector> ScatterAdd::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &indices = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + auto grad_values + = Dispatcher::Instance().Call>({device, "GatherForward"}, grad_output, indices, dim_); + return {grad_values, nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/topk.cc b/infini_train/src/autograd/topk.cc new file mode 100644 index 00000000..4e0420b8 --- /dev/null +++ b/infini_train/src/autograd/topk.cc @@ -0,0 +1,39 @@ +#include "infini_train/include/autograd/topk.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> TopK::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + auto topk_outputs = Dispatcher::Instance().Call>>( + {device, "TopKForward"}, input, topk_, dim_, largest_, sorted_); + CHECK_EQ(topk_outputs.size(), 2); + top_indices_ = topk_outputs[1]; + return {topk_outputs[0]}; +} + +void TopK::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + input_dims_ = input_tensors[0]->Dims(); + saved_tensors_ = {top_indices_}; +} + +std::vector> TopK::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &top_grad = grad_outputs[0]; + const auto &top_indices = saved_tensors_[0]; + auto device = top_grad->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "TopKBackward"}, top_grad, top_indices, + input_dims_, dim_)}; +} + +std::shared_ptr TopK::TopIndices() const { return top_indices_; } + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/concat.cc b/infini_train/src/kernels/cpu/concat.cc index b421063f..169cc557 100644 --- a/infini_train/src/kernels/cpu/concat.cc +++ b/infini_train/src/kernels/cpu/concat.cc @@ -1,7 +1,6 @@ -#include +#include #include #include -#include #include #include "glog/logging.h" @@ -42,23 +41,24 @@ std::shared_ptr ConcatForward(const std::vector> const int64_t K_total = std::accumulate(Ks.begin(), Ks.end(), int64_t{0}); output_dims[dim] = K_total; - auto output = std::make_shared(output_dims, DataType::kFLOAT32); + auto output = std::make_shared(output_dims, dtype, device); const int64_t outer_size = std::accumulate(output_dims.begin(), output_dims.begin() + dim, 1LL, std::multiplies()); const int64_t inner_size = std::accumulate(output_dims.begin() + dim + 1, output_dims.end(), 1LL, std::multiplies()); - const size_t elem_size = sizeof(float); + const size_t elem_size = kDataTypeToSize.at(dtype); - float *dst_ptr_base = static_cast(output->DataPtr()); + auto *dst_ptr_base = static_cast(output->DataPtr()); for (int64_t n = 0; n < outer_size; ++n) { int64_t offset_k = 0; - float *dst_block = dst_ptr_base + n * K_total * inner_size; + auto *dst_block = dst_ptr_base + n * K_total * inner_size * elem_size; for (size_t i = 0; i < inputs.size(); ++i) { const int64_t Ki = Ks[i]; - const float *src_ptr = static_cast(inputs[i]->DataPtr()) + n * Ki * inner_size; - float *dst_ptr = dst_block + offset_k * inner_size; + const auto *src_ptr + = static_cast(inputs[i]->DataPtr()) + n * Ki * inner_size * elem_size; + auto *dst_ptr = dst_block + offset_k * inner_size * elem_size; std::memcpy(dst_ptr, src_ptr, static_cast(Ki) * inner_size * elem_size); offset_k += Ki; } diff --git a/infini_train/src/kernels/cpu/topk.cc b/infini_train/src/kernels/cpu/topk.cc new file mode 100644 index 00000000..9e191143 --- /dev/null +++ b/infini_train/src/kernels/cpu/topk.cc @@ -0,0 +1,124 @@ +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + const float *in = static_cast(input->DataPtr()); + float *values = static_cast(top_values->DataPtr()); + int64_t *indices = static_cast(top_indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + std::vector selected_indices(dim_size, false); + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value + = largest ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); + for (int64_t idx = 0; idx < dim_size; ++idx) { + if (selected_indices[idx]) { + continue; + } + const float value = in[outer * dim_size * inner_size + idx * inner_size + inner]; + const bool better = largest ? value > best_value : value < best_value; + if (better) { + best_value = value; + best_idx = idx; + } + } + CHECK_GE(best_idx, 0); + selected_indices[best_idx] = true; + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + values[out_offset] = best_value; + indices[out_offset] = best_idx; + } + } + } + + return {top_values, top_indices}; +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + std::memset(grad_input->DataPtr(), 0, grad_input->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(grad_values->Dtype()); + const auto *src = static_cast(grad_values->DataPtr()); + auto *dst = static_cast(grad_input->DataPtr()); + const auto *idx_ptr = static_cast(indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = idx_ptr[out_offset]; + CHECK_GE(selected_idx, 0); + CHECK_LT(selected_idx, dim_size); + std::memcpy(dst + (outer * dim_size * inner_size + selected_idx * inner_size + inner) * elem_size, + src + out_offset * elem_size, elem_size); + } + } + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOPK_KERNEL(TopKForward) +REGISTER_CPU_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CPU_TOPK_KERNEL diff --git a/infini_train/src/kernels/cpu/transform.cc b/infini_train/src/kernels/cpu/transform.cc index 1a810b44..48063c7a 100644 --- a/infini_train/src/kernels/cpu/transform.cc +++ b/infini_train/src/kernels/cpu/transform.cc @@ -1,4 +1,6 @@ #include +#include +#include #include #include "glog/logging.h" @@ -167,14 +169,15 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i output_dims[dim] = dim_size * repeat; auto output = std::make_shared(output_dims, input->Dtype(), input->GetDevice()); - const float *input_ptr = static_cast(input->DataPtr()); - float *output_ptr = static_cast(output->DataPtr()); + const size_t elem_size = kDataTypeToSize.at(input->Dtype()); + const auto *input_ptr = static_cast(input->DataPtr()); + auto *output_ptr = static_cast(output->DataPtr()); for (int64_t o = 0; o < outer; ++o) { for (int64_t i = 0; i < dim_size; ++i) { for (int r = 0; r < repeat; ++r) { - std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner), - input_ptr + ((o * dim_size + i) * inner), sizeof(float) * inner); + std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner * elem_size), + input_ptr + ((o * dim_size + i) * inner * elem_size), elem_size * inner); } } } diff --git a/infini_train/src/kernels/cuda/topk.cu b/infini_train/src/kernels/cuda/topk.cu new file mode 100644 index 00000000..32044c3f --- /dev/null +++ b/infini_train/src/kernels/cuda/topk.cu @@ -0,0 +1,155 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void TopKForwardKernel(const T *__restrict__ input, T *__restrict__ top_values, + int64_t *__restrict__ top_indices, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk, bool largest) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t idx = 0; idx < dim_size; ++idx) { + const float value = static_cast(input[outer * dim_size * inner_size + idx * inner_size + inner]); + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < dim_size; ++other_idx) { + const float other_value + = static_cast(input[outer * dim_size * inner_size + other_idx * inner_size + inner]); + const bool ranks_before = largest ? (other_value > value || (other_value == value && other_idx < idx)) + : (other_value < value || (other_value == value && other_idx < idx)); + if (ranks_before) { + ++rank; + } + } + if (rank < topk) { + const int64_t out_offset = outer * topk * inner_size + rank * inner_size + inner; + top_values[out_offset] = input[outer * dim_size * inner_size + idx * inner_size + inner]; + top_indices[out_offset] = idx; + } + } +} + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + TopKForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(top_values->DataPtr()), + static_cast(top_indices->DataPtr()), rows, dim_size, inner_size, topk, largest); + }, + "CUDA TopKForward"); + + return {top_values, top_indices}; +} + +template +__global__ void TopKBackwardKernel(const T *__restrict__ grad_values, const int64_t *__restrict__ indices, + T *__restrict__ grad_input, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = indices[out_offset]; + grad_input[outer * dim_size * inner_size + selected_idx * inner_size + inner] = grad_values[out_offset]; + } +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + auto device = grad_values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(grad_input->DataPtr(), 0, grad_input->SizeInBytes(), stream)); + + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + core::cuda::DispatchCudaFunc( + grad_values->Dtype(), + [=]() { + TopKBackwardKernel<<>>( + static_cast(grad_values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_input->DataPtr()), rows, dim_size, inner_size, topk); + }, + "CUDA TopKBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOPK_KERNEL(TopKForward) +REGISTER_CUDA_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CUDA_TOPK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/mlp.cc b/infini_train/src/nn/modules/transformer/mlp.cc index 9f1f488c..ac35d144 100644 --- a/infini_train/src/nn/modules/transformer/mlp.cc +++ b/infini_train/src/nn/modules/transformer/mlp.cc @@ -37,6 +37,12 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) { // Round up to multiple_of ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value() + && config.moe_config->moe_ffn_hidden_size > 0) { + ffn_hidden = config.moe_config->moe_ffn_hidden_size; + } + CHECK_GT(ffn_hidden, 0); + // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/ffn_hidden, diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc new file mode 100644 index 00000000..fa8681da --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -0,0 +1,65 @@ +#include "infini_train/include/nn/modules/transformer/moe/experts.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.expert_impl == MoEConfig::ExpertImpl::kSequential); + CHECK_EQ(moe_config.expert_parallel_size, 1) + << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; + CHECK(moe_config.token_dispatcher_type == MoEConfig::TokenDispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; + + num_local_experts_ = moe_config.num_experts; + CHECK_GT(num_local_experts_, 0); + + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + modules_[std::string(kExpertNamePrefix) + std::to_string(expert_idx)] = std::make_shared(config_); + } +} + +std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 3); + auto hidden_states = input_tensors[0]; + auto routing_probs = input_tensors[1]; + auto routing_map = input_tensors[2]; + std::unique_ptr dispatcher + = std::make_unique(num_local_experts_, config_); + const auto &dispatch = dispatcher->Dispatch(hidden_states, routing_map, routing_probs); + + std::vector> expert_outputs; + int64_t start = 0; + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + const int64_t num_tokens_for_expert = dispatch.metadata.tokens_per_expert_host[expert_idx]; + const int64_t end = start + num_tokens_for_expert; + if (num_tokens_for_expert == 0) { + start = end; + continue; + } + + auto expert_input = dispatch.permuted_hidden_states->Slice(0, start, end); + auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); + expert_outputs.push_back((*modules_.at(expert_name))({expert_input})[0]); + start = end; + } + CHECK_EQ(start, dispatch.permuted_hidden_states->Dims()[0]); + CHECK(!expert_outputs.empty()) << "No tokens were dispatched to any local expert"; + + auto permuted_expert_output + = expert_outputs.size() == 1 ? expert_outputs[0] : nn::function::Concat(expert_outputs, 0); + return {dispatcher->Combine(permuted_expert_output)}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc new file mode 100644 index 00000000..1e15fe81 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -0,0 +1,33 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/moe/experts.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(config_.ffn_type == FFNType::kMoE); + CHECK(moe_config.token_dispatcher_type == MoEConfig::TokenDispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; + + modules_[kRouterLayerName] = std::make_shared(config_); + modules_[kExpertsLayerName] = std::make_shared(config_); +} + +std::vector> MoELayer::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + auto hidden_states = input_tensors[0]; + auto router_output = (*modules_.at(kRouterLayerName))({hidden_states}); + CHECK_EQ(router_output.size(), 2); + return (*modules_.at(kExpertsLayerName))({hidden_states, router_output[0], router_output[1]}); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc new file mode 100644 index 00000000..040b29df --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -0,0 +1,180 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" + +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/scatter_add.h" +#include "infini_train/include/autograd/topk.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/nn/functional.h" + +namespace infini_train::nn::moe { + +std::vector> +TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function) { + + // Megatron TopKRouter returns dense tensors: + // routing_probs: [num_tokens, num_experts] + // routing_map: [num_tokens, num_experts], bool + std::shared_ptr top_probs; + std::shared_ptr top_indices; + + if (score_function == MoEConfig::RouterScoreFunction::kSoftmax) { + if (use_pre_softmax) { + auto scores = function::Softmax(logits, -1); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({scores})[0]; + top_indices = topk_function->TopIndices(); + } else { + auto topk_function = std::make_shared(topk); + auto top_scores = topk_function->Apply({logits})[0]; + top_indices = topk_function->TopIndices(); + top_probs = function::Softmax(top_scores, -1); + } + } else if (score_function == MoEConfig::RouterScoreFunction::kSigmoid) { + auto sigmoid_scores = function::Sigmoid(logits); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({sigmoid_scores})[0]; + top_indices = topk_function->TopIndices(); + if (topk > 1) { + top_probs = top_probs / (top_probs->Sum(-1, true) + 1e-20f); + } + } else { + LOG(FATAL) << "Unsupported MoE router score function"; + } + + if (scaling_factor.has_value()) { + top_probs = top_probs * scaling_factor.value(); + } + + auto routing_probs = std::make_shared(logits->Dims())->Apply({top_probs, top_indices})[0]; + auto routing_map_values = std::make_shared(top_indices->Equals(top_indices)->To(DataType::kBOOL)); + auto routing_map = Dispatcher::Instance().Call>( + {logits->GetDevice().type(), "ScatterForward"}, routing_map_values, top_indices, logits->Dims()); + return {routing_probs, routing_map}; +} + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { + CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; + return config.moe_config.value(); +} + +PermutationMetadata BuildPermutationMetadata(const std::shared_ptr &routing_map) { + CHECK(routing_map->Dtype() == DataType::kBOOL); + CHECK_EQ(routing_map->Dims().size(), 2); + + const int64_t num_tokens = routing_map->Dims()[0]; + const int64_t num_experts = routing_map->Dims()[1]; + CHECK_GT(num_tokens, 0); + CHECK_GT(num_experts, 0); + + Tensor routing_map_cpu_storage = routing_map->To(Device()); + auto routing_map_cpu = std::make_shared(routing_map_cpu_storage); + const auto *routing_map_ptr = static_cast(routing_map_cpu->DataPtr()); + + std::vector sorted_indices_host; + std::vector route_indices_host; + std::vector tokens_per_expert_host; + sorted_indices_host.reserve(routing_map->NumElements()); + route_indices_host.reserve(routing_map->NumElements()); + tokens_per_expert_host.reserve(num_experts); + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + int64_t tokens_for_expert = 0; + for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + if (routing_map_ptr[token_idx * num_experts + expert_idx]) { + sorted_indices_host.push_back(token_idx); + route_indices_host.push_back(token_idx * num_experts + expert_idx); + ++tokens_for_expert; + } + } + tokens_per_expert_host.push_back(tokens_for_expert); + } + + const int64_t num_dispatched_tokens = static_cast(sorted_indices_host.size()); + auto sorted_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens}, DataType::kINT64, Device()); + auto route_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens}, DataType::kINT64, Device()); + auto gather_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens, 1}, DataType::kINT64, Device()); + auto tokens_per_expert_cpu + = std::make_shared(std::vector{num_experts}, DataType::kINT64, Device()); + + auto *sorted_indices_ptr = static_cast(sorted_indices_cpu->DataPtr()); + auto *route_indices_ptr = static_cast(route_indices_cpu->DataPtr()); + auto *gather_indices_ptr = static_cast(gather_indices_cpu->DataPtr()); + auto *tokens_per_expert_ptr = static_cast(tokens_per_expert_cpu->DataPtr()); + for (int64_t idx = 0; idx < num_dispatched_tokens; ++idx) { + sorted_indices_ptr[idx] = sorted_indices_host[idx]; + route_indices_ptr[idx] = route_indices_host[idx]; + gather_indices_ptr[idx] = sorted_indices_host[idx]; + } + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + tokens_per_expert_ptr[expert_idx] = tokens_per_expert_host[expert_idx]; + } + + auto to_device = [&](const std::shared_ptr &cpu_tensor) -> std::shared_ptr { + if (routing_map->GetDevice().type() == Device::DeviceType::kCPU) { + return cpu_tensor; + } + return std::make_shared(cpu_tensor->To(routing_map->GetDevice())); + }; + + return {to_device(sorted_indices_cpu), to_device(gather_indices_cpu), to_device(route_indices_cpu), + to_device(tokens_per_expert_cpu), tokens_per_expert_host}; +} + +PermutationResult Permute(const std::shared_ptr &hidden_states_2d, + const std::shared_ptr &routing_probs_2d, + const std::shared_ptr &routing_map_2d) { + CHECK_EQ(hidden_states_2d->Dims().size(), 2); + CHECK(routing_probs_2d->Dims() == routing_map_2d->Dims()); + CHECK(routing_map_2d->Dtype() == DataType::kBOOL); + + const int64_t hidden_size = hidden_states_2d->Dims()[1]; + auto metadata = BuildPermutationMetadata(routing_map_2d); + const int64_t num_dispatched_tokens = metadata.sorted_indices->Dims()[0]; + + std::shared_ptr permuted_hidden_states; + std::shared_ptr permuted_probs; + if (num_dispatched_tokens == 0) { + permuted_hidden_states = std::make_shared(std::vector{0, hidden_size}, + hidden_states_2d->Dtype(), hidden_states_2d->GetDevice()); + permuted_probs = std::make_shared(std::vector{0}, routing_probs_2d->Dtype(), + routing_probs_2d->GetDevice()); + } else { + auto gather_indices = metadata.gather_indices; + if (hidden_size != 1) { + gather_indices = metadata.gather_indices->RepeatInterleave(hidden_size, 1); + } + permuted_hidden_states = hidden_states_2d->Gather(0, gather_indices); + permuted_probs = routing_probs_2d->View({static_cast(routing_probs_2d->NumElements())}) + ->Gather(0, metadata.route_indices); + } + + return {permuted_hidden_states, permuted_probs, metadata}; +} + +std::shared_ptr Unpermute(const std::shared_ptr &permuted_hidden_states, + const std::shared_ptr &permuted_probs, const PermutationMetadata &metadata, + const std::vector &restore_shape) { + CHECK_EQ(permuted_hidden_states->Dims().size(), 2); + CHECK_EQ(permuted_probs->Dims().size(), 1); + CHECK_EQ(permuted_hidden_states->Dims()[0], permuted_probs->Dims()[0]); + CHECK_EQ(restore_shape.size(), 2); + + auto weighted = permuted_hidden_states * permuted_probs->View({permuted_probs->Dims()[0], 1}); + auto scatter_indices = metadata.gather_indices; + const int64_t hidden_size = restore_shape[1]; + if (hidden_size != 1) { + scatter_indices = metadata.gather_indices->RepeatInterleave(hidden_size, 1); + } + return std::make_shared(0, restore_shape)->Apply({weighted, scatter_indices})[0]; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc new file mode 100644 index 00000000..25208684 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -0,0 +1,57 @@ +#include "infini_train/include/nn/modules/transformer/moe/router.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK_GT(moe_config.num_experts, 0); + CHECK_GT(moe_config.router_topk, 0); + CHECK_LE(moe_config.router_topk, moe_config.num_experts); + parameters_[kParamWeightName] + = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::KaimingUniform(parameters_[kParamWeightName]); + + if (config_.add_bias_linear) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{moe_config.num_experts}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + parameters_[kParamBiasName]->Fill(0.0f); + } +} + +std::vector> TopKRouter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + std::vector> linear_inputs{input_tensors[0], parameters_.at(kParamWeightName)}; + if (parameters_.contains(kParamBiasName)) { + linear_inputs.push_back(parameters_.at(kParamBiasName)); + } + + auto logits = std::make_shared()->Apply(linear_inputs)[0]; + + const auto &moe_config = RequireMoEConfig(config_); + + auto routing_results + = TopkRoutingWithScoreFunction(logits, moe_config.router_topk, moe_config.router_pre_softmax, + moe_config.router_topk_scaling_factor, moe_config.router_score_function); + + auto routing_probs = routing_results[0]; + auto routing_map = routing_results[1]; + return {routing_probs, routing_map}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc b/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc new file mode 100644 index 00000000..667dba8f --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc @@ -0,0 +1,95 @@ +#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h" + +#include +#include + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +MoETokenDispatcher::MoETokenDispatcher(const TransformerConfig &config) : config_(config) {} + +const PermutationResult &MoETokenDispatcher::Dispatch(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) { + auto preprocessed = DispatchPreprocess(tokens, routing_map, probs); + auto dispatched = TokenDispatch(preprocessed[0], preprocessed[1]); + return DispatchPostprocess(dispatched[0], dispatched[1]); +} + +std::shared_ptr MoETokenDispatcher::Combine(const std::shared_ptr &hidden_states) const { + auto preprocessed = CombinePreprocess(hidden_states); + auto combined = TokenCombine(preprocessed); + return CombinePostprocess(combined); +} + +MoEAllGatherTokenDispatcher::MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config) + : MoETokenDispatcher(config), num_local_experts_(num_local_experts) { + CHECK_GT(num_local_experts_, 0); +} + +std::vector> +MoEAllGatherTokenDispatcher::DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) { + CHECK(probs->Dims() == routing_map->Dims()); + CHECK(routing_map->Dtype() == DataType::kBOOL); + CHECK_GE(tokens->Dims().size(), 2); + + hidden_dims_ = tokens->Dims(); + hidden_size_ = hidden_dims_.back(); + CHECK_GT(hidden_size_, 0); + num_tokens_ = tokens->NumElements() / hidden_size_; + CHECK_EQ(probs->Dims().back(), num_local_experts_); + CHECK_EQ(probs->NumElements(), static_cast(num_tokens_ * num_local_experts_)); + + routing_map_ = routing_map->View({num_tokens_, num_local_experts_}); + auto hidden_states_2d = tokens->View({num_tokens_, hidden_size_}); + auto probs_2d = probs->View({num_tokens_, num_local_experts_}); + return {hidden_states_2d, probs_2d}; +} + +std::vector> +MoEAllGatherTokenDispatcher::TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const { + // AllGather dispatcher will gather tokens across TP*EP ranks here. For the current single-rank + // path (tp_size=1, ep_size=1), no communication is required. + return {hidden_states, probs}; +} + +const PermutationResult &MoEAllGatherTokenDispatcher::DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) { + CHECK(routing_map_ != nullptr); + CHECK_EQ(hidden_states->Dims().size(), 2); + CHECK_EQ(probs->Dims().size(), 2); + CHECK_EQ(hidden_states->Dims()[0], probs->Dims()[0]); + CHECK_EQ(probs->Dims()[1], num_local_experts_); + + // With ep_size=1 all experts are local, so the local expert map/probs are the gathered map/probs. + // Future EP support should slice [local_expert_start, local_expert_end) after AllGather. + local_map_ = routing_map_; + local_probs_ = probs; + dispatch_ = Permute(hidden_states, local_probs_, local_map_); + routing_map_ = nullptr; + return dispatch_; +} + +std::shared_ptr +MoEAllGatherTokenDispatcher::CombinePreprocess(const std::shared_ptr &hidden_states) const { + CHECK(local_map_ != nullptr); + CHECK(local_probs_ != nullptr); + return Unpermute(hidden_states, dispatch_.permuted_probs, dispatch_.metadata, + std::vector{num_tokens_, hidden_size_}); +} + +std::shared_ptr MoEAllGatherTokenDispatcher::TokenCombine(const std::shared_ptr &hidden_states) const { + // AllGather dispatcher will reduce-scatter combined token outputs here. For ep_size=1 this is a no-op. + return hidden_states; +} + +std::shared_ptr +MoEAllGatherTokenDispatcher::CombinePostprocess(const std::shared_ptr &hidden_states) const { + return hidden_states->View(hidden_dims_); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..bdcde449 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -86,7 +87,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea } modules_[kAttnLayerName] = std::make_shared(config); - modules_[kMlpLayerName] = std::make_shared(config); + if (config.ffn_type == FFNType::kMoE) { + modules_[kMlpLayerName] = std::make_shared(config); + } else { + modules_[kMlpLayerName] = std::make_shared(config); + } } std::vector> TransformerLayer::Forward(const std::vector> &x) { diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index e3c67293..351e6755 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -66,12 +66,15 @@ read_var() { jq -r --arg k "$key" '.variables[$k] // empty' "$CONFIG_FILE" } -BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" -LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" -PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" -COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" -RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" -CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" +BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" +LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" +PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" +COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" +RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" +CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" +GPT2_TEST_GROUPS="$(read_var GPT2_TEST_GROUPS)"; : "${GPT2_TEST_GROUPS:=basic,zero,lora}" +LLAMA3_TEST_GROUPS="$(read_var LLAMA3_TEST_GROUPS)"; : "${LLAMA3_TEST_GROUPS:=basic,zero,lora}" +MIXTRAL_TEST_GROUPS="$(read_var MIXTRAL_TEST_GROUPS)"; : "${MIXTRAL_TEST_GROUPS:=moe}" mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR" @@ -219,6 +222,54 @@ args_string_for_test() { ' "$CONFIG_FILE" | paste -sd' ' - } +tag_enabled_for_model() { + local tag="$1" + local enabled_tags="$2" + + if [[ "$enabled_tags" == "*" ]]; then + return 0 + fi + + IFS=',' read -r -a tags <<< "$enabled_tags" + for raw_tag in "${tags[@]}"; do + local enabled_tag + enabled_tag="$(normalize_tag "$raw_tag")" + if [[ "$enabled_tag" == "$tag" ]]; then + return 0 + fi + done + return 1 +} + +model_has_selected_group() { + local enabled_tags="$1" + + for ((gi=0; gi&2 + exit 1 + fi + if [[ ! -f "$MIXTRAL_LLMC_FILEPATH" ]]; then + echo "Error: missing MIXTRAL_LLMC_FILEPATH: $MIXTRAL_LLMC_FILEPATH" >&2 + exit 1 + fi + fi +} + # Run tests num_builds=$(jq '.builds | length' "$CONFIG_FILE") num_groups=$(jq '.test_groups | length' "$CONFIG_FILE") @@ -226,7 +277,7 @@ num_groups=$(jq '.test_groups | length' "$CONFIG_FILE") selected_group_count=0 for ((gi=0; ginum_experts = 2; + config.moe_config->router_topk = 1; + config.moe_config->router_pre_softmax = true; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + EXPECT_FALSE(moe->Parameters().empty()); +} + +TEST_P(TransformerModuleTest, MoELayerTop2SwiGLU) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto state = moe->StateDict(); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc2.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_proj.weight")); + EXPECT_EQ(state.at("experts.expert_0.c_fc.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_fc2.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_proj.weight")->Dims(), (std::vector{config.n_embd, 48})); +} + +TEST_P(TransformerModuleTest, TopKRouterMegatronOutputs) { + nn::TransformerConfig config; + config.n_embd = 32; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto router = std::make_shared(config); + router->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*router)({input}); + ASSERT_EQ(output.size(), 2); + EXPECT_EQ(output[0]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[1]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[0]->Dtype(), DataType::kFLOAT32); + EXPECT_EQ(output[1]->Dtype(), DataType::kBOOL); +} + +TEST_P(TransformerModuleTest, TopKTorchInterface) { + ONLY_CPU(); + const float data[] = {1.0f, 5.0f, 2.0f, 4.0f, 3.0f, 0.0f}; + auto input = std::make_shared(data, std::vector{2, 3}, DataType::kFLOAT32); + + auto largest_topk = std::make_shared(2, 1, true, true); + auto largest_values = largest_topk->Apply({input})[0]; + auto largest_indices = largest_topk->TopIndices(); + ASSERT_EQ(largest_values->Dims(), (std::vector{2, 2})); + ASSERT_EQ(largest_indices->Dims(), (std::vector{2, 2})); + const auto *largest_values_ptr = static_cast(largest_values->DataPtr()); + const auto *largest_indices_ptr = static_cast(largest_indices->DataPtr()); + EXPECT_FLOAT_EQ(largest_values_ptr[0], 5.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[1], 2.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[2], 4.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[3], 3.0f); + EXPECT_EQ(largest_indices_ptr[0], 1); + EXPECT_EQ(largest_indices_ptr[1], 2); + EXPECT_EQ(largest_indices_ptr[2], 0); + EXPECT_EQ(largest_indices_ptr[3], 1); + + auto smallest_topk = std::make_shared(1, 0, false, true); + auto smallest_values = smallest_topk->Apply({input})[0]; + auto smallest_indices = smallest_topk->TopIndices(); + ASSERT_EQ(smallest_values->Dims(), (std::vector{1, 3})); + ASSERT_EQ(smallest_indices->Dims(), (std::vector{1, 3})); + const auto *smallest_values_ptr = static_cast(smallest_values->DataPtr()); + const auto *smallest_indices_ptr = static_cast(smallest_indices->DataPtr()); + EXPECT_FLOAT_EQ(smallest_values_ptr[0], 1.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[1], 3.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[2], 0.0f); + EXPECT_EQ(smallest_indices_ptr[0], 0); + EXPECT_EQ(smallest_indices_ptr[1], 1); + EXPECT_EQ(smallest_indices_ptr[2], 1); +} + +TEST_P(TransformerModuleTest, TopKRouterNormalization) { + ONLY_CPU(); + auto make_router = [](nn::MoEConfig::RouterScoreFunction score_function, bool pre_softmax) { + nn::TransformerConfig config; + config.n_embd = 2; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 3; + config.moe_config->router_topk = 2; + config.moe_config->router_score_function = score_function; + config.moe_config->router_pre_softmax = pre_softmax; + auto router = std::make_shared(config); + auto weight = router->parameter(nn::moe::TopKRouter::kParamWeightName); + auto *weight_ptr = static_cast(weight->DataPtr()); + weight_ptr[0] = 1.0f; + weight_ptr[1] = 0.0f; + weight_ptr[2] = 2.0f; + weight_ptr[3] = 0.0f; + weight_ptr[4] = 0.0f; + weight_ptr[5] = 0.0f; + return router; + }; + + const float input_data[] = {1.0f, 1.0f}; + auto input = std::make_shared(input_data, std::vector{1, 1, 2}, DataType::kFLOAT32); + + auto softmax_router = make_router(nn::MoEConfig::RouterScoreFunction::kSoftmax, false); + auto softmax_output = (*softmax_router)({input}); + const auto *softmax_probs = static_cast(softmax_output[0]->DataPtr()); + EXPECT_NEAR(softmax_probs[0] + softmax_probs[1] + softmax_probs[2], 1.0f, 1e-5f); + EXPECT_GT(softmax_probs[1], softmax_probs[0]); + EXPECT_FLOAT_EQ(softmax_probs[2], 0.0f); + + auto sigmoid_router = make_router(nn::MoEConfig::RouterScoreFunction::kSigmoid, true); + auto sigmoid_output = (*sigmoid_router)({input}); + const auto *sigmoid_probs = static_cast(sigmoid_output[0]->DataPtr()); + EXPECT_NEAR(sigmoid_probs[0] + sigmoid_probs[1] + sigmoid_probs[2], 1.0f, 1e-5f); + EXPECT_GT(sigmoid_probs[1], sigmoid_probs[0]); + EXPECT_FLOAT_EQ(sigmoid_probs[2], 0.0f); +} + INFINI_TRAIN_REGISTER_TEST(TransformerModuleTest);