@@ -86,10 +86,7 @@ DEFINE_string(load, "", "checkpoint directory to resume from");
8686DEFINE_string (save, " ./checkpoints" , " root directory used to store checkpoints" );
8787DEFINE_uint32 (max_checkpoint_keep, 3 , " max number of checkpoint steps to keep" );
8888DEFINE_bool (no_save_optim, true , " whether optimizer state is persisted in checkpoints" );
89- DEFINE_string (checkpoint_file_format, " ckpt" ,
90- " checkpoint format: bin|ckpt. "
91- " 'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
92- " 'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary)." );
89+
9390// precision check
9491DEFINE_string (
9592 precision_check, " " ,
@@ -311,7 +308,7 @@ void Train(const nn::parallel::Rank &rank) {
311308 }
312309
313310 auto train_iter = train_loader.begin ();
314- size_t saved_data_batch_idx = train_iter. BatchIndex ();
311+
315312 std::shared_ptr<nn::Module> loss_fn
316313 = (tp_world_size > 1 ) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
317314 : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
@@ -322,50 +319,48 @@ void Train(const nn::parallel::Rank &rank) {
322319
323320 int start_step = 0 ;
324321 TrainerState state;
325- CheckpointLoadOptions load_options;
326- load_options.load_optimizer_state = true ;
327- load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
328- auto loaded_model = llama3::LoadFromLLMC (model_path.string ());
329- target_model->LoadStateDict (loaded_model->StateDict ());
330- };
331322 const auto resume_result = ResumeFromCheckpoint ({
332323 .resume_root = FLAGS_load,
333324 .rank = rank,
334325 .model = model,
335326 .optimizer = optimizer,
336327 .train_loader = train_loader,
337328 .state = state,
338- .train_iter = train_iter,
339- .load_options = load_options,
340329 });
330+
341331 start_step = resume_result.global_step ;
342- saved_data_batch_idx = resume_result.data_batch_idx ;
343-
344- auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step,
345- bool prune_step_checkpoints) {
346- SaveCheckpoint ({
347- .save_dir = save_dir,
348- .global_step = global_step,
349- .data_batch_idx = saved_data_batch_idx,
350- .last_lr = FLAGS_learning_rate,
351- .optimizer_type = " Adam" ,
352- .checkpoint_file_format = FLAGS_checkpoint_file_format,
353- .ddp_size = ddp_world_size,
354- .tp_size = tp_world_size,
355- .sp_size = sp_world_size,
356- .pp_size = pp_world_size,
357- .no_save_optim = FLAGS_no_save_optim,
358- .prune_step_checkpoints = prune_step_checkpoints,
359- .checkpoint_root_dir = FLAGS_save,
360- .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
361- .rank = rank,
362- .model = *model,
363- .optimizer = *optimizer,
364- .model_bin_writer
365- = [&](const nn::Module &,
366- const std::filesystem::path &model_path) { llama3::SaveAsLLMC (llmc_model, model_path.string ()); },
367- });
368- };
332+ size_t consumed_batches = resume_result.consumed_batches ;
333+
334+ // TODO(jym): Replace with Sampler abstraction when available.
335+ // Skip dataloader to resume from the correct batch position.
336+ if (consumed_batches > 0 ) {
337+ size_t start = train_iter.BatchIndex ();
338+ // Each rank processes every ddp_world_size-th batch starting from its own rank.
339+ // num_skips calculates how many ++ iterations to reach the saved batch position.
340+ size_t num_skips = (consumed_batches - start) / ddp_world_size;
341+ for (size_t i = 0 ; i < num_skips; ++i) { ++train_iter; }
342+ }
343+
344+ auto save_checkpoint
345+ = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) {
346+ SaveCheckpoint ({
347+ .save_dir = save_dir,
348+ .global_step = global_step,
349+ .consumed_batches = consumed_batches,
350+ .last_lr = FLAGS_learning_rate,
351+ .ddp_size = ddp_world_size,
352+ .tp_size = tp_world_size,
353+ .sp_size = sp_world_size,
354+ .pp_size = pp_world_size,
355+ .no_save_optim = FLAGS_no_save_optim,
356+ .prune_step_checkpoints = prune_step_checkpoints,
357+ .checkpoint_root_dir = FLAGS_save,
358+ .max_checkpoint_keep = FLAGS_max_checkpoint_keep,
359+ .rank = rank,
360+ .model = *model,
361+ .optimizer = *optimizer,
362+ });
363+ };
369364
370365 for (int step = start_step; step < FLAGS_num_iteration + 1 ; ++step) {
371366 // Reset precision check counters at start of each iteration for file overwrite
@@ -414,10 +409,11 @@ void Train(const nn::parallel::Rank &rank) {
414409
415410 // (bs, seq_len), (bs, seq_len)
416411 auto [x, y] = *train_iter;
417- // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
412+ // if we are trying to overfit a single batch, we reset the loader here by commenting out the
413+ // line below
418414 // TODO(dcj): support dataloader.reset() later
419415 ++train_iter;
420- saved_data_batch_idx = train_iter.BatchIndex ();
416+ consumed_batches = train_iter.BatchIndex ();
421417 x = std::make_shared<Tensor>(x->To (device));
422418 y = std::make_shared<Tensor>(y->To (device));
423419
@@ -444,10 +440,11 @@ void Train(const nn::parallel::Rank &rank) {
444440 optimizer->Step ();
445441 } else {
446442 auto [x, y] = *train_iter;
447- // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
443+ // if we are trying to overfit a single batch, we reset the loader here by commenting out the line
444+ // below
448445 // TODO(dcj): support dataloader.reset() later
449446 ++train_iter;
450- saved_data_batch_idx = train_iter.BatchIndex ();
447+ consumed_batches = train_iter.BatchIndex ();
451448 x = std::make_shared<Tensor>(x->To (device));
452449 y = std::make_shared<Tensor>(y->To (device));
453450
0 commit comments