4141#include " example/common/checkpoint_loader.h"
4242#include " example/common/tiny_shakespeare_dataset.h"
4343#include " example/common/tokenizer.h"
44+ #include " example/gpt2/checkpoint_loader.h"
4445#include " example/gpt2/config.h"
4546
47+ // TODO(jym): Reorganize CLI flags into categories for better readability and maintainability.
4648// I/O
4749DEFINE_string (input_bin, " " , " input .bin to train on" );
4850DEFINE_string (input_val_bin, " " , " input .bin to eval validation loss on" );
@@ -81,12 +83,12 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
8183
8284// precision
8385DEFINE_string (dtype, " float32" , " precision used in training (float32/bfloat16)" );
84- DEFINE_uint32 (save_steps , 0 , " save checkpoint every N steps; 0 disables saving" );
85- DEFINE_string (resume_from , " " , " checkpoint directory to resume from" );
86- DEFINE_string (checkpoint_dir , " ./checkpoints" , " root directory used to store checkpoints" );
86+ DEFINE_uint32 (save_interval , 0 , " save checkpoint every N steps; 0 disables saving" );
87+ DEFINE_string (load , " " , " checkpoint directory to resume from" );
88+ DEFINE_string (save , " ./checkpoints" , " root directory used to store checkpoints" );
8789DEFINE_uint32 (max_checkpoint_keep, 3 , " max number of checkpoint steps to keep" );
88- DEFINE_bool (save_optimizer_state, true , " whether optimizer state is persisted in checkpoints" );
89- DEFINE_string (checkpoint_format , " ckpt" ,
90+ DEFINE_bool (no_save_optim, false , " whether optimizer state is persisted in checkpoints" );
91+ DEFINE_string (checkpoint_file_format , " ckpt" ,
9092 " checkpoint format: bin|ckpt. "
9193 " 'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
9294 " 'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary)." );
@@ -317,7 +319,7 @@ void Train(const nn::parallel::Rank &rank) {
317319 // TODO(dcj): support more complex optimizer later
318320 // auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
319321 std::shared_ptr<Optimizer> optimizer = nullptr ;
320- auto optimizer_creator = optimizers::SGD::Create (FLAGS_learning_rate);
322+ auto optimizer_creator = optimizers::SGD::CreateNamed (FLAGS_learning_rate);
321323
322324 if (FLAGS_use_distributed_optimizer) {
323325 auto model_chunks = (pp_world_size > 1 )
@@ -341,16 +343,15 @@ void Train(const nn::parallel::Rank &rank) {
341343 auto impl = core::GetDeviceGuardImpl (device.type ());
342344
343345 int start_step = 0 ;
344- float best_loss = std::numeric_limits<float >::infinity ();
345346 TrainerState state;
346347 CheckpointLoadOptions load_options;
347348 load_options.load_optimizer_state = true ;
348349 load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
349350 auto loaded_model = gpt2::LoadFromLLMC (model_path.string ());
350351 target_model->LoadStateDict (loaded_model->StateDict ());
351352 };
352- const auto resume_result = infini_train:: ResumeFromCheckpoint ({
353- .resume_root = FLAGS_resume_from ,
353+ const auto resume_result = ResumeFromCheckpoint ({
354+ .resume_root = FLAGS_load ,
354355 .rank = rank,
355356 .model = model,
356357 .optimizer = optimizer,
@@ -360,26 +361,24 @@ void Train(const nn::parallel::Rank &rank) {
360361 .load_options = load_options,
361362 });
362363 start_step = resume_result.global_step ;
363- best_loss = resume_result.best_loss ;
364364 saved_data_batch_idx = resume_result.data_batch_idx ;
365365
366366 auto save_checkpoint
367367 = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
368- infini_train:: SaveCheckpoint ({
368+ SaveCheckpoint ({
369369 .save_dir = save_dir,
370370 .global_step = global_step,
371371 .data_batch_idx = saved_data_batch_idx,
372- .best_loss = best_loss,
373372 .last_lr = FLAGS_learning_rate,
374373 .optimizer_type = " SGD" ,
375- .checkpoint_format = FLAGS_checkpoint_format ,
374+ .checkpoint_file_format = FLAGS_checkpoint_file_format ,
376375 .ddp_size = ddp_world_size,
377376 .tp_size = tp_world_size,
378377 .sp_size = sp_world_size,
379378 .pp_size = pp_world_size,
380- .save_optimizer_state = FLAGS_save_optimizer_state ,
379+ .no_save_optim = FLAGS_no_save_optim ,
381380 .prune_step_checkpoints = prune_step_checkpoints,
382- .checkpoint_root_dir = FLAGS_checkpoint_dir ,
381+ .checkpoint_root_dir = FLAGS_save ,
383382 .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
384383 .rank = rank,
385384 .model = *model,
@@ -484,8 +483,6 @@ void Train(const nn::parallel::Rank &rank) {
484483 lossf = static_cast <const float *>(lossf_tensor->To (Device ()).DataPtr ())[0 ];
485484 }
486485
487- best_loss = std::min (best_loss, lossf);
488-
489486 const auto iter_end = std::chrono::high_resolution_clock::now ();
490487 const double duration_us = std::chrono::duration<double , std::micro>(iter_end - iter_start).count ();
491488 const double tps = FLAGS_total_batch_size / (duration_us / 1e6 );
@@ -509,9 +506,9 @@ void Train(const nn::parallel::Rank &rank) {
509506 }
510507 }
511508
512- if (FLAGS_save_steps > 0 && (step + 1 ) % FLAGS_save_steps == 0 ) {
509+ if (FLAGS_save_interval > 0 && (step + 1 ) % FLAGS_save_interval == 0 ) {
513510 std::filesystem::path step_dir
514- = std::filesystem::path (FLAGS_checkpoint_dir ) / std::format (" checkpoint_step_{:06d}" , step + 1 );
511+ = std::filesystem::path (FLAGS_save ) / std::format (" checkpoint_step_{:06d}" , step + 1 );
515512 if (rank.IsParallel ()) {
516513 step_dir /= std::format (" rank_{:06d}" , rank.GlobalRank ());
517514 }
@@ -525,7 +522,7 @@ void Train(const nn::parallel::Rank &rank) {
525522 nn::lora::SaveLoRAWeights (model, FLAGS_lora_save_path);
526523 }
527524
528- std::filesystem::path final_dir = std::filesystem::path (FLAGS_checkpoint_dir ) / " checkpoint_final" ;
525+ std::filesystem::path final_dir = std::filesystem::path (FLAGS_save ) / " checkpoint_final" ;
529526 if (rank.IsParallel ()) {
530527 final_dir /= std::format (" rank_{:06d}" , rank.GlobalRank ());
531528 }
0 commit comments