Skip to content

Commit 0bdef97

Browse files
JYMiracle305 authored and kilinchange committed
feat: ckpt and bin checkpoint formats are kept as an interim solution, with plans to unify into one later.
1 parent b550d35 commit 0bdef97

25 files changed

Lines changed: 125 additions & 333 deletions

example/common/checkpoint_loader.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616
using namespace infini_train;
1717
namespace nn = infini_train::nn;
1818

19+
// TODO(jym): ckpt is a new checkpoint format; bin is the legacy format. Keeping both as an interim solution; plan to
20+
// consolidate into one later.
1921
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
2022
ResumeFromCheckpointResult result;
2123
if (args.resume_root.empty()) {
@@ -36,7 +38,7 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &
3638
}
3739
}
3840

39-
Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state);
41+
Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state, true);
4042

4143
result.global_step = static_cast<int>(args.state.global_step);
4244

@@ -86,7 +88,7 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) {
8688
state.sp_size = args.sp_size;
8789
state.pp_size = args.pp_size;
8890

89-
Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state);
91+
Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state, args.no_save_optim);
9092

9193
const auto ckpt_end = std::chrono::high_resolution_clock::now();
9294
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

example/common/checkpoint_loader.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ struct ResumeFromCheckpointArgs {
2222
const nn::parallel::Rank &rank;
2323
std::shared_ptr<nn::Module> model;
2424
std::shared_ptr<Optimizer> optimizer;
25-
DistributedDataLoader &train_loader;
2625
const nn::TransformerConfig &model_config;
2726
TrainerState &state;
2827
};

example/gpt2/checkpoint_loader.cc

Lines changed: 23 additions & 135 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <filesystem>
66
#include <fstream>
77
#include <memory>
8+
#include <random>
89
#include <string>
910
#include <tuple>
1011
#include <vector>
@@ -28,25 +29,35 @@ using namespace infini_train;
2829
namespace nn = infini_train::nn;
2930

3031
namespace {
31-
constexpr int32_t kGPT2Magic = 20240326;
32-
constexpr int32_t kGPT2FP32Version = 3;
33-
constexpr int32_t kGPT2BF16Version = 5;
32+
constexpr int kRandomSeed = 42;
3433

35-
std::tuple<int32_t, DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header, size_t offset) {
34+
// TODO(dcj): make this rng generator compatible with torch later
35+
static std::mt19937 gen{kRandomSeed};
36+
} // namespace
37+
38+
namespace {
39+
constexpr int32_t kHeaderMagic = 20240326;
40+
constexpr int32_t kHeaderFP32Version = 3;
41+
constexpr int32_t kHeaderBF16Version = 5;
42+
43+
std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header,
44+
size_t offset) {
3645
const auto version = BytesToType<uint32_t>(header, offset);
3746
switch (version) {
38-
case kGPT2BF16Version:
39-
return {version, DataType::kBFLOAT16};
40-
case kGPT2FP32Version:
41-
return {version, DataType::kFLOAT32};
47+
case kHeaderBF16Version:
48+
return {version, infini_train::DataType::kBFLOAT16};
49+
case kHeaderFP32Version:
50+
return {version, infini_train::DataType::kFLOAT32};
4251
default:
4352
LOG(FATAL) << "Unsupported version: " << version << " at " << __FILE__ << ":" << __LINE__;
4453
return {}; // Unreachable, but keeps compiler happy
4554
}
4655
}
4756
} // namespace
4857

49-
std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &filepath) {
58+
namespace gpt2 {
59+
60+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
5061
if (!std::filesystem::exists(filepath)) {
5162
LOG(FATAL) << "File not found: " << filepath;
5263
}
@@ -55,9 +66,9 @@ std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &file
5566
const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs);
5667

5768
const auto magic = BytesToType<uint32_t>(header, 0);
58-
CHECK_EQ(magic, kGPT2Magic);
69+
CHECK_EQ(magic, kHeaderMagic);
5970
auto [version, dtype] = DetermineAndCheckVersion(header, 4);
60-
CHECK_EQ(version, kGPT2FP32Version);
71+
CHECK_EQ(version, kHeaderFP32Version);
6172

6273
auto tp_size = nn::parallel::global::GetTensorParallelSize();
6374

@@ -418,127 +429,4 @@ std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &file
418429

419430
return local_gpt2;
420431
}
421-
422-
void gpt2::SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath) {
423-
CHECK_EQ(nn::parallel::global::GetTensorParallelSize(), 1) << "SaveAsLLMC currently supports TP=1 only.";
424-
CHECK_EQ(nn::parallel::global::GetPipelineParallelSize(), 1) << "SaveAsLLMC currently supports PP=1 only.";
425-
426-
std::ofstream ofs(filepath, std::ios::binary);
427-
CHECK(ofs.is_open()) << "Failed to open model file for write: " << filepath;
428-
429-
auto config = model->Config();
430-
std::vector<int32_t> header(256, 0);
431-
header[0] = kGPT2Magic;
432-
header[1] = kGPT2FP32Version;
433-
header[2] = static_cast<int32_t>(config.block_size);
434-
header[3] = static_cast<int32_t>(config.original_vocab_size);
435-
header[4] = static_cast<int32_t>(config.n_layer);
436-
header[5] = static_cast<int32_t>(config.n_head);
437-
header[6] = static_cast<int32_t>(config.n_embd);
438-
header[7] = static_cast<int32_t>(config.vocab_size);
439-
ofs.write(reinterpret_cast<const char *>(header.data()),
440-
static_cast<std::streamsize>(header.size() * sizeof(int32_t)));
441-
442-
const auto state_dict = model->StateDict();
443-
auto get_tensor = [&](const std::string &name) -> std::shared_ptr<Tensor> {
444-
CHECK(state_dict.contains(name)) << "Missing tensor in GPT2 state_dict: " << name;
445-
return state_dict.at(name);
446-
};
447-
448-
auto write_tensor_fp32 = [&](const std::shared_ptr<Tensor> &tensor) {
449-
Tensor cpu = tensor->To(Device());
450-
if (cpu.Dtype() != DataType::kFLOAT32) {
451-
cpu = cpu.To(DataType::kFLOAT32);
452-
}
453-
const auto bytes = static_cast<std::streamsize>(cpu.SizeInBytes());
454-
ofs.write(reinterpret_cast<const char *>(cpu.DataPtr()), bytes);
455-
};
456-
457-
// transformer.wte.weight
458-
write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
459-
nn::TransformerFirstStage::kWTELayerName,
460-
nn::parallel::VocabParallelEmbedding::kParamWeightName)));
461-
462-
// transformer.wpe.weight
463-
write_tensor_fp32(
464-
get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
465-
nn::TransformerFirstStage::kWPELayerName, nn::Embedding::kParamWeightName)));
466-
467-
for (int idx = 0; idx < config.n_layer; ++idx) {
468-
write_tensor_fp32(get_tensor(std::format(
469-
"{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx,
470-
nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)));
471-
}
472-
for (int idx = 0; idx < config.n_layer; ++idx) {
473-
write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
474-
nn::TransformerChunk::kHLayerName, idx,
475-
nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)));
476-
}
477-
for (int idx = 0; idx < config.n_layer; ++idx) {
478-
write_tensor_fp32(get_tensor(std::format(
479-
"{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx,
480-
nn::TransformerLayer::kAttnLayerName, nn::CausalSelfAttention::kCAttnLayerName,
481-
nn::parallel::ColumnParallelLinear::kParamWeightName)));
482-
}
483-
for (int idx = 0; idx < config.n_layer; ++idx) {
484-
write_tensor_fp32(get_tensor(
485-
std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
486-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName,
487-
nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)));
488-
}
489-
for (int idx = 0; idx < config.n_layer; ++idx) {
490-
write_tensor_fp32(get_tensor(
491-
std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
492-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName,
493-
nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)));
494-
}
495-
for (int idx = 0; idx < config.n_layer; ++idx) {
496-
write_tensor_fp32(get_tensor(
497-
std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
498-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName,
499-
nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)));
500-
}
501-
for (int idx = 0; idx < config.n_layer; ++idx) {
502-
write_tensor_fp32(get_tensor(std::format(
503-
"{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx,
504-
nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)));
505-
}
506-
for (int idx = 0; idx < config.n_layer; ++idx) {
507-
write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
508-
nn::TransformerChunk::kHLayerName, idx,
509-
nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)));
510-
}
511-
for (int idx = 0; idx < config.n_layer; ++idx) {
512-
write_tensor_fp32(
513-
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
514-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
515-
nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)));
516-
}
517-
for (int idx = 0; idx < config.n_layer; ++idx) {
518-
write_tensor_fp32(
519-
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
520-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
521-
nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)));
522-
}
523-
for (int idx = 0; idx < config.n_layer; ++idx) {
524-
write_tensor_fp32(
525-
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
526-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
527-
nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)));
528-
}
529-
for (int idx = 0; idx < config.n_layer; ++idx) {
530-
write_tensor_fp32(
531-
get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
532-
nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName,
533-
nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)));
534-
}
535-
536-
write_tensor_fp32(
537-
get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
538-
nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)));
539-
write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
540-
nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)));
541-
542-
ofs.flush();
543-
CHECK(ofs.good()) << "Failed to flush model file: " << filepath;
544-
}
432+
} // namespace gpt2

example/gpt2/checkpoint_loader.h

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,12 @@
11
#pragma once
22

3-
#include <cstring>
43
#include <memory>
54
#include <string>
65

76
namespace infini_train::nn {
87
class TransformerModel;
9-
}
8+
} // namespace infini_train::nn
109

1110
namespace gpt2 {
1211
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
13-
void SaveAsLLMC(const std::shared_ptr<infini_train::nn::TransformerModel> &model, const std::string &filepath);
1412
} // namespace gpt2

example/gpt2/config.h

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,18 +4,19 @@
44

55
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
66

7+
namespace nn = infini_train::nn;
78
namespace gpt2 {
8-
inline infini_train::nn::TransformerConfig GPT2Config() {
9+
inline nn::TransformerConfig GPT2Config() {
910
return {.block_size = 1024,
1011
.vocab_size = 50304,
1112
.original_vocab_size = 50257,
1213
.n_layer = 12,
1314
.n_head = 12,
1415
.n_kv_head = 12,
1516
.n_embd = 768,
16-
.attention_type = infini_train::nn::AttentionType::kStandard,
17-
.activation_type = infini_train::nn::MLPType::kGELU,
18-
.norm_type = infini_train::nn::NormType::kLayerNorm,
17+
.attention_type = nn::AttentionType::kStandard,
18+
.activation_type = nn::MLPType::kGELU,
19+
.norm_type = nn::NormType::kLayerNorm,
1920
.add_bias_linear = true,
2021
.add_bias_lm_head = false,
2122
.tie_weights = true,
@@ -24,7 +25,7 @@ inline infini_train::nn::TransformerConfig GPT2Config() {
2425
.multiple_of = 1};
2526
}
2627

27-
inline void SanitizeGPT2Config(const infini_train::nn::TransformerConfig &c) {
28+
inline void SanitizeGPT2Config(const nn::TransformerConfig &c) {
2829
CHECK_GT(c.block_size, 0);
2930
CHECK_GT(c.vocab_size, 0);
3031
CHECK_GE(c.vocab_size, c.original_vocab_size);
@@ -33,8 +34,9 @@ inline void SanitizeGPT2Config(const infini_train::nn::TransformerConfig &c) {
3334
CHECK_GT(c.n_embd, 0);
3435
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
3536
CHECK_EQ(c.n_kv_head, c.n_head) << "GPT-2 does not use GQA; n_kv_head must equal n_head";
36-
CHECK(c.attention_type == infini_train::nn::AttentionType::kStandard) << "GPT-2 requires standard attention";
37-
CHECK(c.activation_type == infini_train::nn::MLPType::kGELU) << "GPT-2 requires GELU activation";
38-
CHECK(c.norm_type == infini_train::nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm";
37+
CHECK(c.attention_type == nn::AttentionType::kStandard) << "GPT-2 requires standard attention";
38+
CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation";
39+
CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm";
3940
}
41+
4042
} // namespace gpt2

example/gpt2/main.cc

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
1-
#include <algorithm>
21
#include <chrono>
32
#include <cstdlib>
43
#include <filesystem>
54
#include <format>
6-
#include <limits>
75
#include <memory>
86
#include <optional>
97
#include <unordered_map>
@@ -205,8 +203,6 @@ void Train(const nn::parallel::Rank &rank) {
205203
gpt2::SanitizeGPT2Config(model_config);
206204
model = std::make_shared<nn::TransformerModel>(model_config);
207205
}
208-
auto llmc_model = std::dynamic_pointer_cast<nn::TransformerModel>(model);
209-
CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O.";
210206

211207
model->To(device);
212208

@@ -305,8 +301,8 @@ void Train(const nn::parallel::Rank &rank) {
305301

306302
// TODO(dcj): support more complex optimizer later
307303
// auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
308-
std::shared_ptr<Optimizer> optimizer = nullptr;
309304
auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate);
305+
std::shared_ptr<Optimizer> optimizer = nullptr;
310306

311307
if (FLAGS_zero_stage >= 1) {
312308
auto model_chunks = (pp_world_size > 1)
@@ -319,7 +315,6 @@ void Train(const nn::parallel::Rank &rank) {
319315
}
320316

321317
auto train_iter = train_loader.begin();
322-
323318
std::shared_ptr<nn::Module> loss_fn
324319
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
325320
std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
@@ -335,7 +330,6 @@ void Train(const nn::parallel::Rank &rank) {
335330
.rank = rank,
336331
.model = model,
337332
.optimizer = optimizer,
338-
.train_loader = train_loader,
339333
.model_config = model_config,
340334
.state = state});
341335
start_step = resume_result.global_step;
@@ -377,6 +371,8 @@ void Train(const nn::parallel::Rank &rank) {
377371
});
378372
};
379373

374+
LOG(INFO) << "start training";
375+
380376
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
381377
// Reset precision check counters at start of each iteration for file overwrite
382378
utils::PrecisionChecker::ResetCounters();

0 commit comments

Comments
 (0)