2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -235,6 +235,7 @@
"CosmosControlNetModel",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"DualTransformer2DModel",
"EasyAnimateTransformer3DModel",
"ErnieImageTransformer2DModel",
"Flux2Transformer2DModel",
@@ -1057,6 +1058,7 @@
CosmosControlNetModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
ErnieImageTransformer2DModel,
Flux2Transformer2DModel,
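
For reference, a minimal sketch of what this new export enables once a build containing the change is installed (it mirrors the `test_top_level_import` test added further down):

from diffusers import DualTransformer2DModel
from diffusers.models import DualTransformer2DModel as ModelsDualTransformer2DModel

# The top-level re-export and the models namespace should resolve to the same class.
assert DualTransformer2DModel is ModelsDualTransformer2DModel
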
10 changes: 9 additions & 1 deletion src/diffusers/models/transformers/dual_transformer_2d.py
@@ -99,6 +99,7 @@ def forward(
encoder_hidden_states,
timestep=None,
attention_mask=None,
encoder_attention_mask=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
@@ -113,6 +114,8 @@
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.Tensor`, *optional*):
Optional attention mask to be applied in Attention.
encoder_attention_mask (`torch.Tensor`, *optional*):
Optional cross-attention mask to be applied to `encoder_hidden_states`.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -130,15 +133,20 @@

encoded_states = []
tokens_start = 0
# attention_mask is not used yet
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
condition_mask = None
if encoder_attention_mask is not None:
condition_mask = encoder_attention_mask[..., tokens_start : tokens_start + self.condition_lengths[i]]

transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
attention_mask=attention_mask,
encoder_attention_mask=condition_mask,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
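
As a quick illustration of the slicing behaviour added above, here is a minimal, self-contained sketch; the 77 + 257 condition lengths are only an example, chosen to match the dual-cross-attention test further down:

import torch

condition_lengths = [77, 257]  # e.g. text tokens followed by image tokens
encoder_attention_mask = torch.ones(1, sum(condition_lengths))

tokens_start = 0
condition_masks = []
for length in condition_lengths:
    # each inner transformer only receives the mask slice covering its own condition tokens
    condition_masks.append(encoder_attention_mask[..., tokens_start : tokens_start + length])
    tokens_start += length

assert condition_masks[0].shape == (1, 77)
assert condition_masks[1].shape == (1, 257)
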
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_2d.py
@@ -517,7 +517,7 @@ def _get_output_for_vectorized_inputs(self, hidden_states):
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
output = F.log_softmax(logits.float(), dim=1)
return output

def _get_output_for_patched_inputs(
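
The numerical claim behind dropping the float64 upcast can be sanity-checked with a small sketch: `F.log_softmax` uses a numerically stable log-sum-exp, so float32 results stay very close to a float64 reference even for large logits (illustrative only; exact tolerances depend on input scale):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 32) * 50.0  # deliberately large logits

reference = F.log_softmax(logits.double(), dim=1).float()
output = F.log_softmax(logits.float(), dim=1)

# expected to be tiny (on the order of 1e-6 to 1e-5) despite skipping the double upcast
print((reference - output).abs().max())
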
12 changes: 9 additions & 3 deletions src/diffusers/models/transformers/transformer_temporal.py
@@ -94,6 +94,9 @@ def __init__(
inner_dim = num_attention_heads * attention_head_dim

self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
if self.out_channels != in_channels:
raise ValueError("`out_channels` must be `None` or equal to `in_channels` for TransformerTemporalModel.")

self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
@@ -118,7 +121,7 @@ def __init__(
]
)

self.proj_out = nn.Linear(inner_dim, in_channels)
self.proj_out = nn.Linear(inner_dim, self.out_channels)

def forward(
self,
@@ -272,8 +275,11 @@ def __init__(

# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = nn.Linear(inner_dim, in_channels)
if self.out_channels != in_channels:
raise ValueError(
"`out_channels` must be `None` or equal to `in_channels` for TransformerSpatioTemporalModel."
)
self.proj_out = nn.Linear(inner_dim, self.out_channels)

self.gradient_checkpointing = False

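
With this change a mismatched `out_channels` fails fast at construction time instead of silently projecting back to `in_channels`. A minimal sketch of the new behaviour, with constructor arguments mirroring the test added below:

from diffusers.models.transformers import TransformerTemporalModel

try:
    TransformerTemporalModel(
        num_attention_heads=1,
        attention_head_dim=4,
        in_channels=4,
        out_channels=8,  # != in_channels, so the new check raises
        num_layers=1,
        norm_num_groups=1,
    )
except ValueError as err:
    print(err)
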
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -1110,6 +1110,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DualTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class EasyAnimateTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

13 changes: 10 additions & 3 deletions tests/models/test_layers_utils.py
@@ -15,6 +15,7 @@


import unittest
from unittest import mock

import numpy as np
import torch
@@ -23,11 +24,11 @@
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers import transformer_2d as transformer_2d_module
from diffusers.models.transformers.transformer_2d import Transformer2DModel

from ..testing_utils import (
backend_manual_seed,
require_torch_accelerator_with_fp64,
require_torch_version_greater_equal,
torch_device,
)
@@ -432,7 +433,6 @@ def test_spatial_transformer_dropout(self):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

@require_torch_accelerator_with_fp64
def test_spatial_transformer_discrete(self):
torch.manual_seed(0)
backend_manual_seed(torch_device, 0)
@@ -451,8 +451,15 @@ def test_spatial_transformer_discrete(self):
.eval()
)

original_log_softmax = transformer_2d_module.F.log_softmax

def checked_log_softmax(input, *args, **kwargs):
assert input.dtype != torch.float64
return original_log_softmax(input, *args, **kwargs)

with torch.no_grad():
attention_scores = spatial_transformer_block(sample).sample
with mock.patch.object(transformer_2d_module.F, "log_softmax", side_effect=checked_log_softmax):
attention_scores = spatial_transformer_block(sample).sample

assert attention_scores.shape == (1, num_embed - 1, 32)

156 changes: 156 additions & 0 deletions tests/models/transformers/test_models_dual_transformer_2d.py
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from torch import nn

from diffusers.models import DualTransformer2DModel
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D


class CapturingTransformer(nn.Module):
def __init__(self, delta):
super().__init__()
self.delta = delta
self.calls = []

def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None,
encoder_attention_mask=None,
cross_attention_kwargs=None,
return_dict=True,
):
self.calls.append(
{
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"attention_mask": attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"cross_attention_kwargs": cross_attention_kwargs,
"return_dict": return_dict,
}
)
return (hidden_states + self.delta,)


class DualTransformer2DModelTests(unittest.TestCase):
def get_model_with_capturing_transformers(self):
model = DualTransformer2DModel(
num_attention_heads=1,
attention_head_dim=4,
in_channels=4,
num_layers=1,
norm_num_groups=1,
cross_attention_dim=4,
)
transformer_0 = CapturingTransformer(delta=4)
transformer_1 = CapturingTransformer(delta=2)
model.transformers = nn.ModuleList([transformer_0, transformer_1])
model.condition_lengths = [2, 3]
model.transformer_index_for_condition = [1, 0]
return model, transformer_0, transformer_1

def check_mask_routing(self, encoder_attention_mask):
model, transformer_0, transformer_1 = self.get_model_with_capturing_transformers()

hidden_states = torch.randn(1, 4, 2, 2)
encoder_hidden_states = torch.randn(1, 5, 4)
attention_mask = torch.ones(1, 4)
timestep = torch.tensor([1])
cross_attention_kwargs = {"foo": "bar"}

output = model(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)

assert torch.equal(output.sample, hidden_states + 3)

assert len(transformer_1.calls) == 1
first_condition_call = transformer_1.calls[0]
assert first_condition_call["attention_mask"] is attention_mask
assert first_condition_call["timestep"] is timestep
assert first_condition_call["cross_attention_kwargs"] is cross_attention_kwargs
assert first_condition_call["return_dict"] is False
assert torch.equal(first_condition_call["encoder_hidden_states"], encoder_hidden_states[:, :2])
assert torch.equal(first_condition_call["encoder_attention_mask"], encoder_attention_mask[..., :2])

assert len(transformer_0.calls) == 1
second_condition_call = transformer_0.calls[0]
assert second_condition_call["attention_mask"] is attention_mask
assert second_condition_call["timestep"] is timestep
assert second_condition_call["cross_attention_kwargs"] is cross_attention_kwargs
assert second_condition_call["return_dict"] is False
assert torch.equal(second_condition_call["encoder_hidden_states"], encoder_hidden_states[:, 2:5])
assert torch.equal(second_condition_call["encoder_attention_mask"], encoder_attention_mask[..., 2:5])

def test_forward_passes_attention_masks_to_child_transformers(self):
self.check_mask_routing(torch.tensor([[1.0, 1.0, 0.0, 1.0, 0.0]]))
self.check_mask_routing(torch.tensor([[[0.0, 0.0, -10000.0, 0.0, -10000.0]]]))

def test_forward_tuple_output(self):
model, _, _ = self.get_model_with_capturing_transformers()

hidden_states = torch.randn(1, 4, 2, 2)
output = model(
hidden_states,
encoder_hidden_states=torch.randn(1, 5, 4),
return_dict=False,
)

assert isinstance(output, tuple)
assert torch.equal(output[0], hidden_states + 3)

def test_cross_attn_down_block_dual_cross_attention_accepts_encoder_attention_mask(self):
block = CrossAttnDownBlock2D(
in_channels=4,
out_channels=4,
temb_channels=8,
num_layers=1,
transformer_layers_per_block=1,
num_attention_heads=1,
cross_attention_dim=8,
dual_cross_attention=True,
resnet_groups=1,
add_downsample=False,
)
block.eval()

with torch.no_grad():
output = block(
torch.randn(1, 4, 4, 4),
temb=torch.randn(1, 8),
encoder_hidden_states=torch.randn(1, 77 + 257, 8),
attention_mask=torch.ones(1, 16),
encoder_attention_mask=torch.ones(1, 77 + 257),
)

assert output[0].shape == (1, 4, 4, 4)

def test_top_level_import(self):
from diffusers import DualTransformer2DModel as TopLevelDualTransformer2DModel
from diffusers.models import DualTransformer2DModel as ModelsDualTransformer2DModel

assert TopLevelDualTransformer2DModel is ModelsDualTransformer2DModel
55 changes: 55 additions & 0 deletions tests/models/transformers/test_models_transformer_temporal.py
@@ -18,6 +18,7 @@
import torch

from diffusers.models.transformers import TransformerTemporalModel
from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel

from ...testing_utils import (
enable_full_determinism,
@@ -65,3 +66,57 @@ def prepare_init_args_and_inputs_for_common(self):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_out_channels_none_and_equal_channels_supported(self):
model = TransformerTemporalModel(
num_attention_heads=1,
attention_head_dim=4,
in_channels=4,
out_channels=None,
num_layers=1,
norm_num_groups=1,
)
assert model.proj_out.out_features == 4

model = TransformerTemporalModel(
num_attention_heads=1,
attention_head_dim=4,
in_channels=4,
out_channels=4,
num_layers=1,
norm_num_groups=1,
)
assert model.proj_out.out_features == 4

hidden_states = torch.randn((2, 4, 4, 4)).to(torch_device)
output = model(hidden_states, num_frames=2).sample
assert output.shape == hidden_states.shape

spatio_temporal_model = TransformerSpatioTemporalModel(
num_attention_heads=1,
attention_head_dim=4,
in_channels=32,
out_channels=32,
num_layers=1,
)
assert spatio_temporal_model.proj_out.out_features == 32

def test_out_channels_mismatch_raises(self):
with self.assertRaisesRegex(ValueError, "out_channels.*in_channels"):
TransformerTemporalModel(
num_attention_heads=1,
attention_head_dim=4,
in_channels=4,
out_channels=8,
num_layers=1,
norm_num_groups=1,
)

with self.assertRaisesRegex(ValueError, "out_channels.*in_channels"):
TransformerSpatioTemporalModel(
num_attention_heads=1,
attention_head_dim=4,
in_channels=32,
out_channels=64,
num_layers=1,
)