@@ -114,6 +114,7 @@ def __init__(self,
                tgt_block_size=None,
                use_sigmoid_attn=False,
                sigmoid_attn_bias=None,
+               linformer_dim=None,
                **kwargs):
     """Initializes `TransformerEncoderBlock`.
 
@@ -191,6 +192,8 @@ def __init__(self,
         `block_sparse_attention.MultiHeadAttention`
       sigmoid_attn_bias: This param is only used in
         `block_sparse_attention.MultiHeadAttention`
+      linformer_dim: Applies low-rank factorization on keys/values as in
+        https://arxiv.org/pdf/2006.04768.
       **kwargs: keyword arguments.
     """
     util.filter_kwargs(kwargs)
@@ -230,6 +233,7 @@ def __init__(self,
     self._tgt_block_size = tgt_block_size
     self._use_sigmoid_attn = use_sigmoid_attn
     self._sigmoid_attn_bias = sigmoid_attn_bias
+    self._linformer_dim = linformer_dim
     if self._num_kv_heads is not None and self._src_block_size is not None:
       raise ValueError(
           "Block sparse attention does not support Multi-query attention."
@@ -366,16 +370,31 @@ def build(self, input_shape):
         name="output",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
-        **common_kwargs)
+        **common_kwargs,
+    )
     self._output_dropout = tf_keras.layers.Dropout(
-        rate=self._output_dropout_rate)
+        rate=self._output_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf_keras.layers.LayerNormalization(
         name="output_layer_norm",
         axis=-1,
         epsilon=self._norm_epsilon,
-        dtype=tf.float32)
-
+        dtype=tf.float32,
+    )
+    if self._linformer_dim is not None:
+      # Current implementation uses the same weights for keys and values.
+      # TODO(akandoor): Explore using different weights for keys and values.
+      self._lowrank_kv_projection = tf_keras.layers.EinsumDense(
+          "...bc,cd->...bd",
+          output_shape=(None, self._linformer_dim),
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer
+          ),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
+          name="lowrank_kv_projection",
+          **common_kwargs,
+      )
     super().build(input_shape)
 
   def get_config(self):
@@ -480,6 +499,19 @@ def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
     if key_value is None:
       key_value = input_tensor
 
+    if self._linformer_dim is not None:
+      if attention_mask is not None:
+        # Applying mask before the low rank factorization so that padding is
+        # accounted for.
+        query_mask = tf.cast(attention_mask[:, :, 0], dtype=target_tensor.dtype)
+        target_tensor = target_tensor * tf.expand_dims(query_mask, axis=-1)
+        key_mask = tf.cast(attention_mask[:, 0, :], dtype=target_tensor.dtype)
+        key_value = key_value * tf.expand_dims(key_mask, axis=-1)
+        attention_mask = None
+      key_value = tf.transpose(key_value, [0, 2, 1])
+      key_value = self._lowrank_kv_projection(key_value)
+      key_value = tf.transpose(key_value, [0, 2, 1])
+
     if self._return_attention_scores:
       attention_output, attention_scores = self._attention_layer(
           query=target_tensor,
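
For reference, a minimal standalone sketch of the key/value down-projection this change introduces, assuming the tf_keras (Keras 2) package used by the Model Garden; the shapes, mask, and variable names below are illustrative, and a single projection is shared by keys and values, as the TODO in the diff notes.

import tensorflow as tf
import tf_keras

batch, seq_len, hidden, linformer_dim = 2, 128, 64, 32

# Fake inputs and a padding mask (1 = keep, 0 = pad), purely illustrative.
key_value = tf.random.normal([batch, seq_len, hidden])
key_mask = tf.concat([tf.ones([batch, 100]), tf.zeros([batch, 28])], axis=-1)

# Same einsum equation and output_shape as in the diff: the projection acts on
# the last axis, so the sequence axis is moved there before projecting.
lowrank_kv_projection = tf_keras.layers.EinsumDense(
    "...bc,cd->...bd",
    output_shape=(None, linformer_dim),
    name="lowrank_kv_projection",
)

# Zero out padded positions before factorizing, mirroring the call() logic,
# then project [B, S, H] -> [B, H, S] -> [B, H, k] -> [B, k, H].
key_value = key_value * tf.expand_dims(key_mask, axis=-1)
key_value = tf.transpose(key_value, [0, 2, 1])
key_value = lowrank_kv_projection(key_value)
key_value = tf.transpose(key_value, [0, 2, 1])

print(key_value.shape)  # (2, 32, 64): keys/values now have length linformer_dim.

After this projection, attention over the shortened keys/values scales linearly in the original sequence length S (roughly O(S * k) with k = linformer_dim) instead of O(S^2), which is the point of the Linformer factorization cited in the docstring.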