Skip to content

Commit a1e9b30

Browse files
committed
temp2
1 parent 8ebe12f commit a1e9b30

12 files changed

Lines changed: 196 additions & 252 deletions

File tree

example/common/checkpoint_loader.cc

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ namespace nn = infini_train::nn;
1717

1818
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
1919
ResumeFromCheckpointResult result;
20-
int ddp_world_size = nn::parallel::global::GetDataParallelSize();
21-
int tp_world_size = nn::parallel::global::GetTensorParallelSize();
22-
int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1;
23-
int pp_world_size = nn::parallel::global::GetPipelineParallelSize();
24-
2520
if (args.resume_root.empty()) {
2621
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
2722
return result;
2823
}
2924

25+
int ddp_world_size = nn::parallel::global::GetDataParallelSize();
26+
int tp_world_size = nn::parallel::global::GetTensorParallelSize();
27+
int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1;
28+
int pp_world_size = nn::parallel::global::GetPipelineParallelSize();
29+
3030
std::filesystem::path resume_dir = args.resume_root;
3131
if (args.rank.IsParallel()) {
3232
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
@@ -35,27 +35,23 @@ ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &
3535
}
3636
}
3737

38-
Checkpoint::Load(resume_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options);
38+
Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state);
3939

4040
result.global_step = static_cast<int>(args.state.global_step);
41-
if (args.state.data_batch_stride != static_cast<int64_t>(ddp_world_size)) {
42-
LOG(FATAL) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. "
43-
"Proceeding with recorded data_batch_idx {}.",
44-
args.state.data_batch_stride, ddp_world_size, args.state.data_batch_idx);
45-
}
4641

42+
CHECK_EQ(args.state.ddp_size, ddp_world_size) << "DDP size mismatch: checkpoint has DDP=" << args.state.ddp_size
43+
<< ", but current run has DDP=" << ddp_world_size;
4744
CHECK_EQ(args.state.tp_size, tp_world_size)
4845
<< "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size;
4946
CHECK_EQ(args.state.sp_size, sp_world_size)
5047
<< "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size;
5148
CHECK_EQ(args.state.pp_size, pp_world_size)
5249
<< "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size;
5350

54-
result.data_batch_idx = static_cast<size_t>(std::max<int64_t>(args.state.data_batch_idx, 0));
55-
args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx);
51+
result.consumed_batches = static_cast<size_t>(std::max<int64_t>(args.state.consumed_batches, 0));
5652
if (args.rank.IsMainRank()) {
57-
LOG(INFO) << std::format("Resume training from step {}, last_lr {:.3e}, data_batch_idx {}",
58-
args.state.global_step, args.state.last_lr, args.state.data_batch_idx);
53+
LOG(INFO) << std::format("Resume training from step {}, last_lr {:.3e}, consumed_batches {}",
54+
args.state.global_step, args.state.last_lr, args.state.consumed_batches);
5955
}
6056

6157
return result;
@@ -66,21 +62,14 @@ void SaveCheckpoint(const SaveCheckpointArgs &args) {
6662

6763
TrainerState state;
6864
state.global_step = args.global_step;
69-
state.data_batch_idx = static_cast<int64_t>(args.data_batch_idx);
70-
state.data_batch_stride = args.ddp_size;
65+
state.consumed_batches = static_cast<int64_t>(args.consumed_batches);
7166
state.last_lr = args.last_lr;
72-
state.optimizer_type = args.optimizer_type;
73-
state.checkpoint_file_format = args.checkpoint_file_format;
7467
state.ddp_size = args.ddp_size;
7568
state.tp_size = args.tp_size;
7669
state.sp_size = args.sp_size;
7770
state.pp_size = args.pp_size;
7871

79-
CheckpointOptions options;
80-
options.format = args.checkpoint_file_format;
81-
options.no_save_optim = args.no_save_optim;
82-
options.model_bin_writer = args.model_bin_writer;
83-
Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options);
72+
Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state);
8473

8574
const auto ckpt_end = std::chrono::high_resolution_clock::now();
8675
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

example/common/checkpoint_loader.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,18 @@ struct ResumeFromCheckpointArgs {
2525
std::shared_ptr<Optimizer> optimizer;
2626
DistributedDataLoader &train_loader;
2727
TrainerState &state;
28-
DataLoaderIterator &train_iter;
29-
CheckpointLoadOptions load_options;
3028
};
3129

3230
struct ResumeFromCheckpointResult {
3331
int global_step = 0;
34-
size_t data_batch_idx = 0;
32+
size_t consumed_batches = 0;
3533
};
3634

3735
struct SaveCheckpointArgs {
3836
std::filesystem::path save_dir;
3937
int64_t global_step = 0;
40-
size_t data_batch_idx = 0;
38+
size_t consumed_batches = 0;
4139
double last_lr = 0.0;
42-
std::string optimizer_type;
43-
std::string checkpoint_file_format = "bin";
4440
int ddp_size = 1;
4541
int tp_size = 1;
4642
int sp_size = 1;
@@ -52,7 +48,6 @@ struct SaveCheckpointArgs {
5248
const nn::parallel::Rank &rank;
5349
const nn::Module &model;
5450
const Optimizer &optimizer;
55-
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
5651
};
5752

5853
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

example/gpt2/main.cc

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,6 @@ DEFINE_string(load, "", "checkpoint directory to resume from");
8888
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8989
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
9090
DEFINE_bool(no_save_optim, false, "whether optimizer state is persisted in checkpoints");
91-
DEFINE_string(checkpoint_file_format, "ckpt",
92-
"checkpoint format: bin|ckpt. "
93-
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
94-
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
9591
// precision check
9692
DEFINE_string(
9793
precision_check, "",
@@ -332,7 +328,7 @@ void Train(const nn::parallel::Rank &rank) {
332328
}
333329

334330
auto train_iter = train_loader.begin();
335-
size_t saved_data_batch_idx = train_iter.BatchIndex();
331+
336332
std::shared_ptr<nn::Module> loss_fn
337333
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
338334
std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
@@ -344,34 +340,32 @@ void Train(const nn::parallel::Rank &rank) {
344340

345341
int start_step = 0;
346342
TrainerState state;
347-
CheckpointLoadOptions load_options;
348-
load_options.load_optimizer_state = true;
349-
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
350-
auto loaded_model = gpt2::LoadFromLLMC(model_path.string());
351-
target_model->LoadStateDict(loaded_model->StateDict());
352-
};
353-
const auto resume_result = ResumeFromCheckpoint({
354-
.resume_root = FLAGS_load,
355-
.rank = rank,
356-
.model = model,
357-
.optimizer = optimizer,
358-
.train_loader = train_loader,
359-
.state = state,
360-
.train_iter = train_iter,
361-
.load_options = load_options,
362-
});
343+
const auto resume_result = ResumeFromCheckpoint({.resume_root = FLAGS_load,
344+
.rank = rank,
345+
.model = model,
346+
.optimizer = optimizer,
347+
.train_loader = train_loader,
348+
.state = state});
363349
start_step = resume_result.global_step;
364-
saved_data_batch_idx = resume_result.data_batch_idx;
350+
size_t consumed_batches = resume_result.consumed_batches;
351+
352+
// TODO(jym): Replace with Sampler abstraction when available.
353+
// Skip dataloader to resume from the correct batch position.
354+
if (consumed_batches > 0) {
355+
size_t start = train_iter.BatchIndex();
356+
// Each rank processes every ddp_world_size-th batch starting from its own rank.
357+
// num_skips calculates how many ++ iterations to reach the saved batch position.
358+
size_t num_skips = (consumed_batches - start) / ddp_world_size;
359+
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
360+
}
365361

366362
auto save_checkpoint
367363
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
368364
SaveCheckpoint({
369365
.save_dir = save_dir,
370366
.global_step = global_step,
371-
.data_batch_idx = saved_data_batch_idx,
367+
.consumed_batches = consumed_batches,
372368
.last_lr = FLAGS_learning_rate,
373-
.optimizer_type = "SGD",
374-
.checkpoint_file_format = FLAGS_checkpoint_file_format,
375369
.ddp_size = ddp_world_size,
376370
.tp_size = tp_world_size,
377371
.sp_size = sp_world_size,
@@ -383,9 +377,6 @@ void Train(const nn::parallel::Rank &rank) {
383377
.rank = rank,
384378
.model = *model,
385379
.optimizer = *optimizer,
386-
.model_bin_writer
387-
= [&](const nn::Module &,
388-
const std::filesystem::path &model_path) { gpt2::SaveAsLLMC(llmc_model, model_path.string()); },
389380
});
390381
};
391382

@@ -439,7 +430,7 @@ void Train(const nn::parallel::Rank &rank) {
439430
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
440431
// TODO(dcj): support dataloader.reset() later
441432
++train_iter;
442-
saved_data_batch_idx = train_iter.BatchIndex();
433+
consumed_batches = train_iter.BatchIndex();
443434
x = std::make_shared<Tensor>(x->To(device));
444435
y = std::make_shared<Tensor>(y->To(device));
445436

@@ -470,7 +461,7 @@ void Train(const nn::parallel::Rank &rank) {
470461
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
471462
// TODO(dcj): support dataloader.reset() later
472463
++train_iter;
473-
saved_data_batch_idx = train_iter.BatchIndex();
464+
consumed_batches = train_iter.BatchIndex();
474465
x = std::make_shared<Tensor>(x->To(device));
475466
y = std::make_shared<Tensor>(y->To(device));
476467

example/llama3/main.cc

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,7 @@ DEFINE_string(load, "", "checkpoint directory to resume from");
8686
DEFINE_string(save, "./checkpoints", "root directory used to store checkpoints");
8787
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
8888
DEFINE_bool(no_save_optim, true, "whether optimizer state is persisted in checkpoints");
89-
DEFINE_string(checkpoint_file_format, "ckpt",
90-
"checkpoint format: bin|ckpt. "
91-
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
92-
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
89+
9390
// precision check
9491
DEFINE_string(
9592
precision_check, "",
@@ -311,7 +308,7 @@ void Train(const nn::parallel::Rank &rank) {
311308
}
312309

313310
auto train_iter = train_loader.begin();
314-
size_t saved_data_batch_idx = train_iter.BatchIndex();
311+
315312
std::shared_ptr<nn::Module> loss_fn
316313
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
317314
: std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
@@ -322,50 +319,48 @@ void Train(const nn::parallel::Rank &rank) {
322319

323320
int start_step = 0;
324321
TrainerState state;
325-
CheckpointLoadOptions load_options;
326-
load_options.load_optimizer_state = true;
327-
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
328-
auto loaded_model = llama3::LoadFromLLMC(model_path.string());
329-
target_model->LoadStateDict(loaded_model->StateDict());
330-
};
331322
const auto resume_result = ResumeFromCheckpoint({
332323
.resume_root = FLAGS_load,
333324
.rank = rank,
334325
.model = model,
335326
.optimizer = optimizer,
336327
.train_loader = train_loader,
337328
.state = state,
338-
.train_iter = train_iter,
339-
.load_options = load_options,
340329
});
330+
341331
start_step = resume_result.global_step;
342-
saved_data_batch_idx = resume_result.data_batch_idx;
343-
344-
auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
345-
bool prune_step_checkpoints) {
346-
SaveCheckpoint({
347-
.save_dir = save_dir,
348-
.global_step = global_step,
349-
.data_batch_idx = saved_data_batch_idx,
350-
.last_lr = FLAGS_learning_rate,
351-
.optimizer_type = "Adam",
352-
.checkpoint_file_format = FLAGS_checkpoint_file_format,
353-
.ddp_size = ddp_world_size,
354-
.tp_size = tp_world_size,
355-
.sp_size = sp_world_size,
356-
.pp_size = pp_world_size,
357-
.no_save_optim = FLAGS_no_save_optim,
358-
.prune_step_checkpoints = prune_step_checkpoints,
359-
.checkpoint_root_dir = FLAGS_save,
360-
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
361-
.rank = rank,
362-
.model = *model,
363-
.optimizer = *optimizer,
364-
.model_bin_writer
365-
= [&](const nn::Module &,
366-
const std::filesystem::path &model_path) { llama3::SaveAsLLMC(llmc_model, model_path.string()); },
367-
});
368-
};
332+
size_t consumed_batches = resume_result.consumed_batches;
333+
334+
// TODO(jym): Replace with Sampler abstraction when available.
335+
// Skip dataloader to resume from the correct batch position.
336+
if (consumed_batches > 0) {
337+
size_t start = train_iter.BatchIndex();
338+
// Each rank processes every ddp_world_size-th batch starting from its own rank.
339+
// num_skips calculates how many ++ iterations to reach the saved batch position.
340+
size_t num_skips = (consumed_batches - start) / ddp_world_size;
341+
for (size_t i = 0; i < num_skips; ++i) { ++train_iter; }
342+
}
343+
344+
auto save_checkpoint
345+
= [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
346+
SaveCheckpoint({
347+
.save_dir = save_dir,
348+
.global_step = global_step,
349+
.consumed_batches = consumed_batches,
350+
.last_lr = FLAGS_learning_rate,
351+
.ddp_size = ddp_world_size,
352+
.tp_size = tp_world_size,
353+
.sp_size = sp_world_size,
354+
.pp_size = pp_world_size,
355+
.no_save_optim = FLAGS_no_save_optim,
356+
.prune_step_checkpoints = prune_step_checkpoints,
357+
.checkpoint_root_dir = FLAGS_save,
358+
.max_checkpoint_keep = FLAGS_max_checkpoint_keep,
359+
.rank = rank,
360+
.model = *model,
361+
.optimizer = *optimizer,
362+
});
363+
};
369364

370365
for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) {
371366
// Reset precision check counters at start of each iteration for file overwrite
@@ -414,10 +409,11 @@ void Train(const nn::parallel::Rank &rank) {
414409

415410
// (bs, seq_len), (bs, seq_len)
416411
auto [x, y] = *train_iter;
417-
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
412+
// if we are trying to overfit a single batch, we reset the loader here by commenting out the
413+
// line below
418414
// TODO(dcj): support dataloader.reset() later
419415
++train_iter;
420-
saved_data_batch_idx = train_iter.BatchIndex();
416+
consumed_batches = train_iter.BatchIndex();
421417
x = std::make_shared<Tensor>(x->To(device));
422418
y = std::make_shared<Tensor>(y->To(device));
423419

@@ -444,10 +440,11 @@ void Train(const nn::parallel::Rank &rank) {
444440
optimizer->Step();
445441
} else {
446442
auto [x, y] = *train_iter;
447-
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
443+
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line
444+
// below
448445
// TODO(dcj): support dataloader.reset() later
449446
++train_iter;
450-
saved_data_batch_idx = train_iter.BatchIndex();
447+
consumed_batches = train_iter.BatchIndex();
451448
x = std::make_shared<Tensor>(x->To(device));
452449
y = std::make_shared<Tensor>(y->To(device));
453450

infini_train/include/checkpoint.h

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,24 @@ class Module;
1616

1717
struct TrainerState {
1818
int64_t global_step = 0;
19-
int64_t data_batch_idx = 0;
20-
int64_t data_batch_stride = 1;
19+
int64_t consumed_batches = 0;
20+
// FIXME(jym): learning_rate should be restored from scheduler state, move `last_lr` from TrainerState to
21+
// SchedulerState later
2122
double last_lr = 0.0;
22-
std::string optimizer_type = "unknown";
23-
std::string checkpoint_file_format = "bin";
2423

2524
int ddp_size = 1;
2625
int tp_size = 1;
2726
int sp_size = 1;
2827
int pp_size = 1;
2928
};
3029

31-
struct CheckpointOptions {
32-
std::string format = "bin";
33-
bool no_save_optim = false;
34-
std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
35-
};
36-
37-
struct CheckpointLoadOptions {
38-
bool load_optimizer_state = true;
39-
std::function<void(nn::Module *, const std::filesystem::path &)> model_bin_loader;
40-
};
41-
4230
class Checkpoint {
4331
public:
44-
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
45-
const TrainerState &state, const CheckpointOptions &options = {});
32+
static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
33+
const TrainerState &state, bool no_save_optim = false);
4634

47-
static void Load(const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer,
48-
TrainerState *state, const CheckpointLoadOptions &options = {});
35+
static void Load(const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
36+
TrainerState &state, bool load_optimizer_state = true);
4937

5038
private:
5139
static void SaveStateDictBinary(const std::filesystem::path &path,
@@ -56,7 +44,6 @@ class Checkpoint {
5644

5745
static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state);
5846
static TrainerState LoadTrainerState(const std::filesystem::path &path);
59-
static std::string InferFormat(const std::filesystem::path &checkpoint_dir);
6047
};
6148

6249
} // namespace infini_train

0 commit comments

Comments
 (0)