@@ -347,6 +347,11 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
347347 '%d %d %d seq.shape: %s pid: %s, clips: %s' , epoch , it , jt , seq .shape ,
348348 ',' .join (batch ['pid' ]), batch .get ('clip' )
349349 )
350+ length_scaler = 1.0
351+ if args .train_apply_sqrt_length_scale :
352+ length_scaler = torch .sqrt (
353+ (torch .mean (torch .sum (batch ['mask' ], dim = - 1 ) + 1e-6 )) / args .max_crop_len
354+ )
350355
351356 # maybe sync or not
352357 with no_sync_ctx (
@@ -360,7 +365,7 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
360365 shard_size = args .model_shard_size
361366 )
362367 )
363- grad_scaler .scale (r .loss * loss_scaler ).backward ()
368+ grad_scaler .scale (r .loss * loss_scaler * length_scaler ).backward ()
364369
365370 # running loss
366371 running_loss += MetricDict ({'all' : r .loss })
@@ -542,6 +547,12 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
542547 parser .add_argument (
543548 '--max_crop_len' , type = int , default = 255 , help = 'crop protein whose length>LEN.'
544549 )
550+ parser .add_argument (
551+ '--train_apply_sqrt_length_scale' ,
552+ action = 'store_true' ,
553+ help = 'multiply the final loss of each training example by the sqrt of the number '
554+ 'of residues after cropping.'
555+ )
545556 parser .add_argument (
546557 '--crop_algorithm' ,
547558 type = str ,
0 commit comments