Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/basic_usage/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,11 @@ It is important to note that the `run_llama3.1_8b_eagle3_offline.sh` script cons
## 📈 Experiment Tracking

This project supports logging training progress to Wandb, TensorBoard, and SwanLab. You can enable tracking by adding the `--report-to` argument to the command line in your shell script.


## 💾 Checkpoint Management
To ensure robust training, SpecForge now supports fine-grained checkpointing strategies through `--save-strategy=best` and `--save-total-limit`.

`--save-strategy=best`: Tracks the checkpoint with the best evaluation metric (selected with `--metric-for-best`; `acc_0` by default) and protects it from deletion. This guards against ending up with only an overfitted final model and ensures that you always have access to the best-performing version of the draft model, regardless of when the training is interrupted.

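`--save-total-limit`: Limits the number of saved checkpoints by automatically deleting the oldest ones first, preventing disk overflow while maintaining a reliable recovery point. The most recent checkpoint and the best checkpoint are never deleted, so two checkpoints may remain on disk even when the limit is set to 1.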
6 changes: 5 additions & 1 deletion examples/run_llama3.1_8b_eagle3_online.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,8 @@ torchrun \
--attention-backend sdpa \
--target-model-backend sglang \
--log-interval 10 \
--sglang-mem-fraction-static 0.25
--sglang-mem-fraction-static 0.25 \
--save-total-limit 3 \
--save-strategy best \
--metric-for-best acc_0 \
--load-best-model-at-end
98 changes: 94 additions & 4 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import math
import os
import shutil
import time
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -46,12 +47,14 @@
from specforge.optimizer import BF16Optimizer
from specforge.tracker import Tracker, create_tracker, get_tracker_class
from specforge.utils import (
PREFIX_CHECKPOINT_DIR,
create_draft_config_from_target,
get_last_checkpoint,
print_args_with_dots,
print_on_rank0,
print_with_rank,
rank_0_priority,
rotate_checkpoints,
safe_conversations_generator,
)

Expand Down Expand Up @@ -157,6 +160,10 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
)
training_group.add_argument("--eval-interval", type=int, default=5000)
training_group.add_argument("--save-interval", type=int, default=5000)
training_group.add_argument(
"--save-strategy", type=str, default="steps", choices=["steps", "best"]
)
training_group.add_argument("--save-total-limit", type=int, default=None)
training_group.add_argument(
"--log-interval",
type=int,
Expand All @@ -165,6 +172,17 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
)
training_group.add_argument("--seed", type=int, default=0)
training_group.add_argument("--draft-accumulation-steps", type=int, default=1)
# best model tracking
training_group.add_argument(
"--metric-for-best",
type=str,
default="acc_0",
choices=["acc_0", "acc_1", "acc_2", "ploss_0", "ploss_1", "ploss_2"],
)
training_group.add_argument(
"--greater-is-better", action="store_true", default=True
)
training_group.add_argument("--load-best-model-at-end", action="store_true")

# data processing type
optimization_group = parser.add_argument_group("optimization")
Expand Down Expand Up @@ -341,6 +359,13 @@ def sanity_check(args: Namespace) -> None:
args.target_batch_size = args.tp_size * args.batch_size
if args.attention_backend == "usp":
sp_sanity_check(args)
# check save best model args
assert not (
args.load_best_model_at_end and args.save_strategy != "best"
), "--load-best-model-at-end requires --save-strategy to be 'best'"
assert (
args.eval_interval % args.save_interval == 0
), "--eval-interval should be a multiple of --save-interval to ensure proper best model tracking"


def sp_sanity_check(args: Namespace) -> None:
Expand Down Expand Up @@ -555,8 +580,14 @@ def save_checkpoints(
step: int,
eagle3_model: nn.Module,
optimizer: Optimizer,
):
epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
is_best: bool = False,
best_model_checkpoint: str = None,
) -> str:
epoch_output_dir = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}_{epoch}_step_{step}"
)
if is_best:
best_model_checkpoint = epoch_output_dir
if dist.get_rank() == 0:
os.makedirs(epoch_output_dir, exist_ok=True)
dist.barrier()
Expand Down Expand Up @@ -588,7 +619,15 @@ def save_checkpoints(
state_dict=draft_model_state_dict,
)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
rotate_checkpoints(
args.output_dir,
args.save_total_limit,
best_model_checkpoint,
False,
PREFIX_CHECKPOINT_DIR,
)
dist.barrier()
return epoch_output_dir


def run_forward(
Expand Down Expand Up @@ -744,6 +783,11 @@ def main():
sanity_check(args)
print_args_with_dots(args)
print_with_rank("Initialized distributed environment")
# to track best model
best_metric = float("-inf") if args.greater_is_better else float("inf")
best_model_checkpoint = None
current_is_best = False
is_best = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable is_best is initialized here but shadowed by a local variable of the same name inside the training loop (line 1038). It appears to be unused in this scope.


# ================================================
# 2. Build models
Expand Down Expand Up @@ -984,12 +1028,40 @@ def main():
tracker,
mode="eval",
)
# best metric
metric_name = args.metric_for_best
kind, idx = metric_name.split("_")
idx = int(idx)
current = (
eval_acces[idx] if kind == "acc" else eval_plosses[idx]
).item()
is_best = (
(current > best_metric)
if args.greater_is_better
else (current < best_metric)
)
if is_best:
best_metric = current
current_is_best = True
else:
current_is_best = False
# ================================================
# 7.3 Save Checkpoints
# ================================================
if global_step % args.save_interval == 0:
# Save the model
save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)
ckpt_path = save_checkpoints(
args,
epoch,
global_step,
eagle3_model,
optimizer,
is_best=current_is_best,
best_model_checkpoint=best_model_checkpoint,
)
if current_is_best:
best_model_checkpoint = ckpt_path
current_is_best = False

if args.max_num_steps is not None and global_step >= args.max_num_steps:
break
Expand All @@ -1001,7 +1073,25 @@ def main():
print_on_rank0(
f"Training completed at step {global_step}, saving final checkpoint..."
)
save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)
save_checkpoints(
args,
epoch,
global_step,
eagle3_model,
optimizer,
is_best=current_is_best,
best_model_checkpoint=best_model_checkpoint,
)
# save best model at end
if (
dist.get_rank() == 0
and args.load_best_model_at_end
and best_model_checkpoint is not None
):
final_model_ckpt = os.path.join(args.output_dir, "best")
if os.path.exists(final_model_ckpt):
shutil.rmtree(final_model_ckpt)
shutil.copytree(best_model_checkpoint, final_model_ckpt)

# Close the tracker
tracker.close()
Expand Down
97 changes: 96 additions & 1 deletion specforge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import logging
import os
import re
import shutil
from contextlib import contextmanager
from pathlib import Path

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from transformers import AutoConfig, PretrainedConfig

logger = logging.getLogger(__name__)
PREFIX_CHECKPOINT_DIR = "epoch" # keep specforge original format


@contextmanager
Expand Down Expand Up @@ -409,4 +412,96 @@ def safe_conversations_generator(file_path):

except Exception as e:
logger.warning(f"Skipping line {i + 1}: {e}")
continue


def rotate_checkpoints(
    output_dir: str,
    save_total_limit: int | None = None,
    best_model_checkpoint: str | None = None,
    use_mtime: bool = False,
    checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
) -> None:
    """Delete the oldest checkpoint directories so that at most a fixed number remain.

    Checkpoints are discovered and ordered by ``sort_checkpoints`` (oldest
    first). The most recent checkpoint and, when it is one of the discovered
    checkpoints, ``best_model_checkpoint`` are never deleted; if both are
    protected, two directories are kept even when ``save_total_limit`` is 1.

    Args:
        output_dir: Directory that contains the checkpoint sub-directories.
        save_total_limit: Maximum number of checkpoints to keep. ``None`` or a
            non-positive value disables rotation entirely.
        best_model_checkpoint: Path of the checkpoint that must be preserved,
            or ``None`` if no best checkpoint is being tracked.
        use_mtime: Order checkpoints by modification time instead of by the
            step number embedded in the directory name.
        checkpoint_prefix: Directory-name prefix that identifies checkpoints.
    """
    if save_total_limit is None or save_total_limit <= 0:
        return

    checkpoints = sort_checkpoints(output_dir, checkpoint_prefix, use_mtime)
    if len(checkpoints) <= save_total_limit:
        return

    # Checkpoints that must not be deleted.
    protected = {checkpoints[-1]}  # most recent, for resuming
    if best_model_checkpoint is not None:
        # Normalize the same way sort_checkpoints builds its paths (str(Path)).
        best = str(Path(best_model_checkpoint))
        # Only reserve a slot for the best checkpoint if it is actually one of
        # the discovered checkpoints. A stale or foreign path cannot be deleted
        # by this function anyway, and counting it would keep more directories
        # than save_total_limit allows.
        if best in checkpoints:
            protected.add(best)

    # Delete oldest non-protected checkpoints until num_to_keep are left.
    num_to_keep = max(save_total_limit, len(protected))
    num_to_delete = len(checkpoints) - num_to_keep
    for checkpoint in checkpoints:
        if num_to_delete <= 0:
            break
        if checkpoint in protected:
            continue
        # Best-effort removal: errors (e.g. the directory already being gone)
        # are ignored rather than raised.
        shutil.rmtree(checkpoint, ignore_errors=True)
        num_to_delete -= 1


def sort_checkpoints(
    output_dir: str,
    checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
    use_mtime: bool = False,
    best_model_checkpoint: str | None = None,
) -> list[str]:
    """Return the checkpoint directories under ``output_dir``, oldest first.

    Only sub-directories matching ``{checkpoint_prefix}_*`` are considered.
    They are ordered either by modification time (``use_mtime=True``) or by
    the step number parsed from a ``{prefix}_<epoch>_step_<step>`` name; in
    the latter mode directories whose name does not match are dropped.

    When ``best_model_checkpoint`` is given, is present in the result and is
    not already the last entry, it is moved to the second-to-last position,
    directly before the most recent checkpoint.

    Args:
        output_dir: Directory that contains the checkpoint sub-directories.
        checkpoint_prefix: Directory-name prefix that identifies checkpoints.
        use_mtime: Order by modification time instead of by step number.
        best_model_checkpoint: Optional path of the best checkpoint.

    Returns:
        Checkpoint directory paths as strings, in the order described above.
    """
    candidate_dirs = [
        str(entry)
        for entry in Path(output_dir).glob(f"{checkpoint_prefix}_*")
        if os.path.isdir(entry)
    ]

    # Pair every directory with its sort key: mtime, or the parsed step number.
    if use_mtime:
        keyed_dirs = [(os.path.getmtime(d), d) for d in candidate_dirs]
    else:
        keyed_dirs = []
        for d in candidate_dirs:
            found = re.match(f".*{checkpoint_prefix}_([0-9]+)_step_([0-9]+)", d)
            if found is None:
                continue
            # Second capture group is the step; the epoch is not used as a key.
            keyed_dirs.append((int(found.group(2)), d))
    keyed_dirs.sort()

    # mtime is not reliable on some filesystems (e.g., cloud fuse filesystem):
    # if all timestamps fall within one second, redo the sort numerically.
    if (
        use_mtime
        and len(keyed_dirs) > 1
        and keyed_dirs[-1][0] - keyed_dirs[0][0] < 1.0
    ):
        logger.warning(
            "mtime may not be reliable on this filesystem, falling back to numerical ordering"
        )
        return sort_checkpoints(
            output_dir,
            checkpoint_prefix,
            use_mtime=False,
            best_model_checkpoint=best_model_checkpoint,
        )

    ordered = [d for _, d in keyed_dirs]
    if best_model_checkpoint is None:
        return ordered

    best = str(Path(best_model_checkpoint))
    if best not in ordered or ordered[-1] == best:
        return ordered

    # Park the best checkpoint right before the newest one so it is among the
    # last entries, while the newest stays last for resuming training.
    newest = ordered[-1]
    others = [d for d in ordered if d != best and d != newest]
    return others + [best, newest]