Commit 388b4bb

feat: unify GPT2 and LLaMA3 into DecoderOnlyTransformer
1 parent 0bcc009 commit 388b4bb

7 files changed: +99 additions, -108 deletions

example/gpt2/main.cc

Lines changed: 9 additions & 9 deletions

@@ -107,11 +107,11 @@ const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
     {"d36", {.block_size = 1024, .vocab_size = 50257, .n_layer = 36, .n_head = 20, .n_embd = 1280}},
     {"d48", {.block_size = 1024, .vocab_size = 50257, .n_layer = 48, .n_head = 25, .n_embd = 1600}},
 };
-const std::unordered_map<std::string, GPT2::ModelType> kStrToModelType = {
-    {"gpt2", GPT2::ModelType::kGPT2},
-    {"gpt2-medium", GPT2::ModelType::kGPT2Medium},
-    {"gpt2-large", GPT2::ModelType::kGPT2Large},
-    {"gpt2-xl", GPT2::ModelType::kGPT2XL},
+const std::unordered_map<std::string, DecoderOnlyTransformer::ModelType> kStrToModelType = {
+    {"gpt2", DecoderOnlyTransformer::ModelType::kGPT2},
+    {"gpt2-medium", DecoderOnlyTransformer::ModelType::kGPT2Medium},
+    {"gpt2-large", DecoderOnlyTransformer::ModelType::kGPT2Large},
+    {"gpt2-xl", DecoderOnlyTransformer::ModelType::kGPT2XL},
 };

 } // namespace

@@ -192,20 +192,20 @@ void Train(const nn::parallel::Rank &rank) {
     std::shared_ptr<nn::Module> model = nullptr;

     if (!FLAGS_llmc_filepath.empty()) {
-        model = GPT2::FromLLMC(FLAGS_llmc_filepath);
+        model = DecoderOnlyTransformer::FromLLMC_GPT2(FLAGS_llmc_filepath);
     } else if (kModelToConfigs.count(FLAGS_model)) {
         model_config = kModelToConfigs.at(FLAGS_model);
-        model = std::make_shared<GPT2>(model_config);
+        model = std::make_shared<DecoderOnlyTransformer>(model_config);
     } else {
-        model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
+        model = DecoderOnlyTransformer::FromPretrained(kStrToModelType.at(FLAGS_model));
     }

     model->To(device);

     utils::PrecisionChecker::BuildNameMap(model.get());

     // Get chunk size before wrapping with LoRA (needed for PipelineParallel)
-    auto gpt2_model = std::dynamic_pointer_cast<GPT2>(model);
+    auto gpt2_model = std::dynamic_pointer_cast<DecoderOnlyTransformer>(model);
     CHECK(gpt2_model) << "GPT2 example expects GPT2 model.";

     // Apply LoRA using GetLoRAModel (in-place injection)
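
Note: taken together, the call sites in this file imply a single factory-style interface on DecoderOnlyTransformer. Below is a minimal sketch of that interface, reconstructed only from the calls visible in this diff; the real declaration lives in the example's headers, which are not shown here, so everything beyond the member names used above is an assumption.

// Hypothetical interface sketch, inferred from the call sites in this diff;
// only the member names visible above are known, the rest is illustrative.
class DecoderOnlyTransformer : public nn::Module {
public:
    enum class ModelType { kGPT2, kGPT2Medium, kGPT2Large, kGPT2XL };

    explicit DecoderOnlyTransformer(const nn::TransformerConfig &config);

    // llm.c-style .bin checkpoint loaders, one entry point per architecture.
    static std::shared_ptr<DecoderOnlyTransformer> FromLLMC_GPT2(const std::string &filepath);
    static std::shared_ptr<DecoderOnlyTransformer> FromLLMC_LLaMA3(const std::string &filepath);

    // Named pretrained-checkpoint loader used by the GPT-2 example.
    static std::shared_ptr<DecoderOnlyTransformer> FromPretrained(ModelType model_type);

    // Number of pipeline-parallel chunks on this stage (used by PipelineParallel).
    int GetChunkSize() const;
};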

example/gpt2/net.cc

Lines changed: 35 additions & 44 deletions

@@ -32,12 +32,6 @@ constexpr int kRandomSeed = 42;
 static std::mt19937 gen{kRandomSeed};
 } // namespace

-std::shared_ptr<GPT2> GPT2::FromPretrained(ModelType model_type) {
-    // TODO(dcj): implement this later
-    LOG(FATAL) << "Not implemented yet";
-    return nullptr;
-}
-
 namespace {
 constexpr int32_t kHeaderMagic = 20240326;
 constexpr int32_t kHeaderFP32Version = 3;

@@ -58,7 +52,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
 }
 } // namespace

-std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
+std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_GPT2(const std::string &filepath) {
     if (!std::filesystem::exists(filepath)) {
         LOG(FATAL) << "File not found: " << filepath;
     }

@@ -89,7 +83,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     gpt2_config.n_layer = n_layer;
     gpt2_config.n_head = n_head;
     gpt2_config.n_embd = n_embd;
-    auto local_gpt2 = std::make_shared<GPT2>(gpt2_config);
+    auto local_gpt2 = std::make_shared<DecoderOnlyTransformer>(gpt2_config);

     LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
               << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head

@@ -135,7 +129,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     // local: (vocab_size_per_partition, n_embd)
     if (is_first_stage) {
         auto &transformer_wte_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerFirstStage::kWTELayerName,
+            = state_dict[std::format("{}.{}.{}", kTransformerModelName, nn::TransformerFirstStage::kWTELayerName,
                                      nn::parallel::VocabParallelEmbedding::kParamWeightName)];
         ReadMatrixRowShardFloat(ifs, static_cast<float *>(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd,
                                 v_start, vpp);

@@ -157,7 +151,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     if (is_first_stage) {
         // transformer.wpe.weight
         auto &transformer_wpe_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerFirstStage::kWPELayerName,
+            = state_dict[std::format("{}.{}.{}", kTransformerModelName, nn::TransformerFirstStage::kWPELayerName,
                                      nn::Embedding::kParamWeightName)];
         ReadMatrixAllFloat(ifs, static_cast<float *>(transformer_wpe_weight->DataPtr()), block_size, n_embd);
     } else {

@@ -170,9 +164,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
-                                         nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                         nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)];
+                = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                                         std::to_string(local_layer_index), nn::TransformerLayer::kLn1LayerName,
+                                         nn::LayerNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {

@@ -185,7 +179,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
                                                   nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);

@@ -201,7 +195,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                "{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
                 std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)];
             // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim,

@@ -244,7 +238,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                "{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
                 std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)];
             // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated

@@ -286,7 +280,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                "{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
                 std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp,

@@ -303,7 +297,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                "{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
                 std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);

@@ -319,9 +313,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
-                                         nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                         nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)];
+                = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                                         std::to_string(local_layer_index), nn::TransformerLayer::kLn2LayerName,
+                                         nn::LayerNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {

@@ -334,7 +328,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
                                                   nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);

@@ -349,10 +343,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
-                                                  nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
-                                                  nn::parallel::ColumnParallelLinear::kParamWeightName)];
+            auto &tensor
+                = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                                         std::to_string(local_layer_index), nn::TransformerLayer::kMlpLayerName,
+                                         nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)];
             ReadMatrixRowShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp);
             ++local_layer_index;
         } else {

@@ -365,10 +359,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
-                                                  nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
-                                                  nn::parallel::ColumnParallelLinear::kParamBiasName)];
+            auto &tensor
+                = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                                         std::to_string(local_layer_index), nn::TransformerLayer::kMlpLayerName,
+                                         nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)];
             ReadVectorShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), fc_out, fc_start, fc_pp);
             ++local_layer_index;
         } else {

@@ -381,10 +375,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
-                                                  nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
-                                                  nn::parallel::RowParallelLinear::kParamWeightName)];
+            auto &tensor
+                = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                                         std::to_string(local_layer_index), nn::TransformerLayer::kMlpLayerName,
+                                         nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp,
                                     in4_pp);
             ++local_layer_index;

@@ -398,10 +392,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
-                                                  nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
-                                                  nn::parallel::RowParallelLinear::kParamBiasName)];
+            auto &tensor
+                = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                                         std::to_string(local_layer_index), nn::TransformerLayer::kMlpLayerName,
+                                         nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {

@@ -413,13 +407,12 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     if (is_last_stage) {
         // transformer.ln_f.weight
         auto &transformer_ln_f_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
+            = state_dict[std::format("{}.{}.{}", kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
                                      nn::LayerNorm::kParamWeightName)];
         ReadVectorAllFloat(ifs, static_cast<float *>(transformer_ln_f_weight->DataPtr()), n_embd);
         // transformer.ln_f.bias
-        auto &transformer_ln_f_bias
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
-                                     nn::LayerNorm::kParamBiasName)];
+        auto &transformer_ln_f_bias = state_dict[std::format(
+            "{}.{}.{}", kTransformerModelName, nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)];
         ReadVectorAllFloat(ifs, static_cast<float *>(transformer_ln_f_bias->DataPtr()), n_embd);
     } else {
         size_t ln_f_w_bytes = n_embd * sizeof(float);

@@ -428,5 +421,3 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     }
     return local_gpt2;
 }
-
-int GPT2::GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }
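
Note: most of the edits in this file are mechanical: inside what is now a DecoderOnlyTransformer member function, the GPT2:: qualifier on kTransformerModelName is dropped and the surrounding std::format calls are re-wrapped. As a standalone illustration of how those dotted state_dict keys are composed, here is a minimal sketch; the literal constant values below are placeholders chosen for the example, not the library's actual definitions.

// Standalone C++20 illustration of composing dotted state_dict keys the way
// FromLLMC_GPT2 does; the name constants below are placeholder values.
#include <format>
#include <iostream>
#include <string>

int main() {
    const std::string kTransformerModelName = "transformer"; // placeholder value
    const std::string kHLayerName = "h";                      // placeholder value
    const std::string kLn1LayerName = "ln_1";                 // placeholder value
    const std::string kParamWeightName = "weight";            // placeholder value

    int local_layer_index = 0;
    // Mirrors: std::format("{}.{}.{}.{}.{}", kTransformerModelName, ..., kParamWeightName)
    std::string key = std::format("{}.{}.{}.{}.{}", kTransformerModelName, kHLayerName,
                                  std::to_string(local_layer_index), kLn1LayerName, kParamWeightName);
    std::cout << key << '\n'; // prints "transformer.h.0.ln_1.weight" with these placeholders
    return 0;
}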

example/llama3/main.cc

Lines changed: 3 additions & 3 deletions

@@ -171,9 +171,9 @@ void Train(const nn::parallel::Rank &rank) {
     nn::TransformerConfig model_config = nn::TransformerConfig::LLaMA3();
     std::shared_ptr<nn::Module> model = nullptr;
     if (!FLAGS_llmc_filepath.empty()) {
-        model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
+        model = DecoderOnlyTransformer::FromLLMC_LLaMA3(FLAGS_llmc_filepath);
     } else {
-        model = std::make_shared<LLaMA3>(model_config);
+        model = std::make_shared<DecoderOnlyTransformer>(model_config);
     }

     model->To(device);

@@ -220,7 +220,7 @@ void Train(const nn::parallel::Rank &rank) {

         model = std::make_shared<nn::parallel::PipelineParallel>(
             model, pp_world_size, num_micro_batches, shapes, pp_rank, device,
-            std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
+            std::dynamic_pointer_cast<DecoderOnlyTransformer>(model)->GetChunkSize());
         if (ddp_world_size > 1) {
             auto ddp_config
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
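
Note: because the example holds the model as a std::shared_ptr<nn::Module>, reaching GetChunkSize() before wrapping with PipelineParallel requires a downcast, just as in the GPT-2 example above. A small self-contained sketch of that pattern follows; the two struct names are stand-ins for nn::Module and DecoderOnlyTransformer, not the library types.

// Self-contained sketch of the downcast pattern used above.
#include <cassert>
#include <iostream>
#include <memory>

struct Module { virtual ~Module() = default; };   // stand-in for nn::Module
struct DecoderOnlyTransformerLike : Module {      // stand-in for DecoderOnlyTransformer
    int GetChunkSize() const { return 1; }        // placeholder chunk count
};

int main() {
    std::shared_ptr<Module> model = std::make_shared<DecoderOnlyTransformerLike>();
    // The cast yields nullptr if the pointee is not (derived from) the target type,
    // which is why the chunk size is queried before the model is wrapped.
    auto decoder = std::dynamic_pointer_cast<DecoderOnlyTransformerLike>(model);
    assert(decoder && "expected a decoder-only transformer before pipeline wrapping");
    std::cout << "chunks per stage: " << decoder->GetChunkSize() << '\n';
    return 0;
}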

example/llama3/net.cc

Lines changed: 3 additions & 8 deletions

@@ -18,6 +18,7 @@
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/causal_self_attention.h"
 #include "infini_train/include/nn/modules/mlp.h"
+#include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/transformer.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"

@@ -31,18 +32,12 @@ constexpr int kRandomSeed = 42;
 static std::mt19937 gen{kRandomSeed};
 } // namespace

-std::shared_ptr<LLaMA3> LLaMA3::FromPretrained(ModelType model_type) {
-    // TODO(zbl): implement this later
-    LOG(FATAL) << "Not implemented yet";
-    return nullptr;
-}
-
 namespace {
 constexpr int32_t kLLaMA3Magic = 20240803;
 constexpr int32_t kLLaMA3FP32Version = 3;
 } // namespace

-std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
+std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(const std::string &filepath) {
     if (!std::filesystem::exists(filepath)) {
         LOG(FATAL) << "File not found: " << filepath;
     }

@@ -83,7 +78,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
     llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
     llama3_config.norm_eps = norm_eps;
     llama3_config.max_gen_batch_size = max_gen_bs;
-    auto llama3 = std::make_shared<LLaMA3>(llama3_config);
+    auto llama3 = std::make_shared<DecoderOnlyTransformer>(llama3_config);

     // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
     int pp_size = nn::parallel::global::GetPipelineParallelSize();
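
Note: the LLaMA3 loader keeps its existing header constants (kLLaMA3Magic, kLLaMA3FP32Version) and only changes the type it constructs. For readers who want to see the shape of such a check in isolation, here is a rough self-contained sketch; beyond the two constants shown in this diff, the header layout and field order are assumptions made only for illustration.

// Illustrative header check for an llm.c-style .bin checkpoint; only the magic
// and version constants come from this diff, the rest of the layout is assumed.
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;

bool CheckLLMCHeader(const std::string &filepath) {
    std::ifstream ifs(filepath, std::ios::binary);
    if (!ifs) {
        std::cerr << "File not found: " << filepath << '\n';
        return false;
    }
    // Assumed layout: first two little-endian int32 fields are magic and version.
    int32_t magic = 0, version = 0;
    ifs.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    ifs.read(reinterpret_cast<char *>(&version), sizeof(version));
    if (!ifs || magic != kLLaMA3Magic || version != kLLaMA3FP32Version) {
        std::cerr << "Unexpected magic/version: " << magic << " / " << version << '\n';
        return false;
    }
    return true;
}

int main(int argc, char **argv) {
    if (argc > 1) {
        std::cout << (CheckLLMCHeader(argv[1]) ? "header OK" : "header mismatch") << '\n';
    }
    return 0;
}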
