Skip to content

Commit 3c811bc

Browse files
authored
Save multiple checkpoints per epoch (#509)
<!-- markdownlint-disable --> PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED. ## Purpose Save multiple checkpoints per epoch. <!--- Why your changes are needed --> ## Description Changed `checkpoint_freq` to a float. A value below 1 enables sub-epoch checkpointing: for example, 0.5 saves twice per epoch, with each save overwriting the previous one in the same directory. The last sub-epoch save, which would fall right before validation, is skipped because the end-of-epoch checkpoint is handled separately (open question: should it be?). A value greater than 1 must be a whole number and means a checkpoint is saved once every that many epochs. <!--- High-level concise summary of changes --> ## Related Issue <!--- Link related issue if applicable --> #493 ## Tests Tested locally with an example. <!--- Please describe in detail how you tested your changes. --> I have filled in: - [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". - [x] The test plan/results, such as providing test command and pasting the results. - [ ] (Optional) The necessary documentation update. - [x] I (a human) have written or reviewed the code in this PR to the best of my ability. --------- Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent 2979b8a commit 3c811bc

2 files changed

Lines changed: 36 additions & 11 deletions

File tree

scripts/train.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,16 @@ def main(args: argparse.Namespace):
383383
maybe_destroy_distributed()
384384

385385

386-
def _checkpoint_freq(value: str) -> int:
387-
ivalue = int(value)
388-
if ivalue < 1:
389-
raise argparse.ArgumentTypeError("--checkpoint-freq must be >= 1")
390-
return ivalue
386+
def _checkpoint_freq(value: str) -> float:
    """Parse and validate the ``--checkpoint-freq`` command-line value.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        value: Raw string taken from the command line.

    Returns:
        The frequency as a float. Values in ``(0, 1]`` may be fractional;
        values greater than 1 are always whole numbers.

    Raises:
        ValueError: If ``value`` cannot be parsed as a float at all
            (argparse reports this as an invalid argument value).
        argparse.ArgumentTypeError: If the parsed value is NaN or infinite,
            is not strictly positive, or is greater than 1 but not a whole
            number.
    """
    import math  # function-scope import keeps this block self-contained

    fvalue = float(value)
    # float() accepts "nan" and "inf". NaN compares False against everything,
    # so without this guard it would slip past both range checks below and be
    # returned as a "valid" frequency.
    if not math.isfinite(fvalue):
        raise argparse.ArgumentTypeError("--checkpoint-freq must be a finite number")
    if fvalue <= 0:
        raise argparse.ArgumentTypeError("--checkpoint-freq must be > 0")
    if fvalue > 1 and not fvalue.is_integer():
        raise argparse.ArgumentTypeError(
            f"--checkpoint-freq={fvalue} is not an integer. Values > 1 are treated "
            "as epoch counts and must be whole numbers."
        )
    return fvalue
391396

392397

393398
def parse_args():
@@ -632,8 +637,9 @@ def parse_args():
632637
parser.add_argument(
633638
"--checkpoint-freq",
634639
type=_checkpoint_freq,
635-
default=1,
636-
help="Save a checkpoint every N epochs.",
640+
default=1.0,
641+
help="Save a checkpoint every N epochs. Values < 1 enable sub-epoch "
642+
"checkpointing (e.g. 0.5 = every half epoch).",
637643
)
638644
parser.add_argument(
639645
"--save-best",

src/speculators/train/trainer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
metric_logger = logging.getLogger("speculators.metrics")
3030

3131
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
32+
MIN_STEP_PCT = 0.25
3233

3334

3435
class TrainerConfig(NamedTuple):
@@ -44,7 +45,7 @@ class TrainerConfig(NamedTuple):
4445
scheduler_warmup_steps: int | None = None
4546
scheduler_total_steps: int | None = None
4647
scheduler_num_cosine_cycles: float = 0.5
47-
checkpoint_freq: int = 1
48+
checkpoint_freq: float = 1
4849
save_best: bool = False
4950
hidden_states_dtype: torch.dtype = torch.bfloat16
5051
log_freq: int = 1
@@ -191,7 +192,13 @@ def train_epoch(self, epoch: int):
191192
if self.local_rank == 0:
192193
train_loader = tqdm(train_loader, desc=f"Epoch {epoch}") # type: ignore[assignment]
193194

194-
for batch in train_loader:
195+
num_steps = len(self.train_loader)
196+
step_interval = (
197+
max(1, round(num_steps * self.config.checkpoint_freq))
198+
if self.config.checkpoint_freq < 1
199+
else None
200+
)
201+
for local_step, batch in enumerate(train_loader, 1):
195202
gpu_batch = {
196203
k: v.to(self.local_rank, non_blocking=True)
197204
if isinstance(v, torch.Tensor)
@@ -229,6 +236,15 @@ def train_epoch(self, epoch: int):
229236
)
230237
self.global_step += 1
231238

239+
if (
240+
step_interval is not None
241+
and not self.config.save_best
242+
and local_step % step_interval == 0
243+
and num_steps - local_step >= step_interval * MIN_STEP_PCT
244+
# Avoid saving back to back at the end of each epoch
245+
):
246+
self.maybe_save_checkpoint(epoch)
247+
232248
@torch.no_grad()
233249
def val_epoch(self, epoch: int) -> dict[str, float] | None:
234250
if self.val_loader is None:
@@ -271,7 +287,8 @@ def maybe_save_checkpoint(self, epoch: int | str):
271287
if epoch != "interrupted" and (
272288
self.config.save_best
273289
or (
274-
isinstance(epoch, int)
290+
self.config.checkpoint_freq >= 1
291+
and isinstance(epoch, int)
275292
and epoch != 0
276293
and (epoch + 1) % self.config.checkpoint_freq != 0
277294
)
@@ -294,7 +311,9 @@ def maybe_update_best(self, epoch: int, val_metrics: dict | None):
294311
self.checkpointer.save_checkpoint(self.model, self.opt, epoch)
295312
if self.scheduler is not None:
296313
self.checkpointer.save_scheduler_state_dict(self.scheduler, epoch)
297-
elif not (epoch == 0 or (epoch + 1) % self.config.checkpoint_freq == 0):
314+
elif self.config.checkpoint_freq >= 1 and not (
315+
epoch == 0 or (epoch + 1) % int(self.config.checkpoint_freq) == 0
316+
):
298317
return
299318

300319
self.best_val_loss = val_metrics["loss_epoch"]

0 commit comments

Comments
 (0)