Commit 1e43784

feat: extract the common module of Transformer
Parent: 77030ca

4 files changed: 33 additions & 8 deletions

example/gpt2/main.cc

Lines changed: 1 addition & 2 deletions
@@ -192,8 +192,7 @@ void Train(const nn::parallel::Rank &rank) {
     std::shared_ptr<nn::Module> model = nullptr;

     if (!FLAGS_llmc_filepath.empty()) {
-        auto gpt2_model = GPT2::FromLLMC(FLAGS_llmc_filepath);
-        model = gpt2_model;
+        model = GPT2::FromLLMC(FLAGS_llmc_filepath);
     } else if (kModelToConfigs.count(FLAGS_model)) {
         model_config = kModelToConfigs.at(FLAGS_model);
         model = std::make_shared<GPT2>(model_config);
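This tidy-up works because a std::shared_ptr to a derived class converts implicitly to a std::shared_ptr to its base, so the named temporary added nothing. A minimal standalone sketch of that conversion (the Module/GPT2 bodies below are illustrative stand-ins, not the repository's definitions):

    #include <memory>
    #include <string>

    // Illustrative stand-ins for nn::Module and GPT2 (not the repo's code).
    struct Module {
        virtual ~Module() = default;
    };

    struct GPT2 : Module {
        // Hypothetical factory mirroring the shape of GPT2::FromLLMC.
        static std::shared_ptr<GPT2> FromLLMC(const std::string & /*filepath*/) {
            return std::make_shared<GPT2>();
        }
    };

    int main() {
        // shared_ptr<Derived> converts implicitly to shared_ptr<Base>,
        // so the result can be assigned without a named temporary.
        std::shared_ptr<Module> model = GPT2::FromLLMC("model.bin");
    }

The same simplification is applied to the LLaMA3 loader in the next file.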

example/llama3/main.cc

Lines changed: 1 addition & 2 deletions
@@ -171,8 +171,7 @@ void Train(const nn::parallel::Rank &rank) {
     nn::TransformerConfig model_config = nn::TransformerConfig::LLaMA3();
     std::shared_ptr<nn::Module> model = nullptr;
     if (!FLAGS_llmc_filepath.empty()) {
-        auto llama3_model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
-        model = llama3_model;
+        model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
     } else {
         model = std::make_shared<LLaMA3>(model_config);
     }

example/llama3/net.cc

Lines changed: 6 additions & 0 deletions
@@ -75,6 +75,12 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
     llama3_config.n_head = n_head;
     llama3_config.n_kv_head = n_kv_head;
     llama3_config.n_embd = n_embd;
+    llama3_config.ffn_dim_multiplier = ffn_dim_multiplier;
+    llama3_config.multiple_of = multiple_of;
+    llama3_config.rope_theta = rope_theta;
+    llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
+    llama3_config.norm_eps = norm_eps;
+    llama3_config.max_gen_batch_size = max_gen_bs;
     auto llama3 = std::make_shared<LLaMA3>(llama3_config);

     // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
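The six new assignments forward values (ffn_dim_multiplier, multiple_of, rope_theta, use_scaled_rope, norm_eps, max_gen_bs) that FromLLMC presumably decodes from the llmc checkpoint header earlier in the function; before this commit they silently fell back to the defaults in TransformerConfig. The static_cast<bool> suggests the on-disk flag is stored as an integer. As a hedged sketch of how such scalar header fields are typically pulled out of a binary file — the helper name and the field order below are assumptions, not the repository's actual format:

    #include <cstdint>
    #include <fstream>

    // Hypothetical helper: read one raw scalar from a binary
    // checkpoint stream (the header layout here is assumed).
    template <typename T>
    T ReadScalar(std::ifstream &ifs) {
        T value{};
        ifs.read(reinterpret_cast<char *>(&value), sizeof(T));
        return value;
    }

    // Possible use inside a loader like LLaMA3::FromLLMC:
    //   std::ifstream ifs(filepath, std::ios::binary);
    //   auto rope_theta      = ReadScalar<float>(ifs);
    //   auto use_scaled_rope = ReadScalar<int32_t>(ifs); // cast to bool later
    //   auto multiple_of     = ReadScalar<int32_t>(ifs);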

infini_train/include/core/transformer/transformer_config.h

Lines changed: 25 additions & 4 deletions
@@ -20,8 +20,7 @@ enum class NormType {
     kRMSNorm // RMSNorm (LLaMA3 style)
 };

-class TransformerConfig {
-public:
+struct TransformerConfig {
     static constexpr char kGPT2Name[] = "GPT2";
     static constexpr char kLLaMA3Name[] = "LLaMA3";

@@ -61,7 +60,26 @@ class TransformerConfig {
     bool flash = false;             // flash attention
     int64_t max_gen_batch_size = 4; // max batch size during inference

-    static TransformerConfig GPT2() { return {}; }
+    static TransformerConfig GPT2() {
+        return {.model_type = kGPT2Name,
+                .block_size = 1024,
+                .vocab_size = 50304,
+                .original_vocab_size = 50257,
+                .n_layer = 12,
+                .n_head = 12,
+                .n_kv_head = 12,
+                .n_embd = 768,
+                .attention_type = AttentionType::kStandard,
+                .activation_type = MLPType::kGELU,
+                .norm_type = NormType::kLayerNorm,
+                .use_bias = true,
+                .use_gqa = false,
+                .use_rope = false,
+                .tie_weights = true,
+                .ffn_expansion_ratio = 4.0f,
+                .ffn_dim_multiplier = std::nullopt,
+                .multiple_of = 1};
+    }

     static TransformerConfig LLaMA3() {
         return {.model_type = kLLaMA3Name,
@@ -77,7 +95,10 @@ class TransformerConfig {
                 .use_bias = false,
                 .use_gqa = true,
                 .use_rope = true,
-                .tie_weights = false};
+                .tie_weights = false,
+                .ffn_expansion_ratio = 4.0f,
+                .ffn_dim_multiplier = 1.5f,
+                .multiple_of = 256};
     }
 };
 } // namespace infini_train::nn
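Two things change here: TransformerConfig becomes a struct, making its all-public aggregate nature explicit (the public: access specifier becomes redundant), and GPT2() now spells out the GPT-2 small preset instead of relying on the member defaults via return {}. Both factories use C++20 designated initializers, which require an aggregate type and fields named in declaration order. A brief usage sketch under those assumptions (the 24-layer override is purely illustrative):

    // Start from a preset, then tweak individual fields; this works
    // because TransformerConfig is a plain aggregate with public members.
    auto config = infini_train::nn::TransformerConfig::GPT2();
    config.n_layer = 24; // e.g. deepen the 12-layer preset
    auto model = std::make_shared<GPT2>(config);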
