Skip to content

Commit 7048618

Browse files
authored
Merge pull request #376 from bigict/trainer
feat: multiply the final loss of each training example by the square root of the number of residues after cropping.
2 parents b8dfa48 + ebfd182 commit 7048618

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

profold2/command/trainer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,11 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
347347
'%d %d %d seq.shape: %s pid: %s, clips: %s', epoch, it, jt, seq.shape,
348348
','.join(batch['pid']), batch.get('clip')
349349
)
350+
length_scaler = 1.0
351+
if args.train_apply_sqrt_length_scale:
352+
length_scaler = torch.sqrt(
353+
(torch.mean(torch.sum(batch['mask'], dim=-1) + 1e-6)) / args.max_crop_len
354+
)
350355

351356
# maybe sync or not
352357
with no_sync_ctx(
@@ -360,7 +365,7 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
360365
shard_size=args.model_shard_size
361366
)
362367
)
363-
grad_scaler.scale(r.loss * loss_scaler).backward()
368+
grad_scaler.scale(r.loss * loss_scaler * length_scaler).backward()
364369

365370
# running loss
366371
running_loss += MetricDict({'all': r.loss})
@@ -542,6 +547,12 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
542547
parser.add_argument(
543548
'--max_crop_len', type=int, default=255, help='crop protein whose length>LEN.'
544549
)
550+
parser.add_argument(
551+
'--train_apply_sqrt_length_scale',
552+
action='store_true',
553+
help='multiply the final loss of each training example by the srqt of the number '
554+
'of residues after cropping.'
555+
)
545556
parser.add_argument(
546557
'--crop_algorithm',
547558
type=str,

0 commit comments

Comments
 (0)