Skip to content

Commit cca3b0b

Browse files
authored
Merge pull request #1352 from calmdown539/dev-postgresql
2 parents 4383899 + d813917 commit cca3b0b

1 file changed

Lines changed: 32 additions & 0 deletions

File tree

  • examples/singa_peft/examples/model

examples/singa_peft/examples/model/trans.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,38 @@ def train_one_batch(self, enc_inputs, dec_inputs, dec_outputs, pad):
101101
def set_optimizer(self, opt):
    """Store the given optimizer on this object as ``self.opt``.

    Args:
        opt: the optimizer object to keep; it is stored as-is.
    """
    self.opt = opt
103103

104+
class TransformerDecoderLayer(layer.Layer):
    """One decoder layer built from three sub-layers applied in sequence.

    The sub-layers are an attention layer over the decoder inputs, an
    attention layer over the encoder outputs, and a position-wise
    feed-forward network.
    """

    def __init__(self, d_model=512, n_head=8, dim_feedforward=2048):
        """
        Args:
            d_model: forwarded as ``d_model`` to all three sub-layers
            n_head: forwarded as ``n_head`` to both attention sub-layers
            dim_feedforward: forwarded as ``dim_feedforward`` to the
                feed-forward sub-layer
        """
        super(TransformerDecoderLayer, self).__init__()

        # Keep the hyper-parameters on the instance.
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward

        # Both attention sub-layers are configured identically.
        attn_config = dict(d_model=d_model, n_head=n_head)
        self.dec_self_attn = MultiHeadAttention(**attn_config)
        self.dec_enc_attn = MultiHeadAttention(**attn_config)
        self.pos_ffn = PoswiseFeedForwardNet(
            d_model=d_model, dim_feedforward=dim_feedforward)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        Args:
            dec_inputs: [batch_size, tgt_len, d_model]
            enc_outputs: [batch_size, src_len, d_model]
            dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
            dec_enc_attn_mask: [batch_size, tgt_len, src_len]

        Returns:
            A 3-tuple ``(dec_outputs, dec_self_attn, dec_enc_attn)``:
            dec_outputs: [batch_size, tgt_len, d_model]
            dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        """
        # Self-attention: the decoder inputs fill the first three arguments
        # (presumably query, key and value -- confirm against
        # MultiHeadAttention).
        # self_attended: [batch_size, tgt_len, d_model]
        # self_attn_weights: [batch_size, n_heads, tgt_len, tgt_len]
        self_attended, self_attn_weights = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)

        # Encoder-decoder attention: the self-attention output goes first,
        # the encoder outputs fill the next two arguments.
        # cross_attended: [batch_size, tgt_len, d_model]
        # enc_attn_weights: [batch_size, n_heads, tgt_len, src_len]
        cross_attended, enc_attn_weights = self.dec_enc_attn(
            self_attended, enc_outputs, enc_outputs, dec_enc_attn_mask)

        # Position-wise feed-forward: [batch_size, tgt_len, d_model]
        ffn_output = self.pos_ffn(cross_attended)
        return ffn_output, self_attn_weights, enc_attn_weights
135+
104136

105137
class TransformerDecoder(layer.Layer):
106138
"""TransformerDecoder is a stack of N decoder layers

0 commit comments

Comments
 (0)