@@ -115,6 +115,7 @@ def __init__(self,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
                linformer_dim=None,
+               linformer_shared_kv_projection=True,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.

@@ -194,6 +195,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       linformer_dim: Applies low-rank factorization on keys/values as in
         https://arxiv.org/pdf/2006.04768.
+      linformer_shared_kv_projection: If set, keys and values share the same
+        low-rank projection layer.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -234,6 +237,7 @@ def __init__(self,
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
     self._linformer_dim = linformer_dim
+    self._linformer_shared_kv_projection = linformer_shared_kv_projection
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -383,11 +387,13 @@ def build(self, input_shape):
         dtype=tf.float32,
     )
     if self._linformer_dim is not None:
-      # Current implementation uses the same weights for keys and values.
-      # TODO(akandoor): Explore using different weights for keys and values.
+      if self._linformer_shared_kv_projection:
+        low_rank_dim = self._linformer_dim
+      else:
+        low_rank_dim = 2 * self._linformer_dim
       self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
           "...bc,cd->...bd",
-          output_shape=(None, self._linformer_dim),
+          output_shape=(None, low_rank_dim),
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer
           ),
@@ -444,6 +450,8 @@ def get_config(self):
444450 "tgt_block_size" : self ._tgt_block_size ,
445451 "use_sigmoid_attn" : self ._use_sigmoid_attn ,
446452 "sigmoid_attn_bias" : self ._sigmoid_attn_bias ,
453+ "linformer_dim" : self ._linformer_dim ,
454+ "linformer_shared_kv_projection" : self ._linformer_shared_kv_projection ,
447455 }
448456 base_config = super ().get_config ()
449457 return dict (list (base_config .items ()) + list (config .items ()))
@@ -499,6 +507,8 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
     if key_value is None:
       key_value = input_tensor

+    key = key_value
+    value = key_value
     if self._linformer_dim is not None:
       if attention_mask is not None:
         # Applying mask before the low rank factorization so that padding is
@@ -510,17 +520,28 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
         attention_mask = None
       key_value = tf.transpose(key_value, [0, 2, 1])
       key_value = self._lowrank_kv_projection(key_value)
-      key_value = tf.transpose(key_value, [0, 2, 1])
-
+      if self._linformer_shared_kv_projection:
+        key_value = tf.transpose(key_value, [0, 2, 1])
+        key = key_value
+        value = key_value
+      else:
+        key = tf.transpose(key_value[:, :, :self._linformer_dim], [0, 2, 1])
+        value = tf.transpose(key_value[:, :, self._linformer_dim:], [0, 2, 1])
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
-          value=key_value,
+          key=key,
+          value=value,
           attention_mask=attention_mask,
-          return_attention_scores=True)
+          return_attention_scores=True,
+      )
     else:
       attention_output = self._attention_layer(
-          query=target_tensor, value=key_value, attention_mask=attention_mask)
+          query=target_tensor,
+          key=key,
+          value=value,
+          attention_mask=attention_mask,
+      )
     attention_output = self._attention_dropout(attention_output)

     if self._norm_first:
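
Usage sketch (not part of the diff): the snippet below shows how the new `linformer_shared_kv_projection` flag could be exercised, assuming the layer is importable as `official.nlp.modeling.layers.TransformerEncoderBlock`; the toy shapes and hyperparameters are illustrative assumptions.

```python
import tensorflow as tf

from official.nlp.modeling import layers  # assumed import path

# Toy dimensions for illustration only.
batch_size, seq_len, hidden_dim = 2, 128, 64

block = layers.TransformerEncoderBlock(
    num_attention_heads=4,
    inner_dim=256,
    inner_activation="relu",
    linformer_dim=32,                      # low-rank sequence length for keys/values
    linformer_shared_kv_projection=False,  # new flag: separate key/value projections
)

x = tf.random.normal([batch_size, seq_len, hidden_dim])
y = block(x)
print(y.shape)  # (2, 128, 64): the block preserves the input shape
```

With `linformer_shared_kv_projection=False`, `build()` widens the low-rank projection to `2 * linformer_dim` and `call()` splits the projected tensor into separate key and value halves; with the default `True`, the previous behavior of reusing one projected tensor for both keys and values is kept.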