2525)
2626from profold2 .model import accelerator , optim , FeatureBuilder , MetricDict , ReturnValues
2727from profold2 .model .utils import CheckpointManager
28- from profold2 .utils import exists
28+ from profold2 .utils import default , exists
2929
3030from profold2 .command import worker
3131
@@ -260,11 +260,11 @@ def model_params_groups(optim_options):
260260 break
261261 return params
262262
263- optim = Adam (
263+ optimizer = Adam (
264264 model_params_groups (args .model_params_optim_option ), lr = args .learning_rate
265265 )
266266 else :
267- optim = Adam (model .parameters (), lr = args .learning_rate )
267+ optimizer = Adam (model .parameters (), lr = args .learning_rate )
268268
269269 # tensorboard
270270 writer = SummaryWriter (os .path .join (args .prefix , 'runs' , 'eval' )
@@ -310,12 +310,21 @@ def writer_add_scalars(writer, loss, it, prefix=''):
310310 os .path .join (args .prefix , 'checkpoints' ),
311311 max_to_keep = args .checkpoint_max_to_keep ,
312312 model = model ,
313- optimizer = optim
313+ optimizer = optimizer
314314 )
315315 global_step = checkpoint_manager .restore_or_initialize () + 1
316316 logging .info ('checkpoint_manager.global_step: %d' , global_step )
317317 model .train ()
318318
319+ scheduler = optim .get_scheduler (
320+ args .lr_scheduler ,
321+ optimizer ,
322+ num_warmup_steps = args .lr_scheduler_warmup_steps ,
323+ num_training_steps = default (args .lr_scheduler_training_steps , args .num_batches ),
324+ eta_min = args .lr_scheduler_eta_min ,
325+ last_global_step = global_step
326+ )
327+
319328 # .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
320329 # gradient will be ``M`` times smaller when compared to the same model
321330 # trained on a single node with ``batch=M*N`` if the loss is summed (NOT
@@ -331,9 +340,11 @@ def writer_add_scalars(writer, loss, it, prefix=''):
331340 1 ) / (args .gradient_accumulate_every or 1.0 )
332341
333342 def _step (data_loader , it , writer , stage = 'train' , batch_callback = None ):
334- optim .zero_grad (set_to_none = True )
343+ optimizer .zero_grad (set_to_none = True )
335344
336- logging .debug ('_step it: %d, loss_scaler: %f' , it , loss_scaler )
345+ logging .debug (
346+ '_step it: %d, loss_scaler: %f, lr: %s' , it , loss_scaler , scheduler .get_lr ()
347+ )
337348
338349 running_loss = MetricDict ()
339350 for jt in range (args .gradient_accumulate_every ):
@@ -378,9 +389,10 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
378389 writer_add_scalars (writer , v , it , prefix = f'Loss/{ stage } @{ k } ' )
379390 # writer.add_scalar(f'Loss/train@{k}', v, it)
380391
381- # optim .step()
382- grad_scaler .step (optim )
392+ # optimizer .step()
393+ grad_scaler .step (optimizer )
383394 grad_scaler .update ()
395+ scheduler .step ()
384396
385397 def batch_seq_only (batch ):
386398 batch = copy .copy (batch )
@@ -687,6 +699,31 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
687699 parser .add_argument (
688700 '-l' , '--learning_rate' , type = float , default = '1e-3' , help = 'learning rate.'
689701 )
702+ parser .add_argument (
703+ '--lr_scheduler' ,
704+ type = str ,
705+ default = optim .SchedulerType .CONSTANT .value ,
706+ choices = [m .value for m in optim .SchedulerType ],
707+ help = 'lr scheduler.'
708+ )
709+ parser .add_argument (
710+ '--lr_scheduler_warmup_steps' ,
711+ type = float ,
712+ default = None ,
713+ help = 'num of warmup steps for lr scheduler.'
714+ )
715+ parser .add_argument (
716+ '--lr_scheduler_training_steps' ,
717+ type = float ,
718+ default = None ,
719+ help = 'num of training steps for applying lr scheduler.'
720+ )
721+ parser .add_argument (
722+ '--lr_scheduler_eta_min' ,
723+ type = float ,
724+ default = 0.0 ,
725+ help = 'eta_min for applying lr scheduler.'
726+ )
690727
691728 parser .add_argument (
692729 '--model_features' ,
0 commit comments