Skip to content

Commit 026a02d

Browse files
authored
Add the implementations for the MultiHeadAttention Layer
Add the implementations for the MultiHeadAttention Layer
1 parent 15c8794 commit 026a02d

1 file changed

Lines changed: 72 additions & 0 deletions

File tree

  • examples/singa_peft/examples/model

examples/singa_peft/examples/model/trans.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,75 @@ def matmul4d(x1, x2):
372372
ys.append(yb)
373373
y = autograd.cat(ys, axis=0)
374374
return y
375+
376+
class MultiHeadAttention(layer.Layer):
377+
def __init__(self, d_model=512, n_head=8):
378+
super(MultiHeadAttention, self).__init__()
379+
self.d_k = d_model // n_head
380+
assert (
381+
self.d_k * n_head == d_model
382+
), "embed_dim must be divisible by num_heads"
383+
self.d_model = d_model
384+
self.d_v = self.d_k
385+
self.n_head = n_head
386+
self.W_Q = Linear3D(d_model, self.d_k * n_head)
387+
self.W_K = Linear3D(d_model, self.d_k * n_head)
388+
self.W_V = Linear3D(d_model, self.d_v * n_head)
389+
390+
self.scaled_dot_product_attention = ScaledDotProductAttention(d_model, n_head)
391+
self.linear = Linear3D(self.d_v * n_head, d_model)
392+
self.add = layer.Add()
393+
self.layer_norm = LayerNorm(d_model)
394+
395+
def forward(self, query, key, value, attn_mask):
396+
"""
397+
Args:
398+
query: [batch_size, len_q, d_model]
399+
key: [batch_size, len_k, d_model]
400+
value: [batch_size, len_v(=len_k), d_model]
401+
attn_mask: [batch_size, seq_len, seq_len]
402+
Returns:
403+
"""
404+
residual = query
405+
batch_size = query.shape[0]
406+
407+
# (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
408+
Q = self.W_Q(query)
409+
Q = autograd.reshape(Q, [batch_size, -1, self.n_head, self.d_k])
410+
Q = autograd.transpose(Q, [0, 2, 1, 3])
411+
412+
K = self.W_K(key)
413+
K = autograd.reshape(K, [batch_size, -1, self.n_head, self.d_k])
414+
K = autograd.transpose(K, [0, 2, 1, 3])
415+
416+
V = self.W_V(value)
417+
V = autograd.reshape(V, [batch_size, -1, self.n_head, self.d_v])
418+
V = autograd.transpose(V, [0, 2, 1, 3])
419+
420+
# Q: [batch_size, n_heads, len_q, d_k]
421+
# K: [batch_size, n_heads, len_k, d_k]
422+
# V: [batch_size, n_heads, len_v(=len_k), d_v]
423+
424+
# attn_mask : [batch_size, n_heads, seq_len, seq_len]
425+
attn_mask = MultiHeadAttention._get_attn_mask(attn_mask, self.n_head)
426+
427+
# context: [batch_size, n_heads, len_q, d_v]
428+
# attn: [batch_size, n_heads, seq_len, seq_len]
429+
context, attn = self.scaled_dot_product_attention(Q, K, V, attn_mask)
430+
context = autograd.transpose(context, [0, 2, 1, 3])
431+
# context: [batch_size, len_q, n_heads * d_v]
432+
context = autograd.reshape(context, [batch_size, -1, self.n_head * self.d_v])
433+
434+
output = self.linear(context)
435+
output = self.add(output, residual)
436+
# [batch_size, len_q, d_model]
437+
output = self.layer_norm(output)
438+
return output, attn
439+
440+
@staticmethod
441+
def _get_attn_mask(attn_mask, n_head):
442+
batch_size, seq_q_len,seq_k_len = attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2]
443+
attn_mask_np = tensor.to_numpy(attn_mask)
444+
attn_mask_np = np.expand_dims(attn_mask_np, axis=1)
445+
attn_mask_np = np.broadcast_to(attn_mask_np, (batch_size, n_head, seq_q_len, seq_k_len))
446+
return tensor.from_numpy(attn_mask_np)

0 commit comments

Comments
 (0)