Skip to content

Commit d283af4

Browse files
committed
upd
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent 0635182 commit d283af4

5 files changed

Lines changed: 250 additions & 136 deletions

File tree

src/diffusers/models/transformers/transformer_longcat_audio_dit.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -188,8 +188,8 @@ def __call__(
188188
self,
189189
attn: "AudioDiTAttention",
190190
hidden_states: torch.Tensor,
191-
mask: torch.BoolTensor | None = None,
192-
rope: tuple | None = None,
191+
attention_mask: torch.BoolTensor | None = None,
192+
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
193193
) -> torch.Tensor:
194194
batch_size = hidden_states.shape[0]
195195
query = attn.to_q(hidden_states)
@@ -205,20 +205,20 @@ def __call__(
205205
key = key.view(batch_size, -1, attn.heads, head_dim)
206206
value = value.view(batch_size, -1, attn.heads, head_dim)
207207

208-
if rope is not None:
209-
query = _apply_rotary_emb(query, rope)
210-
key = _apply_rotary_emb(key, rope)
208+
if audio_rotary_emb is not None:
209+
query = _apply_rotary_emb(query, audio_rotary_emb)
210+
key = _apply_rotary_emb(key, audio_rotary_emb)
211211

212212
hidden_states = dispatch_attention_fn(
213213
query,
214214
key,
215215
value,
216-
attn_mask=mask,
216+
attn_mask=attention_mask,
217217
backend=self._attention_backend,
218218
parallel_config=self._parallel_config,
219219
)
220-
if mask is not None:
221-
hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype)
220+
if attention_mask is not None:
221+
hidden_states = hidden_states * attention_mask[:, :, None, None].to(hidden_states.dtype)
222222

223223
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
224224
hidden_states = attn.to_out[0](hidden_states)
@@ -261,11 +261,14 @@ def forward(
261261
attention_mask: torch.BoolTensor | None = None,
262262
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
263263
prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
264-
mask: torch.BoolTensor | None = None,
265-
rope: tuple | None = None,
266264
) -> torch.Tensor:
267265
if encoder_hidden_states is None:
268-
return self.processor(self, hidden_states, mask=mask, rope=rope)
266+
return self.processor(
267+
self,
268+
hidden_states,
269+
attention_mask=attention_mask,
270+
audio_rotary_emb=audio_rotary_emb,
271+
)
269272
return self.processor(
270273
self,
271274
hidden_states,
@@ -419,7 +422,11 @@ def forward(
419422

420423
norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states)
421424
norm_hidden_states = norm_hidden_states * (1 + scale_sa[:, None]) + shift_sa[:, None]
422-
attn_output = self.self_attn(norm_hidden_states, mask=mask, rope=rope)
425+
attn_output = self.self_attn(
426+
norm_hidden_states,
427+
attention_mask=mask,
428+
audio_rotary_emb=rope,
429+
)
423430
hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_output
424431

425432
if self.use_cross_attn:

src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,7 @@ def __init__(
171171
transformer=transformer,
172172
)
173173
self.sample_rate = getattr(vae.config, "sample_rate", 24000)
174-
self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048)
175-
self.vae_scale_factor = self.latent_hop
174+
self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048)
176175
self.latent_dim = getattr(transformer.config, "latent_dim", 64)
177176
self.max_wav_duration = 30.0
178177
self.text_norm_feat = True
@@ -321,8 +320,7 @@ def from_pretrained(
321320

322321
pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
323322
pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate)
324-
pipe.latent_hop = config.get("latent_hop", pipe.latent_hop)
325-
pipe.vae_scale_factor = pipe.latent_hop
323+
pipe.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", pipe.vae_scale_factor))
326324
pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration)
327325
pipe.text_norm_feat = config.get("text_norm_feat", pipe.text_norm_feat)
328326
pipe.text_add_embed = config.get("text_add_embed", pipe.text_add_embed)
Lines changed: 120 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,125 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pytest
117
import torch
218

319
from diffusers import LongCatAudioDiTTransformer
20+
from diffusers.utils.torch_utils import randn_tensor
21+
22+
from ...testing_utils import enable_full_determinism, torch_device
23+
from ..testing_utils import (
24+
AttentionTesterMixin,
25+
BaseModelTesterConfig,
26+
MemoryTesterMixin,
27+
ModelTesterMixin,
28+
TorchCompileTesterMixin,
29+
)
30+
31+
32+
enable_full_determinism()
33+
34+
35+
class LongCatAudioDiTTransformerTesterConfig(BaseModelTesterConfig):
36+
@property
37+
def main_input_name(self) -> str:
38+
return "hidden_states"
39+
40+
@property
41+
def model_class(self):
42+
return LongCatAudioDiTTransformer
43+
44+
@property
45+
def output_shape(self) -> tuple[int, ...]:
46+
return (16, 8)
47+
48+
@property
49+
def generator(self):
50+
return torch.Generator("cpu").manual_seed(0)
51+
52+
def get_init_dict(self) -> dict[str, int | bool | float | str]:
53+
return {
54+
"dit_dim": 64,
55+
"dit_depth": 2,
56+
"dit_heads": 4,
57+
"dit_text_dim": 32,
58+
"latent_dim": 8,
59+
"text_conv": False,
60+
}
61+
62+
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
63+
batch_size = 1
64+
sequence_length = 16
65+
encoder_sequence_length = 10
66+
latent_dim = 8
67+
text_dim = 32
68+
69+
return {
70+
"hidden_states": randn_tensor(
71+
(batch_size, sequence_length, latent_dim), generator=self.generator, device=torch_device
72+
),
73+
"encoder_hidden_states": randn_tensor(
74+
(batch_size, encoder_sequence_length, text_dim), generator=self.generator, device=torch_device
75+
),
76+
"encoder_attention_mask": torch.ones(
77+
batch_size, encoder_sequence_length, dtype=torch.bool, device=torch_device
78+
),
79+
"attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.bool, device=torch_device),
80+
"timestep": torch.ones(batch_size, device=torch_device),
81+
}
82+
83+
84+
class TestLongCatAudioDiTTransformer(LongCatAudioDiTTransformerTesterConfig, ModelTesterMixin):
85+
pass
86+
87+
88+
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
89+
def test_layerwise_casting_memory(self):
90+
pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")
91+
92+
def test_layerwise_casting_training(self):
93+
pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")
94+
95+
def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
96+
pytest.skip("LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet.")
97+
98+
99+
class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
100+
def test_torch_compile_repeated_blocks(self):
101+
pytest.skip("LongCatAudioDiTTransformer does not define repeated blocks for regional compilation.")
102+
103+
104+
class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterConfig, AttentionTesterMixin):
105+
pass
106+
107+
108+
def test_longcat_audio_attention_uses_standard_self_attn_kwargs():
109+
from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention
110+
111+
attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4, dropout=0.0, bias=False)
112+
113+
eye = torch.eye(4)
114+
with torch.no_grad():
115+
attn.to_q.weight.copy_(eye)
116+
attn.to_k.weight.copy_(eye)
117+
attn.to_v.weight.copy_(eye)
118+
attn.to_out[0].weight.copy_(eye)
119+
120+
hidden_states = torch.tensor([[[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]]])
121+
attention_mask = torch.tensor([[True, False]])
4122

123+
output = attn(hidden_states=hidden_states, attention_mask=attention_mask)
5124

6-
def test_longcat_audio_transformer_forward_shape():
7-
model = LongCatAudioDiTTransformer(
8-
dit_dim=64,
9-
dit_depth=2,
10-
dit_heads=4,
11-
dit_text_dim=32,
12-
latent_dim=8,
13-
text_conv=False,
14-
)
15-
hidden_states = torch.randn(2, 16, 8)
16-
encoder_hidden_states = torch.randn(2, 10, 32)
17-
encoder_attention_mask = torch.ones(2, 10, dtype=torch.bool)
18-
timestep = torch.tensor([1.0, 1.0])
19-
20-
output = model(
21-
hidden_states=hidden_states,
22-
encoder_hidden_states=encoder_hidden_states,
23-
encoder_attention_mask=encoder_attention_mask,
24-
timestep=timestep,
25-
)
26-
27-
assert output.sample.shape == hidden_states.shape
28-
29-
30-
def test_longcat_audio_transformer_masked_forward():
31-
model = LongCatAudioDiTTransformer(
32-
dit_dim=64,
33-
dit_depth=2,
34-
dit_heads=4,
35-
dit_text_dim=32,
36-
latent_dim=8,
37-
text_conv=False,
38-
)
39-
hidden_states = torch.randn(2, 16, 8)
40-
encoder_hidden_states = torch.randn(2, 10, 32)
41-
encoder_attention_mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4], dtype=torch.bool)
42-
attention_mask = torch.tensor([[1] * 16, [1] * 9 + [0] * 7], dtype=torch.bool)
43-
timestep = torch.tensor([1.0, 1.0])
44-
45-
output = model(
46-
hidden_states=hidden_states,
47-
encoder_hidden_states=encoder_hidden_states,
48-
encoder_attention_mask=encoder_attention_mask,
49-
timestep=timestep,
50-
attention_mask=attention_mask,
51-
)
52-
53-
assert output.sample.shape == hidden_states.shape
54-
assert torch.all(output.sample[1, 9:] == 0)
125+
assert torch.allclose(output[:, 1], torch.zeros_like(output[:, 1]))

tests/pipelines/longcat_audio_dit/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)