@@ -112,6 +112,8 @@ def __init__(self,
112112 num_kv_heads = None ,
113113 src_block_size = None ,
114114 tgt_block_size = None ,
115+ use_sigmoid_attn = False ,
116+ sigmoid_attn_bias = None ,
115117 ** kwargs ):
116118 """Initializes `TransformerEncoderBlock`.
117119
@@ -185,6 +187,10 @@ def __init__(self,
185187 `block_sparse_attention.MultiHeadAttention` for more details.
186188 tgt_block_size: Target block size. Refer to
187189 `block_sparse_attention.MultiHeadAttention` for more details.
190+ use_sigmoid_attn: This param is only used in
191+ `block_sparse_attention.MultiHeadAttention`
192+ sigmoid_attn_bias: This param is only used in
193+ `block_sparse_attention.MultiHeadAttention`
188194 **kwargs: keyword arguments.
189195 """
190196 util .filter_kwargs (kwargs )
@@ -222,6 +228,8 @@ def __init__(self,
222228 self ._num_kv_heads = num_kv_heads
223229 self ._src_block_size = src_block_size
224230 self ._tgt_block_size = tgt_block_size
231+ self ._use_sigmoid_attn = use_sigmoid_attn
232+ self ._sigmoid_attn_bias = sigmoid_attn_bias
225233 if self ._num_kv_heads is not None and self ._src_block_size is not None :
226234 raise ValueError (
227235 "Block sparse attention does not support Multi-query attention."
@@ -285,6 +293,8 @@ def build(self, input_shape):
285293 attention_layer_kwargs .update (
286294 src_block_size = self ._src_block_size ,
287295 tgt_block_size = self ._tgt_block_size ,
296+ use_sigmoid_attn = self ._use_sigmoid_attn ,
297+ sigmoid_attn_bias = self ._sigmoid_attn_bias ,
288298 name = "block_sparse_attention" ,
289299 )
290300 attention_fn = block_sparse_attention .MultiHeadAttention
@@ -413,6 +423,8 @@ def get_config(self):
413423 "num_kv_heads" : self ._num_kv_heads ,
414424 "src_block_size" : self ._src_block_size ,
415425 "tgt_block_size" : self ._tgt_block_size ,
426+ "use_sigmoid_attn" : self ._use_sigmoid_attn ,
427+ "sigmoid_attn_bias" : self ._sigmoid_attn_bias ,
416428 }
417429 base_config = super ().get_config ()
418430 return dict (list (base_config .items ()) + list (config .items ()))
0 commit comments