Skip to content

Commit e165a19

Browse files
committed
feat: add tiny_mixtral example
1 parent 69a74fd commit e165a19

5 files changed

Lines changed: 408 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ add_executable(gpt2
199199
)
200200
link_infini_train_exe(gpt2)
201201

202+
add_executable(tiny_mixtral
203+
example/tiny_mixtral/main.cc
204+
example/common/tiny_shakespeare_dataset.cc
205+
example/common/utils.cc
206+
example/tiny_mixtral/checkpoint_loader.cc
207+
)
208+
link_infini_train_exe(tiny_mixtral)
209+
202210
add_executable(llama3
203211
example/llama3/main.cc
204212
example/common/tiny_shakespeare_dataset.cc
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#include "example/tiny_mixtral/checkpoint_loader.h"
2+
3+
#include <cstdint>
4+
#include <fstream>
5+
#include <memory>
6+
#include <string>
7+
#include <vector>
8+
9+
#include "glog/logging.h"
10+
11+
#include "infini_train/include/datatype.h"
12+
#include "infini_train/include/nn/modules/transformer/transformer.h"
13+
#include "infini_train/include/tensor.h"
14+
15+
#include "example/common/utils.h"
16+
#include "example/tiny_mixtral/config.h"
17+
18+
namespace nn = infini_train::nn;
19+
20+
namespace {
21+
22+
constexpr int32_t kTinyMixtralLLMCMagic = 20260513;
23+
constexpr int32_t kTinyMixtralLLMCVersion = 2;
24+
constexpr int64_t kLLMCHeaderEntries = 256;
25+
26+
} // namespace
27+
28+
namespace tiny_mixtral {
29+
30+
namespace {
31+
32+
template <typename T>
33+
void CompareCheckpointValue(const std::string &name, const T &checkpoint_value,
34+
const T &runtime_value) {
35+
CHECK_EQ(checkpoint_value, runtime_value)
36+
<< name << " value from checkpoint (" << checkpoint_value << ") is not equal to runtime config value ("
37+
<< runtime_value << ")";
38+
}
39+
40+
} // namespace
41+
42+
nn::TransformerConfig ConfigFromLLMC(const std::string &filepath) {
43+
std::ifstream ifs(filepath, std::ios::binary);
44+
CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath;
45+
const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs);
46+
CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath;
47+
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic);
48+
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion);
49+
50+
auto config = TinyMixtralConfig();
51+
config.block_size = infini_train::BytesToType<int32_t>(header, 2 * sizeof(int32_t));
52+
config.vocab_size = infini_train::BytesToType<int32_t>(header, 3 * sizeof(int32_t));
53+
config.original_vocab_size = config.vocab_size;
54+
config.n_layer = infini_train::BytesToType<int32_t>(header, 4 * sizeof(int32_t));
55+
config.n_head = infini_train::BytesToType<int32_t>(header, 5 * sizeof(int32_t));
56+
config.n_kv_head = infini_train::BytesToType<int32_t>(header, 6 * sizeof(int32_t));
57+
config.n_embd = infini_train::BytesToType<int32_t>(header, 7 * sizeof(int32_t));
58+
config.ffn_expansion_ratio = infini_train::BytesToType<float>(header, 9 * sizeof(int32_t));
59+
config.ffn_dim_multiplier = infini_train::BytesToType<float>(header, 10 * sizeof(int32_t));
60+
config.multiple_of = infini_train::BytesToType<int32_t>(header, 11 * sizeof(int32_t));
61+
config.norm_eps = infini_train::BytesToType<float>(header, 12 * sizeof(int32_t));
62+
config.rope_theta = infini_train::BytesToType<float>(header, 13 * sizeof(int32_t));
63+
config.use_scaled_rope = infini_train::BytesToType<int32_t>(header, 14 * sizeof(int32_t)) != 0;
64+
65+
nn::MoEConfig moe_config;
66+
moe_config.num_experts = infini_train::BytesToType<int32_t>(header, 8 * sizeof(int32_t));
67+
moe_config.expert_parallel_size = 1;
68+
moe_config.router_topk = infini_train::BytesToType<int32_t>(header, 15 * sizeof(int32_t));
69+
moe_config.moe_ffn_hidden_size = infini_train::BytesToType<int32_t>(header, 16 * sizeof(int32_t));
70+
moe_config.dispatcher_type = nn::MoEConfig::DispatcherType::kAllGather;
71+
moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential;
72+
config.moe_config = moe_config;
73+
SanitizeTinyMixtralConfig(config);
74+
return config;
75+
}
76+
77+
void CheckLLMCConfig(const std::string &filepath, const nn::TransformerConfig &expected_config) {
78+
SanitizeTinyMixtralConfig(expected_config);
79+
const auto checkpoint_config = ConfigFromLLMC(filepath);
80+
CompareCheckpointValue("block_size", checkpoint_config.block_size, expected_config.block_size);
81+
CompareCheckpointValue("vocab_size", checkpoint_config.vocab_size, expected_config.vocab_size);
82+
CompareCheckpointValue("original_vocab_size", checkpoint_config.original_vocab_size,
83+
expected_config.original_vocab_size);
84+
CompareCheckpointValue("n_layer", checkpoint_config.n_layer, expected_config.n_layer);
85+
CompareCheckpointValue("n_head", checkpoint_config.n_head, expected_config.n_head);
86+
CompareCheckpointValue("n_kv_head", checkpoint_config.n_kv_head, expected_config.n_kv_head);
87+
CompareCheckpointValue("n_embd", checkpoint_config.n_embd, expected_config.n_embd);
88+
CompareCheckpointValue("ffn_expansion_ratio", checkpoint_config.ffn_expansion_ratio,
89+
expected_config.ffn_expansion_ratio);
90+
CHECK(checkpoint_config.ffn_dim_multiplier.has_value()) << "checkpoint ffn_dim_multiplier is missing";
91+
CHECK(expected_config.ffn_dim_multiplier.has_value()) << "runtime ffn_dim_multiplier is missing";
92+
CompareCheckpointValue("ffn_dim_multiplier", checkpoint_config.ffn_dim_multiplier.value(),
93+
expected_config.ffn_dim_multiplier.value());
94+
CompareCheckpointValue("multiple_of", checkpoint_config.multiple_of, expected_config.multiple_of);
95+
CompareCheckpointValue("norm_eps", checkpoint_config.norm_eps, expected_config.norm_eps);
96+
CompareCheckpointValue("rope_theta", checkpoint_config.rope_theta, expected_config.rope_theta);
97+
CompareCheckpointValue("use_scaled_rope", checkpoint_config.use_scaled_rope, expected_config.use_scaled_rope);
98+
99+
CHECK(expected_config.moe_config.has_value()) << "tiny Mixtral runtime config requires MoE config";
100+
const auto &checkpoint_moe = checkpoint_config.moe_config.value();
101+
const auto &expected_moe = expected_config.moe_config.value();
102+
CompareCheckpointValue("num_experts", checkpoint_moe.num_experts, expected_moe.num_experts);
103+
CompareCheckpointValue("router_topk", checkpoint_moe.router_topk, expected_moe.router_topk);
104+
CompareCheckpointValue("moe_ffn_hidden_size", checkpoint_moe.moe_ffn_hidden_size,
105+
expected_moe.moe_ffn_hidden_size);
106+
}
107+
108+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
109+
const nn::TransformerConfig &expected_config) {
110+
CheckLLMCConfig(filepath, expected_config);
111+
auto model = std::make_shared<nn::TransformerModel>(expected_config);
112+
113+
std::ifstream ifs(filepath, std::ios::binary);
114+
CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath;
115+
const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs);
116+
CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath;
117+
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic);
118+
CHECK_EQ(infini_train::BytesToType<int32_t>(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion);
119+
120+
const auto &config = expected_config;
121+
auto state = model->StateDict();
122+
auto read_tensor_by_state_key = [&](const std::string &name) {
123+
CHECK(state.contains(name)) << "Model state_dict does not contain " << name;
124+
std::shared_ptr<infini_train::Tensor> tensor = state.at(name);
125+
CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32)
126+
<< "Only float32 tiny Mixtral LLMC files are supported: " << name;
127+
infini_train::ReadMatrixAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), tensor->NumElements(), 1);
128+
CHECK(ifs) << "Failed to read tensor " << name;
129+
};
130+
131+
auto read_projection_into_packed_qkv = [&](const std::string &packed_qkv_name, int64_t row_offset,
132+
int64_t num_rows, const std::string &projection_name) {
133+
CHECK(state.contains(packed_qkv_name)) << "Model state_dict does not contain " << packed_qkv_name;
134+
std::shared_ptr<infini_train::Tensor> tensor = state.at(packed_qkv_name);
135+
CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32)
136+
<< "Only float32 tiny Mixtral LLMC files are supported: " << projection_name;
137+
CHECK_EQ(tensor->Dims().size(), 2);
138+
CHECK_GE(row_offset, 0);
139+
CHECK_GT(num_rows, 0);
140+
CHECK_LE(row_offset + num_rows, tensor->Dims()[0]);
141+
const int64_t cols = tensor->Dims()[1];
142+
auto *data = static_cast<float *>(tensor->DataPtr()) + row_offset * cols;
143+
infini_train::ReadMatrixAllFloat(ifs, data, num_rows, cols);
144+
CHECK(ifs) << "Failed to read tensor rows " << projection_name;
145+
};
146+
147+
const auto &moe_config = config.moe_config.value();
148+
read_tensor_by_state_key("transformer.wte.weight");
149+
for (int64_t layer = 0; layer < config.n_layer; ++layer) {
150+
const std::string prefix = "transformer.h." + std::to_string(layer);
151+
read_tensor_by_state_key(prefix + ".ln_1.weight");
152+
const auto c_attn_name = prefix + ".attn.c_attn.weight";
153+
const int64_t head_dim = config.n_embd / config.n_head;
154+
const int64_t q_rows = config.n_head * head_dim;
155+
const int64_t kv_rows = config.n_kv_head * head_dim;
156+
read_projection_into_packed_qkv(c_attn_name, 0, q_rows, c_attn_name + ".q_proj");
157+
read_projection_into_packed_qkv(c_attn_name, q_rows, kv_rows, c_attn_name + ".k_proj");
158+
read_projection_into_packed_qkv(c_attn_name, q_rows + kv_rows, kv_rows, c_attn_name + ".v_proj");
159+
read_tensor_by_state_key(prefix + ".attn.c_proj.weight");
160+
read_tensor_by_state_key(prefix + ".ln_2.weight");
161+
read_tensor_by_state_key(prefix + ".mlp.router.weight");
162+
for (int64_t expert = 0; expert < moe_config.num_experts; ++expert) {
163+
const std::string expert_prefix = prefix + ".mlp.experts.expert_" + std::to_string(expert);
164+
read_tensor_by_state_key(expert_prefix + ".c_fc2.weight"); // Mixtral w1/gate_proj
165+
read_tensor_by_state_key(expert_prefix + ".c_fc.weight"); // Mixtral w3/up_proj
166+
read_tensor_by_state_key(expert_prefix + ".c_proj.weight"); // Mixtral w2/down_proj
167+
}
168+
}
169+
read_tensor_by_state_key("transformer.ln_f.weight");
170+
read_tensor_by_state_key("lm_head.weight");
171+
172+
CHECK_EQ(ifs.peek(), std::ifstream::traits_type::eof()) << "Unexpected trailing bytes in tiny Mixtral LLMC file";
173+
return model;
174+
}
175+
176+
} // namespace tiny_mixtral
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
6+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
7+
8+
namespace infini_train::nn {
9+
class TransformerModel;
10+
} // namespace infini_train::nn
11+
12+
namespace tiny_mixtral {
13+
14+
infini_train::nn::TransformerConfig ConfigFromLLMC(const std::string &filepath);
15+
16+
void CheckLLMCConfig(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config);
17+
18+
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(
19+
const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config);
20+
21+
} // namespace tiny_mixtral

example/tiny_mixtral/config.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#pragma once
2+
3+
#include "glog/logging.h"
4+
5+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
6+
7+
namespace nn = infini_train::nn;
8+
9+
namespace tiny_mixtral {
10+
11+
inline nn::TransformerConfig TinyMixtralConfig() {
12+
nn::TransformerConfig config;
13+
config.block_size = 16; // Tiny max positions; Megatron --max-position-embeddings is 32768.
14+
config.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000.
15+
config.original_vocab_size = 128256;
16+
config.n_layer = 2; // Tiny scale; Megatron --num-layers 32.
17+
config.n_head = 4; // Tiny scale; preserves the Megatron 4:1 GQA ratio.
18+
config.n_kv_head = 1; // Tiny scale; Megatron --num-query-groups 8.
19+
config.n_embd = 32; // Tiny scale; Megatron --hidden-size 4096.
20+
config.attention_type = nn::AttentionType::kRoPE;
21+
config.activation_type = nn::MLPType::kSwiGLU;
22+
config.ffn_type = nn::FFNType::kMoE;
23+
config.norm_type = nn::NormType::kRMSNorm;
24+
config.add_bias_linear = false;
25+
config.add_bias_lm_head = false;
26+
config.tie_weights = false;
27+
config.ffn_expansion_ratio = 3.5f;
28+
config.ffn_dim_multiplier = 1.0f; // InfiniTrain helper; Megatron passes --ffn-hidden-size directly.
29+
config.multiple_of = 1; // InfiniTrain helper; Megatron passes --ffn-hidden-size directly.
30+
config.norm_eps = 1e-5f;
31+
config.rope_theta = 1000000.0f;
32+
config.use_scaled_rope = false;
33+
34+
nn::MoEConfig moe_config;
35+
moe_config.num_experts = 8;
36+
moe_config.expert_parallel_size = 1; // Single-rank validation scale.
37+
moe_config.router_topk = 2;
38+
moe_config.moe_ffn_hidden_size = 112; // Tiny scale; Megatron --ffn-hidden-size 14336.
39+
moe_config.dispatcher_type = nn::MoEConfig::DispatcherType::kAllGather; // Single-rank validation path.
40+
moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; // Local correctness path.
41+
config.moe_config = moe_config;
42+
return config;
43+
}
44+
45+
inline void SanitizeTinyMixtralConfig(const nn::TransformerConfig &c) {
46+
CHECK_GT(c.block_size, 0);
47+
CHECK_GT(c.vocab_size, 0);
48+
CHECK_GE(c.vocab_size, c.original_vocab_size);
49+
CHECK_GT(c.n_layer, 0);
50+
CHECK_GT(c.n_head, 0);
51+
CHECK_GT(c.n_kv_head, 0);
52+
CHECK_LE(c.n_kv_head, c.n_head);
53+
CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA";
54+
CHECK_GT(c.n_embd, 0);
55+
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
56+
CHECK(c.attention_type == nn::AttentionType::kRoPE) << "tiny Mixtral requires RoPE attention";
57+
CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "tiny Mixtral requires SwiGLU activation";
58+
CHECK(c.ffn_type == nn::FFNType::kMoE) << "tiny Mixtral requires MoE FFN";
59+
CHECK(c.norm_type == nn::NormType::kRMSNorm) << "tiny Mixtral requires RMSNorm";
60+
CHECK(!c.add_bias_linear) << "tiny Mixtral has no bias in linear layers";
61+
CHECK(!c.add_bias_lm_head) << "tiny Mixtral has no bias in lm_head";
62+
CHECK(!c.tie_weights) << "tiny Mixtral does not tie embedding and lm_head weights";
63+
CHECK(c.ffn_dim_multiplier.has_value()) << "tiny Mixtral requires ffn_dim_multiplier";
64+
CHECK_GT(c.multiple_of, 0);
65+
CHECK(!c.use_scaled_rope) << "tiny Mixtral precision validation keeps scaled RoPE disabled";
66+
CHECK(c.moe_config.has_value()) << "tiny Mixtral requires MoE config";
67+
68+
const auto &moe = c.moe_config.value();
69+
CHECK_GT(moe.num_experts, 0);
70+
CHECK_EQ(moe.expert_parallel_size, 1) << "tiny Mixtral single-rank validation expects EP=1";
71+
CHECK_GT(moe.router_topk, 0);
72+
CHECK_LE(moe.router_topk, moe.num_experts);
73+
CHECK_GT(moe.moe_ffn_hidden_size, 0);
74+
CHECK(moe.dispatcher_type == nn::MoEConfig::DispatcherType::kAllGather)
75+
<< "tiny Mixtral uses the Megatron-style AllGather dispatcher";
76+
CHECK(moe.expert_impl == nn::MoEConfig::ExpertImpl::kSequential)
77+
<< "tiny Mixtral validation uses SequentialMLP experts";
78+
}
79+
80+
} // namespace tiny_mixtral

0 commit comments

Comments
 (0)