diff --git a/docs/basic_usage/training.md b/docs/basic_usage/training.md index a41b5a0de..9f2e5dbf6 100644 --- a/docs/basic_usage/training.md +++ b/docs/basic_usage/training.md @@ -60,3 +60,11 @@ It is important to note that the `run_llama3.1_8b_eagle3_offline.sh` script cons ## 📈 Experiment Tracking This project supports logging training progress to Wandb, TensorBoard, and SwanLab. You can enable tracking by adding the `--report-to` argument to the command line in your shell script. + + +## 💾 Checkpoint Management +To make training more robust, SpecForge supports finer-grained checkpointing through `--save-strategy=best` and `--save-total-limit`. + +`--save-strategy=best`: Tracks and preserves the checkpoint with the best evaluation metric (selected with `--metric-for-best`; by default, `acc_0`). This guards against being left with only a late, possibly overfitted checkpoint and ensures that the best-performing version of the draft model seen so far remains available, regardless of when training is interrupted. + +`--save-total-limit`: Limits the number of saved checkpoints by automatically deleting the oldest ones to prevent disk overflow. The most recent checkpoint (needed for resuming) and the best checkpoint are never deleted, so up to two checkpoints may remain even when the limit is 1. 
diff --git a/examples/run_llama3.1_8b_eagle3_online.sh b/examples/run_llama3.1_8b_eagle3_online.sh index d47c1797f..a1baab39a 100755 --- a/examples/run_llama3.1_8b_eagle3_online.sh +++ b/examples/run_llama3.1_8b_eagle3_online.sh @@ -26,4 +26,8 @@ torchrun \ --attention-backend sdpa \ --target-model-backend sglang \ --log-interval 10 \ - --sglang-mem-fraction-static 0.25 + --sglang-mem-fraction-static 0.25 \ + --save-total-limit 3 \ + --save-strategy best \ + --metric-for-best acc_0 \ + --load-best-model-at-end diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0bd157b39..c8e3edf32 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -2,6 +2,7 @@ import hashlib import math import os +import shutil import time from argparse import ArgumentParser, Namespace from typing import List, Optional, Tuple, Union @@ -46,12 +47,14 @@ from specforge.optimizer import BF16Optimizer from specforge.tracker import Tracker, create_tracker, get_tracker_class from specforge.utils import ( + PREFIX_CHECKPOINT_DIR, create_draft_config_from_target, get_last_checkpoint, print_args_with_dots, print_on_rank0, print_with_rank, rank_0_priority, + rotate_checkpoints, safe_conversations_generator, ) @@ -157,6 +160,10 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: ) training_group.add_argument("--eval-interval", type=int, default=5000) training_group.add_argument("--save-interval", type=int, default=5000) + training_group.add_argument( + "--save-strategy", type=str, default="steps", choices=["steps", "best"] + ) + training_group.add_argument("--save-total-limit", type=int, default=None) training_group.add_argument( "--log-interval", type=int, @@ -165,6 +172,17 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: ) training_group.add_argument("--seed", type=int, default=0) training_group.add_argument("--draft-accumulation-steps", type=int, default=1) + # best model tracking + training_group.add_argument( + "--metric-for-best", + type=str, + 
default="acc_0", + choices=["acc_0", "acc_1", "acc_2", "ploss_0", "ploss_1", "ploss_2"], + ) + training_group.add_argument( + "--greater-is-better", action="store_true", default=True + ) + training_group.add_argument("--load-best-model-at-end", action="store_true") # data processing type optimization_group = parser.add_argument_group("optimization") @@ -341,6 +359,13 @@ def sanity_check(args: Namespace) -> None: args.target_batch_size = args.tp_size * args.batch_size if args.attention_backend == "usp": sp_sanity_check(args) + # check save best model args + assert not ( + args.load_best_model_at_end and args.save_strategy != "best" + ), "--load-best-model-at-end requires --save-strategy to be 'best'" + assert ( + args.eval_interval % args.save_interval == 0 + ), "--eval-interval should be a multiple of --save-interval to ensure proper best model tracking" def sp_sanity_check(args: Namespace) -> None: @@ -555,8 +580,14 @@ def save_checkpoints( step: int, eagle3_model: nn.Module, optimizer: Optimizer, -): - epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + is_best: bool = False, + best_model_checkpoint: str = None, +) -> str: + epoch_output_dir = os.path.join( + args.output_dir, f"{PREFIX_CHECKPOINT_DIR}_{epoch}_step_{step}" + ) + if is_best: + best_model_checkpoint = epoch_output_dir if dist.get_rank() == 0: os.makedirs(epoch_output_dir, exist_ok=True) dist.barrier() @@ -588,7 +619,15 @@ def save_checkpoints( state_dict=draft_model_state_dict, ) print_on_rank0(f"Saved model configuration to {epoch_output_dir}") + rotate_checkpoints( + args.output_dir, + args.save_total_limit, + best_model_checkpoint, + False, + PREFIX_CHECKPOINT_DIR, + ) dist.barrier() + return epoch_output_dir def run_forward( @@ -744,6 +783,11 @@ def main(): sanity_check(args) print_args_with_dots(args) print_with_rank("Initialized distributed environment") + # to track best model + best_metric = float("-inf") if args.greater_is_better else float("inf") + 
best_model_checkpoint = None + current_is_best = False + is_best = False # ================================================ # 2. Build models @@ -984,12 +1028,40 @@ def main(): tracker, mode="eval", ) + # best metric + metric_name = args.metric_for_best + kind, idx = metric_name.split("_") + idx = int(idx) + current = ( + eval_acces[idx] if kind == "acc" else eval_plosses[idx] + ).item() + is_best = ( + (current > best_metric) + if args.greater_is_better + else (current < best_metric) + ) + if is_best: + best_metric = current + current_is_best = True + else: + current_is_best = False # ================================================ # 7.3 Save Checkpoints # ================================================ if global_step % args.save_interval == 0: # Save the model - save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) + ckpt_path = save_checkpoints( + args, + epoch, + global_step, + eagle3_model, + optimizer, + is_best=current_is_best, + best_model_checkpoint=best_model_checkpoint, + ) + if current_is_best: + best_model_checkpoint = ckpt_path + current_is_best = False if args.max_num_steps is not None and global_step >= args.max_num_steps: break @@ -1001,7 +1073,25 @@ def main(): print_on_rank0( f"Training completed at step {global_step}, saving final checkpoint..." 
) - save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) + save_checkpoints( + args, + epoch, + global_step, + eagle3_model, + optimizer, + is_best=current_is_best, + best_model_checkpoint=best_model_checkpoint, + ) + # save best model at end + if ( + dist.get_rank() == 0 + and args.load_best_model_at_end + and best_model_checkpoint is not None + ): + final_model_ckpt = os.path.join(args.output_dir, "best") + if os.path.exists(final_model_ckpt): + shutil.rmtree(final_model_ckpt) + shutil.copytree(best_model_checkpoint, final_model_ckpt) # Close the tracker tracker.close() diff --git a/specforge/utils.py b/specforge/utils.py index af4d627c8..427383d87 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -2,7 +2,9 @@ import logging import os import re +import shutil from contextlib import contextmanager +from pathlib import Path import torch import torch.distributed as dist @@ -10,6 +12,7 @@ from transformers import AutoConfig, PretrainedConfig logger = logging.getLogger(__name__) +PREFIX_CHECKPOINT_DIR = "epoch" # keep specforge original format @contextmanager @@ -409,4 +412,96 @@ def safe_conversations_generator(file_path): except Exception as e: logger.warning(f"Skipping line {i + 1}: {e}") - continue + + +def rotate_checkpoints( + output_dir: str, + save_total_limit: int | None = None, + best_model_checkpoint: str | None = None, + use_mtime: bool = False, + checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR, +): + if save_total_limit is None or save_total_limit <= 0: + return + + checkpoints = sort_checkpoints(output_dir, checkpoint_prefix, use_mtime) + if len(checkpoints) <= save_total_limit: + return + + # Checkpoints that must not be deleted + protected = {checkpoints[-1]} # most recent, for resuming + if best_model_checkpoint is not None: + protected.add(str(Path(best_model_checkpoint))) + + # Delete oldest non-protected checkpoints until we have save_total_limit left + num_to_keep = max(save_total_limit, len(protected)) + remaining = 
len(checkpoints) + for checkpoint in checkpoints: + if remaining <= num_to_keep: + break + if checkpoint not in protected: + shutil.rmtree(checkpoint, ignore_errors=True) + remaining -= 1 + + +def sort_checkpoints( + output_dir: str, + checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR, + use_mtime: bool = False, + best_model_checkpoint: str | None = None, +) -> list[str]: + glob_checkpoints = [ + str(x) + for x in Path(output_dir).glob(f"{checkpoint_prefix}_*") + if os.path.isdir(x) + ] + + ordering_and_checkpoint_path = [] + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) + else: + regex_match = re.match( + f".*{checkpoint_prefix}_([0-9]+)_step_([0-9]+)", path + ) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append( + (int(regex_match.groups()[1]), path) + ) # sort by step + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + + # mtime is not reliable on some filesystems (e.g., cloud fuse filesystem) + # so we check if the mtime is fake and fail back to numerical ordering + if use_mtime and len(checkpoints_sorted) > 1: + mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0] + if mtime_diff < 1.0: + logger.warning( + "mtime may not be reliable on this filesystem, falling back to numerical ordering" + ) + return sort_checkpoints( + output_dir, + checkpoint_prefix, + use_mtime=False, + best_model_checkpoint=best_model_checkpoint, + ) + + checkpoints_sorted = [path for _, path in checkpoints_sorted] + + # Move best_model_checkpoint to second-to-last position to protect it from deletion + # while keeping the most recent checkpoint at the end for resuming training + if best_model_checkpoint is not None: + best_model_checkpoint = str(Path(best_model_checkpoint)) + if ( + best_model_checkpoint in checkpoints_sorted + and checkpoints_sorted[-1] != best_model_checkpoint + ): + most_recent = checkpoints_sorted[-1] + checkpoints_sorted = 
[ + c + for c in checkpoints_sorted + if c not in {best_model_checkpoint, most_recent} + ] + checkpoints_sorted += [best_model_checkpoint, most_recent] + + return checkpoints_sorted