 XLA_AVAILABLE = False


+def _apply_exclusive_self_attention(hidden_states: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    if hidden_states.shape != value.shape:
+        return hidden_states
+
+    eps = 1e-6 if value.dtype in (torch.float16, torch.bfloat16) else 1e-12
+    value_normalized = F.normalize(value, p=2, dim=-1, eps=eps)
+    return hidden_states - (hidden_states * value_normalized).sum(dim=-1, keepdim=True) * value_normalized
+
+
 @maybe_allow_in_graph
 class Attention(nn.Module):
     r"""
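
For intuition: the new helper subtracts, from each attention output vector, its projection onto the corresponding L2-normalized value vector, i.e. out = h - (h . v_hat) * v_hat, so the result is orthogonal to that value direction along the last axis. A minimal, self-contained sketch (illustration only, not part of the patch; the shapes and the standalone copy of the helper are made up for the demo):

# --- illustration only, not part of the patch ---
import torch
import torch.nn.functional as F

def _apply_exclusive_self_attention(hidden_states, value):
    # Same logic as the helper added above: remove the component of each
    # output vector that lies along its own normalized value vector.
    if hidden_states.shape != value.shape:
        return hidden_states
    eps = 1e-6 if value.dtype in (torch.float16, torch.bfloat16) else 1e-12
    value_normalized = F.normalize(value, p=2, dim=-1, eps=eps)
    return hidden_states - (hidden_states * value_normalized).sum(dim=-1, keepdim=True) * value_normalized

hidden = torch.randn(2, 4, 8)   # (batch * heads, seq_len, head_dim) -- made-up shapes
value = torch.randn(2, 4, 8)
out = _apply_exclusive_self_attention(hidden, value)
# The output has (numerically) zero dot product with the value direction.
print((out * F.normalize(value, dim=-1)).sum(-1).abs().max())  # ~1e-7
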
@@ -97,6 +106,8 @@ class Attention(nn.Module):
             A factor to rescale the output by dividing it with this value.
         residual_connection (`bool`, *optional*, defaults to `False`):
             Set to `True` to add the residual connection to the output.
+        exclusive_self_attention (`bool`, *optional*, defaults to `False`):
+            Whether to remove the value-vector component from self-attention outputs.
         _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
             Set to `True` if the attention block is loaded from a deprecated state dict.
         processor (`AttnProcessor`, *optional*, defaults to `None`):
@@ -136,6 +147,7 @@ def __init__(
         pre_only=False,
         elementwise_affine: bool = True,
         is_causal: bool = False,
+        exclusive_self_attention: bool = False,
     ):
         super().__init__()

@@ -159,6 +171,7 @@ def __init__(
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
         self.is_causal = is_causal
+        self.exclusive_self_attention = exclusive_self_attention

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -1120,6 +1133,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)

         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None

         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1154,6 +1168,8 @@ def __call__(

         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
@@ -2515,6 +2531,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)

         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None

         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2561,6 +2578,8 @@ def __call__(
             query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = hidden_states.to(query.dtype)
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
@@ -2606,6 +2625,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)

         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2674,6 +2694,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -2717,6 +2740,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)

         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2768,6 +2792,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -2814,6 +2841,7 @@ def __call__(
         **kwargs,
     ) -> torch.Tensor:
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2884,6 +2912,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -3695,6 +3726,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)

         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -3748,6 +3780,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -4018,6 +4053,7 @@ def __call__(
         attention_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None

         input_ndim = hidden_states.ndim

@@ -4066,6 +4102,9 @@ def __call__(

             hidden_states[start_idx:end_idx] = attn_slice

+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
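
Taken together, the patch threads one `exclusive_self_attention` flag through `Attention`, and each processor applies the projection only when `encoder_hidden_states` is None, i.e. for pure self-attention; cross-attention is left unchanged. A rough usage sketch, assuming a diffusers build that contains this patch (the `exclusive_self_attention` argument does not exist in released diffusers):

# --- usage sketch, assumes this patch is installed ---
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

attn = Attention(query_dim=64, heads=4, dim_head=16, exclusive_self_attention=True)
attn.set_processor(AttnProcessor2_0())

x = torch.randn(1, 32, 64)                      # (batch, seq_len, query_dim)
out_self = attn(x)                              # self-attention: projection is applied
ctx = torch.randn(1, 77, 64)                    # e.g. text-encoder states
out_cross = attn(x, encoder_hidden_states=ctx)  # cross-attention: behavior unchanged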