File tree Expand file tree Collapse file tree
librispeech/ASR/zipformer Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -882,8 +882,10 @@ def compute_loss(
882882 if params .use_ctc :
883883 loss += params .ctc_loss_scale * ctc_loss
884884 if use_cr_ctc :
885- loss += params .cr_loss_scale * cr_loss
886-
885+ # linear warmup
886+ cr_loss_scale = min (batch_idx_train / warm_step , 1.0 ) * params .cr_loss_scale
887+ loss += cr_loss_scale * cr_loss
888+
887889 assert loss .requires_grad == is_training
888890
889891 info = MetricsTracker ()
Original file line number Diff line number Diff line change @@ -967,7 +967,9 @@ def compute_loss(
967967 if params .use_ctc :
968968 loss += params .ctc_loss_scale * ctc_loss
969969 if use_cr_ctc :
970- loss += params .cr_loss_scale * cr_loss
970+ # linear warmup
971+ cr_loss_scale = min (batch_idx_train / warm_step , 1.0 ) * params .cr_loss_scale
972+ loss += cr_loss_scale * cr_loss
971973
972974 if params .use_attention_decoder :
973975 loss += params .attention_decoder_loss_scale * attention_decoder_loss
You can’t perform that action at this time.
0 commit comments