Skip to content

Commit dfdd913

Browse files
committed
feat: extract the common module of Transformer
1 parent b1e4b03 commit dfdd913

File tree

20 files changed

+1963
-1166
lines changed

20 files changed

+1963
-1166
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -200,3 +200,6 @@ target_link_libraries(test_hook infini_train)
200200

201201
add_executable(test_precision_check test/hook/test_precision_check.cc)
202202
target_link_libraries(test_precision_check infini_train)
203+
204+
add_executable(test_transformer_spec test/transformer_spec/test_transformer_spec.cc)
205+
link_infini_train_exe(test_transformer_spec)

example/gpt2/main.cc

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@
1010
#include "glog/logging.h"
1111

1212
#include "infini_train/include/autocast.h"
13+
#include "infini_train/include/core/models/decode_only_transformer/model.h"
1314
#include "infini_train/include/core/runtime/device_guard.h"
15+
#include "infini_train/include/core/transformer/transformer_config.h"
1416
#include "infini_train/include/dataloader.h"
1517
#include "infini_train/include/device.h"
1618
#include "infini_train/include/nn/modules/loss.h"
@@ -34,7 +36,6 @@
3436

3537
#include "example/common/tiny_shakespeare_dataset.h"
3638
#include "example/common/tokenizer.h"
37-
#include "example/gpt2/net.h"
3839

3940
// I/O
4041
DEFINE_string(input_bin, "", "input .bin to train on");
@@ -91,7 +92,7 @@ constexpr char kDtypeFP32[] = "float32";
9192
constexpr char kDtypeBF16[] = "bfloat16";
9293

9394
//
94-
const std::unordered_map<std::string, GPT2Config> kModelToConfigs = {
95+
const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
9596
{"d12", {.block_size = 1024, .vocab_size = 50257, .n_layer = 12, .n_head = 12, .n_embd = 768}},
9697
{"d24", {.block_size = 1024, .vocab_size = 50257, .n_layer = 24, .n_head = 16, .n_embd = 1024}},
9798
{"d36", {.block_size = 1024, .vocab_size = 50257, .n_layer = 36, .n_head = 20, .n_embd = 1280}},
@@ -178,10 +179,12 @@ void Train(const nn::parallel::Rank &rank) {
178179
// ManualSeed(42);
179180

180181
// init the model, either from scratch or from OpenAI pretrained checkpoint
181-
GPT2Config model_config;
182+
nn::TransformerConfig model_config;
182183
std::shared_ptr<nn::Module> model = nullptr;
183184
if (!FLAGS_llmc_filepath.empty()) {
184-
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
185+
auto gpt2_model = GPT2::FromLLMC(FLAGS_llmc_filepath);
186+
model_config = gpt2_model->GetConfig();
187+
model = gpt2_model;
185188
} else if (kModelToConfigs.count(FLAGS_model)) {
186189
model_config = kModelToConfigs.at(FLAGS_model);
187190
model = std::make_shared<GPT2>(model_config);

example/gpt2/net.cc

Lines changed: 74 additions & 356 deletions
Large diffs are not rendered by default.

example/gpt2/net.h

Lines changed: 0 additions & 150 deletions
This file was deleted.

example/llama3/main.cc

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,9 @@
88
#include "glog/logging.h"
99

1010
#include "infini_train/include/autocast.h"
11+
#include "infini_train/include/core/models/decode_only_transformer/model.h"
1112
#include "infini_train/include/core/runtime/device_guard.h"
13+
#include "infini_train/include/core/transformer/transformer_config.h"
1214
#include "infini_train/include/dataloader.h"
1315
#include "infini_train/include/device.h"
1416
#include "infini_train/include/nn/modules/loss.h"
@@ -33,7 +35,6 @@
3335

3436
#include "example/common/tiny_shakespeare_dataset.h"
3537
#include "example/common/tokenizer.h"
36-
#include "example/llama3/net.h"
3738

3839
// I/O
3940
DEFINE_string(input_bin, "", "input .bin to train on");
@@ -160,10 +161,12 @@ void Train(const nn::parallel::Rank &rank) {
160161
// rng / reproducibility
161162
// ManualSeed(42);
162163

163-
LLaMA3Config model_config = LLaMA3Config();
164+
nn::TransformerConfig model_config;
164165
std::shared_ptr<nn::Module> model = nullptr;
165166
if (!FLAGS_llmc_filepath.empty()) {
166-
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
167+
auto llama3_model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
168+
model_config = llama3_model->GetConfig();
169+
model = llama3_model;
167170
} else {
168171
model = std::make_shared<LLaMA3>(model_config);
169172
}

0 commit comments

Comments (0)