Commit b02db5b

Luodian and anxiangsir authored
fix: correct gradient accumulation off-by-one and lr_scheduler over-stepping (#82)
* fix: correct gradient accumulation off-by-one and lr_scheduler over-stepping

* fix: align scheduler total_iters with optimizer steps under gradient accumulation

  lr_scheduler total_iters was set to the micro-step count (total_steps), but after moving lr_scheduler.step() to fire only on optimizer steps, the scheduler would traverse only 1/backward_passes_per_step of its budget. Dividing total_iters by backward_passes_per_step lets the full LR curve (warmup + polynomial decay) complete over the actual optimizer steps. No-op when backward_passes_per_step=1 (Stage-1).

---------

Co-authored-by: Xiang An <anxiangsir@outlook.com>
1 parent 29826ef commit b02db5b

File tree

1 file changed: +4 -4 lines

training/train.py

Lines changed: 4 additions & 4 deletions
@@ -350,7 +350,8 @@ def _expand(name, v):
         optimizer_cls = torch.optim.AdamW

         opt = optimizer_cls(parameters, lr=args.lr, weight_decay=args.weight_decay)
-        lr_scheduler = PolynomialLRWarmup(opt, int(args.total_steps * args.warmup_ratio), args.total_steps, 2)
+        optimizer_total_steps = args.total_steps // args.backward_passes_per_step
+        lr_scheduler = PolynomialLRWarmup(opt, int(optimizer_total_steps * args.warmup_ratio), optimizer_total_steps, 2)
     else:
         raise ValueError(f"{args.opt} not support!")
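The total_iters change can be sanity-checked with a small sketch. The values below are hypothetical stand-ins for the repo's `args.total_steps`, `args.backward_passes_per_step`, and `args.warmup_ratio`:

```python
# Sketch with assumed values: total_iters must count optimizer steps,
# not micro-steps, once the scheduler is stepped once per optimizer step.
total_steps = 1000                # micro-step budget (hypothetical value)
backward_passes_per_step = 4      # gradient accumulation factor (hypothetical)
warmup_ratio = 0.1                # hypothetical

optimizer_total_steps = total_steps // backward_passes_per_step
warmup_iters = int(optimizer_total_steps * warmup_ratio)

print(optimizer_total_steps, warmup_iters)  # 250 25
```

With the old code the scheduler would have been built with total_iters=1000 but stepped only 250 times, leaving the decay curve three-quarters unfinished.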

@@ -652,7 +653,7 @@ def wrap_ddp(model):
             list_loss.append(head_loss)
             list_loss_float.append(head_loss.item())

-            is_accumulation_step = global_step % args.backward_passes_per_step != 0
+            is_accumulation_step = (global_step + 1) % args.backward_passes_per_step != 0
             scaled_loss = sum(list_loss) / args.backward_passes_per_step

             if is_accumulation_step:
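A quick sketch of why the `+ 1` matters, assuming `global_step` counts micro-steps from 0 (the accumulation factor below is a hypothetical value):

```python
# Which micro-steps run opt.step() (i.e. are NOT accumulation steps)
# under the old vs. fixed condition, with 4 micro-batches per update.
backward_passes_per_step = 4

old = [s for s in range(8) if s % backward_passes_per_step == 0]
new = [s for s in range(8) if (s + 1) % backward_passes_per_step == 0]

# Old: updates at micro-steps 0 and 4, so the first update lands after a
# single backward pass. Fixed: updates at 3 and 7, after exactly 4 passes.
print(old, new)  # [0, 4] [3, 7]
```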
@@ -665,8 +666,7 @@ def wrap_ddp(model):
                 clip_grad_norm_(pfc.parameters(), max_norm=5, norm_type=2)
                 opt.step()
                 opt.zero_grad(set_to_none=True)
-
-            lr_scheduler.step()
+                lr_scheduler.step()

             batch_end_callback(
                 global_step=global_step,
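Putting the two fixes together: a minimal simulation (not the repo's code, step numbering assumed to start at 0) showing that the scheduler now steps exactly `optimizer_total_steps` times over the run:

```python
# Simulate the loop's bookkeeping: lr_scheduler.step() fires only on
# optimizer steps, and the count matches the divided total_iters.
total_steps = 1000                # hypothetical micro-step budget
backward_passes_per_step = 4      # hypothetical accumulation factor
optimizer_total_steps = total_steps // backward_passes_per_step

scheduler_steps = 0
for global_step in range(total_steps):
    is_accumulation_step = (global_step + 1) % backward_passes_per_step != 0
    if not is_accumulation_step:
        # opt.step(); opt.zero_grad(set_to_none=True)
        scheduler_steps += 1  # lr_scheduler.step() fires only here

print(scheduler_steps == optimizer_total_steps)  # True
```

This is the no-op case the commit message mentions when `backward_passes_per_step=1`: every micro-step is an optimizer step, and the division leaves total_iters unchanged.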
