Skip to content

Commit e9f4eb0

Browse files
add linear warmup for cr-ctc loss (#2075)
- this prevents cr-ctc loss from diverging at the beginning of the training
1 parent 2d2470b commit e9f4eb0

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

egs/aishell/ASR/zipformer/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -882,8 +882,10 @@ def compute_loss(
882882
if params.use_ctc:
883883
loss += params.ctc_loss_scale * ctc_loss
884884
if use_cr_ctc:
885-
loss += params.cr_loss_scale * cr_loss
886-
885+
# linear warmup
886+
cr_loss_scale = min(batch_idx_train / warm_step, 1.0) * params.cr_loss_scale
887+
loss += cr_loss_scale * cr_loss
888+
887889
assert loss.requires_grad == is_training
888890

889891
info = MetricsTracker()

egs/librispeech/ASR/zipformer/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,9 @@ def compute_loss(
967967
if params.use_ctc:
968968
loss += params.ctc_loss_scale * ctc_loss
969969
if use_cr_ctc:
970-
loss += params.cr_loss_scale * cr_loss
970+
# linear warmup
971+
cr_loss_scale = min(batch_idx_train / warm_step, 1.0) * params.cr_loss_scale
972+
loss += cr_loss_scale * cr_loss
971973

972974
if params.use_attention_decoder:
973975
loss += params.attention_decoder_loss_scale * attention_decoder_loss

0 commit comments

Comments
 (0)