@@ -86,10 +86,7 @@ DEFINE_string(load, "", "checkpoint directory to resume from");
8686DEFINE_string (save, " ./checkpoints" , " root directory used to store checkpoints" );
8787DEFINE_uint32 (max_checkpoint_keep, 3 , " max number of checkpoint steps to keep" );
8888DEFINE_bool (no_save_optim, true , " whether optimizer state is persisted in checkpoints" );
89- DEFINE_string (checkpoint_file_format, " ckpt" ,
90- " checkpoint format: bin|ckpt. "
91- " 'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
92- " 'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary)." );
89+
9390// precision check
9491DEFINE_string (
9592 precision_check, " " ,
@@ -311,7 +308,7 @@ void Train(const nn::parallel::Rank &rank) {
311308 }
312309
313310 auto train_iter = train_loader.begin ();
314- size_t saved_data_batch_idx = train_iter. BatchIndex ();
311+
315312 std::shared_ptr<nn::Module> loss_fn
316313 = (tp_world_size > 1 ) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
317314 : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
@@ -322,50 +319,46 @@ void Train(const nn::parallel::Rank &rank) {
322319
323320 int start_step = 0 ;
324321 TrainerState state;
325- CheckpointLoadOptions load_options;
326- load_options.load_optimizer_state = true ;
327- load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
328- auto loaded_model = llama3::LoadFromLLMC (model_path.string ());
329- target_model->LoadStateDict (loaded_model->StateDict ());
330- };
331322 const auto resume_result = ResumeFromCheckpoint ({
332323 .resume_root = FLAGS_load,
333324 .rank = rank,
334325 .model = model,
335326 .optimizer = optimizer,
336327 .train_loader = train_loader,
337328 .state = state,
338- .train_iter = train_iter,
339- .load_options = load_options,
340329 });
330+
341331 start_step = resume_result.global_step ;
342- saved_data_batch_idx = resume_result.data_batch_idx ;
343-
344- auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
345- bool prune_step_checkpoints) {
346- SaveCheckpoint ({
347- .save_dir = save_dir,
348- .global_step = global_step,
349- .data_batch_idx = saved_data_batch_idx,
350- .last_lr = FLAGS_learning_rate,
351- .optimizer_type = " Adam" ,
352- .checkpoint_file_format = FLAGS_checkpoint_file_format,
353- .ddp_size = ddp_world_size,
354- .tp_size = tp_world_size,
355- .sp_size = sp_world_size,
356- .pp_size = pp_world_size,
357- .no_save_optim = FLAGS_no_save_optim,
358- .prune_step_checkpoints = prune_step_checkpoints,
359- .checkpoint_root_dir = FLAGS_save,
360- .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
361- .rank = rank,
362- .model = *model,
363- .optimizer = *optimizer,
364- .model_bin_writer
365- = [&](const nn::Module &,
366- const std::filesystem::path &model_path) { llama3::SaveAsLLMC (llmc_model, model_path.string ()); },
367- });
368- };
332+ size_t consumed_batches = resume_result.consumed_batches ;
333+
334+ // TODO(jym): Replace with Sampler abstraction when available.
335+ // Skip dataloader to resume from the correct batch position.
336+ if (consumed_batches > 0 ) {
337+ size_t start = train_iter.BatchIndex ();
338+ size_t num_skips = (consumed_batches - start) / ddp_world_size;
339+ for (size_t i = 0 ; i < num_skips; ++i) { ++train_iter; }
340+ }
341+
342+ auto save_checkpoint
343+ = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
344+ SaveCheckpoint ({
345+ .save_dir = save_dir,
346+ .global_step = global_step,
347+ .consumed_batches = consumed_batches,
348+ .last_lr = FLAGS_learning_rate,
349+ .ddp_size = ddp_world_size,
350+ .tp_size = tp_world_size,
351+ .sp_size = sp_world_size,
352+ .pp_size = pp_world_size,
353+ .no_save_optim = FLAGS_no_save_optim,
354+ .prune_step_checkpoints = prune_step_checkpoints,
355+ .checkpoint_root_dir = FLAGS_save,
356+ .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
357+ .rank = rank,
358+ .model = *model,
359+ .optimizer = *optimizer,
360+ });
361+ };
369362
370363 for (int step = start_step; step < FLAGS_num_iteration + 1 ; ++step) {
371364 // Reset precision check counters at start of each iteration for file overwrite
@@ -414,10 +407,11 @@ void Train(const nn::parallel::Rank &rank) {
414407
415408 // (bs, seq_len), (bs, seq_len)
416409 auto [x, y] = *train_iter;
417- // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
410+ // if we are trying to overfit a single batch, we reset the loader here by commenting out the
411+ // line below
418412 // TODO(dcj): support dataloader.reset() later
419413 ++train_iter;
420- saved_data_batch_idx = train_iter.BatchIndex ();
414+ consumed_batches = train_iter.BatchIndex ();
421415 x = std::make_shared<Tensor>(x->To (device));
422416 y = std::make_shared<Tensor>(y->To (device));
423417
@@ -444,10 +438,11 @@ void Train(const nn::parallel::Rank &rank) {
444438 optimizer->Step ();
445439 } else {
446440 auto [x, y] = *train_iter;
447- // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
441+ // if we are trying to overfit a single batch, we reset the loader here by commenting out the line
442+ // below
448443 // TODO(dcj): support dataloader.reset() later
449444 ++train_iter;
450- saved_data_batch_idx = train_iter.BatchIndex ();
445+ consumed_batches = train_iter.BatchIndex ();
451446 x = std::make_shared<Tensor>(x->To (device));
452447 y = std::make_shared<Tensor>(y->To (device));
453448
0 commit comments