Skip to content

Commit e896c4f

Browse files
committed
feat: add save-total-limit with rotation, save-strategy with best, and save the best parameter at end
1 parent d5fb617 commit e896c4f

4 files changed

Lines changed: 203 additions & 6 deletions

File tree

docs/basic_usage/training.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,11 @@ It is important to note that the `run_llama3.1_8b_eagle3_offline.sh` script cons
6060
## 📈 Experiment Tracking
6161

6262
This project supports logging training progress to Wandb, TensorBoard, and SwanLab. You can enable tracking by adding the `--report-to` argument to the command line in your shell script.
63+
64+
65+
## 💾 Checkpoint Management
66+
To ensure robust training, SpecForge now supports fine-grained checkpointing strategies through `--save-strategy=best` and `--save-total-limit`.
67+
68+
`--save-strategy=best`: Automatically tracks and preserves the checkpoint with the best evaluation metric (by default, `acc_0`). This prevents overfitting and ensures that you always have access to the most optimized version of the draft model, regardless of when the training is interrupted.
69+
70+
`--save-total-limit`: This feature strictly limits the number of saved checkpoints, automatically deleting older or underperforming ones to prevent disk overflow while maintaining a reliable recovery point.

examples/run_llama3.1_8b_eagle3_online.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ torchrun \
2626
--attention-backend sdpa \
2727
--target-model-backend sglang \
2828
--log-interval 10 \
29-
--sglang-mem-fraction-static 0.25
29+
--sglang-mem-fraction-static 0.25 \
30+
--save-total-limit 3 \
31+
--save-strategy best \
32+
--metric-for-best acc_0 \
33+
--load-best-model-at-end

scripts/train_eagle3.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import hashlib
33
import math
44
import os
5+
import shutil
56
import time
67
from argparse import ArgumentParser, Namespace
78
from typing import List, Optional, Tuple, Union
@@ -46,12 +47,14 @@
4647
from specforge.optimizer import BF16Optimizer
4748
from specforge.tracker import Tracker, create_tracker, get_tracker_class
4849
from specforge.utils import (
50+
PREFIX_CHECKPOINT_DIR,
4951
create_draft_config_from_target,
5052
get_last_checkpoint,
5153
print_args_with_dots,
5254
print_on_rank0,
5355
print_with_rank,
5456
rank_0_priority,
57+
rotate_checkpoints,
5558
safe_conversations_generator,
5659
)
5760

@@ -157,6 +160,10 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
157160
)
158161
training_group.add_argument("--eval-interval", type=int, default=5000)
159162
training_group.add_argument("--save-interval", type=int, default=5000)
163+
training_group.add_argument(
164+
"--save-strategy", type=str, default="steps", choices=["steps", "best"]
165+
)
166+
training_group.add_argument("--save-total-limit", type=int, default=None)
160167
training_group.add_argument(
161168
"--log-interval",
162169
type=int,
@@ -165,6 +172,17 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
165172
)
166173
training_group.add_argument("--seed", type=int, default=0)
167174
training_group.add_argument("--draft-accumulation-steps", type=int, default=1)
175+
# best model tracking
176+
training_group.add_argument(
177+
"--metric-for-best",
178+
type=str,
179+
default="acc_0",
180+
choices=["acc_0", "acc_1", "acc_2", "ploss_0", "ploss_1", "ploss_2"],
181+
)
182+
training_group.add_argument(
183+
"--greater-is-better", action="store_true", default=True
184+
)
185+
training_group.add_argument("--load-best-model-at-end", action="store_true")
168186

169187
# data processing type
170188
optimization_group = parser.add_argument_group("optimization")
@@ -341,6 +359,13 @@ def sanity_check(args: Namespace) -> None:
341359
args.target_batch_size = args.tp_size * args.batch_size
342360
if args.attention_backend == "usp":
343361
sp_sanity_check(args)
362+
# check save best model args
363+
assert not (
364+
args.load_best_model_At_end and args.save_strategy != "best"
365+
), "--load-best-model-at-end requires --save-strategy to be 'best'"
366+
assert (
367+
args.eval_interval % args.save_interval == 0
368+
), "--eval-interval should be a multiple of --save-interval to ensure proper best model tracking"
344369

345370

346371
def sp_sanity_check(args: Namespace) -> None:
@@ -555,8 +580,14 @@ def save_checkpoints(
555580
step: int,
556581
eagle3_model: nn.Module,
557582
optimizer: Optimizer,
558-
):
559-
epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
583+
is_best: bool = False,
584+
best_model_checkpoint: str = None,
585+
) -> str:
586+
epoch_output_dir = os.path.join(
587+
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}_{epoch}_step_{step}"
588+
)
589+
if is_best:
590+
best_model_checkpoint = epoch_output_dir
560591
if dist.get_rank() == 0:
561592
os.makedirs(epoch_output_dir, exist_ok=True)
562593
dist.barrier()
@@ -588,7 +619,15 @@ def save_checkpoints(
588619
state_dict=draft_model_state_dict,
589620
)
590621
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
622+
rotate_checkpoints(
623+
args.output_dir,
624+
args.save_total_limit,
625+
best_model_checkpoint,
626+
False,
627+
PREFIX_CHECKPOINT_DIR,
628+
)
591629
dist.barrier()
630+
return epoch_output_dir
592631

593632

594633
def run_forward(
@@ -744,6 +783,11 @@ def main():
744783
sanity_check(args)
745784
print_args_with_dots(args)
746785
print_with_rank("Initialized distributed environment")
786+
# to track best model
787+
best_metric = float("-inf") if args.greater_is_better else float("inf")
788+
best_model_checkpoint = None
789+
current_is_best = False
790+
is_best = False
747791

748792
# ================================================
749793
# 2. Build models
@@ -984,12 +1028,40 @@ def main():
9841028
tracker,
9851029
mode="eval",
9861030
)
1031+
# best metric
1032+
metric_name = args.metric_for_best
1033+
kind, idx = metric_name.split("_")
1034+
idx = int(idx)
1035+
current = (
1036+
eval_acces[idx] if kind == "acc" else eval_plosses[idx]
1037+
).item()
1038+
is_best = (
1039+
(current > best_metric)
1040+
if args.greater_is_better
1041+
else (current < best_metric)
1042+
)
1043+
if is_best:
1044+
best_metric = current
1045+
current_is_best = True
1046+
else:
1047+
current_is_best = False
9871048
# ================================================
9881049
# 7.3 Save Checkpoints
9891050
# ================================================
9901051
if global_step % args.save_interval == 0:
9911052
# Save the model
992-
save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)
1053+
ckpt_path = save_checkpoints(
1054+
args,
1055+
epoch,
1056+
global_step,
1057+
eagle3_model,
1058+
optimizer,
1059+
is_best=current_is_best,
1060+
best_model_checkpoint=best_model_checkpoint,
1061+
)
1062+
if current_is_best:
1063+
best_model_checkpoint = ckpt_path
1064+
current_is_best = False
9931065

9941066
if args.max_num_steps is not None and global_step >= args.max_num_steps:
9951067
break
@@ -1001,7 +1073,25 @@ def main():
10011073
print_on_rank0(
10021074
f"Training completed at step {global_step}, saving final checkpoint..."
10031075
)
1004-
save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)
1076+
save_checkpoints(
1077+
args,
1078+
epoch,
1079+
global_step,
1080+
eagle3_model,
1081+
optimizer,
1082+
is_best=current_is_best,
1083+
best_model_checkpoint=best_model_checkpoint,
1084+
)
1085+
# save best model at end
1086+
if (
1087+
dist.get_rank() == 0
1088+
and args.load_best_model_at_end
1089+
and best_model_checkpoint is not None
1090+
):
1091+
final_model_ckpt = os.path.join(args.output_dir, "best")
1092+
if os.path.exists(final_model_ckpt):
1093+
shutil.rmtree(final_model_ckpt)
1094+
shutil.copytree(best_model_checkpoint, final_model_ckpt)
10051095

10061096
# Close the tracker
10071097
tracker.close()

specforge/utils.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
import logging
33
import os
44
import re
5+
import shutil
56
from contextlib import contextmanager
7+
from pathlib import Path
68

79
import torch
810
import torch.distributed as dist
911
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
1012
from transformers import AutoConfig, PretrainedConfig
1113

1214
logger = logging.getLogger(__name__)
15+
PREFIX_CHECKPOINT_DIR = "epoch" # keep specforge original format
1316

1417

1518
@contextmanager
@@ -409,4 +412,96 @@ def safe_conversations_generator(file_path):
409412

410413
except Exception as e:
411414
logger.warning(f"Skipping line {i + 1}: {e}")
412-
continue
415+
416+
417+
def rotate_checkpoints(
418+
output_dir: str,
419+
save_total_limit: int | None = None,
420+
best_model_checkpoint: str | None = None,
421+
use_mtime: bool = False,
422+
checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
423+
):
424+
if save_total_limit is None or save_total_limit <= 0:
425+
return
426+
427+
checkpoints = sort_checkpoints(output_dir, checkpoint_prefix, use_mtime)
428+
if len(checkpoints) <= save_total_limit:
429+
return
430+
431+
# Checkpoints that must not be deleted
432+
protected = {checkpoints[-1]} # most recent, for resuming
433+
if best_model_checkpoint is not None:
434+
protected.add(str(Path(best_model_checkpoint)))
435+
436+
# Delete oldest non-protected checkpoints until we have save_total_limit left
437+
num_to_keep = max(save_total_limit, len(protected))
438+
remaining = len(checkpoints)
439+
for checkpoint in checkpoints:
440+
if remaining <= num_to_keep:
441+
break
442+
if checkpoint not in protected:
443+
shutil.rmtree(checkpoint, ignore_errors=True)
444+
remaining -= 1
445+
446+
447+
def sort_checkpoints(
448+
output_dir: str,
449+
checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
450+
use_mtime: bool = False,
451+
best_model_checkpoint: str | None = None,
452+
) -> list[str]:
453+
glob_checkpoints = [
454+
str(x)
455+
for x in Path(output_dir).glob(f"{checkpoint_prefix}_*")
456+
if os.path.isdir(x)
457+
]
458+
459+
ordering_and_checkpoint_path = []
460+
for path in glob_checkpoints:
461+
if use_mtime:
462+
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
463+
else:
464+
regex_match = re.match(
465+
f".*{checkpoint_prefix}_([0-9]+)_step_([0-9]+)", path
466+
)
467+
if regex_match is not None and regex_match.groups() is not None:
468+
ordering_and_checkpoint_path.append(
469+
(int(regex_match.groups()[1]), path)
470+
) # sort by step
471+
472+
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
473+
474+
# mtime is not reliable on some filesystems (e.g., cloud fuse filesystem)
475+
# so we check if the mtime is fake and fail back to numerical ordering
476+
if use_mtime and len(checkpoints_sorted) > 1:
477+
mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0]
478+
if mtime_diff < 1.0:
479+
logger.warning_once(
480+
"mtime may not be reliable on this filesystem, falling back to numerical ordering"
481+
)
482+
return sort_checkpoints(
483+
output_dir,
484+
checkpoint_prefix,
485+
use_mtime=False,
486+
best_model_checkpoint=best_model_checkpoint,
487+
)
488+
489+
checkpoints_sorted = [path for _, path in checkpoints_sorted]
490+
491+
# Move best_model_checkpoint to second-to-last position to protect it from deletion
492+
# while keeping the most recent checkpoint at the end for resuming training
493+
if best_model_checkpoint is not None:
494+
best_model_checkpoint = str(Path(best_model_checkpoint))
495+
if (
496+
best_model_checkpoint in checkpoints_sorted
497+
and checkpoints_sorted[-1] != best_model_checkpoint
498+
):
499+
most_recent = checkpoints_sorted[-1]
500+
checkpoints_sorted = [
501+
c
502+
for c in checkpoints_sorted
503+
if c not in {best_model_checkpoint, most_recent}
504+
]
505+
checkpoint_sorted += [best_model_checkpoint, most_recent]
506+
507+
return checkpoints_sorted

0 commit comments

Comments
 (0)