Skip to content

Commit 8dcdcba

Browse files
authored
Merge pull request #1353 from calmdown539/dev-postgresql
Add the encoder layer for the transformer model
2 parents cca3b0b + 243fd76 commit 8dcdcba

1 file changed

Lines changed: 79 additions & 0 deletions

File tree

  • examples/singa_peft/examples/model

examples/singa_peft/examples/model/trans.py

Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -250,3 +250,82 @@ def get_posi_angle_vec(position):
250250
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # Cosine function for odd digits
251251
return tensor.Tensor(data=sinusoid_table, requires_grad=False)
252252

253+
class TransformerEncoder(layer.Layer):
    """TransformerEncoder is a stack of N encoder layers.

    Args:
        src_n_token: the source vocab size.
        d_model: the number of expected features in the encoder inputs (default=512).
        n_head: the number of heads in the multi head attention models (default=8).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        n_layers: the number of sub-encoder-layers in the encoder (default=6).
    """

    def __init__(self, src_n_token, d_model=512, n_head=8, dim_feedforward=2048, n_layers=6):
        super(TransformerEncoder, self).__init__()
        self.src_n_token = src_n_token
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward
        self.n_layers = n_layers

        # input_emb / pos_emb / n-encoder layers
        self.input_emb = layer.Embedding(input_dim=src_n_token, output_dim=d_model)
        # The positional table has one row per vocabulary entry (src_n_token rows);
        # its weights are overwritten with the sinusoid table in forward().
        self.pos_emb = layer.Embedding(input_dim=src_n_token, output_dim=d_model)
        self.layers = []
        for _ in range(self.n_layers):
            self.layers.append(
                TransformerEncoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=dim_feedforward))

    def forward(self, enc_inputs):
        """Pass the input through the encoder layers in turn.

        Args:
            enc_inputs: the sequence to the encoder (required). [batch_size, src_len]

        Returns:
            A tuple ``(enc_outputs, enc_self_attns)``: the output of the last
            encoder layer and a list with the second value returned by every
            encoder layer, in layer order.
        """
        # [batch_size, src_len, d_model]
        word_emb = self.input_emb(enc_inputs)

        # NOTE(review): the positional embedding is re-initialized and the
        # sinusoid table is rebuilt on every forward call; consider doing this
        # once if the layer life cycle allows it.
        self.pos_emb.initialize(enc_inputs)
        self.pos_emb.from_pretrained(
            W=TransformerEncoder._get_sinusoid_encoding_table(self.src_n_token, self.d_model),
            freeze=True)
        # [batch_size, src_len, d_model]
        # NOTE(review): the table is looked up with the same ids as input_emb
        # (enc_inputs), not with position indices 0..src_len-1 -- confirm this
        # is intended.
        pos_emb = self.pos_emb(enc_inputs)
        # enc_outputs [batch_size, src_len, d_model]
        enc_outputs = autograd.add(word_emb, pos_emb)

        # enc_self_attn_mask [batch_size, src_len, src_len]
        enc_self_attn_mask = TransformerEncoder._get_attn_pad_mask(enc_inputs, enc_inputs)

        enc_self_attns = []
        # The loop variable is deliberately not called `layer`, so that the
        # module-level `layer` module is not shadowed inside this method.
        for enc_layer in self.layers:
            enc_outputs, enc_self_attn = enc_layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

    @staticmethod
    def _get_attn_pad_mask(seq_q, seq_k):
        """Build the attention mask that marks zero-valued key positions.

        Args:
            seq_q: [batch_size, seq_len]
            seq_k: [batch_size, seq_len]

        Returns:
            An int32 tensor of shape [batch_size, len_q, len_k] that holds 1
            where the key id in ``seq_k`` equals 0 (treated as padding) and 0
            elsewhere; the key row is repeated for every query position.
        """
        _, len_q = seq_q.shape
        batch_size, len_k = seq_k.shape
        # [batch_size, len_k] boolean: True where the key id is 0.
        is_pad = tensor.to_numpy(seq_k) == 0
        # [batch_size, 1, len_k] -> [batch_size, len_q, len_k]
        pad_attn_mask_np = np.broadcast_to(np.expand_dims(is_pad, axis=1), (batch_size, len_q, len_k))
        # astype returns a new array, so the result must be kept; it also turns
        # the read-only broadcast view into a regular contiguous int32 array.
        pad_attn_mask_np = pad_attn_mask_np.astype(np.int32)
        return tensor.from_numpy(pad_attn_mask_np)

    @staticmethod
    def _get_sinusoid_encoding_table(n_position, d_model):
        """Build the sinusoid position encoding table.

        Entry ``[pos, j]`` starts as ``pos / 10000 ** (2 * (j // 2) / d_model)``;
        sine is then applied to the even columns and cosine to the odd columns.

        Args:
            n_position: number of rows (positions) in the table.
            d_model: number of columns (features) in the table.

        Returns:
            A float32 tensor of shape [n_position, d_model] with
            ``requires_grad=False``.
        """
        # Vectorised equivalent of the per-element loops; it also yields a
        # well-formed (0, d_model) table when n_position is 0.
        positions = np.arange(n_position, dtype=np.float64).reshape(-1, 1)
        exponents = 2 * (np.arange(d_model) // 2) / d_model
        sinusoid_table = (positions / np.power(10000, exponents)).astype(np.float32)
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even columns
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd columns
        return tensor.Tensor(data=sinusoid_table, requires_grad=False)

0 commit comments

Comments
 (0)