# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
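# Illustrative sketch (not part of the original module): `cos`/`sin` broadcast against the full
# hidden size, and a zero rotation angle leaves the input unchanged. Shapes are assumptions.
#   x = torch.randn(2, 16, 64)                        # [B, S, C] with C even
#   cos, sin = torch.ones(2, 16, 64), torch.zeros(2, 16, 64)
#   out = apply_interleaved_rotary_emb(x, (cos, sin))
#   assert torch.allclose(out, x)                     # cos=1, sin=0 -> identity rotation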
def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_dtype = x.dtype
needs_reshape = False
if x.ndim != 4 and cos.ndim == 4:
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
b, h, t, _ = cos.shape
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
needs_reshape = True
# Split last dim (2*r) into (d=2, r)
last = x.shape[-1]
if last % 2 != 0:
raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.")
r = last // 2
# (..., 2, r)
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
first_x = split_x[..., :1, :] # (..., 1, r)
second_x = split_x[..., 1:, :] # (..., 1, r)
cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
sin_u = sin.unsqueeze(-2)
out = split_x * cos_u
first_out = out[..., :1, :]
second_out = out[..., 1:, :]
first_out.addcmul_(-sin_u, second_x)
second_out.addcmul_(sin_u, first_x)
out = out.reshape(*out.shape[:-2], last)
if needs_reshape:
out = out.swapaxes(1, 2).reshape(b, t, -1)
out = out.to(dtype=x_dtype)
return out
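# Illustrative sketch (not part of the original module): in the "split" layout, the first and
# second halves of the last dimension are rotated against each other, so `cos`/`sin` carry half
# the hidden size. Shapes are assumptions.
#   x = torch.randn(2, 16, 64)                        # [B, S, C]
#   cos, sin = torch.ones(2, 16, 32), torch.zeros(2, 16, 32)  # [B, S, C // 2]
#   out = apply_split_rotary_emb(x, (cos, sin))
#   assert torch.allclose(out, x)                     # cos=1, sin=0 -> identity rotation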
@dataclass
class AudioVisualModelOutput(BaseOutput):
r"""
Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs.
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output
of the model. This is typically a video (spatiotemporal) output.
audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`):
The audio output of the audiovisual model.
"""
sample: "torch.Tensor" # noqa: F821
audio_sample: "torch.Tensor" # noqa: F821
class LTX2AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0
model. In particular, the number of modulation parameters to be calculated is now configurable.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_mod_params (`int`, *optional*, defaults to `6`):
The number of modulation parameters which will be calculated in the first return argument. The default of 6
is standard, but sometimes we may want to have a different (usually smaller) number of modulation
parameters.
use_additional_conditions (`bool`, *optional*, defaults to `False`):
Whether to use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False):
super().__init__()
self.num_mod_params = num_mod_params
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: dict[str, torch.Tensor] | None = None,
batch_size: int | None = None,
hidden_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
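# Illustrative usage sketch (not part of the original module); argument values and shapes are
# assumptions based on the forward signature above:
#   adaln = LTX2AdaLayerNormSingle(embedding_dim=128, num_mod_params=6)
#   timestep = torch.randint(0, 1000, (2,))
#   mod_params, embedded_t = adaln(timestep, batch_size=2, hidden_dtype=torch.float32)
#   # mod_params: [2, 6 * 128], later split into the 6 AdaLN modulation tensors
#   # embedded_t: [2, 128]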
class LTX2AudioVideoAttnProcessor:
r"""
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
if attn.rope_type == "interleaved":
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
key = apply_interleaved_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
elif attn.rope_type == "split":
query = apply_split_rotary_emb(query, query_rotary_emb)
key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 ensures that if the gate logits are zero-initialized, the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTX2PerturbedAttnProcessor:
r"""
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool | None = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
value = attn.to_v(encoder_hidden_states)
if all_perturbed is None:
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
if all_perturbed:
# Skip attention entirely and use the value projection output directly
hidden_states = value
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
if attn.rope_type == "interleaved":
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
key = apply_interleaved_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
elif attn.rope_type == "split":
query = apply_split_rotary_emb(query, query_rotary_emb)
key = apply_split_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if perturbation_mask is not None:
value = value.flatten(2, 3)
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 ensures that if the gate logits are zero-initialized, the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
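# Note (illustrative): `torch.lerp(value, attn_out, mask)` returns the plain value projection where
# the perturbation mask is 0 and the full attention output where it is 1, so an all-zero mask is
# equivalent to the `all_perturbed` fast path that skips attention entirely.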
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
r"""
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
kv_heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = True,
cross_attention_dim: int | None = None,
out_bias: bool = True,
qk_norm: str = "rms_norm_across_heads",
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
rope_type: str = "interleaved",
apply_gated_attention: bool = False,
processor=None,
):
super().__init__()
if qk_norm != "rms_norm_across_heads":
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = query_dim
self.heads = heads
self.rope_type = rope_type
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if apply_gated_attention:
# Per head gate values
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
hidden_states = self.processor(
self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
)
return hidden_states
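# Illustrative self-attention sketch (not part of the original module); argument values are
# assumptions chosen only to show the expected shapes:
#   attn = LTX2Attention(query_dim=64, heads=4, kv_heads=4, dim_head=16)
#   x = torch.randn(1, 10, 64)                        # [B, S, query_dim]
#   out = attn(x)                                     # self-attention, no RoPE
#   # out: [1, 10, 64]; pass query_rotary_emb/key_rotary_emb to apply rotary embeddings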
class LTX2VideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video).
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
audio_dim: int,
audio_num_attention_heads: int,
audio_attention_head_dim: int,
audio_cross_attention_dim: int,
video_gated_attn: bool = False,
video_cross_attn_adaln: bool = False,
audio_gated_attn: bool = False,
audio_cross_attn_adaln: bool = False,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
attention_out_bias: bool = True,
eps: float = 1e-6,
elementwise_affine: bool = False,
rope_type: str = "interleaved",
perturbed_attn: bool = False,
):
super().__init__()
self.perturbed_attn = perturbed_attn
if perturbed_attn:
attn_processor_cls = LTX2PerturbedAttnProcessor
else:
attn_processor_cls = LTX2AudioVideoAttnProcessor
# 1. Self-Attention (video and audio)
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_attn1 = LTX2Attention(
query_dim=audio_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 2. Prompt Cross-Attention
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn2 = LTX2Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_attn2 = LTX2Attention(
query_dim=audio_dim,
cross_attention_dim=audio_cross_attention_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
# Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio
self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_to_video_attn = LTX2Attention(
query_dim=dim,
cross_attention_dim=audio_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.video_to_audio_attn = LTX2Attention(
query_dim=audio_dim,
cross_attention_dim=dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 4. Feedforward layers
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.ff = FeedForward(dim, activation_fn=activation_fn)
self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
# 5. Per-Layer Modulation Parameters
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
# 6 base mod params (shift/scale/gate for self-attn and for the feedforward); if cross_attn_adaln,
# 3 additional mod params (shift/scale/gate) are used for the text cross-attention query
self.video_cross_attn_adaln = video_cross_attn_adaln
self.audio_cross_attn_adaln = audio_cross_attn_adaln
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
# Prompt cross-attn (attn2) additional modulation params
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
if self.cross_attn_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim))
# Per-layer a2v, v2a Cross-Attention mod params
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
@staticmethod
def get_mod_params(
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
) -> tuple[torch.Tensor, ...]:
num_ada_params = scale_shift_table.shape[0]
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.shape[1], num_ada_params, -1
)
ada_params = ada_values.unbind(dim=2)
return ada_params
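# Illustrative shape sketch (not part of the original module); values are assumptions:
#   table = torch.randn(6, 128)                       # [num_ada_params, dim]
#   temb = torch.randn(2, 1, 6 * 128)                 # [B, 1, num_ada_params * dim]
#   params = LTX2VideoTransformerBlock.get_mod_params(table, temb, batch_size=2)
#   # len(params) == 6, each entry of shape [2, 1, 128]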
def forward(
self,
hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
audio_encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
temb_audio: torch.Tensor,
temb_ca_scale_shift: torch.Tensor,
temb_ca_audio_scale_shift: torch.Tensor,
temb_ca_gate: torch.Tensor,
temb_ca_audio_gate: torch.Tensor,
temb_prompt: torch.Tensor | None = None,
temb_prompt_audio: torch.Tensor | None = None,
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
self_attention_mask: torch.Tensor | None = None,
audio_self_attention_mask: torch.Tensor | None = None,
a2v_cross_attention_mask: torch.Tensor | None = None,
v2a_cross_attention_mask: torch.Tensor | None = None,
use_a2v_cross_attention: bool = True,
use_v2a_cross_attention: bool = True,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
# 1. Video and Audio Self-Attention
# 1.1. Video Self-Attention
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
if self.video_cross_attn_adaln:
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
video_self_attn_args = {
"hidden_states": norm_hidden_states,
"encoder_hidden_states": None,
"query_rotary_emb": video_rotary_emb,
"attention_mask": self_attention_mask,
}
if self.perturbed_attn:
video_self_attn_args["perturbation_mask"] = perturbation_mask
video_self_attn_args["all_perturbed"] = all_perturbed
attn_hidden_states = self.attn1(**video_self_attn_args)
hidden_states = hidden_states + attn_hidden_states * gate_msa
# 1.2. Audio Self-Attention
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
audio_ada_params[:6]
)
if self.audio_cross_attn_adaln:
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
audio_self_attn_args = {
"hidden_states": norm_audio_hidden_states,
"encoder_hidden_states": None,
"query_rotary_emb": audio_rotary_emb,
"attention_mask": audio_self_attention_mask,
}
if self.perturbed_attn:
audio_self_attn_args["perturbation_mask"] = perturbation_mask
audio_self_attn_args["all_perturbed"] = all_perturbed
attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
if self.cross_attn_adaln:
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
shift_text_kv, scale_text_kv = video_prompt_ada_params
audio_prompt_ada_params = self.get_mod_params(
self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size
)
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
norm_hidden_states = self.norm2(hidden_states)
if self.video_cross_attn_adaln:
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
if self.cross_attn_adaln:
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
attn_hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
if self.video_cross_attn_adaln:
attn_hidden_states = attn_hidden_states * gate_text_q
hidden_states = hidden_states + attn_hidden_states
# 2.2. Audio-Text Cross-Attention
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
if self.audio_cross_attn_adaln:
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
if self.cross_attn_adaln:
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
attn_audio_hidden_states = self.audio_attn2(
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=audio_encoder_attention_mask,
)
if self.audio_cross_attn_adaln:
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
if use_a2v_cross_attention or use_v2a_cross_attention:
norm_hidden_states = self.audio_to_video_norm(hidden_states)
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
# 3.1. Combine global and per-layer cross attention modulation parameters
# Video
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
a2v_gate = video_ca_gate_param[0].squeeze(2)
# Audio
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
audio_ca_ada_params = self.get_mod_params(
audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size
)
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
v2a_gate = audio_ca_gate_param[0].squeeze(2)
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
if use_a2v_cross_attention:
mod_norm_hidden_states = norm_hidden_states * (
1 + video_a2v_ca_scale.squeeze(2)
) + video_a2v_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_a2v_ca_scale.squeeze(2)
) + audio_a2v_ca_shift.squeeze(2)
a2v_attn_hidden_states = self.audio_to_video_attn(
mod_norm_hidden_states,
encoder_hidden_states=mod_norm_audio_hidden_states,
query_rotary_emb=ca_video_rotary_emb,
key_rotary_emb=ca_audio_rotary_emb,
attention_mask=a2v_cross_attention_mask,
)
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
if use_v2a_cross_attention:
mod_norm_hidden_states = norm_hidden_states * (
1 + video_v2a_ca_scale.squeeze(2)
) + video_v2a_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_v2a_ca_scale.squeeze(2)
) + audio_v2a_ca_shift.squeeze(2)
v2a_attn_hidden_states = self.video_to_audio_attn(
mod_norm_audio_hidden_states,
encoder_hidden_states=mod_norm_hidden_states,
query_rotary_emb=ca_audio_rotary_emb,
key_rotary_emb=ca_video_rotary_emb,
attention_mask=v2a_cross_attention_mask,
)
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
# 4. Feedforward
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp
norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp
audio_ff_output = self.audio_ff(norm_audio_hidden_states)
audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp
return hidden_states, audio_hidden_states
class LTX2AudioVideoRotaryPosEmbed(nn.Module):
"""
Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.
Args:
causal_offset (`int`, *optional*, defaults to `1`):
Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE
treats the very first frame differently), but could also be 0 (for non-causal modeling).
"""
def __init__(
self,
dim: int,
patch_size: int = 1,
patch_size_t: int = 1,
base_num_frames: int = 20,
base_height: int = 2048,
base_width: int = 2048,
sampling_rate: int = 16000,
hop_length: int = 160,
scale_factors: tuple[int, ...] = (8, 32, 32),
theta: float = 10000.0,
causal_offset: int = 1,
modality: str = "video",
double_precision: bool = True,
rope_type: str = "interleaved",
num_attention_heads: int = 32,
) -> None:
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.patch_size_t = patch_size_t
if rope_type not in ["interleaved", "split"]:
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
self.rope_type = rope_type
self.base_num_frames = base_num_frames
self.num_attention_heads = num_attention_heads
# Video-specific
self.base_height = base_height
self.base_width = base_width
# Audio-specific
self.sampling_rate = sampling_rate
self.hop_length = hop_length
self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0])
self.scale_factors = scale_factors
self.theta = theta
self.causal_offset = causal_offset
self.modality = modality
if self.modality not in ["video", "audio"]:
raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.")
self.double_precision = double_precision
def prepare_video_coords(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
device: torch.device,
fps: float = 24.0,
) -> torch.Tensor:
"""
Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel
space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2)
where
- axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames)
- axis 3 (size 2) stores `[start, end)` indices within each dimension
Args:
batch_size (`int`):
Batch size of the video latents.
num_frames (`int`):
Number of latent frames in the video latents.
height (`int`):
Latent height of the video latents.
width (`int`):
Latent width of the video latents.
device (`torch.device`):
Device on which to create the video grid.
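fps (`float`, *optional*, defaults to `24.0`):
Frame rate used to convert the temporal patch boundaries into seconds.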
Returns:
`torch.Tensor`:
Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2].
"""
# 1. Generate grid coordinates for each spatiotemporal dimension (frames, height, width)
# Always compute rope in fp32
grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device)
grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device)
grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device)
# indexing='ij' ensures that the dimensions are kept in order as (frames, height, width)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches
# 2. Get the patch boundaries with respect to the latent video grid
patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
# Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension
latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2]
# Reshape to (batch_size, 3, num_patches, 2)
latent_coords = latent_coords.flatten(1, 3)
latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)
# 3. Calculate the pixel space patch boundaries from the latent boundaries.
scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
# Broadcast the VAE scale factors such that they are compatible with latent_coords's shape
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # This is the (frame, height, width) dim
# Apply per-axis scaling to convert latent coordinates to pixel space coordinates
pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape)
# As the VAE temporal stride for the first frame is 1 instead of self.scale_factors[0], we need to shift
# and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0)
# Scale the temporal coordinates by the video FPS
pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
return pixel_coords
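# Illustrative shape sketch (not part of the original module), assuming the default patch sizes
# of 1 and scale_factors=(8, 32, 32); `rope` is a hypothetical LTX2AudioVideoRotaryPosEmbed instance:
#   coords = rope.prepare_video_coords(batch_size=2, num_frames=3, height=4, width=6, device="cpu")
#   # num_patches = 3 * 4 * 6 = 72, so coords has shape [2, 3, 72, 2];
#   # coords[:, 0] holds [start, end) times in seconds, coords[:, 1:] pixel-space row/column bounds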
def prepare_audio_coords(
self,
batch_size: int,
num_frames: int,
device: torch.device,
shift: int = 0,
) -> torch.Tensor:
"""
Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame.
This will ultimately have shape (batch_size, 1, num_patches, 2) where
- axis 1 (size 1) represents the temporal dimension
- axis 3 (size 2) stores the `[start, end)` timestamps of each latent frame
Args:
batch_size (`int`):
Batch size of the audio latents.
num_frames (`int`):
Number of latent frames in the audio latents.
device (`torch.device`):
Device on which to create the audio grid.
shift (`int`, *optional*, defaults to `0`):
Offset on the latent indices. Different shift values correspond to different overlapping windows with
respect to the same underlying latent grid.
Returns:
`torch.Tensor`:
Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2].
"""
# 1. Generate coordinates in the frame (time) dimension.
# Always compute rope in fp32
grid_f = torch.arange(
start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device
)
# 2. Calculate start timestamps in seconds with respect to the original spectrogram grid
audio_scale_factor = self.scale_factors[0]
# Scale back to mel spectrogram space
grid_start_mel = grid_f * audio_scale_factor
# Handle first frame causal offset, ensuring non-negative timestamps
grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0)
# Convert mel bins back into seconds
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
# 3. Calculate end timestamps in seconds with respect to the original spectrogram grid
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0)
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2]
audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2]
audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2]
return audio_coords
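# Worked example (illustrative, not part of the original module), assuming scale_factors[0] = 8,
# hop_length = 160, sampling_rate = 16000, causal_offset = 1 and patch_size_t = 1:
#   latent frame 0 -> start mel bin max(0 * 8 + 1 - 8, 0) = 0 -> 0.00 s
#                     end   mel bin max(1 * 8 + 1 - 8, 0) = 1 -> 1 * 160 / 16000 = 0.01 s
#   latent frame 1 -> start 0.01 s, end (2 * 8 + 1 - 8) * 160 / 16000 = 0.09 s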
def prepare_coords(self, *args, **kwargs):
if self.modality == "video":
return self.prepare_video_coords(*args, **kwargs)
elif self.modality == "audio":
return self.prepare_audio_coords(*args, **kwargs)
def forward(
self, coords: torch.Tensor, device: str | torch.device | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
device = device or coords.device
# Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
num_pos_dims = coords.shape[1]
# 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
# position index
if coords.ndim == 4:
coords_start, coords_end = coords.chunk(2, dim=-1)
coords = (coords_start + coords_end) / 2.0
coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches]
# 2. Get coordinates as a fraction of the base data shape
if self.modality == "video":