diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1b1f6b3032b3..915f1a2d0a42 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -235,6 +235,7 @@
         "CosmosControlNetModel",
         "CosmosTransformer3DModel",
         "DiTTransformer2DModel",
+        "DualTransformer2DModel",
         "EasyAnimateTransformer3DModel",
         "ErnieImageTransformer2DModel",
         "Flux2Transformer2DModel",
@@ -1057,6 +1058,7 @@
         CosmosControlNetModel,
         CosmosTransformer3DModel,
         DiTTransformer2DModel,
+        DualTransformer2DModel,
         EasyAnimateTransformer3DModel,
         ErnieImageTransformer2DModel,
         Flux2Transformer2DModel,
diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py
index c25c6e9c4227..c4896e7efe65 100644
--- a/src/diffusers/models/transformers/dual_transformer_2d.py
+++ b/src/diffusers/models/transformers/dual_transformer_2d.py
@@ -99,6 +99,7 @@ def forward(
         encoder_hidden_states,
         timestep=None,
         attention_mask=None,
+        encoder_attention_mask=None,
         cross_attention_kwargs=None,
         return_dict: bool = True,
     ):
@@ -113,6 +114,8 @@
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
             attention_mask (`torch.Tensor`, *optional*):
                 Optional attention mask to be applied in Attention.
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+                Optional cross-attention mask to be applied to `encoder_hidden_states`.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
@@ -130,15 +133,20 @@
         encoded_states = []
         tokens_start = 0
-        # attention_mask is not used yet
         for i in range(2):
             # for each of the two transformers, pass the corresponding condition tokens
             condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+            condition_mask = None
+            if encoder_attention_mask is not None:
+                condition_mask = encoder_attention_mask[..., tokens_start : tokens_start + self.condition_lengths[i]]
+
             transformer_index = self.transformer_index_for_condition[i]
             encoded_state = self.transformers[transformer_index](
                 input_states,
                 encoder_hidden_states=condition_state,
                 timestep=timestep,
+                attention_mask=attention_mask,
+                encoder_attention_mask=condition_mask,
                 cross_attention_kwargs=cross_attention_kwargs,
                 return_dict=False,
             )[0]
             encoded_states.append(encoded_state - input_states)
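The `[..., tokens_start : tokens_start + self.condition_lengths[i]]` indexing in the hunk above is what lets a single code path serve both mask conventions `Transformer2DModel` accepts: a 2D keep/ignore mask of shape `(batch, seq_len)` and a 3D additive bias of shape `(batch, 1, seq_len)`. A minimal standalone sketch of the slicing behavior (not part of the PR; the `[2, 3]` condition lengths mirror the new test rather than the model's `[77, 257]` defaults):

```python
import torch

# Hypothetical condition lengths, mirroring the [2, 3] split used in the new test.
condition_lengths = [2, 3]

# Convention 1: 2D keep mask, shape (batch, seq_len); 1.0 = attend, 0.0 = ignore.
keep_mask = torch.tensor([[1.0, 1.0, 0.0, 1.0, 0.0]])
# Convention 2: 3D additive bias, shape (batch, 1, seq_len); 0.0 = attend, large negative = ignore.
bias_mask = torch.tensor([[[0.0, 0.0, -10000.0, 0.0, -10000.0]]])

tokens_start = 0
for length in condition_lengths:
    # Ellipsis indexing always slices the *last* dim regardless of rank, so both
    # conventions yield a per-condition mask with the right token count.
    print(keep_mask[..., tokens_start : tokens_start + length].shape)  # (1, 2), then (1, 3)
    print(bias_mask[..., tokens_start : tokens_start + length].shape)  # (1, 1, 2), then (1, 1, 3)
    tokens_start += length
```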
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index 12f89201d752..1f034f5ab218 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -517,7 +517,7 @@ def _get_output_for_vectorized_inputs(self, hidden_states):
         # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
         logits = logits.permute(0, 2, 1)
         # log(p(x_0))
-        output = F.log_softmax(logits.double(), dim=1).float()
+        output = F.log_softmax(logits.float(), dim=1)
         return output

     def _get_output_for_patched_inputs(
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index b6fedcb26cc8..9a1991d28578 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -94,6 +94,9 @@ def __init__(
         inner_dim = num_attention_heads * attention_head_dim

         self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        if self.out_channels != in_channels:
+            raise ValueError("`out_channels` must be `None` or equal to `in_channels` for TransformerTemporalModel.")

         self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         self.proj_in = nn.Linear(in_channels, inner_dim)
@@ -118,7 +121,7 @@ def __init__(
             ]
         )

-        self.proj_out = nn.Linear(inner_dim, in_channels)
+        self.proj_out = nn.Linear(inner_dim, self.out_channels)

     def forward(
         self,
@@ -272,8 +275,11 @@ def __init__(

         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
-        # TODO: should use out_channels for continuous projections
-        self.proj_out = nn.Linear(inner_dim, in_channels)
+        if self.out_channels != in_channels:
+            raise ValueError(
+                "`out_channels` must be `None` or equal to `in_channels` for TransformerSpatioTemporalModel."
+            )
+        self.proj_out = nn.Linear(inner_dim, self.out_channels)

         self.gradient_checkpointing = False
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 9bfb73c1999e..6bb1078c3a55 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1110,6 +1110,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class DualTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class EasyAnimateTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]
diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py
index eaeffa699db2..8b0c794979e4 100644
--- a/tests/models/test_layers_utils.py
+++ b/tests/models/test_layers_utils.py
@@ -15,6 +15,7 @@


 import unittest
+from unittest import mock

 import numpy as np
 import torch
@@ -23,11 +24,11 @@
 from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from diffusers.models.transformers import transformer_2d as transformer_2d_module
 from diffusers.models.transformers.transformer_2d import Transformer2DModel

 from ..testing_utils import (
     backend_manual_seed,
-    require_torch_accelerator_with_fp64,
     require_torch_version_greater_equal,
     torch_device,
 )
@@ -432,7 +433,6 @@ def test_spatial_transformer_dropout(self):
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

-    @require_torch_accelerator_with_fp64
     def test_spatial_transformer_discrete(self):
         torch.manual_seed(0)
         backend_manual_seed(torch_device, 0)
@@ -451,8 +451,15 @@ def test_spatial_transformer_discrete(self):
             .eval()
         )

+        original_log_softmax = transformer_2d_module.F.log_softmax
+
+        def checked_log_softmax(input, *args, **kwargs):
+            assert input.dtype != torch.float64
+            return original_log_softmax(input, *args, **kwargs)
+
         with torch.no_grad():
-            attention_scores = spatial_transformer_block(sample).sample
+            with mock.patch.object(transformer_2d_module.F, "log_softmax", side_effect=checked_log_softmax):
+                attention_scores = spatial_transformer_block(sample).sample

         assert attention_scores.shape == (1, num_embed - 1, 32)
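For context on the `log_softmax` change this test guards: PyTorch computes `log_softmax` with a max-subtracted log-sum-exp, so single precision is already numerically stable, and the old `double()` round-trip only added memory traffic plus a hard float64 requirement (hence dropping the `require_torch_accelerator_with_fp64` marker; MPS, for instance, has no float64 support). A standalone sketch, not part of the PR, illustrating the equivalence; the `* 50.0` scale factor is an arbitrary choice to stress the dynamic range:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Logits with a wide dynamic range, shaped like the discrete head's (batch, classes, pixels).
logits = torch.randn(2, 10, 32) * 50.0

fp32 = F.log_softmax(logits.float(), dim=1)           # new path
fp64 = F.log_softmax(logits.double(), dim=1).float()  # old path

# Max-subtraction keeps the exponentials in range, so fp32 stays finite and
# agrees with the double-precision result to within float32 rounding error.
assert torch.isfinite(fp32).all()
print((fp32 - fp64).abs().max().item())  # small, on the order of float32 rounding error
```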
diff --git a/tests/models/transformers/test_models_dual_transformer_2d.py b/tests/models/transformers/test_models_dual_transformer_2d.py
new file mode 100644
index 000000000000..58bc1f3fba8c
--- /dev/null
+++ b/tests/models/transformers/test_models_dual_transformer_2d.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from torch import nn
+
+from diffusers.models import DualTransformer2DModel
+from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D
+
+
+class CapturingTransformer(nn.Module):
+    def __init__(self, delta):
+        super().__init__()
+        self.delta = delta
+        self.calls = []
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states=None,
+        timestep=None,
+        attention_mask=None,
+        encoder_attention_mask=None,
+        cross_attention_kwargs=None,
+        return_dict=True,
+    ):
+        self.calls.append(
+            {
+                "encoder_hidden_states": encoder_hidden_states,
+                "timestep": timestep,
+                "attention_mask": attention_mask,
+                "encoder_attention_mask": encoder_attention_mask,
+                "cross_attention_kwargs": cross_attention_kwargs,
+                "return_dict": return_dict,
+            }
+        )
+        return (hidden_states + self.delta,)
+
+
+class DualTransformer2DModelTests(unittest.TestCase):
+    def get_model_with_capturing_transformers(self):
+        model = DualTransformer2DModel(
+            num_attention_heads=1,
+            attention_head_dim=4,
+            in_channels=4,
+            num_layers=1,
+            norm_num_groups=1,
+            cross_attention_dim=4,
+        )
+        transformer_0 = CapturingTransformer(delta=4)
+        transformer_1 = CapturingTransformer(delta=2)
+        model.transformers = nn.ModuleList([transformer_0, transformer_1])
+        model.condition_lengths = [2, 3]
+        model.transformer_index_for_condition = [1, 0]
+        return model, transformer_0, transformer_1
+
+    def check_mask_routing(self, encoder_attention_mask):
+        model, transformer_0, transformer_1 = self.get_model_with_capturing_transformers()
+
+        hidden_states = torch.randn(1, 4, 2, 2)
+        encoder_hidden_states = torch.randn(1, 5, 4)
+        attention_mask = torch.ones(1, 4)
+        timestep = torch.tensor([1])
+        cross_attention_kwargs = {"foo": "bar"}
+
+        output = model(
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            timestep=timestep,
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            cross_attention_kwargs=cross_attention_kwargs,
+        )
+
+        assert torch.equal(output.sample, hidden_states + 3)
+
+        assert len(transformer_1.calls) == 1
+        first_condition_call = transformer_1.calls[0]
+        assert first_condition_call["attention_mask"] is attention_mask
+        assert first_condition_call["timestep"] is timestep
+        assert first_condition_call["cross_attention_kwargs"] is cross_attention_kwargs
+        assert first_condition_call["return_dict"] is False
+        assert torch.equal(first_condition_call["encoder_hidden_states"], encoder_hidden_states[:, :2])
+        assert torch.equal(first_condition_call["encoder_attention_mask"], encoder_attention_mask[..., :2])
+
+        assert len(transformer_0.calls) == 1
+        second_condition_call = transformer_0.calls[0]
+        assert second_condition_call["attention_mask"] is attention_mask
+        assert second_condition_call["timestep"] is timestep
+        assert second_condition_call["cross_attention_kwargs"] is cross_attention_kwargs
+        assert second_condition_call["return_dict"] is False
+        assert torch.equal(second_condition_call["encoder_hidden_states"], encoder_hidden_states[:, 2:5])
+        assert torch.equal(second_condition_call["encoder_attention_mask"], encoder_attention_mask[..., 2:5])
+
+    def test_forward_passes_attention_masks_to_child_transformers(self):
+        self.check_mask_routing(torch.tensor([[1.0, 1.0, 0.0, 1.0, 0.0]]))
+        self.check_mask_routing(torch.tensor([[[0.0, 0.0, -10000.0, 0.0, -10000.0]]]))
+
+    def test_forward_tuple_output(self):
+        model, _, _ = self.get_model_with_capturing_transformers()
+
+        hidden_states = torch.randn(1, 4, 2, 2)
+        output = model(
+            hidden_states,
+            encoder_hidden_states=torch.randn(1, 5, 4),
+            return_dict=False,
+        )
+
+        assert isinstance(output, tuple)
+        assert torch.equal(output[0], hidden_states + 3)
+
+    def test_cross_attn_down_block_dual_cross_attention_accepts_encoder_attention_mask(self):
+        block = CrossAttnDownBlock2D(
+            in_channels=4,
+            out_channels=4,
+            temb_channels=8,
+            num_layers=1,
+            transformer_layers_per_block=1,
+            num_attention_heads=1,
+            cross_attention_dim=8,
+            dual_cross_attention=True,
+            resnet_groups=1,
+            add_downsample=False,
+        )
+        block.eval()
+
+        with torch.no_grad():
+            output = block(
+                torch.randn(1, 4, 4, 4),
+                temb=torch.randn(1, 8),
+                encoder_hidden_states=torch.randn(1, 77 + 257, 8),
+                attention_mask=torch.ones(1, 16),
+                encoder_attention_mask=torch.ones(1, 77 + 257),
+            )
+
+        assert output[0].shape == (1, 4, 4, 4)
+
+    def test_top_level_import(self):
+        from diffusers import DualTransformer2DModel as TopLevelDualTransformer2DModel
+        from diffusers.models import DualTransformer2DModel as ModelsDualTransformer2DModel
+
+        assert TopLevelDualTransformer2DModel is ModelsDualTransformer2DModel
diff --git a/tests/models/transformers/test_models_transformer_temporal.py b/tests/models/transformers/test_models_transformer_temporal.py
index aff83be51124..60bea217e62f 100644
--- a/tests/models/transformers/test_models_transformer_temporal.py
+++ b/tests/models/transformers/test_models_transformer_temporal.py
@@ -18,6 +18,7 @@
 import torch

 from diffusers.models.transformers import TransformerTemporalModel
+from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel

 from ...testing_utils import (
     enable_full_determinism,
@@ -65,3 +66,57 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_out_channels_none_and_equal_channels_supported(self):
+        model = TransformerTemporalModel(
+            num_attention_heads=1,
+            attention_head_dim=4,
+            in_channels=4,
+            out_channels=None,
+            num_layers=1,
+            norm_num_groups=1,
+        )
+        assert model.proj_out.out_features == 4
+
+        model = TransformerTemporalModel(
+            num_attention_heads=1,
+            attention_head_dim=4,
+            in_channels=4,
+            out_channels=4,
+            num_layers=1,
+            norm_num_groups=1,
+        ).to(torch_device)
+        assert model.proj_out.out_features == 4
+
+        hidden_states = torch.randn((2, 4, 4, 4)).to(torch_device)
+        output = model(hidden_states, num_frames=2).sample
+        assert output.shape == hidden_states.shape
+
+        spatio_temporal_model = TransformerSpatioTemporalModel(
+            num_attention_heads=1,
+            attention_head_dim=4,
+            in_channels=32,
+            out_channels=32,
+            num_layers=1,
+        )
+        assert spatio_temporal_model.proj_out.out_features == 32
+
+    def test_out_channels_mismatch_raises(self):
+        with self.assertRaisesRegex(ValueError, "out_channels.*in_channels"):
+            TransformerTemporalModel(
+                num_attention_heads=1,
+                attention_head_dim=4,
+                in_channels=4,
+                out_channels=8,
+                num_layers=1,
+                norm_num_groups=1,
+            )
+
+        with self.assertRaisesRegex(ValueError, "out_channels.*in_channels"):
+            TransformerSpatioTemporalModel(
+                num_attention_heads=1,
+                attention_head_dim=4,
+                in_channels=32,
+                out_channels=64,
+                num_layers=1,
+            )
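For downstream callers (for example a Versatile Diffusion-style UNet built with `dual_cross_attention=True`), the convention these changes imply is that `encoder_attention_mask` is concatenated along the token axis in the same order as `encoder_hidden_states`. A hedged usage sketch, not part of the PR, reusing the toy block from the new test; the 77/257 split matches `DualTransformer2DModel`'s default `condition_lengths`, and the padding cutoff at token 200 is an arbitrary illustration:

```python
import torch
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D

# Toy block matching the new test's configuration.
block = CrossAttnDownBlock2D(
    in_channels=4, out_channels=4, temb_channels=8, num_layers=1,
    transformer_layers_per_block=1, num_attention_heads=1, cross_attention_dim=8,
    dual_cross_attention=True, resnet_groups=1, add_downsample=False,
).eval()

# Per-modality condition tokens and keep-masks, concatenated in the same token order.
text_states, image_states = torch.randn(1, 77, 8), torch.randn(1, 257, 8)
text_mask = torch.ones(1, 77)
image_mask = torch.ones(1, 257)
image_mask[:, 200:] = 0.0  # e.g. padded image tokens to be ignored

with torch.no_grad():
    sample, _ = block(
        torch.randn(1, 4, 4, 4),
        temb=torch.randn(1, 8),
        encoder_hidden_states=torch.cat([text_states, image_states], dim=1),
        encoder_attention_mask=torch.cat([text_mask, image_mask], dim=-1),
    )
print(sample.shape)  # torch.Size([1, 4, 4, 4])
```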