@@ -523,6 +523,7 @@ def __init__(self, config):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.use_global_attention = getattr(config, "use_global_attention", False)

         self.q_proj = nn.Linear(
             self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
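
Note on backward compatibility: the flag is read with a getattr default, so configs that never define use_global_attention fall back to False and keep the existing masked attention behaviour. A minimal sketch of that pattern, using a throwaway SimpleNamespace in place of the real config class:

from types import SimpleNamespace

config = SimpleNamespace()  # a config without the new field
assert getattr(config, "use_global_attention", False) is False  # defaults to masked attention

config.use_global_attention = True  # opt in to global (full) attention
assert getattr(config, "use_global_attention", False) is True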
@@ -760,6 +761,10 @@ class LlamaFlexAttention(LlamaAttention):
     - past_key_values: dynamic cache used for storing past key and value states.
     """

+    def __init__(self, config):
+        super().__init__(config)
+        self.use_global_attention = getattr(config, "use_global_attention", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -821,39 +826,42 @@ def forward(
                 cache_kwargs=cache_kwargs,
             )

         seq_lengths = attention_mask.sum(dim=-1)
         # Shrink the attention mask to align with the padding to the right.
         # This is equivalent to the shrinking logic in eagle3.py
         seq_lengths -= lck
         # TODO: Remove the usage of uncompiled create_block_mask after
         # https://github.com/pytorch/pytorch/issues/160018
         if q_len <= 128:
             create_block_mask_func = create_block_mask
             flex_attention_func = flex_attention
         else:
             create_block_mask_func = compile_friendly_create_block_mask
             flex_attention_func = compile_friendly_flex_attention

-        block_mask = create_block_mask_func(
-            mask_mod=generate_eagle3_mask(
-                seq_lengths=seq_lengths,
-                Q_LEN=q_len,
-                KV_LEN=key_cache.shape[-2],
-                lck=lck,
-            ),
-            B=bsz,
-            H=1,  # Rely on broadcast
-            Q_LEN=q_len,
-            KV_LEN=key_cache.shape[-2],
-            device=query_states.device,
-        )
+        if self.use_global_attention:
+            block_mask = None  # A None block mask results in dense (full) attention
+        else:
+            block_mask = create_block_mask_func(
+                mask_mod=generate_eagle3_mask(
+                    seq_lengths=seq_lengths,
+                    Q_LEN=q_len,
+                    KV_LEN=key_cache.shape[-2],
+                    lck=lck,
+                ),
+                B=bsz,
+                H=1,  # Rely on broadcast
+                Q_LEN=q_len,
+                KV_LEN=key_cache.shape[-2],
+                device=query_states.device,
+            )
         attn_output = flex_attention_func(
             query=query_states,
             key=key_cache.contiguous(),
             value=value_cache.contiguous(),
             block_mask=block_mask,
             enable_gqa=True,
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)
         attn_output = self.o_proj(attn_output)
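
Note: passing block_mask=None makes flex_attention compute dense attention over every key/value position, while a BlockMask built from a mask_mod restricts which positions each query may attend to. The sketch below is illustrative only, not the repo's code: a toy causal mask_mod stands in for generate_eagle3_mask, and it assumes a recent PyTorch with FlexAttention and a CUDA device.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def toy_attention(q, k, v, use_global_attention: bool):
    if use_global_attention:
        block_mask = None  # no mask: every query attends to every key/value position
    else:
        def causal(b, h, q_idx, kv_idx):  # stand-in for generate_eagle3_mask
            return q_idx >= kv_idx
        block_mask = create_block_mask(
            causal, B=None, H=None, Q_LEN=q.shape[-2], KV_LEN=k.shape[-2], device=q.device
        )
    return flex_attention(q, k, v, block_mask=block_mask)

q = k = v = torch.randn(1, 2, 128, 16, device="cuda")
dense_out = toy_attention(q, k, v, use_global_attention=True)
masked_out = toy_attention(q, k, v, use_global_attention=False)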
@@ -869,6 +877,10 @@ class LlamaFlashAttention(LlamaAttention):
     - cache_hidden: manual cache used for storing past key and value states
     """

+    def __init__(self, config):
+        super().__init__(config)
+        self.use_global_attention = getattr(config, "use_global_attention", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -934,7 +946,7 @@ def forward(
             v0,
             dropout_p=0.0,
             softmax_scale=1.0 / math.sqrt(self.head_dim),
-            causal=True,
+            causal=not self.use_global_attention,  # drop the causal mask when global attention is enabled
             return_attn_probs=True,
         )
         lse = lse.transpose(1, 2)
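
Semantically, the flash path change means the causal mask is dropped whenever global attention is enabled. A minimal sketch of what that flag controls, written against plain PyTorch SDPA rather than the flash-attn kernel used above:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 8, 16)  # (batch, heads, seq_len, head_dim)
use_global_attention = True

# Causal masking only when global attention is off; with the flag on, every token
# can attend to every other token (bidirectional / full attention).
out = F.scaled_dot_product_attention(q, k, v, is_causal=not use_global_attention)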