@@ -372,3 +372,75 @@ def matmul4d(x1, x2):
372372 ys .append (yb )
373373 y = autograd .cat (ys , axis = 0 )
374374 return y
375+
376+ class MultiHeadAttention (layer .Layer ):
377+ def __init__ (self , d_model = 512 , n_head = 8 ):
378+ super (MultiHeadAttention , self ).__init__ ()
379+ self .d_k = d_model // n_head
380+ assert (
381+ self .d_k * n_head == d_model
382+ ), "embed_dim must be divisible by num_heads"
383+ self .d_model = d_model
384+ self .d_v = self .d_k
385+ self .n_head = n_head
386+ self .W_Q = Linear3D (d_model , self .d_k * n_head )
387+ self .W_K = Linear3D (d_model , self .d_k * n_head )
388+ self .W_V = Linear3D (d_model , self .d_v * n_head )
389+
390+ self .scaled_dot_product_attention = ScaledDotProductAttention (d_model , n_head )
391+ self .linear = Linear3D (self .d_v * n_head , d_model )
392+ self .add = layer .Add ()
393+ self .layer_norm = LayerNorm (d_model )
394+
395+ def forward (self , query , key , value , attn_mask ):
396+ """
397+ Args:
398+ query: [batch_size, len_q, d_model]
399+ key: [batch_size, len_k, d_model]
400+ value: [batch_size, len_v(=len_k), d_model]
401+ attn_mask: [batch_size, seq_len, seq_len]
402+ Returns:
403+ """
404+ residual = query
405+ batch_size = query .shape [0 ]
406+
407+ # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
408+ Q = self .W_Q (query )
409+ Q = autograd .reshape (Q , [batch_size , - 1 , self .n_head , self .d_k ])
410+ Q = autograd .transpose (Q , [0 , 2 , 1 , 3 ])
411+
412+ K = self .W_K (key )
413+ K = autograd .reshape (K , [batch_size , - 1 , self .n_head , self .d_k ])
414+ K = autograd .transpose (K , [0 , 2 , 1 , 3 ])
415+
416+ V = self .W_V (value )
417+ V = autograd .reshape (V , [batch_size , - 1 , self .n_head , self .d_v ])
418+ V = autograd .transpose (V , [0 , 2 , 1 , 3 ])
419+
420+ # Q: [batch_size, n_heads, len_q, d_k]
421+ # K: [batch_size, n_heads, len_k, d_k]
422+ # V: [batch_size, n_heads, len_v(=len_k), d_v]
423+
424+ # attn_mask : [batch_size, n_heads, seq_len, seq_len]
425+ attn_mask = MultiHeadAttention ._get_attn_mask (attn_mask , self .n_head )
426+
427+ # context: [batch_size, n_heads, len_q, d_v]
428+ # attn: [batch_size, n_heads, seq_len, seq_len]
429+ context , attn = self .scaled_dot_product_attention (Q , K , V , attn_mask )
430+ context = autograd .transpose (context , [0 , 2 , 1 , 3 ])
431+ # context: [batch_size, len_q, n_heads * d_v]
432+ context = autograd .reshape (context , [batch_size , - 1 , self .n_head * self .d_v ])
433+
434+ output = self .linear (context )
435+ output = self .add (output , residual )
436+ # [batch_size, len_q, d_model]
437+ output = self .layer_norm (output )
438+ return output , attn
439+
440+ @staticmethod
441+ def _get_attn_mask (attn_mask , n_head ):
442+ batch_size , seq_q_len ,seq_k_len = attn_mask .shape [0 ], attn_mask .shape [1 ], attn_mask .shape [2 ]
443+ attn_mask_np = tensor .to_numpy (attn_mask )
444+ attn_mask_np = np .expand_dims (attn_mask_np , axis = 1 )
445+ attn_mask_np = np .broadcast_to (attn_mask_np , (batch_size , n_head , seq_q_len , seq_k_len ))
446+ return tensor .from_numpy (attn_mask_np )
0 commit comments