Skip to content

Commit 072d15e

Browse files
dg845 and sayakpaul authored
Add Support for LTX-2.3 Models (#13217)
* Initial implementation of perturbed attn processor for LTX 2.3 * Update DiT block for LTX 2.3 + add self_attention_mask * Add flag to control using perturbed attn processor for now * Add support for new video upsampling blocks used by LTX-2.3 * Support LTX-2.3 Big-VGAN V2-style vocoder * Initial implementation of LTX-2.3 vocoder with bandwidth extender * Initial support for LTX-2.3 per-modality feature extractor * Refactor so that text connectors own all text encoder hidden_states normalization logic * Fix some bugs for inference * Fix LTX-2.X DiT block forward pass * Support prompt timestep embeds and prompt cross attn modulation * Add LTX-2.3 configs to conversion script * Support converting LTX-2.3 DiT checkpoints * Support converting LTX-2.3 Video VAE checkpoints * Support converting LTX-2.3 Vocoder with bandwidth extender * Support converting LTX-2.3 text connectors * Don't convert any upsamplers for now * Support self attention mask for LTX2Pipeline * Fix some inference bugs * Support self attn mask and sigmas for LTX-2.3 I2V, Cond pipelines * Support STG and modality isolation guidance for LTX-2.3 * make style and make quality * Make audio guidance values default to video values by default * Update to LTX-2.3 style guidance rescaling * Support cross timesteps for LTX-2.3 cross attention modulation * Fix RMS norm bug for LTX-2.3 text connectors * Perform guidance rescale in sample (x0) space following original code * Support LTX-2.3 Latent Spatial Upsampler model * Support LTX-2.3 distilled LoRA * Support LTX-2.3 Distilled checkpoint * Support LTX-2.3 prompt enhancement * Make LTX-2.X processor non-required so that tests pass * Fix test_components_function tests for LTX2 T2V and I2V * Fix LTX-2.3 Video VAE configuration bug causing pixel jitter * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Refactor LTX-2.X Video VAE upsampler block init logic * Refactor LTX-2.X guidance rescaling to use rescale_noise_cfg * Use 
generator initial seed to control prompt enhancement if available * Remove self attention mask logic as it is not used in any current pipelines * Commit fixes suggested by claude code (guidance in sample (x0) space, denormalize after timestep conditioning) * Use constant shift following original code --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 6761336 commit 072d15e

File tree

13 files changed

+2494
-566
lines changed

13 files changed

+2494
-566
lines changed

scripts/convert_ltx2_to_diffusers.py

Lines changed: 335 additions & 43 deletions
Large diffs are not rendered by default.

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2156,6 +2156,9 @@ def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_pref
21562156
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
21572157
"q_norm": "norm_q",
21582158
"k_norm": "norm_k",
2159+
# LTX-2.3
2160+
"audio_prompt_adaln_single": "audio_prompt_adaln",
2161+
"prompt_adaln_single": "prompt_adaln",
21592162
}
21602163
else:
21612164
rename_dict = {"aggregate_embed": "text_proj_in"}

src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def forward(
237237

238238

239239
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
240-
class LTXVideoDownsampler3d(nn.Module):
240+
class LTX2VideoDownsampler3d(nn.Module):
241241
def __init__(
242242
self,
243243
in_channels: int,
@@ -285,10 +285,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten
285285

286286

287287
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
288-
class LTXVideoUpsampler3d(nn.Module):
288+
class LTX2VideoUpsampler3d(nn.Module):
289289
def __init__(
290290
self,
291291
in_channels: int,
292+
out_channels: int | None = None,
292293
stride: int | tuple[int, int, int] = 1,
293294
residual: bool = False,
294295
upscale_factor: int = 1,
@@ -300,7 +301,8 @@ def __init__(
300301
self.residual = residual
301302
self.upscale_factor = upscale_factor
302303

303-
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
304+
out_channels = out_channels or in_channels
305+
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
304306

305307
self.conv = LTX2VideoCausalConv3d(
306308
in_channels=in_channels,
@@ -408,7 +410,7 @@ def __init__(
408410
)
409411
elif downsample_type == "spatial":
410412
self.downsamplers.append(
411-
LTXVideoDownsampler3d(
413+
LTX2VideoDownsampler3d(
412414
in_channels=in_channels,
413415
out_channels=out_channels,
414416
stride=(1, 2, 2),
@@ -417,7 +419,7 @@ def __init__(
417419
)
418420
elif downsample_type == "temporal":
419421
self.downsamplers.append(
420-
LTXVideoDownsampler3d(
422+
LTX2VideoDownsampler3d(
421423
in_channels=in_channels,
422424
out_channels=out_channels,
423425
stride=(2, 1, 1),
@@ -426,7 +428,7 @@ def __init__(
426428
)
427429
elif downsample_type == "spatiotemporal":
428430
self.downsamplers.append(
429-
LTXVideoDownsampler3d(
431+
LTX2VideoDownsampler3d(
430432
in_channels=in_channels,
431433
out_channels=out_channels,
432434
stride=(2, 2, 2),
@@ -580,6 +582,7 @@ def __init__(
580582
resnet_eps: float = 1e-6,
581583
resnet_act_fn: str = "swish",
582584
spatio_temporal_scale: bool = True,
585+
upsample_type: str = "spatiotemporal",
583586
inject_noise: bool = False,
584587
timestep_conditioning: bool = False,
585588
upsample_residual: bool = False,
@@ -609,16 +612,23 @@ def __init__(
609612

610613
self.upsamplers = None
611614
if spatio_temporal_scale:
612-
self.upsamplers = nn.ModuleList(
613-
[
614-
LTXVideoUpsampler3d(
615-
out_channels * upscale_factor,
616-
stride=(2, 2, 2),
617-
residual=upsample_residual,
618-
upscale_factor=upscale_factor,
619-
spatial_padding_mode=spatial_padding_mode,
620-
)
621-
]
615+
self.upsamplers = nn.ModuleList()
616+
617+
if upsample_type == "spatial":
618+
upsample_stride = (1, 2, 2)
619+
elif upsample_type == "temporal":
620+
upsample_stride = (2, 1, 1)
621+
elif upsample_type == "spatiotemporal":
622+
upsample_stride = (2, 2, 2)
623+
624+
self.upsamplers.append(
625+
LTX2VideoUpsampler3d(
626+
in_channels=out_channels * upscale_factor,
627+
stride=upsample_stride,
628+
residual=upsample_residual,
629+
upscale_factor=upscale_factor,
630+
spatial_padding_mode=spatial_padding_mode,
631+
)
622632
)
623633

624634
resnets = []
@@ -716,7 +726,7 @@ def __init__(
716726
"LTX2VideoDownBlock3D",
717727
"LTX2VideoDownBlock3D",
718728
),
719-
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
729+
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
720730
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
721731
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
722732
patch_size: int = 4,
@@ -726,6 +736,9 @@ def __init__(
726736
spatial_padding_mode: str = "zeros",
727737
):
728738
super().__init__()
739+
num_encoder_blocks = len(layers_per_block)
740+
if isinstance(spatio_temporal_scaling, bool):
741+
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
729742

730743
self.patch_size = patch_size
731744
self.patch_size_t = patch_size_t
@@ -860,19 +873,27 @@ def __init__(
860873
in_channels: int = 128,
861874
out_channels: int = 3,
862875
block_out_channels: tuple[int, ...] = (256, 512, 1024),
863-
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
876+
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
864877
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
878+
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
865879
patch_size: int = 4,
866880
patch_size_t: int = 1,
867881
resnet_norm_eps: float = 1e-6,
868882
is_causal: bool = False,
869-
inject_noise: tuple[bool, ...] = (False, False, False),
883+
inject_noise: bool | tuple[bool, ...] = (False, False, False),
870884
timestep_conditioning: bool = False,
871-
upsample_residual: tuple[bool, ...] = (True, True, True),
885+
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
872886
upsample_factor: tuple[bool, ...] = (2, 2, 2),
873887
spatial_padding_mode: str = "reflect",
874888
) -> None:
875889
super().__init__()
890+
num_decoder_blocks = len(layers_per_block)
891+
if isinstance(spatio_temporal_scaling, bool):
892+
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
893+
if isinstance(inject_noise, bool):
894+
inject_noise = (inject_noise,) * num_decoder_blocks
895+
if isinstance(upsample_residual, bool):
896+
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
876897

877898
self.patch_size = patch_size
878899
self.patch_size_t = patch_size_t
@@ -917,6 +938,7 @@ def __init__(
917938
num_layers=layers_per_block[i + 1],
918939
resnet_eps=resnet_norm_eps,
919940
spatio_temporal_scale=spatio_temporal_scaling[i],
941+
upsample_type=upsample_type[i],
920942
inject_noise=inject_noise[i + 1],
921943
timestep_conditioning=timestep_conditioning,
922944
upsample_residual=upsample_residual[i],
@@ -1058,11 +1080,12 @@ def __init__(
10581080
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
10591081
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
10601082
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
1061-
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
1062-
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
1063-
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
1083+
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
1084+
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
1085+
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
10641086
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
1065-
upsample_residual: tuple[bool, ...] = (True, True, True),
1087+
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
1088+
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
10661089
upsample_factor: tuple[int, ...] = (2, 2, 2),
10671090
timestep_conditioning: bool = False,
10681091
patch_size: int = 4,
@@ -1077,6 +1100,16 @@ def __init__(
10771100
temporal_compression_ratio: int = None,
10781101
) -> None:
10791102
super().__init__()
1103+
num_encoder_blocks = len(layers_per_block)
1104+
num_decoder_blocks = len(decoder_layers_per_block)
1105+
if isinstance(spatio_temporal_scaling, bool):
1106+
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
1107+
if isinstance(decoder_spatio_temporal_scaling, bool):
1108+
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
1109+
if isinstance(decoder_inject_noise, bool):
1110+
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
1111+
if isinstance(upsample_residual, bool):
1112+
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
10801113

10811114
self.encoder = LTX2VideoEncoder3d(
10821115
in_channels=in_channels,
@@ -1098,6 +1131,7 @@ def __init__(
10981131
block_out_channels=decoder_block_out_channels,
10991132
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
11001133
layers_per_block=decoder_layers_per_block,
1134+
upsample_type=upsample_type,
11011135
patch_size=patch_size,
11021136
patch_size_t=patch_size_t,
11031137
resnet_norm_eps=resnet_norm_eps,

0 commit comments

Comments (0)