Skip to content

Commit b23e9c9

Browse files
committed
feat: add tiny_mixtral example
1 parent 69a74fd commit b23e9c9

8 files changed

Lines changed: 400 additions & 4 deletions

File tree

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ add_executable(gpt2
199199
)
200200
link_infini_train_exe(gpt2)
201201

202+
# tiny_mixtral example executable: its own entry point and checkpoint loader, plus the
# dataset and utility sources shared from example/common.
add_executable(tiny_mixtral
203+
example/tiny_mixtral/main.cc
204+
example/common/tiny_shakespeare_dataset.cc
205+
example/common/utils.cc
206+
example/tiny_mixtral/checkpoint_loader.cc
207+
)
208+
# Same link helper that is applied to the gpt2 target; presumably it attaches the
# infini_train libraries to the executable (confirm in the helper's definition).
link_infini_train_exe(tiny_mixtral)
209+
202210
add_executable(llama3
203211
example/llama3/main.cc
204212
example/common/tiny_shakespeare_dataset.cc
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#include "example/tiny_mixtral/checkpoint_loader.h"
2+
3+
#include <cstdint>
4+
#include <fstream>
5+
#include <memory>
6+
#include <string>
7+
#include <vector>
8+
9+
#include "glog/logging.h"
10+
11+
#include "infini_train/include/datatype.h"
12+
#include "infini_train/include/nn/modules/transformer/transformer.h"
13+
#include "infini_train/include/tensor.h"
14+
15+
#include "example/common/utils.h"
16+
#include "example/tiny_mixtral/config.h"
17+
18+
namespace nn = infini_train::nn;
19+
20+
namespace {

// Layout constants for the tiny Mixtral LLMC checkpoint header.
//
// The loader reads kLLMCHeaderEntries 32-bit slots from the start of the file and
// requires slot 0 to equal the magic number and slot 1 to equal the format version.
constexpr int64_t kLLMCHeaderEntries{256};
constexpr int32_t kTinyMixtralLLMCVersion{2};
constexpr int32_t kTinyMixtralLLMCMagic{20260513};

} // namespace
27+
28+
namespace tiny_mixtral {
29+
30+
namespace {

// Fails a glog CHECK_EQ when a field decoded from the checkpoint differs from the same
// field of the runtime config. `name` identifies the field in the failure message.
// The comparison uses operator== as-is, so floating-point fields must match exactly.
template <typename ValueT>
void CompareCheckpointValue(const std::string &name, const ValueT &checkpoint_value, const ValueT &runtime_value) {
    CHECK_EQ(checkpoint_value, runtime_value)
        << name << " value from checkpoint (" << checkpoint_value << ") is not equal to runtime config value ("
        << runtime_value << ")";
}

} // namespace
39+
40+
// Builds a TransformerConfig from the header of the tiny Mixtral LLMC file at `filepath`.
//
// Starts from TinyMixtralConfig() and overwrites the fields that the header carries.
// Every failure is a fatal glog CHECK: the file cannot be opened, the header cannot be
// read in full, the magic number or version does not match, or the resulting config is
// rejected by SanitizeTinyMixtralConfig().
nn::TransformerConfig ConfigFromLLMC(const std::string &filepath) {
    std::ifstream ifs(filepath, std::ios::binary);
    CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath;
    const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs);
    CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath;
    CHECK_EQ(infini_train::BytesToType<int32_t>(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic);
    CHECK_EQ(infini_train::BytesToType<int32_t>(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion);

    // Every header field occupies one 32-bit slot; these accessors decode slot `slot`
    // either as an int32 or as a float stored in the same four bytes.
    const auto int_slot = [&header](int slot) {
        return infini_train::BytesToType<int32_t>(header, slot * sizeof(int32_t));
    };
    const auto float_slot = [&header](int slot) {
        return infini_train::BytesToType<float>(header, slot * sizeof(int32_t));
    };

    // MoE section: only the expert count, top-k and expert hidden size come from the
    // file; the parallelism and implementation choices are fixed by this loader.
    nn::MoEConfig moe;
    moe.num_experts = int_slot(8);
    moe.router_topk = int_slot(15);
    moe.moe_ffn_hidden_size = int_slot(16);
    moe.expert_parallel_size = 1;
    moe.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather;
    moe.expert_impl = nn::MoEConfig::ExpertImpl::kSequential;

    auto config = TinyMixtralConfig();
    config.block_size = int_slot(2);
    config.vocab_size = int_slot(3);
    // The header has no separate slot for the unpadded vocabulary size.
    config.original_vocab_size = config.vocab_size;
    config.n_layer = int_slot(4);
    config.n_head = int_slot(5);
    config.n_kv_head = int_slot(6);
    config.n_embd = int_slot(7);
    config.ffn_expansion_ratio = float_slot(9);
    // Header slots 10 and 11 store dense-MLP helpers; MoE expert size is stored in moe_ffn_hidden_size.
    config.norm_eps = float_slot(12);
    config.rope_theta = float_slot(13);
    config.use_scaled_rope = int_slot(14) != 0;
    config.moe_config = moe;

    SanitizeTinyMixtralConfig(config);
    return config;
}
73+
74+
// Verifies that the checkpoint at `filepath` was written for `expected_config`.
//
// The runtime config is sanitized first, then the config decoded from the checkpoint
// header is compared with it field by field. The first mismatch (or any failure while
// reading the header) terminates the process through a fatal glog CHECK.
void CheckLLMCConfig(const std::string &filepath, const nn::TransformerConfig &expected_config) {
    SanitizeTinyMixtralConfig(expected_config);
    const nn::TransformerConfig from_file = ConfigFromLLMC(filepath);

    // Model shape.
    CompareCheckpointValue("block_size", from_file.block_size, expected_config.block_size);
    CompareCheckpointValue("vocab_size", from_file.vocab_size, expected_config.vocab_size);
    CompareCheckpointValue("original_vocab_size", from_file.original_vocab_size, expected_config.original_vocab_size);
    CompareCheckpointValue("n_layer", from_file.n_layer, expected_config.n_layer);
    CompareCheckpointValue("n_head", from_file.n_head, expected_config.n_head);
    CompareCheckpointValue("n_kv_head", from_file.n_kv_head, expected_config.n_kv_head);
    CompareCheckpointValue("n_embd", from_file.n_embd, expected_config.n_embd);

    // Numeric hyper-parameters; these are compared for exact equality.
    CompareCheckpointValue("ffn_expansion_ratio", from_file.ffn_expansion_ratio, expected_config.ffn_expansion_ratio);
    CompareCheckpointValue("norm_eps", from_file.norm_eps, expected_config.norm_eps);
    CompareCheckpointValue("rope_theta", from_file.rope_theta, expected_config.rope_theta);
    CompareCheckpointValue("use_scaled_rope", from_file.use_scaled_rope, expected_config.use_scaled_rope);

    // MoE section.
    CHECK(expected_config.moe_config.has_value()) << "tiny Mixtral runtime config requires MoE config";
    const auto &file_moe = from_file.moe_config.value();
    const auto &runtime_moe = expected_config.moe_config.value();
    CompareCheckpointValue("num_experts", file_moe.num_experts, runtime_moe.num_experts);
    CompareCheckpointValue("router_topk", file_moe.router_topk, runtime_moe.router_topk);
    CompareCheckpointValue("moe_ffn_hidden_size", file_moe.moe_ffn_hidden_size, runtime_moe.moe_ffn_hidden_size);
}
98+
99+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
100+
const nn::TransformerConfig &expected_config) {
101+
CheckLLMCConfig(filepath, expected_config);
102+
auto model = std::make_shared<nn::TransformerModel>(expected_config);
103+
104+
std::ifstream ifs(filepath, std::ios::binary);
105+
CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath;
106+
const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs);
107+
CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath;
108+
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic);
109+
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion);
110+
111+
const auto &config = expected_config;
112+
auto state = model->StateDict();
113+
auto read_tensor_by_state_key = [&](const std::string &name) {
114+
CHECK(state.contains(name)) << "Model state_dict does not contain " << name;
115+
std::shared_ptr<infini_train::Tensor> tensor = state.at(name);
116+
CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32)
117+
<< "Only float32 tiny Mixtral LLMC files are supported: " << name;
118+
infini_train::ReadMatrixAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), tensor->NumElements(), 1);
119+
CHECK(ifs) << "Failed to read tensor " << name;
120+
};
121+
122+
auto read_projection_into_packed_qkv = [&](const std::string &packed_qkv_name, int64_t row_offset, int64_t num_rows,
123+
const std::string &projection_name) {
124+
CHECK(state.contains(packed_qkv_name)) << "Model state_dict does not contain " << packed_qkv_name;
125+
std::shared_ptr<infini_train::Tensor> tensor = state.at(packed_qkv_name);
126+
CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32)
127+
<< "Only float32 tiny Mixtral LLMC files are supported: " << projection_name;
128+
CHECK_EQ(tensor->Dims().size(), 2);
129+
CHECK_GE(row_offset, 0);
130+
CHECK_GT(num_rows, 0);
131+
CHECK_LE(row_offset + num_rows, tensor->Dims()[0]);
132+
const int64_t cols = tensor->Dims()[1];
133+
auto *data = static_cast<float *>(tensor->DataPtr()) + row_offset * cols;
134+
infini_train::ReadMatrixAllFloat(ifs, data, num_rows, cols);
135+
CHECK(ifs) << "Failed to read tensor rows " << projection_name;
136+
};
137+
138+
const auto &moe_config = config.moe_config.value();
139+
read_tensor_by_state_key("transformer.wte.weight");
140+
for (int64_t layer = 0; layer < config.n_layer; ++layer) {
141+
const std::string prefix = "transformer.h." + std::to_string(layer);
142+
read_tensor_by_state_key(prefix + ".ln_1.weight");
143+
const auto c_attn_name = prefix + ".attn.c_attn.weight";
144+
const int64_t head_dim = config.n_embd / config.n_head;
145+
const int64_t q_rows = config.n_head * head_dim;
146+
const int64_t kv_rows = config.n_kv_head * head_dim;
147+
read_projection_into_packed_qkv(c_attn_name, 0, q_rows, c_attn_name + ".q_proj");
148+
read_projection_into_packed_qkv(c_attn_name, q_rows, kv_rows, c_attn_name + ".k_proj");
149+
read_projection_into_packed_qkv(c_attn_name, q_rows + kv_rows, kv_rows, c_attn_name + ".v_proj");
150+
read_tensor_by_state_key(prefix + ".attn.c_proj.weight");
151+
read_tensor_by_state_key(prefix + ".ln_2.weight");
152+
read_tensor_by_state_key(prefix + ".mlp.router.weight");
153+
for (int64_t expert = 0; expert < moe_config.num_experts; ++expert) {
154+
const std::string expert_prefix = prefix + ".mlp.experts.expert_" + std::to_string(expert);
155+
read_tensor_by_state_key(expert_prefix + ".c_fc2.weight"); // Mixtral w1/gate_proj
156+
read_tensor_by_state_key(expert_prefix + ".c_fc.weight"); // Mixtral w3/up_proj
157+
read_tensor_by_state_key(expert_prefix + ".c_proj.weight"); // Mixtral w2/down_proj
158+
}
159+
}
160+
read_tensor_by_state_key("transformer.ln_f.weight");
161+
read_tensor_by_state_key("lm_head.weight");
162+
163+
CHECK_EQ(ifs.peek(), std::ifstream::traits_type::eof()) << "Unexpected trailing bytes in tiny Mixtral LLMC file";
164+
return model;
165+
}
166+
167+
} // namespace tiny_mixtral
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
6+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
7+
8+
// Forward declaration only: this header names TransformerModel solely inside a
// std::shared_ptr return type, so the full class definition is not required here.
namespace infini_train::nn {
9+
class TransformerModel;
10+
} // namespace infini_train::nn
11+
12+
// Loader API for tiny Mixtral checkpoints stored in the LLMC format.
// All three functions report failure through fatal glog CHECKs, not return values.
namespace tiny_mixtral {
13+
14+
// Reads the header of the LLMC file at `filepath` and returns the TransformerConfig it
// describes. Fails a CHECK if the file cannot be opened or read, if the magic number or
// version does not match, or if the decoded config is not a valid tiny Mixtral config.
infini_train::nn::TransformerConfig ConfigFromLLMC(const std::string &filepath);
15+
16+
// Checks that the config stored in the LLMC file at `filepath` matches `expected_config`
// field by field; the first mismatch fails a CHECK.
void CheckLLMCConfig(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config);
17+
18+
// Validates the LLMC file at `filepath` against `expected_config`, constructs a
// TransformerModel from `expected_config`, loads the float32 weights from the file into
// it, and returns the model.
std::shared_ptr<infini_train::nn::TransformerModel>
19+
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config);
20+
21+
} // namespace tiny_mixtral

example/tiny_mixtral/config.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#pragma once
2+
3+
#include "glog/logging.h"
4+
5+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
6+
7+
namespace nn = infini_train::nn;
8+
9+
namespace tiny_mixtral {
10+
11+
// Returns the default configuration of the tiny Mixtral example: a scaled-down
// RoPE + SwiGLU + RMSNorm transformer with an 8-expert, top-2 MoE feed-forward block.
// The per-field comments record the corresponding full-size Megatron settings.
inline nn::TransformerConfig TinyMixtralConfig() {
    // Mixture-of-experts section.
    nn::MoEConfig experts;
    experts.num_experts = 8;
    experts.router_topk = 2;
    experts.moe_ffn_hidden_size = 112; // Tiny scale; Megatron --ffn-hidden-size 14336.
    experts.expert_parallel_size = 1;  // Single-rank validation scale.
    experts.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; // Single-rank validation path.
    experts.expert_impl = nn::MoEConfig::ExpertImpl::kSequential;                   // Local correctness path.

    nn::TransformerConfig cfg;

    // Architecture choices.
    cfg.attention_type = nn::AttentionType::kRoPE;
    cfg.activation_type = nn::MLPType::kSwiGLU;
    cfg.ffn_type = nn::FFNType::kMoE;
    cfg.norm_type = nn::NormType::kRMSNorm;
    cfg.add_bias_linear = false;
    cfg.add_bias_lm_head = false;
    cfg.tie_weights = false;

    // Model dimensions.
    cfg.block_size = 16;     // Tiny max positions; Megatron --max-position-embeddings is 32768.
    cfg.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000.
    cfg.original_vocab_size = 128256;
    cfg.n_layer = 2;   // Tiny scale; Megatron --num-layers 32.
    cfg.n_head = 4;    // Tiny scale; preserves the Megatron 4:1 GQA ratio.
    cfg.n_kv_head = 1; // Tiny scale; Megatron --num-query-groups 8.
    cfg.n_embd = 32;   // Tiny scale; Megatron --hidden-size 4096.

    // Numeric hyper-parameters.
    cfg.ffn_expansion_ratio = 3.5f;
    cfg.norm_eps = 1e-5f;
    cfg.rope_theta = 1000000.0f;
    cfg.use_scaled_rope = false;

    cfg.moe_config = experts;
    return cfg;
}
42+
43+
// Validates that `c` describes a configuration the tiny Mixtral example supports.
// Each requirement is a glog CHECK, so the first violated one terminates the process.
// The order matters: the "> 0" checks on n_kv_head and n_head run before the modulo
// expressions that divide by them.
inline void SanitizeTinyMixtralConfig(const nn::TransformerConfig &c) {
44+
// Basic dimensions must be positive; vocab_size may not be smaller than original_vocab_size.
CHECK_GT(c.block_size, 0);
45+
CHECK_GT(c.vocab_size, 0);
46+
CHECK_GE(c.vocab_size, c.original_vocab_size);
47+
CHECK_GT(c.n_layer, 0);
48+
CHECK_GT(c.n_head, 0);
49+
// Grouped-query attention: at most n_head KV heads, evenly shared by the query heads.
CHECK_GT(c.n_kv_head, 0);
50+
CHECK_LE(c.n_kv_head, c.n_head);
51+
CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA";
52+
CHECK_GT(c.n_embd, 0);
53+
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
54+
// Architecture is fixed: RoPE attention, SwiGLU, MoE FFN, RMSNorm, no biases, untied weights.
CHECK(c.attention_type == nn::AttentionType::kRoPE) << "tiny Mixtral requires RoPE attention";
55+
CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "tiny Mixtral requires SwiGLU activation";
56+
CHECK(c.ffn_type == nn::FFNType::kMoE) << "tiny Mixtral requires MoE FFN";
57+
CHECK(c.norm_type == nn::NormType::kRMSNorm) << "tiny Mixtral requires RMSNorm";
58+
CHECK(!c.add_bias_linear) << "tiny Mixtral has no bias in linear layers";
59+
CHECK(!c.add_bias_lm_head) << "tiny Mixtral has no bias in lm_head";
60+
CHECK(!c.tie_weights) << "tiny Mixtral does not tie embedding and lm_head weights";
61+
CHECK(!c.use_scaled_rope) << "tiny Mixtral precision validation keeps scaled RoPE disabled";
62+
// The MoE section must be present before it is dereferenced below.
CHECK(c.moe_config.has_value()) << "tiny Mixtral requires MoE config";
63+
64+
const auto &moe = c.moe_config.value();
65+
// MoE constraints: positive sizes, top-k no larger than the expert count, and the
// single-rank settings (EP=1, AllGather dispatcher, sequential experts).
CHECK_GT(moe.num_experts, 0);
66+
CHECK_EQ(moe.expert_parallel_size, 1) << "tiny Mixtral single-rank validation expects EP=1";
67+
CHECK_GT(moe.router_topk, 0);
68+
CHECK_LE(moe.router_topk, moe.num_experts);
69+
CHECK_GT(moe.moe_ffn_hidden_size, 0);
70+
CHECK(moe.token_dispatcher_type == nn::MoEConfig::TokenDispatcherType::kAllGather)
71+
<< "tiny Mixtral uses the Megatron-style AllGather dispatcher";
72+
CHECK(moe.expert_impl == nn::MoEConfig::ExpertImpl::kSequential)
73+
<< "tiny Mixtral validation uses SequentialMLP experts";
74+
}
75+
76+
} // namespace tiny_mixtral

0 commit comments

Comments
 (0)