5 changes: 5 additions & 0 deletions src/diffusers/models/attention.py
@@ -782,6 +782,8 @@ class BasicTransformerBlock(nn.Module):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
exclusive_self_attention (`bool`, *optional*, defaults to `False`):
Whether to remove the value-vector component from self-attention outputs.
"""

def __init__(
@@ -809,6 +811,7 @@ def __init__(
ff_inner_dim: int | None = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
exclusive_self_attention: bool = False,
):
super().__init__()
self.dim = dim
@@ -877,6 +880,7 @@ def __init__(
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
exclusive_self_attention=exclusive_self_attention and not only_cross_attention,
)

# 2. Cross-Attn
@@ -907,6 +911,7 @@ def __init__(
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
exclusive_self_attention=exclusive_self_attention and double_self_attention,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single": # For Latte
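The combination logic above means the new flag is only honored where the block actually self-attends: `attn1` receives it unless the block is `only_cross_attention`, and `attn2` only when `double_self_attention` is set. A minimal, hedged sketch of enabling the flag on a single block (assumes this branch is installed; the dim/head numbers below are illustrative, not taken from any config):

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# Illustrative sizes only; any dim == num_attention_heads * attention_head_dim works.
block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    exclusive_self_attention=True,  # new keyword added in this diff
)

hidden_states = torch.randn(2, 64, 320)  # (batch, seq_len, dim)
out = block(hidden_states)  # attn1 runs as self-attention with the value-direction component removed
print(out.shape)  # torch.Size([2, 64, 320])
```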
39 changes: 39 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -48,6 +48,15 @@
XLA_AVAILABLE = False


def _apply_exclusive_self_attention(hidden_states: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
if hidden_states.shape != value.shape:
return hidden_states

eps = 1e-6 if value.dtype in (torch.float16, torch.bfloat16) else 1e-12
value_normalized = F.normalize(value, p=2, dim=-1, eps=eps)
return hidden_states - (hidden_states * value_normalized).sum(dim=-1, keepdim=True) * value_normalized


@maybe_allow_in_graph
class Attention(nn.Module):
r"""
@@ -97,6 +106,8 @@ class Attention(nn.Module):
A factor to rescale the output by dividing it with this value.
residual_connection (`bool`, *optional*, defaults to `False`):
Set to `True` to add the residual connection to the output.
exclusive_self_attention (`bool`, *optional*, defaults to `False`):
Whether to remove the value-vector component from self-attention outputs.
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
Set to `True` if the attention block is loaded from a deprecated state dict.
processor (`AttnProcessor`, *optional*, defaults to `None`):
@@ -136,6 +147,7 @@ def __init__(
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
exclusive_self_attention: bool = False,
):
super().__init__()

@@ -159,6 +171,7 @@ def __init__(
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self.exclusive_self_attention = exclusive_self_attention

# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -1120,6 +1133,7 @@ def __call__(
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
is_self_attention = encoder_hidden_states is None

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1154,6 +1168,8 @@

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
@@ -2515,6 +2531,7 @@ def __call__(
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
is_self_attention = encoder_hidden_states is None

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2561,6 +2578,8 @@ def __call__(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
@@ -2606,6 +2625,7 @@ def __call__(
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
is_self_attention = encoder_hidden_states is None
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2674,6 +2694,9 @@ def __call__(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -2717,6 +2740,7 @@ def __call__(
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
is_self_attention = encoder_hidden_states is None
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2768,6 +2792,9 @@ def __call__(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -2814,6 +2841,7 @@ def __call__(
**kwargs,
) -> torch.Tensor:
residual = hidden_states
is_self_attention = encoder_hidden_states is None
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2884,6 +2912,9 @@ def __call__(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -3695,6 +3726,7 @@ def __call__(
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
is_self_attention = encoder_hidden_states is None
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -3748,6 +3780,9 @@ def __call__(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -4018,6 +4053,7 @@ def __call__(
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
residual = hidden_states
is_self_attention = encoder_hidden_states is None

input_ndim = hidden_states.ndim

@@ -4066,6 +4102,9 @@

hidden_states[start_idx:end_idx] = attn_slice

if attn.exclusive_self_attention and is_self_attention:
hidden_states = _apply_exclusive_self_attention(hidden_states, value)

hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
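`_apply_exclusive_self_attention` removes, per token, the component of the attention output that lies along that token's L2-normalized value vector, i.e. `out - (out · v̂) v̂`. The shape guard makes it a no-op when `hidden_states` and `value` are laid out differently, and the larger `eps` for fp16/bf16 avoids underflow when normalizing in half precision. A standalone sketch of the same projection (plain torch, illustrative shapes only):

```python
import torch
import torch.nn.functional as F

# Same layout the default processor passes in, e.g. (batch * heads, seq_len, head_dim).
hidden_states = torch.randn(2, 16, 64)
value = torch.randn(2, 16, 64)

v_hat = F.normalize(value, p=2, dim=-1, eps=1e-12)                   # unit value direction per token
along_v = (hidden_states * v_hat).sum(dim=-1, keepdim=True) * v_hat  # projection onto v_hat
out = hidden_states - along_v                                        # value-direction component removed

# Token-wise, the result is orthogonal to the value direction (up to float error).
print((out * v_hat).sum(dim=-1).abs().max())  # ~0
```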
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/dit_transformer_2d.py
@@ -87,6 +87,7 @@ def __init__(
norm_type: str = "ada_norm_zero",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
exclusive_self_attention: bool = False,
):
super().__init__()

@@ -133,6 +134,7 @@ def __init__(
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
exclusive_self_attention=self.config.exclusive_self_attention,
)
for _ in range(self.config.num_layers)
]
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -105,6 +105,7 @@ def __init__(
use_additional_conditions: bool | None = None,
caption_channels: int | None = None,
attention_type: str | None = "default",
exclusive_self_attention: bool = False,
):
super().__init__()

@@ -165,6 +166,7 @@ def __init__(
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
exclusive_self_attention=self.config.exclusive_self_attention,
)
for _ in range(self.config.num_layers)
]
4 changes: 4 additions & 0 deletions src/diffusers/models/transformers/transformer_2d.py
@@ -96,6 +96,7 @@ def __init__(
caption_channels: int = None,
interpolation_scale: float = None,
use_additional_conditions: bool | None = None,
exclusive_self_attention: bool = False,
):
super().__init__()

@@ -199,6 +200,7 @@ def _init_continuous_input(self, norm_type):
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
exclusive_self_attention=self.config.exclusive_self_attention,
)
for _ in range(self.config.num_layers)
]
@@ -241,6 +243,7 @@ def _init_vectorized_inputs(self, norm_type):
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
exclusive_self_attention=self.config.exclusive_self_attention,
)
for _ in range(self.config.num_layers)
]
@@ -288,6 +291,7 @@ def _init_patched_inputs(self, norm_type):
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
exclusive_self_attention=self.config.exclusive_self_attention,
)
for _ in range(self.config.num_layers)
]
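Because `exclusive_self_attention` is a plain `__init__` argument on these models, it is captured by the model config (the diff reads it back via `self.config.exclusive_self_attention`) and forwarded to every `BasicTransformerBlock`, so it should round-trip through `save_pretrained` / `from_pretrained` like any other config field. A hedged end-to-end sketch using `DiTTransformer2DModel` (the reduced `num_layers` is only to keep the example cheap; the forward call follows the usual DiT signature with class labels for the ada_norm_zero blocks):

```python
import torch
from diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel(num_layers=2, exclusive_self_attention=True)
print(model.config.exclusive_self_attention)  # True

# Default config: in_channels=4, sample_size=32, patch_size=2, ada_norm_zero norms.
sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
class_labels = torch.tensor([0])

out = model(sample, timestep=timestep, class_labels=class_labels).sample
print(out.shape)  # (1, out_channels, 32, 32)
```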