Skip to content

Commit 8ebe12f

Browse files
committed
temp
1 parent fa1e61b commit 8ebe12f

28 files changed

Lines changed: 875 additions & 1238 deletions

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ add_executable(gpt2
196196
example/common/utils.cc
197197
example/common/checkpoint_loader.cc
198198
example/common/tokenizer.cc
199+
example/gpt2/checkpoint_loader.cc
199200
)
200201
link_infini_train_exe(gpt2)
201202

@@ -205,6 +206,7 @@ add_executable(llama3
205206
example/common/utils.cc
206207
example/common/checkpoint_loader.cc
207208
example/common/tokenizer.cc
209+
example/llama3/checkpoint_loader.cc
208210
)
209211
link_infini_train_exe(llama3)
210212

example/common/checkpoint_loader.cc

Lines changed: 5 additions & 967 deletions
Large diffs are not rendered by default.

example/common/checkpoint_loader.h

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,11 @@
1515
#include "infini_train/include/nn/parallel/rank.h"
1616
#include "infini_train/include/optimizer.h"
1717

18-
namespace infini_train {
19-
namespace nn {
20-
class TransformerModel;
21-
}
22-
23-
namespace gpt2 {
24-
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
25-
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath);
26-
} // namespace gpt2
27-
namespace llama3 {
28-
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
29-
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath);
30-
} // namespace llama3
18+
using namespace infini_train;
19+
namespace nn = infini_train::nn;
3120

3221
struct ResumeFromCheckpointArgs {
33-
fLS::clstring resume_root;
22+
std::filesystem::path resume_root;
3423
const nn::parallel::Rank &rank;
3524
std::shared_ptr<nn::Module> model;
3625
std::shared_ptr<Optimizer> optimizer;
@@ -42,23 +31,21 @@ struct ResumeFromCheckpointArgs {
4231

4332
struct ResumeFromCheckpointResult {
4433
int global_step = 0;
45-
float best_loss = std::numeric_limits<float>::infinity();
4634
size_t data_batch_idx = 0;
4735
};
4836

4937
struct SaveCheckpointArgs {
5038
std::filesystem::path save_dir;
5139
int64_t global_step = 0;
5240
size_t data_batch_idx = 0;
53-
float best_loss = std::numeric_limits<float>::infinity();
5441
double last_lr = 0.0;
5542
std::string optimizer_type;
56-
std::string checkpoint_format = "bin";
43+
std::string checkpoint_file_format = "bin";
5744
int ddp_size = 1;
5845
int tp_size = 1;
5946
int sp_size = 1;
6047
int pp_size = 1;
61-
bool save_optimizer_state = true;
48+
bool no_save_optim = false;
6249
bool prune_step_checkpoints = false;
6350
std::filesystem::path checkpoint_root_dir;
6451
size_t max_checkpoint_keep = 0;
@@ -71,5 +58,3 @@ struct SaveCheckpointArgs {
7158
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);
7259

7360
void SaveCheckpoint(const SaveCheckpointArgs &args);
74-
75-
} // namespace infini_train

example/gpt2/checkpoint_loader.cc

Lines changed: 544 additions & 0 deletions
Large diffs are not rendered by default.

example/gpt2/checkpoint_loader.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
#include <cstring>
4+
#include <memory>
5+
#include <string>
6+
7+
namespace infini_train::nn {
8+
class TransformerModel;
9+
}
10+
11+
namespace gpt2 {
12+
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
13+
void SaveAsLLMC(const std::shared_ptr<infini_train::nn::TransformerModel> &model, const std::string &filepath);
14+
} // namespace gpt2

example/gpt2/config.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,18 @@
44

55
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
66

7-
namespace infini_train {
87
namespace gpt2 {
9-
inline nn::TransformerConfig GPT2Config() {
8+
inline infini_train::nn::TransformerConfig GPT2Config() {
109
return {.block_size = 1024,
1110
.vocab_size = 50304,
1211
.original_vocab_size = 50257,
1312
.n_layer = 12,
1413
.n_head = 12,
1514
.n_kv_head = 12,
1615
.n_embd = 768,
17-
.attention_type = nn::AttentionType::kStandard,
18-
.activation_type = nn::MLPType::kGELU,
19-
.norm_type = nn::NormType::kLayerNorm,
16+
.attention_type = infini_train::nn::AttentionType::kStandard,
17+
.activation_type = infini_train::nn::MLPType::kGELU,
18+
.norm_type = infini_train::nn::NormType::kLayerNorm,
2019
.add_bias_linear = true,
2120
.add_bias_lm_head = false,
2221
.tie_weights = true,
@@ -25,7 +24,7 @@ inline nn::TransformerConfig GPT2Config() {
2524
.multiple_of = 1};
2625
}
2726

28-
inline void SanitizeGPT2Config(const nn::TransformerConfig &c) {
27+
inline void SanitizeGPT2Config(const infini_train::nn::TransformerConfig &c) {
2928
CHECK_GT(c.block_size, 0);
3029
CHECK_GT(c.vocab_size, 0);
3130
CHECK_GE(c.vocab_size, c.original_vocab_size);
@@ -34,10 +33,8 @@ inline void SanitizeGPT2Config(const nn::TransformerConfig &c) {
3433
CHECK_GT(c.n_embd, 0);
3534
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
3635
CHECK_EQ(c.n_kv_head, c.n_head) << "GPT-2 does not use GQA; n_kv_head must equal n_head";
37-
CHECK(c.attention_type == nn::AttentionType::kStandard) << "GPT-2 requires standard attention";
38-
CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation";
39-
CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm";
36+
CHECK(c.attention_type == infini_train::nn::AttentionType::kStandard) << "GPT-2 requires standard attention";
37+
CHECK(c.activation_type == infini_train::nn::MLPType::kGELU) << "GPT-2 requires GELU activation";
38+
CHECK(c.norm_type == infini_train::nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm";
4039
}
41-
4240
} // namespace gpt2
43-
} // namespace infini_train

example/gpt2/main.cc

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@
4141
#include "example/common/checkpoint_loader.h"
4242
#include "example/common/tiny_shakespeare_dataset.h"
4343
#include "example/common/tokenizer.h"
44+
#include "example/gpt2/checkpoint_loader.h"
4445
#include "example/gpt2/config.h"
4546

47+
// TODO(jym): Reorganize CLI flags into categories for better readability and maintainability.
4648
// I/O
4749
DEFINE_string(input_bin, "", "input .bin to train on");
4850
DEFINE_string(input_val_bin, "", "input .bin to eval validation loss on");
@@ -81,12 +83,12 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
8183

8284
// precision
8385
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
84-
DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving");
85-
DEFINE_string(resume_from, "", "checkpoint directory to resume from");
86-
DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
86+
DEFINE_uint32(save_interval, 0, "save checkpoint every N steps; 0 disables saving");
87+
DEFINE_string(load, "", "checkpoint directory to resume from");
88+
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8789
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
88-
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
89-
DEFINE_string(checkpoint_format, "ckpt",
90+
// Opt-out flag: forwarded to SaveCheckpointArgs::no_save_optim when saving checkpoints.
// The help text describes the negated meaning (true == do NOT save optimizer state).
DEFINE_bool(no_save_optim, false, "do not persist optimizer state in checkpoints");
91+
DEFINE_string(checkpoint_file_format, "ckpt",
9092
"checkpoint format: bin|ckpt. "
9193
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
9294
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
@@ -317,7 +319,7 @@ void Train(const nn::parallel::Rank &rank) {
317319
// TODO(dcj): support more complex optimizer later
318320
// auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
319321
std::shared_ptr<Optimizer> optimizer = nullptr;
320-
auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate);
322+
auto optimizer_creator = optimizers::SGD::CreateNamed(FLAGS_learning_rate);
321323

322324
if (FLAGS_use_distributed_optimizer) {
323325
auto model_chunks = (pp_world_size > 1)
@@ -341,16 +343,15 @@ void Train(const nn::parallel::Rank &rank) {
341343
auto impl = core::GetDeviceGuardImpl(device.type());
342344

343345
int start_step = 0;
344-
float best_loss = std::numeric_limits<float>::infinity();
345346
TrainerState state;
346347
CheckpointLoadOptions load_options;
347348
load_options.load_optimizer_state = true;
348349
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
349350
auto loaded_model = gpt2::LoadFromLLMC(model_path.string());
350351
target_model->LoadStateDict(loaded_model->StateDict());
351352
};
352-
const auto resume_result = infini_train::ResumeFromCheckpoint({
353-
.resume_root = FLAGS_resume_from,
353+
const auto resume_result = ResumeFromCheckpoint({
354+
.resume_root = FLAGS_load,
354355
.rank = rank,
355356
.model = model,
356357
.optimizer = optimizer,
@@ -360,26 +361,24 @@ void Train(const nn::parallel::Rank &rank) {
360361
.load_options = load_options,
361362
});
362363
start_step = resume_result.global_step;
363-
best_loss = resume_result.best_loss;
364364
saved_data_batch_idx = resume_result.data_batch_idx;
365365

366366
auto save_checkpoint
367367
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
368-
infini_train::SaveCheckpoint({
368+
SaveCheckpoint({
369369
.save_dir = save_dir,
370370
.global_step = global_step,
371371
.data_batch_idx = saved_data_batch_idx,
372-
.best_loss = best_loss,
373372
.last_lr = FLAGS_learning_rate,
374373
.optimizer_type = "SGD",
375-
.checkpoint_format = FLAGS_checkpoint_format,
374+
.checkpoint_file_format = FLAGS_checkpoint_file_format,
376375
.ddp_size = ddp_world_size,
377376
.tp_size = tp_world_size,
378377
.sp_size = sp_world_size,
379378
.pp_size = pp_world_size,
380-
.save_optimizer_state = FLAGS_save_optimizer_state,
379+
.no_save_optim = FLAGS_no_save_optim,
381380
.prune_step_checkpoints = prune_step_checkpoints,
382-
.checkpoint_root_dir = FLAGS_checkpoint_dir,
381+
.checkpoint_root_dir = FLAGS_save,
383382
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
384383
.rank = rank,
385384
.model = *model,
@@ -484,8 +483,6 @@ void Train(const nn::parallel::Rank &rank) {
484483
lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
485484
}
486485

487-
best_loss = std::min(best_loss, lossf);
488-
489486
const auto iter_end = std::chrono::high_resolution_clock::now();
490487
const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
491488
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
@@ -509,9 +506,9 @@ void Train(const nn::parallel::Rank &rank) {
509506
}
510507
}
511508

512-
if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) {
509+
if (FLAGS_save_interval > 0 && (step + 1) % FLAGS_save_interval == 0) {
513510
std::filesystem::path step_dir
514-
= std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1);
511+
= std::filesystem::path(FLAGS_save) / std::format("checkpoint_step_{:06d}", step + 1);
515512
if (rank.IsParallel()) {
516513
step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
517514
}
@@ -525,7 +522,7 @@ void Train(const nn::parallel::Rank &rank) {
525522
nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
526523
}
527524

528-
std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final";
525+
std::filesystem::path final_dir = std::filesystem::path(FLAGS_save) / "checkpoint_final";
529526
if (rank.IsParallel()) {
530527
final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
531528
}

example/llama3/checkpoint_loader.cc

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55
#include <filesystem>
66
#include <fstream>
77
#include <memory>
8-
#include <random>
98
#include <string>
109
#include <unordered_map>
1110
#include <vector>
1211

1312
#include "glog/logging.h"
1413

15-
#include "example/common/utils.h"
16-
#include "example/llama3/config.h"
1714
#include "infini_train/include/nn/modules/normalization.h"
1815
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
1916
#include "infini_train/include/nn/modules/transformer/mlp.h"
@@ -22,24 +19,18 @@
2219
#include "infini_train/include/nn/parallel/tensor_parallel.h"
2320
#include "infini_train/include/tensor.h"
2421

22+
#include "example/common/utils.h"
23+
#include "example/llama3/config.h"
24+
2525
using namespace infini_train;
2626
namespace nn = infini_train::nn;
2727

28-
namespace {
29-
constexpr int kRandomSeed = 42;
30-
31-
// TODO(zbl): make this rng generator compatible with torch later
32-
static std::mt19937 gen{kRandomSeed};
33-
} // namespace
34-
3528
namespace {
3629
constexpr int32_t kLLaMA3Magic = 20240803;
3730
constexpr int32_t kLLaMA3FP32Version = 3;
3831
} // namespace
3932

40-
namespace llama3 {
41-
42-
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
33+
std::shared_ptr<nn::TransformerModel> llama3::LoadFromLLMC(const std::string &filepath) {
4334
if (!std::filesystem::exists(filepath)) {
4435
LOG(FATAL) << "File not found: " << filepath;
4536
}
@@ -346,7 +337,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
346337
return llama3;
347338
}
348339

349-
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath) {
340+
void llama3::SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath) {
350341
CHECK_EQ(nn::parallel::global::GetTensorParallelSize(), 1) << "SaveAsLLMC currently supports TP=1 only.";
351342
CHECK_EQ(nn::parallel::global::GetPipelineParallelSize(), 1) << "SaveAsLLMC currently supports PP=1 only.";
352343

@@ -448,4 +439,3 @@ void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::s
448439
ofs.flush();
449440
CHECK(ofs.good()) << "Failed to flush model file: " << filepath;
450441
}
451-
} // namespace llama3

example/llama3/checkpoint_loader.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
#include <cstring>
4+
#include <memory>
5+
#include <string>
6+
7+
namespace infini_train::nn {
8+
class TransformerModel;
9+
}
10+
11+
namespace llama3 {
12+
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
13+
void SaveAsLLMC(const std::shared_ptr<infini_train::nn::TransformerModel> &model, const std::string &filepath);
14+
} // namespace llama3

example/llama3/config.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,18 @@
44

55
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
66

7-
namespace infini_train {
87
namespace llama3 {
9-
inline nn::TransformerConfig LLaMA3Config() {
8+
inline infini_train::nn::TransformerConfig LLaMA3Config() {
109
return {.block_size = 8192,
1110
.vocab_size = 128256,
1211
.original_vocab_size = 128256,
1312
.n_layer = 16,
1413
.n_head = 32,
1514
.n_kv_head = 8,
1615
.n_embd = 2048,
17-
.attention_type = nn::AttentionType::kRoPE,
18-
.activation_type = nn::MLPType::kSwiGLU,
19-
.norm_type = nn::NormType::kRMSNorm,
16+
.attention_type = infini_train::nn::AttentionType::kRoPE,
17+
.activation_type = infini_train::nn::MLPType::kSwiGLU,
18+
.norm_type = infini_train::nn::NormType::kRMSNorm,
2019
.add_bias_linear = false,
2120
.add_bias_lm_head = false,
2221
.tie_weights = false,
@@ -25,7 +24,7 @@ inline nn::TransformerConfig LLaMA3Config() {
2524
.multiple_of = 256};
2625
}
2726

28-
inline void SanitizeLLaMA3Config(const nn::TransformerConfig &c) {
27+
inline void SanitizeLLaMA3Config(const infini_train::nn::TransformerConfig &c) {
2928
CHECK_GT(c.block_size, 0);
3029
CHECK_GT(c.vocab_size, 0);
3130
CHECK_GE(c.vocab_size, c.original_vocab_size);
@@ -36,13 +35,12 @@ inline void SanitizeLLaMA3Config(const nn::TransformerConfig &c) {
3635
CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA";
3736
CHECK_GT(c.n_embd, 0);
3837
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
39-
CHECK(c.attention_type == nn::AttentionType::kRoPE) << "LLaMA-3 requires RoPE attention";
40-
CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "LLaMA-3 requires SwiGLU activation";
41-
CHECK(c.norm_type == nn::NormType::kRMSNorm) << "LLaMA-3 requires RMSNorm";
38+
CHECK(c.attention_type == infini_train::nn::AttentionType::kRoPE) << "LLaMA-3 requires RoPE attention";
39+
CHECK(c.activation_type == infini_train::nn::MLPType::kSwiGLU) << "LLaMA-3 requires SwiGLU activation";
40+
CHECK(c.norm_type == infini_train::nn::NormType::kRMSNorm) << "LLaMA-3 requires RMSNorm";
4241
CHECK(!c.add_bias_linear) << "LLaMA-3 has no bias in linear layers";
4342
CHECK(!c.tie_weights) << "LLaMA-3 does not tie embedding and lm_head weights";
4443
CHECK(c.ffn_dim_multiplier.has_value()) << "LLaMA-3 requires ffn_dim_multiplier";
4544
CHECK_GT(c.multiple_of, 0);
4645
}
4746
} // namespace llama3
48-
} // namespace infini_train

0 commit comments

Comments
 (0)