diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 36d0893734c7..bc0b8e2f5789 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -782,6 +782,8 @@ class BasicTransformerBlock(nn.Module): The type of positional embeddings to apply to. num_positional_embeddings (`int`, *optional*, defaults to `None`): The maximum number of positional embeddings to apply. + exclusive_self_attention (`bool`, *optional*, defaults to `False`): + Whether to remove the value-vector component from self-attention outputs. """ def __init__( @@ -809,6 +811,7 @@ def __init__( ff_inner_dim: int | None = None, ff_bias: bool = True, attention_out_bias: bool = True, + exclusive_self_attention: bool = False, ): super().__init__() self.dim = dim @@ -877,6 +880,7 @@ def __init__( cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, out_bias=attention_out_bias, + exclusive_self_attention=exclusive_self_attention and not only_cross_attention, ) # 2. 
def _apply_exclusive_self_attention(hidden_states: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    r"""
    Remove from every vector of `hidden_states` its component along the matching vector of `value`.

    Vectors are taken along the last dimension. `value` is L2-normalized along that dimension and the projection of
    `hidden_states` onto the normalized vector is subtracted.

    Args:
        hidden_states (`torch.Tensor`):
            Attention output to orthogonalize.
        value (`torch.Tensor`):
            Value tensor. It must have exactly the same shape as `hidden_states` for the projection to be applied.

    Returns:
        `torch.Tensor`: The orthogonalized tensor, in the dtype of `hidden_states`. When the two shapes differ,
        `hidden_states` itself is returned unchanged.
    """
    # Deliberate no-op: without a one-to-one pairing of output and value vectors there is nothing to project out.
    if hidden_states.shape != value.shape:
        return hidden_states

    # `F.normalize` divides by max(norm, eps); the larger eps for half-precision inputs matches the original choice.
    eps = 1e-6 if value.dtype in (torch.float16, torch.bfloat16) else 1e-12
    value_direction = F.normalize(value, p=2, dim=-1, eps=eps)
    projection = (hidden_states * value_direction).sum(dim=-1, keepdim=True)
    exclusive = hidden_states - projection * value_direction

    # Mixed input dtypes are promoted by the arithmetic above; cast back so callers keep the dtype they passed in.
    # `Tensor.to` returns the tensor itself when the dtype already matches, so the common path is unaffected.
    return exclusive.to(hidden_states.dtype)
processor (`AttnProcessor`, *optional*, defaults to `None`): @@ -136,6 +147,7 @@ def __init__( pre_only=False, elementwise_affine: bool = True, is_causal: bool = False, + exclusive_self_attention: bool = False, ): super().__init__() @@ -159,6 +171,7 @@ def __init__( self.context_pre_only = context_pre_only self.pre_only = pre_only self.is_causal = is_causal + self.exclusive_self_attention = exclusive_self_attention # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -1120,6 +1133,7 @@ def __call__( deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states + is_self_attention = encoder_hidden_states is None if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -1154,6 +1168,8 @@ def __call__( attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj @@ -2515,6 +2531,7 @@ def __call__( deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states + is_self_attention = encoder_hidden_states is None if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2561,6 +2578,8 @@ def __call__( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj @@ -2606,6 +2625,7 @@ def __call__( deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states + is_self_attention = encoder_hidden_states is None if attn.spatial_norm is not None: hidden_states = 
attn.spatial_norm(hidden_states, temb) @@ -2674,6 +2694,9 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2717,6 +2740,7 @@ def __call__( deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states + is_self_attention = encoder_hidden_states is None if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2768,6 +2792,9 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2814,6 +2841,7 @@ def __call__( **kwargs, ) -> torch.Tensor: residual = hidden_states + is_self_attention = encoder_hidden_states is None if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2884,6 +2912,9 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -3695,6 +3726,7 @@ def __call__( deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states + is_self_attention = encoder_hidden_states is None if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -3748,6 +3780,9 @@ def __call__( query, key, value, 
attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -4018,6 +4053,7 @@ def __call__( attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: residual = hidden_states + is_self_attention = encoder_hidden_states is None input_ndim = hidden_states.ndim @@ -4066,6 +4102,9 @@ def __call__( hidden_states[start_idx:end_idx] = attn_slice + if attn.exclusive_self_attention and is_self_attention: + hidden_states = _apply_exclusive_self_attention(hidden_states, value) + hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 3d10c278cdbb..72daaf5d3d18 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -87,6 +87,7 @@ def __init__( norm_type: str = "ada_norm_zero", norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, + exclusive_self_attention: bool = False, ): super().__init__() @@ -133,6 +134,7 @@ def __init__( norm_type=norm_type, norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, + exclusive_self_attention=self.config.exclusive_self_attention, ) for _ in range(self.config.num_layers) ] diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 2476668ba307..bb20614abe40 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -105,6 +105,7 @@ def __init__( use_additional_conditions: bool | None = None, caption_channels: int | None = None, attention_type: str 
| None = "default", + exclusive_self_attention: bool = False, ): super().__init__() @@ -165,6 +166,7 @@ def __init__( norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, + exclusive_self_attention=self.config.exclusive_self_attention, ) for _ in range(self.config.num_layers) ] diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 12f89201d752..3e915396e868 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -96,6 +96,7 @@ def __init__( caption_channels: int = None, interpolation_scale: float = None, use_additional_conditions: bool | None = None, + exclusive_self_attention: bool = False, ): super().__init__() @@ -199,6 +200,7 @@ def _init_continuous_input(self, norm_type): norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, + exclusive_self_attention=self.config.exclusive_self_attention, ) for _ in range(self.config.num_layers) ] @@ -241,6 +243,7 @@ def _init_vectorized_inputs(self, norm_type): norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, + exclusive_self_attention=self.config.exclusive_self_attention, ) for _ in range(self.config.num_layers) ] @@ -288,6 +291,7 @@ def _init_patched_inputs(self, norm_type): norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, + exclusive_self_attention=self.config.exclusive_self_attention, ) for _ in range(self.config.num_layers) ] diff --git a/tests/models/test_exclusive_self_attention.py b/tests/models/test_exclusive_self_attention.py new file mode 100644 index 000000000000..8951324a2d44 --- /dev/null +++ b/tests/models/test_exclusive_self_attention.py @@ 
class ExclusiveSelfAttentionTests(unittest.TestCase):
    """Tests for `_apply_exclusive_self_attention`, the processors that call it, and `BasicTransformerBlock` wiring."""

    def setUp(self):
        # Seed before every test so the random inputs, and therefore the assertions, are reproducible.
        torch.manual_seed(0)

    @staticmethod
    def _processor_cases():
        """Return `(name, processor_factory, fused)` triples for every attention processor under test."""
        cases = [
            ("eager", AttnProcessor, False),
            ("sliced", lambda: SlicedAttnProcessor(slice_size=2), False),
        ]
        # The SDPA-based processors are only exercised when PyTorch exposes `scaled_dot_product_attention`.
        if hasattr(F, "scaled_dot_product_attention"):
            cases.extend(
                [
                    ("sdpa", AttnProcessor2_0, False),
                    ("fused", FusedAttnProcessor2_0, True),
                ]
            )
        return cases

    def _get_attention_pair(self, processor_factory, fused=False, cross_attention_dim=None):
        """
        Build two `Attention` modules with identical weights; only the second enables `exclusive_self_attention`.

        When `fused` is True, both modules are created with `AttnProcessor2_0`, their projections are fused, and
        only then is the processor from `processor_factory` installed.
        """
        torch.manual_seed(0)
        shared_kwargs = {
            "query_dim": 8,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "dropout": 0.0,
            "bias": True,
        }

        def initial_processor():
            return AttnProcessor2_0() if fused else processor_factory()

        base = Attention(processor=initial_processor(), **shared_kwargs)
        exclusive = Attention(exclusive_self_attention=True, processor=initial_processor(), **shared_kwargs)
        # Copy the weights so any output difference can only come from the exclusive flag.
        exclusive.load_state_dict(base.state_dict())

        if fused:
            base.fuse_projections()
            exclusive.fuse_projections()
            base.set_processor(processor_factory())
            exclusive.set_processor(processor_factory())

        return base, exclusive

    def test_apply_exclusive_self_attention_orthogonalizes_output(self):
        hidden_states = torch.randn(2, 3, 4)
        value = torch.randn(2, 3, 4)

        output = _apply_exclusive_self_attention(hidden_states, value)
        value_normalized = F.normalize(value, p=2, dim=-1, eps=1e-12)
        projection = (output * value_normalized).sum(dim=-1)

        assert torch.allclose(projection, torch.zeros_like(projection), atol=1e-5)
        assert output.dtype == hidden_states.dtype
        assert output.device == hidden_states.device

    def test_apply_exclusive_self_attention_handles_zero_values(self):
        hidden_states = torch.randn(2, 3, 4)
        value = torch.zeros_like(hidden_states)

        output = _apply_exclusive_self_attention(hidden_states, value)

        assert torch.isfinite(output).all()
        assert torch.allclose(output, hidden_states)

    def test_apply_exclusive_self_attention_skips_shape_mismatch(self):
        hidden_states = torch.randn(2, 3, 4)
        value = torch.randn(2, 2, 4)

        output = _apply_exclusive_self_attention(hidden_states, value)

        assert output is hidden_states

    def test_exclusive_self_attention_changes_self_attention_output(self):
        hidden_states = torch.randn(2, 4, 8)

        for name, processor_factory, fused in self._processor_cases():
            with self.subTest(name=name):
                base, exclusive = self._get_attention_pair(processor_factory, fused=fused)

                with torch.no_grad():
                    base_output = base(hidden_states)
                    exclusive_output = exclusive(hidden_states)

                assert base_output.shape == exclusive_output.shape
                assert torch.isfinite(exclusive_output).all()
                assert not torch.allclose(base_output, exclusive_output)

    def test_exclusive_self_attention_does_not_change_cross_attention_output(self):
        hidden_states = torch.randn(2, 4, 8)
        encoder_hidden_states = torch.randn(2, 4, 8)

        for name, processor_factory, fused in self._processor_cases():
            with self.subTest(name=name):
                base, exclusive = self._get_attention_pair(processor_factory, fused=fused, cross_attention_dim=8)

                with torch.no_grad():
                    base_output = base(hidden_states, encoder_hidden_states=encoder_hidden_states)
                    exclusive_output = exclusive(hidden_states, encoder_hidden_states=encoder_hidden_states)

                assert torch.allclose(base_output, exclusive_output)

    def test_basic_transformer_block_wires_exclusive_self_attention_to_self_attention_only(self):
        # Default block: attn1 is self-attention, attn2 is cross-attention.
        block = BasicTransformerBlock(
            dim=8,
            num_attention_heads=2,
            attention_head_dim=4,
            cross_attention_dim=8,
            exclusive_self_attention=True,
        )
        assert block.attn1.exclusive_self_attention
        assert not block.attn2.exclusive_self_attention

        # `only_cross_attention=True`: the flag must not reach either attention layer.
        block = BasicTransformerBlock(
            dim=8,
            num_attention_heads=2,
            attention_head_dim=4,
            cross_attention_dim=8,
            only_cross_attention=True,
            exclusive_self_attention=True,
        )
        assert not block.attn1.exclusive_self_attention
        assert not block.attn2.exclusive_self_attention

        # `double_self_attention=True`: the flag must reach both attention layers.
        block = BasicTransformerBlock(
            dim=8,
            num_attention_heads=2,
            attention_head_dim=4,
            double_self_attention=True,
            exclusive_self_attention=True,
        )
        assert block.attn1.exclusive_self_attention
        assert block.attn2.exclusive_self_attention
def test_spatial_transformer_exclusive_self_attention_config(self):
    """`exclusive_self_attention=True` is stored in the config, reaches `attn1`, and the forward pass keeps the shape."""
    sample = torch.randn(1, 32, 8, 8).to(torch_device)
    model = Transformer2DModel(
        num_attention_heads=1,
        attention_head_dim=32,
        in_channels=32,
        exclusive_self_attention=True,
    )
    model = model.to(torch_device)

    # The flag must be both recorded in the config and propagated to the first block's self-attention.
    assert model.config.exclusive_self_attention
    first_block = model.transformer_blocks[0]
    assert first_block.attn1.exclusive_self_attention

    with torch.no_grad():
        prediction = model(sample).sample

    assert prediction.shape == sample.shape
def test_exclusive_self_attention_config(self):
    """The flag lands in the config and on `attn1` only, and the forward pass yields the expected output shape."""
    init_kwargs, forward_inputs = self.prepare_init_args_and_inputs_for_common()
    init_kwargs["exclusive_self_attention"] = True
    model = self.model_class(**init_kwargs).to(torch_device)

    assert model.config.exclusive_self_attention
    first_block = model.transformer_blocks[0]
    # Self-attention gets the flag; the second attention layer of the block must not.
    assert first_block.attn1.exclusive_self_attention
    assert not first_block.attn2.exclusive_self_attention

    with torch.no_grad():
        prediction = model(**forward_inputs).sample

    batch_size = forward_inputs[self.main_input_name].shape[0]
    assert prediction.shape == (batch_size,) + self.output_shape