Skip to content

Commit 34fc1fd

Browse files
authored
Fix transformer decoder layer (#1995)
1 parent 5fe1307 commit 34fc1fd

8 files changed

Lines changed: 8 additions & 0 deletions

File tree

egs/aishell/ASR/conformer_ctc/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def forward(
545545
memory_mask: Optional[torch.Tensor] = None,
546546
tgt_key_padding_mask: Optional[torch.Tensor] = None,
547547
memory_key_padding_mask: Optional[torch.Tensor] = None,
548+
**kwargs,
548549
) -> torch.Tensor:
549550
"""Pass the inputs (and mask) through the decoder layer.
550551

egs/aishell/ASR/conformer_mmi/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def forward(
545545
memory_mask: Optional[torch.Tensor] = None,
546546
tgt_key_padding_mask: Optional[torch.Tensor] = None,
547547
memory_key_padding_mask: Optional[torch.Tensor] = None,
548+
**kwargs,
548549
) -> torch.Tensor:
549550
"""Pass the inputs (and mask) through the decoder layer.
550551

egs/gigaspeech/ASR/conformer_ctc/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ def forward(
549549
memory_mask: Optional[torch.Tensor] = None,
550550
tgt_key_padding_mask: Optional[torch.Tensor] = None,
551551
memory_key_padding_mask: Optional[torch.Tensor] = None,
552+
**kwargs,
552553
) -> torch.Tensor:
553554
"""Pass the inputs (and mask) through the decoder layer.
554555

egs/librispeech/ASR/conformer_ctc/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ def forward(
549549
memory_mask: Optional[torch.Tensor] = None,
550550
tgt_key_padding_mask: Optional[torch.Tensor] = None,
551551
memory_key_padding_mask: Optional[torch.Tensor] = None,
552+
**kwargs,
552553
) -> torch.Tensor:
553554
"""Pass the inputs (and mask) through the decoder layer.
554555

egs/librispeech/ASR/conformer_ctc2/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ def forward(
550550
tgt_key_padding_mask: Optional[torch.Tensor] = None,
551551
memory_key_padding_mask: Optional[torch.Tensor] = None,
552552
warmup: float = 1.0,
553+
**kwargs,
553554
) -> torch.Tensor:
554555
"""Pass the inputs (and mask) through the decoder layer.
555556

egs/librispeech/ASR/conformer_mmi/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def forward(
537537
memory_mask: Optional[torch.Tensor] = None,
538538
tgt_key_padding_mask: Optional[torch.Tensor] = None,
539539
memory_key_padding_mask: Optional[torch.Tensor] = None,
540+
**kwargs,
540541
) -> torch.Tensor:
541542
"""Pass the inputs (and mask) through the decoder layer.
542543

egs/librispeech/ASR/streaming_conformer_ctc/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ def forward(
567567
memory_mask: Optional[torch.Tensor] = None,
568568
tgt_key_padding_mask: Optional[torch.Tensor] = None,
569569
memory_key_padding_mask: Optional[torch.Tensor] = None,
570+
**kwargs,
570571
) -> torch.Tensor:
571572
"""Pass the inputs (and mask) through the decoder layer.
572573

egs/tedlium3/ASR/conformer_ctc2/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def forward(
612612
tgt_key_padding_mask: Optional[torch.Tensor] = None,
613613
memory_key_padding_mask: Optional[torch.Tensor] = None,
614614
warmup: float = 1.0,
615+
**kwargs,
615616
) -> torch.Tensor:
616617
"""Pass the inputs (and mask) through the decoder layer.
617618

0 commit comments

Comments
 (0)