Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ add_executable(gpt2
)
link_infini_train_exe(gpt2)

add_executable(tiny_mixtral
example/tiny_mixtral/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/tiny_mixtral/checkpoint_loader.cc
)
link_infini_train_exe(tiny_mixtral)

add_executable(llama3
example/llama3/main.cc
example/common/tiny_shakespeare_dataset.cc
Expand Down
173 changes: 173 additions & 0 deletions example/tiny_mixtral/checkpoint_loader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#include "example/tiny_mixtral/checkpoint_loader.h"

#include <cstdint>
#include <fstream>
#include <memory>
#include <string>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/datatype.h"
#include "infini_train/include/nn/modules/transformer/transformer.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/tiny_mixtral/config.h"

namespace nn = infini_train::nn;

namespace {

constexpr int32_t kTinyMixtralLLMCMagic = 20260513;
constexpr int32_t kTinyMixtralLLMCVersion = 2;
constexpr int64_t kLLMCHeaderEntries = 256;

} // namespace

namespace tiny_mixtral {

namespace {

template <typename T>
void CompareCheckpointValue(const std::string &name, const T &checkpoint_value, const T &runtime_value) {
CHECK_EQ(checkpoint_value, runtime_value) << name << " value from checkpoint (" << checkpoint_value
<< ") is not equal to runtime config value (" << runtime_value << ")";
}

} // namespace

nn::TransformerConfig ConfigFromLLMC(const std::string &filepath) {
std::ifstream ifs(filepath, std::ios::binary);
CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath;
const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs);
CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath;
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic);
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion);

auto config = TinyMixtralConfig();
config.block_size = infini_train::BytesToType<int32_t>(header, 2 * sizeof(int32_t));
config.vocab_size = infini_train::BytesToType<int32_t>(header, 3 * sizeof(int32_t));
config.original_vocab_size = config.vocab_size;
config.n_layer = infini_train::BytesToType<int32_t>(header, 4 * sizeof(int32_t));
config.n_head = infini_train::BytesToType<int32_t>(header, 5 * sizeof(int32_t));
config.n_kv_head = infini_train::BytesToType<int32_t>(header, 6 * sizeof(int32_t));
config.n_embd = infini_train::BytesToType<int32_t>(header, 7 * sizeof(int32_t));
config.ffn_expansion_ratio = infini_train::BytesToType<float>(header, 9 * sizeof(int32_t));
config.ffn_dim_multiplier = infini_train::BytesToType<float>(header, 10 * sizeof(int32_t));
config.multiple_of = infini_train::BytesToType<int32_t>(header, 11 * sizeof(int32_t));
config.norm_eps = infini_train::BytesToType<float>(header, 12 * sizeof(int32_t));
config.rope_theta = infini_train::BytesToType<float>(header, 13 * sizeof(int32_t));
config.use_scaled_rope = infini_train::BytesToType<int32_t>(header, 14 * sizeof(int32_t)) != 0;

nn::MoEConfig moe_config;
moe_config.num_experts = infini_train::BytesToType<int32_t>(header, 8 * sizeof(int32_t));
moe_config.expert_parallel_size = 1;
moe_config.router_topk = infini_train::BytesToType<int32_t>(header, 15 * sizeof(int32_t));
moe_config.moe_ffn_hidden_size = infini_train::BytesToType<int32_t>(header, 16 * sizeof(int32_t));
moe_config.dispatcher_type = nn::MoEConfig::DispatcherType::kAllGather;
moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential;
config.moe_config = moe_config;
SanitizeTinyMixtralConfig(config);
return config;
}

void CheckLLMCConfig(const std::string &filepath, const nn::TransformerConfig &expected_config) {
SanitizeTinyMixtralConfig(expected_config);
const auto checkpoint_config = ConfigFromLLMC(filepath);
CompareCheckpointValue("block_size", checkpoint_config.block_size, expected_config.block_size);
CompareCheckpointValue("vocab_size", checkpoint_config.vocab_size, expected_config.vocab_size);
CompareCheckpointValue("original_vocab_size", checkpoint_config.original_vocab_size,
expected_config.original_vocab_size);
CompareCheckpointValue("n_layer", checkpoint_config.n_layer, expected_config.n_layer);
CompareCheckpointValue("n_head", checkpoint_config.n_head, expected_config.n_head);
CompareCheckpointValue("n_kv_head", checkpoint_config.n_kv_head, expected_config.n_kv_head);
CompareCheckpointValue("n_embd", checkpoint_config.n_embd, expected_config.n_embd);
CompareCheckpointValue("ffn_expansion_ratio", checkpoint_config.ffn_expansion_ratio,
expected_config.ffn_expansion_ratio);
CHECK(checkpoint_config.ffn_dim_multiplier.has_value()) << "checkpoint ffn_dim_multiplier is missing";
CHECK(expected_config.ffn_dim_multiplier.has_value()) << "runtime ffn_dim_multiplier is missing";
CompareCheckpointValue("ffn_dim_multiplier", checkpoint_config.ffn_dim_multiplier.value(),
expected_config.ffn_dim_multiplier.value());
CompareCheckpointValue("multiple_of", checkpoint_config.multiple_of, expected_config.multiple_of);
CompareCheckpointValue("norm_eps", checkpoint_config.norm_eps, expected_config.norm_eps);
CompareCheckpointValue("rope_theta", checkpoint_config.rope_theta, expected_config.rope_theta);
CompareCheckpointValue("use_scaled_rope", checkpoint_config.use_scaled_rope, expected_config.use_scaled_rope);

CHECK(expected_config.moe_config.has_value()) << "tiny Mixtral runtime config requires MoE config";
const auto &checkpoint_moe = checkpoint_config.moe_config.value();
const auto &expected_moe = expected_config.moe_config.value();
CompareCheckpointValue("num_experts", checkpoint_moe.num_experts, expected_moe.num_experts);
CompareCheckpointValue("router_topk", checkpoint_moe.router_topk, expected_moe.router_topk);
CompareCheckpointValue("moe_ffn_hidden_size", checkpoint_moe.moe_ffn_hidden_size, expected_moe.moe_ffn_hidden_size);
}

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
const nn::TransformerConfig &expected_config) {
CheckLLMCConfig(filepath, expected_config);
auto model = std::make_shared<nn::TransformerModel>(expected_config);

std::ifstream ifs(filepath, std::ios::binary);
CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath;
const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs);
CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath;
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic);
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion);

const auto &config = expected_config;
auto state = model->StateDict();
auto read_tensor_by_state_key = [&](const std::string &name) {
CHECK(state.contains(name)) << "Model state_dict does not contain " << name;
std::shared_ptr<infini_train::Tensor> tensor = state.at(name);
CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32)
<< "Only float32 tiny Mixtral LLMC files are supported: " << name;
infini_train::ReadMatrixAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), tensor->NumElements(), 1);
CHECK(ifs) << "Failed to read tensor " << name;
};

auto read_projection_into_packed_qkv = [&](const std::string &packed_qkv_name, int64_t row_offset, int64_t num_rows,
const std::string &projection_name) {
CHECK(state.contains(packed_qkv_name)) << "Model state_dict does not contain " << packed_qkv_name;
std::shared_ptr<infini_train::Tensor> tensor = state.at(packed_qkv_name);
CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32)
<< "Only float32 tiny Mixtral LLMC files are supported: " << projection_name;
CHECK_EQ(tensor->Dims().size(), 2);
CHECK_GE(row_offset, 0);
CHECK_GT(num_rows, 0);
CHECK_LE(row_offset + num_rows, tensor->Dims()[0]);
const int64_t cols = tensor->Dims()[1];
auto *data = static_cast<float *>(tensor->DataPtr()) + row_offset * cols;
infini_train::ReadMatrixAllFloat(ifs, data, num_rows, cols);
CHECK(ifs) << "Failed to read tensor rows " << projection_name;
};

const auto &moe_config = config.moe_config.value();
read_tensor_by_state_key("transformer.wte.weight");
for (int64_t layer = 0; layer < config.n_layer; ++layer) {
const std::string prefix = "transformer.h." + std::to_string(layer);
read_tensor_by_state_key(prefix + ".ln_1.weight");
const auto c_attn_name = prefix + ".attn.c_attn.weight";
const int64_t head_dim = config.n_embd / config.n_head;
const int64_t q_rows = config.n_head * head_dim;
const int64_t kv_rows = config.n_kv_head * head_dim;
read_projection_into_packed_qkv(c_attn_name, 0, q_rows, c_attn_name + ".q_proj");
read_projection_into_packed_qkv(c_attn_name, q_rows, kv_rows, c_attn_name + ".k_proj");
read_projection_into_packed_qkv(c_attn_name, q_rows + kv_rows, kv_rows, c_attn_name + ".v_proj");
read_tensor_by_state_key(prefix + ".attn.c_proj.weight");
read_tensor_by_state_key(prefix + ".ln_2.weight");
read_tensor_by_state_key(prefix + ".mlp.router.weight");
for (int64_t expert = 0; expert < moe_config.num_experts; ++expert) {
const std::string expert_prefix = prefix + ".mlp.experts.expert_" + std::to_string(expert);
read_tensor_by_state_key(expert_prefix + ".c_fc2.weight"); // Mixtral w1/gate_proj
read_tensor_by_state_key(expert_prefix + ".c_fc.weight"); // Mixtral w3/up_proj
read_tensor_by_state_key(expert_prefix + ".c_proj.weight"); // Mixtral w2/down_proj
}
}
read_tensor_by_state_key("transformer.ln_f.weight");
read_tensor_by_state_key("lm_head.weight");

CHECK_EQ(ifs.peek(), std::ifstream::traits_type::eof()) << "Unexpected trailing bytes in tiny Mixtral LLMC file";
return model;
}

} // namespace tiny_mixtral
21 changes: 21 additions & 0 deletions example/tiny_mixtral/checkpoint_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include <memory>
#include <string>

#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {
class TransformerModel;
} // namespace infini_train::nn

namespace tiny_mixtral {

infini_train::nn::TransformerConfig ConfigFromLLMC(const std::string &filepath);

void CheckLLMCConfig(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config);

std::shared_ptr<infini_train::nn::TransformerModel>
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config);

} // namespace tiny_mixtral
80 changes: 80 additions & 0 deletions example/tiny_mixtral/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#pragma once

#include "glog/logging.h"

#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace nn = infini_train::nn;

namespace tiny_mixtral {

inline nn::TransformerConfig TinyMixtralConfig() {
nn::TransformerConfig config;
config.block_size = 16; // Tiny max positions; Megatron --max-position-embeddings is 32768.
config.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000.
config.original_vocab_size = 128256;
config.n_layer = 2; // Tiny scale; Megatron --num-layers 32.
config.n_head = 4; // Tiny scale; preserves the Megatron 4:1 GQA ratio.
config.n_kv_head = 1; // Tiny scale; Megatron --num-query-groups 8.
config.n_embd = 32; // Tiny scale; Megatron --hidden-size 4096.
config.attention_type = nn::AttentionType::kRoPE;
config.activation_type = nn::MLPType::kSwiGLU;
config.ffn_type = nn::FFNType::kMoE;
config.norm_type = nn::NormType::kRMSNorm;
config.add_bias_linear = false;
config.add_bias_lm_head = false;
config.tie_weights = false;
config.ffn_expansion_ratio = 3.5f;
config.ffn_dim_multiplier = 1.0f; // InfiniTrain helper; Megatron passes --ffn-hidden-size directly.
config.multiple_of = 1; // InfiniTrain helper; Megatron passes --ffn-hidden-size directly.
config.norm_eps = 1e-5f;
config.rope_theta = 1000000.0f;
config.use_scaled_rope = false;

nn::MoEConfig moe_config;
moe_config.num_experts = 8;
moe_config.expert_parallel_size = 1; // Single-rank validation scale.
moe_config.router_topk = 2;
moe_config.moe_ffn_hidden_size = 112; // Tiny scale; Megatron --ffn-hidden-size 14336.
moe_config.dispatcher_type = nn::MoEConfig::DispatcherType::kAllGather; // Single-rank validation path.
moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; // Local correctness path.
config.moe_config = moe_config;
return config;
}

inline void SanitizeTinyMixtralConfig(const nn::TransformerConfig &c) {
CHECK_GT(c.block_size, 0);
CHECK_GT(c.vocab_size, 0);
CHECK_GE(c.vocab_size, c.original_vocab_size);
CHECK_GT(c.n_layer, 0);
CHECK_GT(c.n_head, 0);
CHECK_GT(c.n_kv_head, 0);
CHECK_LE(c.n_kv_head, c.n_head);
CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA";
CHECK_GT(c.n_embd, 0);
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
CHECK(c.attention_type == nn::AttentionType::kRoPE) << "tiny Mixtral requires RoPE attention";
CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "tiny Mixtral requires SwiGLU activation";
CHECK(c.ffn_type == nn::FFNType::kMoE) << "tiny Mixtral requires MoE FFN";
CHECK(c.norm_type == nn::NormType::kRMSNorm) << "tiny Mixtral requires RMSNorm";
CHECK(!c.add_bias_linear) << "tiny Mixtral has no bias in linear layers";
CHECK(!c.add_bias_lm_head) << "tiny Mixtral has no bias in lm_head";
CHECK(!c.tie_weights) << "tiny Mixtral does not tie embedding and lm_head weights";
CHECK(c.ffn_dim_multiplier.has_value()) << "tiny Mixtral requires ffn_dim_multiplier";
CHECK_GT(c.multiple_of, 0);
CHECK(!c.use_scaled_rope) << "tiny Mixtral precision validation keeps scaled RoPE disabled";
CHECK(c.moe_config.has_value()) << "tiny Mixtral requires MoE config";

const auto &moe = c.moe_config.value();
CHECK_GT(moe.num_experts, 0);
CHECK_EQ(moe.expert_parallel_size, 1) << "tiny Mixtral single-rank validation expects EP=1";
CHECK_GT(moe.router_topk, 0);
CHECK_LE(moe.router_topk, moe.num_experts);
CHECK_GT(moe.moe_ffn_hidden_size, 0);
CHECK(moe.dispatcher_type == nn::MoEConfig::DispatcherType::kAllGather)
<< "tiny Mixtral uses the Megatron-style AllGather dispatcher";
CHECK(moe.expert_impl == nn::MoEConfig::ExpertImpl::kSequential)
<< "tiny Mixtral validation uses SequentialMLP experts";
}

} // namespace tiny_mixtral
Loading
Loading