Skip to content

Commit e78cb5f

Browse files
authored
Merge pull request #384 from bigict/optim
feat: add lr_scheduler
2 parents 61f11df + fe18ec1 commit e78cb5f

2 files changed

Lines changed: 66 additions & 13 deletions

File tree

profold2/command/trainer.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from profold2.model import accelerator, optim, FeatureBuilder, MetricDict, ReturnValues
2727
from profold2.model.utils import CheckpointManager
28-
from profold2.utils import exists
28+
from profold2.utils import default, exists
2929

3030
from profold2.command import worker
3131

@@ -260,11 +260,11 @@ def model_params_groups(optim_options):
260260
break
261261
return params
262262

263-
optim = Adam(
263+
optimizer = Adam(
264264
model_params_groups(args.model_params_optim_option), lr=args.learning_rate
265265
)
266266
else:
267-
optim = Adam(model.parameters(), lr=args.learning_rate)
267+
optimizer = Adam(model.parameters(), lr=args.learning_rate)
268268

269269
# tensorboard
270270
writer = SummaryWriter(os.path.join(args.prefix, 'runs', 'eval')
@@ -310,12 +310,21 @@ def writer_add_scalars(writer, loss, it, prefix=''):
310310
os.path.join(args.prefix, 'checkpoints'),
311311
max_to_keep=args.checkpoint_max_to_keep,
312312
model=model,
313-
optimizer=optim
313+
optimizer=optimizer
314314
)
315315
global_step = checkpoint_manager.restore_or_initialize() + 1
316316
logging.info('checkpoint_manager.global_step: %d', global_step)
317317
model.train()
318318

319+
scheduler = optim.get_scheduler(
320+
args.lr_scheduler,
321+
optimizer,
322+
num_warmup_steps=args.lr_scheduler_warmup_steps,
323+
num_training_steps=default(args.lr_scheduler_training_steps, args.num_batches),
324+
eta_min=args.lr_scheduler_eta_min,
325+
last_global_step=global_step
326+
)
327+
319328
# .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
320329
# gradient will be ``M`` times smaller when compared to the same model
321330
# trained on a single node with ``batch=M*N`` if the loss is summed (NOT
@@ -331,9 +340,11 @@ def writer_add_scalars(writer, loss, it, prefix=''):
331340
1) / (args.gradient_accumulate_every or 1.0)
332341

333342
def _step(data_loader, it, writer, stage='train', batch_callback=None):
334-
optim.zero_grad(set_to_none=True)
343+
optimizer.zero_grad(set_to_none=True)
335344

336-
logging.debug('_step it: %d, loss_scaler: %f', it, loss_scaler)
345+
logging.debug(
346+
'_step it: %d, loss_scaler: %f, lr: %s', it, loss_scaler, scheduler.get_lr()
347+
)
337348

338349
running_loss = MetricDict()
339350
for jt in range(args.gradient_accumulate_every):
@@ -378,9 +389,10 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
378389
writer_add_scalars(writer, v, it, prefix=f'Loss/{stage}@{k}')
379390
# writer.add_scalar(f'Loss/train@{k}', v, it)
380391

381-
# optim.step()
382-
grad_scaler.step(optim)
392+
# optimizer.step()
393+
grad_scaler.step(optimizer)
383394
grad_scaler.update()
395+
scheduler.step()
384396

385397
def batch_seq_only(batch):
386398
batch = copy.copy(batch)
@@ -687,6 +699,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
687699
parser.add_argument(
688700
'-l', '--learning_rate', type=float, default='1e-3', help='learning rate.'
689701
)
702+
parser.add_argument(
703+
'--lr_scheduler',
704+
type=str,
705+
default=optim.SchedulerType.CONSTANT.value,
706+
choices=[m.value for m in optim.SchedulerType],
707+
help='lr scheduler.'
708+
)
709+
parser.add_argument(
710+
'--lr_scheduler_warmup_steps',
711+
type=float,
712+
default=None,
713+
help='num of warmup steps for lr scheduler.'
714+
)
715+
parser.add_argument(
716+
'--lr_scheduler_training_steps',
717+
type=float,
718+
default=None,
719+
help='num of training steps for applying lr scheduler.'
720+
)
721+
parser.add_argument(
722+
'--lr_scheduler_eta_min',
723+
type=float,
724+
default=0.0,
725+
help='eta_min for applying lr scheduler.'
726+
)
690727

691728
parser.add_argument(
692729
'--model_features',

profold2/model/optim.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""lr_scheduler wrapper
22
"""
33
from enum import Enum
4+
import functools
45
import math
56
from typing import Optional, Union
67

@@ -22,36 +23,51 @@ def get_scheduler(
2223
num_warmup_steps: Optional[int] = None,
2324
num_training_steps: Optional[int] = None,
2425
eta_min: float = 0.0,
25-
last_epoch: int = -1,
26+
last_global_step: int = 0,
2627
) -> LambdaLR:
2728
name = SchedulerType(name)
2829

2930
if name == SchedulerType.CONSTANT:
3031

31-
def lr_lambda(current_step: int) -> float:
32+
def lr_lambda(
33+
current_step: int, num_warmup_steps: Optional[int] = None
34+
) -> float:
35+
current_step = current_step + last_global_step
3236
if exists(num_warmup_steps) and current_step < num_warmup_steps:
3337
return current_step / max(1.0, num_warmup_steps)
3438
return 1.0
3539
elif name == SchedulerType.COSINE:
3640

37-
def lr_lambda(current_step: int) -> float:
41+
def lr_lambda(
42+
current_step: int, num_warmup_steps: Optional[int] = None
43+
) -> float:
44+
current_step = current_step + last_global_step
3845
if exists(num_warmup_steps) and current_step < num_warmup_steps:
3946
return current_step / max(1.0, num_warmup_steps)
47+
elif current_step > num_training_steps:
48+
return eta_min
4049
num_warmup_steps = default(num_warmup_steps, 0)
4150
progress = (
4251
(current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps)
4352
)
4453
return 0.5 * (1.0 - eta_min) * (1.0 + math.cos(math.pi * progress)) + eta_min
4554
elif name == SchedulerType.LINEAR:
4655

47-
def lr_lambda(current_step: int) -> float:
56+
def lr_lambda(
57+
current_step: int, num_warmup_steps: Optional[int] = None
58+
) -> float:
59+
current_step = current_step + last_global_step
4860
if exists(num_warmup_steps) and current_step < num_warmup_steps:
4961
return current_step / max(1.0, num_warmup_steps)
62+
elif current_step > num_training_steps:
63+
return eta_min
5064

5165
num_warmup_steps = default(num_warmup_steps, 0)
5266
progress = (
5367
(num_training_steps - current_step) / (num_training_steps - num_warmup_steps)
5468
)
5569
return (1.0 - eta_min) * progress + eta_min
5670

57-
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
71+
return LambdaLR(
72+
optimizer, functools.partial(lr_lambda, num_warmup_steps=num_warmup_steps)
73+
)

0 commit comments

Comments
 (0)