@@ -101,6 +101,38 @@ def train_one_batch(self, enc_inputs, dec_inputs, dec_outputs, pad):
101101 def set_optimizer (self , opt ):
102102 self .opt = opt
103103
104+ class TransformerDecoderLayer (layer .Layer ):
105+ def __init__ (self , d_model = 512 , n_head = 8 , dim_feedforward = 2048 ):
106+ super (TransformerDecoderLayer , self ).__init__ ()
107+
108+ self .d_model = d_model
109+ self .n_head = n_head
110+ self .dim_feedforward = dim_feedforward
111+
112+ self .dec_self_attn = MultiHeadAttention (d_model = d_model , n_head = n_head )
113+ self .dec_enc_attn = MultiHeadAttention (d_model = d_model , n_head = n_head )
114+ self .pos_ffn = PoswiseFeedForwardNet (d_model = d_model , dim_feedforward = dim_feedforward )
115+
116+ def forward (self , dec_inputs , enc_outputs , dec_self_attn_mask , dec_enc_attn_mask ):
117+ """
118+ Args:
119+ dec_inputs: [batch_size, tgt_len, d_model]
120+ enc_outputs: [batch_size, src_len, d_model]
121+ dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
122+ dec_enc_attn_mask: [batch_size, tgt_len, src_len]
123+ """
124+
125+ # dec_outputs: [batch_size, tgt_len, d_model]
126+ # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
127+ dec_outputs , dec_self_attn = self .dec_self_attn (dec_inputs , dec_inputs , dec_inputs , dec_self_attn_mask )
128+
129+ # dec_outputs: [batch_size, tgt_len, d_model]
130+ # dec_self_attn: [batch_size, n_heads, tgt_len, src_len]
131+ dec_outputs , dec_enc_attn = self .dec_enc_attn (dec_outputs , enc_outputs , enc_outputs , dec_enc_attn_mask )
132+ # [batch_size, tgt_len, d_model]
133+ dec_outputs = self .pos_ffn (dec_outputs )
134+ return dec_outputs , dec_self_attn , dec_enc_attn
135+
104136
105137class TransformerDecoder (layer .Layer ):
106138 """TransformerDecoder is a stack of N decoder layers
0 commit comments