Skip to content

Commit b550d35

Browse files
JYMiracle305 authored and kilinchange committed
feat: add model config validation on resume
1 parent c82cbe8 commit b550d35

9 files changed

Lines changed: 92 additions & 15 deletions

File tree

example/common/checkpoint_loader.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "glog/logging.h"
1111

12+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
1213
#include "infini_train/include/nn/parallel/global.h"
1314
#include "infini_train/include/tensor.h"
1415

@@ -39,6 +40,17 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &
3940

4041
result.global_step = static_cast<int>(args.state.global_step);
4142

43+
CHECK_EQ(args.state.n_layer, args.model_config.n_layer)
44+
<< "n_layer mismatch: ckpt=" << args.state.n_layer << ", config=" << args.model_config.n_layer;
45+
CHECK_EQ(args.state.n_head, args.model_config.n_head)
46+
<< "n_head mismatch: ckpt=" << args.state.n_head << ", config=" << args.model_config.n_head;
47+
CHECK_EQ(args.state.n_kv_head, args.model_config.n_kv_head)
48+
<< "n_kv_head mismatch: ckpt=" << args.state.n_kv_head << ", config=" << args.model_config.n_kv_head;
49+
CHECK_EQ(args.state.n_embd, args.model_config.n_embd)
50+
<< "n_embd mismatch: ckpt=" << args.state.n_embd << ", config=" << args.model_config.n_embd;
51+
CHECK_EQ(args.state.vocab_size, args.model_config.vocab_size)
52+
<< "vocab_size mismatch: ckpt=" << args.state.vocab_size << ", config=" << args.model_config.vocab_size;
53+
4254
CHECK_EQ(args.state.ddp_size, ddp_world_size) << "DDP size mismatch: checkpoint has DDP=" << args.state.ddp_size
4355
<< ", but current run has DDP=" << ddp_world_size;
4456
CHECK_EQ(args.state.tp_size, tp_world_size)
@@ -64,6 +76,11 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) {
6476
state.global_step = args.global_step;
6577
state.consumed_batches = static_cast<int64_t>(args.consumed_batches);
6678
state.last_lr = args.last_lr;
79+
state.n_layer = args.n_layer;
80+
state.n_head = args.n_head;
81+
state.n_kv_head = args.n_kv_head;
82+
state.n_embd = args.n_embd;
83+
state.vocab_size = args.vocab_size;
6784
state.ddp_size = args.ddp_size;
6885
state.tp_size = args.tp_size;
6986
state.sp_size = args.sp_size;

example/common/checkpoint_loader.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
#pragma once
22

3-
#include "gflags/gflags.h"
4-
53
#include <cstdint>
64
#include <cstring>
75
#include <filesystem>
8-
#include <functional>
9-
#include <limits>
10-
#include <string>
116

127
#include "infini_train/include/checkpoint.h"
138
#include "infini_train/include/dataloader.h"
@@ -18,12 +13,17 @@
1813
using namespace infini_train;
1914
namespace nn = infini_train::nn;
2015

16+
namespace infini_train::nn {
17+
class TransformerConfig;
18+
}
19+
2120
struct ResumeFromCheckpointArgs {
2221
std::filesystem::path resume_root;
2322
const nn::parallel::Rank &rank;
2423
std::shared_ptr<nn::Module> model;
2524
std::shared_ptr<Optimizer> optimizer;
2625
DistributedDataLoader &train_loader;
26+
const nn::TransformerConfig &model_config;
2727
TrainerState &state;
2828
};
2929

@@ -37,6 +37,11 @@ struct SaveCheckpointArgs {
3737
int64_t global_step = 0;
3838
size_t consumed_batches = 0;
3939
double last_lr = 0.0;
40+
int64_t n_layer = 0;
41+
int64_t n_head = 0;
42+
int64_t n_kv_head = 0;
43+
int64_t n_embd = 0;
44+
int64_t vocab_size = 0;
4045
int ddp_size = 1;
4146
int tp_size = 1;
4247
int sp_size = 1;

example/gpt2/main.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ void Train(const nn::parallel::Rank &rank) {
336336
.model = model,
337337
.optimizer = optimizer,
338338
.train_loader = train_loader,
339+
.model_config = model_config,
339340
.state = state});
340341
start_step = resume_result.global_step;
341342
size_t consumed_batches = resume_result.consumed_batches;
@@ -357,6 +358,11 @@ void Train(const nn::parallel::Rank &rank) {
357358
.global_step = global_step,
358359
.consumed_batches = consumed_batches,
359360
.last_lr = FLAGS_learning_rate,
361+
.n_layer = model_config.n_layer,
362+
.n_head = model_config.n_head,
363+
.n_kv_head = model_config.n_kv_head,
364+
.n_embd = model_config.n_embd,
365+
.vocab_size = model_config.vocab_size,
360366
.ddp_size = ddp_world_size,
361367
.tp_size = tp_world_size,
362368
.sp_size = sp_world_size,

example/llama3/main.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ void Train(const nn::parallel::Rank &rank) {
316316
.model = model,
317317
.optimizer = optimizer,
318318
.train_loader = train_loader,
319+
.model_config = model_config,
319320
.state = state,
320321
});
321322

@@ -339,6 +340,11 @@ void Train(const nn::parallel::Rank &rank) {
339340
.global_step = global_step,
340341
.consumed_batches = consumed_batches,
341342
.last_lr = FLAGS_learning_rate,
343+
.n_layer = model_config.n_layer,
344+
.n_head = model_config.n_head,
345+
.n_kv_head = model_config.n_kv_head,
346+
.n_embd = model_config.n_embd,
347+
.vocab_size = model_config.vocab_size,
342348
.ddp_size = ddp_world_size,
343349
.tp_size = tp_world_size,
344350
.sp_size = sp_world_size,

infini_train/include/checkpoint.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ struct TrainerState {
2020
// FIXME(jym): learning_rate should be restored from scheduler state, move `last_lr` from TrainerState to
2121
// SchedulerState later
2222
double last_lr = 0.0;
23-
23+
int64_t n_layer = 0;
24+
int64_t n_head = 0;
25+
int64_t n_kv_head = 0;
26+
int64_t n_embd = 0;
27+
int64_t vocab_size = 0;
2428
int ddp_size = 1;
2529
int tp_size = 1;
2630
int sp_size = 1;

infini_train/src/checkpoint.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const Train
208208
std::ofstream ofs(path);
209209
CHECK(ofs.is_open()) << "Failed to open trainer state file: " << path;
210210
ofs << "{\n";
211+
ofs << " \"n_layer\": " << state.n_layer << ",\n";
212+
ofs << " \"n_head\": " << state.n_head << ",\n";
213+
ofs << " \"n_kv_head\": " << state.n_kv_head << ",\n";
214+
ofs << " \"n_embd\": " << state.n_embd << ",\n";
215+
// "vocab_size" is not the last member of the object ("global_step" and others follow),
// so it must end with a comma to keep the emitted trainer-state file valid JSON.
ofs << " \"vocab_size\": " << state.vocab_size << ",\n";
211216
ofs << " \"global_step\": " << state.global_step << ",\n";
212217
ofs << " \"consumed_batches \": " << state.consumed_batches << ",\n";
213218
ofs << " \"last_lr\": " << state.last_lr << ",\n";
@@ -226,6 +231,11 @@ TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) {
226231
const std::string content((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
227232

228233
TrainerState state;
234+
state.n_layer = ExtractNumberField<int64_t>(content, "n_layer", 0);
235+
state.n_head = ExtractNumberField<int64_t>(content, "n_head", 0);
236+
state.n_kv_head = ExtractNumberField<int64_t>(content, "n_kv_head", 0);
237+
state.n_embd = ExtractNumberField<int64_t>(content, "n_embd", 0);
238+
state.vocab_size = ExtractNumberField<int64_t>(content, "vocab_size", 0);
229239
state.global_step = ExtractNumberField<int64_t>(content, "global_step", 0);
230240
state.consumed_batches = ExtractNumberField<int64_t>(content, "consumed_batches ", 0);
231241
state.last_lr = ExtractNumberField<double>(content, "last_lr", 0.0);

scripts/run_models_and_profile.bash

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_
7272
COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}"
7373
RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}"
7474
CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}"
75+
CKPT_CLEAN_DIRS=(
76+
"/data1/ckpt"
77+
"./checkpoints"
78+
)
7579

7680
mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR"
7781

@@ -114,6 +118,17 @@ clean_build_dir() {
114118
rm -rf "${BUILD_DIR:?}/"*
115119
}
116120

121+
# Clean checkpoint directories (invoked inside the main build/test loop, after each round of model runs)
122+
clean_checkpoints() {
123+
echo -e "\033[1;31m[CLEAN] Removing checkpoint directories from previous run\033[0m"
124+
for dir in "${CKPT_CLEAN_DIRS[@]}"; do
125+
if [[ -d "$dir" ]]; then
126+
echo -e "\033[1;31m[CLEAN] Removing: ${dir}\033[0m"
127+
rm -rf "${dir:?}"
128+
fi
129+
done
130+
}
131+
117132
# Run a command and log output
118133
run_and_log() {
119134
local cmd="$1"
@@ -298,6 +313,9 @@ for ((id=0; id<num_builds; ++id)); do
298313
llama3_cmd="${prefix}./llama3 --input_bin ${LLAMA3_INPUT_BIN} --llmc_filepath ${LLAMA3_LLMC_FILEPATH} --device cuda ${llama3_arg_str}"
299314
run_and_log "$llama3_cmd" "llama3_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
300315
done
316+
317+
# Clean checkpoints from previous run to avoid disk overflow and stale state
318+
clean_checkpoints
301319
done
302320
done
303321

tests/checkpoint/test_optimizer_state.cc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,16 @@ TEST_P(OptimizerStateTest, AdamStateDictKeys) {
1717
param->set_requires_grad(true);
1818
param->Fill(1.0f);
1919

20-
auto adam = std::make_shared<optimizers::Adam>(
21-
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>>{{"weight", param}}, 0.001);
20+
auto adam = std::make_shared<optimizers::Adam>(std::vector<std::shared_ptr<Tensor>>{{param}}, 0.001);
2221

2322
adam->ZeroGrad();
2423
adam->Step(); // t=1
2524
adam->Step(); // t=2
2625

2726
auto state = adam->StateDict();
2827
EXPECT_GT(state.size(), 0);
29-
EXPECT_TRUE(state.count("adam.m.weight"));
30-
EXPECT_TRUE(state.count("adam.v.weight"));
28+
EXPECT_TRUE(state.count("adam.m.0"));
29+
EXPECT_TRUE(state.count("adam.v.0"));
3130
EXPECT_TRUE(state.count("adam.t"));
3231

3332
auto t_cpu = state["adam.t"]->To(Device());
@@ -41,8 +40,7 @@ TEST_P(OptimizerStateTest, AdamStateDictRoundTrip) {
4140
param1->set_requires_grad(true);
4241
param1->Fill(1.0f);
4342

44-
auto adam1 = std::make_shared<optimizers::Adam>(
45-
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>>{{"w", param1}}, 0.001);
43+
auto adam1 = std::make_shared<optimizers::Adam>(std::vector<std::shared_ptr<Tensor>>{{param1}}, 0.001);
4644
adam1->ZeroGrad();
4745
adam1->Step();
4846
adam1->Step();
@@ -54,8 +52,7 @@ TEST_P(OptimizerStateTest, AdamStateDictRoundTrip) {
5452
param2->set_requires_grad(true);
5553
param2->Fill(1.0f);
5654

57-
auto adam2 = std::make_shared<optimizers::Adam>(
58-
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>>{{"w", param2}}, 0.001);
55+
auto adam2 = std::make_shared<optimizers::Adam>(std::vector<std::shared_ptr<Tensor>>{{param2}}, 0.001);
5956
adam2->LoadStateDict(saved);
6057

6158
adam2->ZeroGrad();

tests/checkpoint/test_trainer_state.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ TEST_P(TrainerStateTest, DefaultValues) {
2121
TrainerState state;
2222
EXPECT_EQ(state.global_step, 0);
2323
EXPECT_EQ(state.consumed_batches, 0);
24+
EXPECT_EQ(state.n_layer, 0);
25+
EXPECT_EQ(state.n_head, 0);
26+
EXPECT_EQ(state.n_kv_head, 0);
27+
EXPECT_EQ(state.n_embd, 0);
28+
EXPECT_EQ(state.vocab_size, 0);
2429
EXPECT_EQ(state.ddp_size, 1);
2530
EXPECT_EQ(state.tp_size, 1);
2631
EXPECT_EQ(state.sp_size, 1);
@@ -48,7 +53,6 @@ TEST_P(TrainerStateTest, TrainerStateFileCreated) {
4853
std::string content((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
4954
EXPECT_NE(content.find("\"global_step\""), std::string::npos);
5055
EXPECT_NE(content.find("\"consumed_batches \""), std::string::npos);
51-
EXPECT_NE(content.find("\"Adam\""), std::string::npos);
5256

5357
std::filesystem::remove_all(dir);
5458
}
@@ -61,6 +65,11 @@ TEST_P(TrainerStateTest, RoundTrip) {
6165
.global_step = 99,
6266
.consumed_batches = 5000,
6367
.last_lr = 3e-4,
68+
.n_layer = 24,
69+
.n_head = 16,
70+
.n_kv_head = 8,
71+
.n_embd = 1024,
72+
.vocab_size = 128256,
6473
.ddp_size = 2,
6574
.tp_size = 1,
6675
.sp_size = 1,
@@ -87,6 +96,11 @@ TEST_P(TrainerStateTest, RoundTrip) {
8796
EXPECT_EQ(loaded.global_step, 99);
8897
EXPECT_EQ(loaded.consumed_batches, 5000);
8998
EXPECT_NEAR(loaded.last_lr, 3e-4, 1e-10);
99+
EXPECT_EQ(loaded.n_layer, 24);
100+
EXPECT_EQ(loaded.n_head, 16);
101+
EXPECT_EQ(loaded.n_kv_head, 8);
102+
EXPECT_EQ(loaded.n_embd, 1024);
103+
EXPECT_EQ(loaded.vocab_size, 128256);
90104
EXPECT_EQ(loaded.ddp_size, 2);
91105
EXPECT_EQ(loaded.pp_size, 2);
92106

0 commit comments

Comments
 (0)