
Commit 30f522f

Add opt-in exclusive self-attention
1 parent 48f39c2 commit 30f522f

9 files changed

Lines changed: 269 additions & 0 deletions

src/diffusers/models/attention.py

Lines changed: 5 additions & 0 deletions
@@ -782,6 +782,8 @@ class BasicTransformerBlock(nn.Module):
             The type of positional embeddings to apply to.
         num_positional_embeddings (`int`, *optional*, defaults to `None`):
             The maximum number of positional embeddings to apply.
+        exclusive_self_attention (`bool`, *optional*, defaults to `False`):
+            Whether to remove the value-vector component from self-attention outputs.
     """
 
     def __init__(
@@ -809,6 +811,7 @@ def __init__(
         ff_inner_dim: int | None = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        exclusive_self_attention: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -877,6 +880,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
             upcast_attention=upcast_attention,
             out_bias=attention_out_bias,
+            exclusive_self_attention=exclusive_self_attention and not only_cross_attention,
         )
 
         # 2. Cross-Attn
@@ -907,6 +911,7 @@ def __init__(
                 bias=attention_bias,
                 upcast_attention=upcast_attention,
                 out_bias=attention_out_bias,
+                exclusive_self_attention=exclusive_self_attention and double_self_attention,
             )  # is self-attn if encoder_hidden_states is none
         else:
             if norm_type == "ada_norm_single":  # For Latte
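
Note: the flag is forwarded to `attn1` only when the block is not cross-attention-only, and to `attn2` only when `double_self_attention=True`, so genuine cross-attention is never affected. A minimal usage sketch, assuming the block's existing leading arguments `dim`, `num_attention_heads`, and `attention_head_dim` (sizes are illustrative):

import torch

from diffusers.models.attention import BasicTransformerBlock

# Toy sizes; with the default norm_type="layer_norm" the block can be called
# with hidden states alone (no timestep or encoder states required).
block = BasicTransformerBlock(
    dim=64,
    num_attention_heads=4,
    attention_head_dim=16,
    exclusive_self_attention=True,
)
hidden_states = torch.randn(2, 16, 64)  # (batch, tokens, dim)
out = block(hidden_states)
print(out.shape)  # torch.Size([2, 16, 64])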

src/diffusers/models/attention_processor.py

Lines changed: 39 additions & 0 deletions
@@ -48,6 +48,15 @@
     XLA_AVAILABLE = False
 
 
+def _apply_exclusive_self_attention(hidden_states: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    if hidden_states.shape != value.shape:
+        return hidden_states
+
+    eps = 1e-6 if value.dtype in (torch.float16, torch.bfloat16) else 1e-12
+    value_normalized = F.normalize(value, p=2, dim=-1, eps=eps)
+    return hidden_states - (hidden_states * value_normalized).sum(dim=-1, keepdim=True) * value_normalized
+
+
 @maybe_allow_in_graph
 class Attention(nn.Module):
     r"""
@@ -97,6 +106,8 @@ class Attention(nn.Module):
             A factor to rescale the output by dividing it with this value.
         residual_connection (`bool`, *optional*, defaults to `False`):
             Set to `True` to add the residual connection to the output.
+        exclusive_self_attention (`bool`, *optional*, defaults to `False`):
+            Whether to remove the value-vector component from self-attention outputs.
         _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
             Set to `True` if the attention block is loaded from a deprecated state dict.
         processor (`AttnProcessor`, *optional*, defaults to `None`):
@@ -136,6 +147,7 @@ def __init__(
         pre_only=False,
         elementwise_affine: bool = True,
         is_causal: bool = False,
+        exclusive_self_attention: bool = False,
     ):
         super().__init__()
 
@@ -159,6 +171,7 @@ def __init__(
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
         self.is_causal = is_causal
+        self.exclusive_self_attention = exclusive_self_attention
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -1120,6 +1133,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
 
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1154,6 +1168,8 @@ def __call__(
 
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
@@ -2515,6 +2531,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
 
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2561,6 +2578,8 @@ def __call__(
             query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = hidden_states.to(query.dtype)
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
@@ -2606,6 +2625,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -2674,6 +2694,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2717,6 +2740,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -2768,6 +2792,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2814,6 +2841,7 @@ def __call__(
         **kwargs,
     ) -> torch.Tensor:
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -2884,6 +2912,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -3695,6 +3726,7 @@ def __call__(
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -3748,6 +3780,9 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -4018,6 +4053,7 @@ def __call__(
         attention_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         residual = hidden_states
+        is_self_attention = encoder_hidden_states is None
 
         input_ndim = hidden_states.ndim
 
@@ -4066,6 +4102,9 @@ def __call__(
 
             hidden_states[start_idx:end_idx] = attn_slice
 
+        if attn.exclusive_self_attention and is_self_attention:
+            hidden_states = _apply_exclusive_self_attention(hidden_states, value)
+
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
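
Note on the helper: per token, `_apply_exclusive_self_attention` subtracts from the attention output its projection onto that token's L2-normalized value vector, i.e. out' = out - (out · v̂) v̂, so the output becomes orthogonal to the token's own value direction. A standalone sketch of the same computation on toy tensors (tensor names and shapes are illustrative, not taken from the diff):

import torch
import torch.nn.functional as F

def apply_exclusive_self_attention(hidden_states: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Same computation as the new private helper: remove, per token, the
    # component of the attention output along that token's own value vector.
    if hidden_states.shape != value.shape:
        return hidden_states
    eps = 1e-6 if value.dtype in (torch.float16, torch.bfloat16) else 1e-12
    value_normalized = F.normalize(value, p=2, dim=-1, eps=eps)
    return hidden_states - (hidden_states * value_normalized).sum(dim=-1, keepdim=True) * value_normalized

out = torch.randn(2, 8, 64)    # (batch * heads, tokens, head_dim), illustrative shape
value = torch.randn(2, 8, 64)
out_excl = apply_exclusive_self_attention(out, value)
# The result is orthogonal to each token's value direction (up to numerical error).
print((out_excl * F.normalize(value, dim=-1)).sum(dim=-1).abs().max())

The shape check also means the subtraction is silently skipped whenever the attention output and the value tensor do not line up, e.g. when the key/value sequence length differs from the query length.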

src/diffusers/models/transformers/dit_transformer_2d.py

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,7 @@ def __init__(
         norm_type: str = "ada_norm_zero",
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-5,
+        exclusive_self_attention: bool = False,
     ):
         super().__init__()
 
@@ -133,6 +134,7 @@ def __init__(
                     norm_type=norm_type,
                     norm_elementwise_affine=self.config.norm_elementwise_affine,
                     norm_eps=self.config.norm_eps,
+                    exclusive_self_attention=self.config.exclusive_self_attention,
                 )
                 for _ in range(self.config.num_layers)
             ]

src/diffusers/models/transformers/pixart_transformer_2d.py

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,7 @@ def __init__(
         use_additional_conditions: bool | None = None,
         caption_channels: int | None = None,
         attention_type: str | None = "default",
+        exclusive_self_attention: bool = False,
     ):
         super().__init__()
 
@@ -165,6 +166,7 @@ def __init__(
                     norm_elementwise_affine=self.config.norm_elementwise_affine,
                     norm_eps=self.config.norm_eps,
                     attention_type=self.config.attention_type,
+                    exclusive_self_attention=self.config.exclusive_self_attention,
                 )
                 for _ in range(self.config.num_layers)
             ]

src/diffusers/models/transformers/transformer_2d.py

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,7 @@ def __init__(
         caption_channels: int = None,
         interpolation_scale: float = None,
         use_additional_conditions: bool | None = None,
+        exclusive_self_attention: bool = False,
     ):
         super().__init__()
 
@@ -199,6 +200,7 @@ def _init_continuous_input(self, norm_type):
                     norm_elementwise_affine=self.config.norm_elementwise_affine,
                     norm_eps=self.config.norm_eps,
                     attention_type=self.config.attention_type,
+                    exclusive_self_attention=self.config.exclusive_self_attention,
                 )
                 for _ in range(self.config.num_layers)
             ]
@@ -241,6 +243,7 @@ def _init_vectorized_inputs(self, norm_type):
                     norm_elementwise_affine=self.config.norm_elementwise_affine,
                     norm_eps=self.config.norm_eps,
                     attention_type=self.config.attention_type,
+                    exclusive_self_attention=self.config.exclusive_self_attention,
                 )
                 for _ in range(self.config.num_layers)
             ]
@@ -288,6 +291,7 @@ def _init_patched_inputs(self, norm_type):
                     norm_elementwise_affine=self.config.norm_elementwise_affine,
                     norm_eps=self.config.norm_eps,
                     attention_type=self.config.attention_type,
+                    exclusive_self_attention=self.config.exclusive_self_attention,
                )
                for _ in range(self.config.num_layers)
            ]
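
End-to-end, the new argument is registered on the model config and threaded through every `BasicTransformerBlock` down to its `Attention` modules. A sketch of enabling it on one of the wired-up classes, using a small hypothetical configuration rather than a released checkpoint:

from diffusers import DiTTransformer2DModel

# Illustrative tiny config; real checkpoints would need the flag added to
# their saved config (it defaults to False, so existing models are unchanged).
model = DiTTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=2,
    sample_size=8,
    exclusive_self_attention=True,
)
print(model.config.exclusive_self_attention)  # True

Because the default is `False` and no existing arguments change position, the feature is opt-in and existing configs keep their current behavior.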
