4141#include " example/common/checkpoint_loader.h"
4242#include " example/common/tiny_shakespeare_dataset.h"
4343#include " example/common/tokenizer.h"
44+ #include " example/gpt2/checkpoint_loader.h"
4445#include " example/gpt2/config.h"
4546
4647// I/O
@@ -81,12 +82,12 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
8182
8283// precision
8384DEFINE_string (dtype, " float32" , " precision used in training (float32/bfloat16)" );
84- DEFINE_uint32 (save_steps , 0 , " save checkpoint every N steps; 0 disables saving" );
85- DEFINE_string (resume_from , " " , " checkpoint directory to resume from" );
86- DEFINE_string (checkpoint_dir , " ./checkpoints" , " root directory used to store checkpoints" );
85+ DEFINE_uint32 (save_interval , 0 , " save checkpoint every N steps; 0 disables saving" );
86+ DEFINE_string (load , " " , " checkpoint directory to resume from" );
87+ DEFINE_string (save , " ./checkpoints" , " root directory used to store checkpoints" );
8788DEFINE_uint32 (max_checkpoint_keep, 3 , " max number of checkpoint steps to keep" );
88- DEFINE_bool (save_optimizer_state, true , " whether optimizer state is persisted in checkpoints" );
89- DEFINE_string (checkpoint_format , " ckpt" ,
89+ DEFINE_bool (no_save_optim, false , " if true, optimizer state is NOT persisted in checkpoints" ); // negated replacement for save_optimizer_state (which defaulted to true)
90+ DEFINE_string (checkpoint_file_format , " ckpt" ,
9091 " checkpoint format: bin|ckpt. "
9192 " 'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
9293 " 'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary)." );
@@ -317,7 +318,7 @@ void Train(const nn::parallel::Rank &rank) {
317318 // TODO(dcj): support more complex optimizer later
318319 // auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
319320 std::shared_ptr<Optimizer> optimizer = nullptr ;
320- auto optimizer_creator = optimizers::SGD::Create (FLAGS_learning_rate);
321+ auto optimizer_creator = optimizers::SGD::CreateNamed (FLAGS_learning_rate);
321322
322323 if (FLAGS_use_distributed_optimizer) {
323324 auto model_chunks = (pp_world_size > 1 )
@@ -349,8 +350,8 @@ void Train(const nn::parallel::Rank &rank) {
349350 auto loaded_model = gpt2::LoadFromLLMC (model_path.string ());
350351 target_model->LoadStateDict (loaded_model->StateDict ());
351352 };
352- const auto resume_result = infini_train:: ResumeFromCheckpoint ({
353- .resume_root = FLAGS_resume_from ,
353+ const auto resume_result = ResumeFromCheckpoint ({
354+ .resume_root = FLAGS_load ,
354355 .rank = rank,
355356 .model = model,
356357 .optimizer = optimizer,
@@ -365,21 +366,21 @@ void Train(const nn::parallel::Rank &rank) {
365366
366367 auto save_checkpoint
367368 = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
368- infini_train:: SaveCheckpoint ({
369+ SaveCheckpoint ({
369370 .save_dir = save_dir,
370371 .global_step = global_step,
371372 .data_batch_idx = saved_data_batch_idx,
372373 .best_loss = best_loss,
373374 .last_lr = FLAGS_learning_rate,
374375 .optimizer_type = " SGD" ,
375- .checkpoint_format = FLAGS_checkpoint_format ,
376+ .checkpoint_file_format = FLAGS_checkpoint_file_format ,
376377 .ddp_size = ddp_world_size,
377378 .tp_size = tp_world_size,
378379 .sp_size = sp_world_size,
379380 .pp_size = pp_world_size,
380- .save_optimizer_state = FLAGS_save_optimizer_state ,
381+ .no_save_optim = FLAGS_no_save_optim ,
381382 .prune_step_checkpoints = prune_step_checkpoints,
382- .checkpoint_root_dir = FLAGS_checkpoint_dir ,
383+ .checkpoint_root_dir = FLAGS_save ,
383384 .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
384385 .rank = rank,
385386 .model = *model,
@@ -509,9 +510,9 @@ void Train(const nn::parallel::Rank &rank) {
509510 }
510511 }
511512
512- if (FLAGS_save_steps > 0 && (step + 1 ) % FLAGS_save_steps == 0 ) {
513+ if (FLAGS_save_interval > 0 && (step + 1 ) % FLAGS_save_interval == 0 ) {
513514 std::filesystem::path step_dir
514- = std::filesystem::path (FLAGS_checkpoint_dir ) / std::format (" checkpoint_step_{:06d}" , step + 1 );
515+ = std::filesystem::path (FLAGS_save ) / std::format (" checkpoint_step_{:06d}" , step + 1 );
515516 if (rank.IsParallel ()) {
516517 step_dir /= std::format (" rank_{:06d}" , rank.GlobalRank ());
517518 }
@@ -525,7 +526,7 @@ void Train(const nn::parallel::Rank &rank) {
525526 nn::lora::SaveLoRAWeights (model, FLAGS_lora_save_path);
526527 }
527528
528- std::filesystem::path final_dir = std::filesystem::path (FLAGS_checkpoint_dir ) / " checkpoint_final" ;
529+ std::filesystem::path final_dir = std::filesystem::path (FLAGS_save ) / " checkpoint_final" ;
529530 if (rank.IsParallel ()) {
530531 final_dir /= std::format (" rank_{:06d}" , rank.GlobalRank ());
531532 }
0 commit comments