Skip to content

Commit 8af8e86

Browse files
authored
LTX 2 Single File Support (#12983)
* LTX 2 transformer single file support
* LTX 2 video VAE single file support
* LTX 2 audio VAE single file support
* Make it easier to distinguish LTX 1 and 2 models
1 parent 74654df commit 8af8e86

File tree

2 files changed

+184
-1
lines changed

2 files changed

+184
-1
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
convert_hunyuan_video_transformer_to_diffusers,
4141
convert_ldm_unet_checkpoint,
4242
convert_ldm_vae_checkpoint,
43+
convert_ltx2_audio_vae_to_diffusers,
44+
convert_ltx2_transformer_to_diffusers,
45+
convert_ltx2_vae_to_diffusers,
4346
convert_ltx_transformer_checkpoint_to_diffusers,
4447
convert_ltx_vae_checkpoint_to_diffusers,
4548
convert_lumina2_to_diffusers,
@@ -176,6 +179,18 @@
176179
"ZImageControlNetModel": {
177180
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
178181
},
182+
"LTX2VideoTransformer3DModel": {
183+
"checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
184+
"default_subfolder": "transformer",
185+
},
186+
"AutoencoderKLLTX2Video": {
187+
"checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
188+
"default_subfolder": "vae",
189+
},
190+
"AutoencoderKLLTX2Audio": {
191+
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
192+
"default_subfolder": "audio_vae",
193+
},
179194
}
180195

181196

src/diffusers/loaders/single_file_utils.py

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@
112112
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
113113
"patchify_proj.weight",
114114
"transformer_blocks.27.scale_shift_table",
115-
"vae.per_channel_statistics.mean-of-means",
115+
"vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
116+
"vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
116117
],
117118
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
118119
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -147,6 +148,11 @@
147148
"net.pos_embedder.dim_spatial_range",
148149
],
149150
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
151+
"ltx2": [
152+
"model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
153+
"vae.per_channel_statistics.mean-of-means",
154+
"audio_vae.per_channel_statistics.mean-of-means",
155+
],
150156
}
151157

152158
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -228,6 +234,7 @@
228234
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
229235
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
230236
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
237+
"ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
231238
}
232239

233240
# Use to configure model sample size when original config is provided
@@ -796,6 +803,9 @@ def infer_diffusers_model_type(checkpoint):
796803
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
797804
model_type = "z-image-turbo-controlnet"
798805

806+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
807+
model_type = "ltx2-dev"
808+
799809
else:
800810
model_type = "v1"
801811

@@ -3920,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
39203930
return converted_state_dict
39213931
else:
39223932
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
3933+
3934+
3935+
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
3936+
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
3937+
# Transformer prefix
3938+
"model.diffusion_model.": "",
3939+
# Input Patchify Projections
3940+
"patchify_proj": "proj_in",
3941+
"audio_patchify_proj": "audio_proj_in",
3942+
# Modulation Parameters
3943+
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
3944+
# substrings of the other modulation parameters below
3945+
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
3946+
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
3947+
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
3948+
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
3949+
# Transformer Blocks
3950+
# Per-Block Cross Attention Modulation Parameters
3951+
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
3952+
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
3953+
# Attention QK Norms
3954+
"q_norm": "norm_q",
3955+
"k_norm": "norm_k",
3956+
}
3957+
3958+
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
3959+
state_dict[new_key] = state_dict.pop(old_key)
3960+
3961+
def remove_keys_inplace(key: str, state_dict) -> None:
3962+
state_dict.pop(key)
3963+
3964+
def convert_ltx2_transformer_adaln_single(key: str, state_dict) -> None:
3965+
# Skip if not a weight, bias
3966+
if ".weight" not in key and ".bias" not in key:
3967+
return
3968+
3969+
if key.startswith("adaln_single."):
3970+
new_key = key.replace("adaln_single.", "time_embed.")
3971+
param = state_dict.pop(key)
3972+
state_dict[new_key] = param
3973+
3974+
if key.startswith("audio_adaln_single."):
3975+
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
3976+
param = state_dict.pop(key)
3977+
state_dict[new_key] = param
3978+
3979+
return
3980+
3981+
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
3982+
"video_embeddings_connector": remove_keys_inplace,
3983+
"audio_embeddings_connector": remove_keys_inplace,
3984+
"adaln_single": convert_ltx2_transformer_adaln_single,
3985+
}
3986+
3987+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
3988+
3989+
# Handle official code --> diffusers key remapping via the remap dict
3990+
for key in list(converted_state_dict.keys()):
3991+
new_key = key[:]
3992+
for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items():
3993+
new_key = new_key.replace(replace_key, rename_key)
3994+
3995+
update_state_dict_inplace(converted_state_dict, key, new_key)
3996+
3997+
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
3998+
# special_keys_remap
3999+
for key in list(converted_state_dict.keys()):
4000+
for special_key, handler_fn_inplace in LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
4001+
if special_key not in key:
4002+
continue
4003+
handler_fn_inplace(key, converted_state_dict)
4004+
4005+
return converted_state_dict
4006+
4007+
4008+
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
4009+
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
4010+
# Video VAE prefix
4011+
"vae.": "",
4012+
# Encoder
4013+
"down_blocks.0": "down_blocks.0",
4014+
"down_blocks.1": "down_blocks.0.downsamplers.0",
4015+
"down_blocks.2": "down_blocks.1",
4016+
"down_blocks.3": "down_blocks.1.downsamplers.0",
4017+
"down_blocks.4": "down_blocks.2",
4018+
"down_blocks.5": "down_blocks.2.downsamplers.0",
4019+
"down_blocks.6": "down_blocks.3",
4020+
"down_blocks.7": "down_blocks.3.downsamplers.0",
4021+
"down_blocks.8": "mid_block",
4022+
# Decoder
4023+
"up_blocks.0": "mid_block",
4024+
"up_blocks.1": "up_blocks.0.upsamplers.0",
4025+
"up_blocks.2": "up_blocks.0",
4026+
"up_blocks.3": "up_blocks.1.upsamplers.0",
4027+
"up_blocks.4": "up_blocks.1",
4028+
"up_blocks.5": "up_blocks.2.upsamplers.0",
4029+
"up_blocks.6": "up_blocks.2",
4030+
# Common
4031+
# For all 3D ResNets
4032+
"res_blocks": "resnets",
4033+
"per_channel_statistics.mean-of-means": "latents_mean",
4034+
"per_channel_statistics.std-of-means": "latents_std",
4035+
}
4036+
4037+
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
4038+
state_dict[new_key] = state_dict.pop(old_key)
4039+
4040+
def remove_keys_inplace(key: str, state_dict) -> None:
4041+
state_dict.pop(key)
4042+
4043+
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
4044+
"per_channel_statistics.channel": remove_keys_inplace,
4045+
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
4046+
}
4047+
4048+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
4049+
4050+
# Handle official code --> diffusers key remapping via the remap dict
4051+
for key in list(converted_state_dict.keys()):
4052+
new_key = key[:]
4053+
for replace_key, rename_key in LTX_2_0_VIDEO_VAE_RENAME_DICT.items():
4054+
new_key = new_key.replace(replace_key, rename_key)
4055+
4056+
update_state_dict_inplace(converted_state_dict, key, new_key)
4057+
4058+
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
4059+
# special_keys_remap
4060+
for key in list(converted_state_dict.keys()):
4061+
for special_key, handler_fn_inplace in LTX_2_0_VAE_SPECIAL_KEYS_REMAP.items():
4062+
if special_key not in key:
4063+
continue
4064+
handler_fn_inplace(key, converted_state_dict)
4065+
4066+
return converted_state_dict
4067+
4068+
4069+
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
4070+
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
4071+
# Audio VAE prefix
4072+
"audio_vae.": "",
4073+
"per_channel_statistics.mean-of-means": "latents_mean",
4074+
"per_channel_statistics.std-of-means": "latents_std",
4075+
}
4076+
4077+
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
4078+
state_dict[new_key] = state_dict.pop(old_key)
4079+
4080+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
4081+
4082+
# Handle official code --> diffusers key remapping via the remap dict
4083+
for key in list(converted_state_dict.keys()):
4084+
new_key = key[:]
4085+
for replace_key, rename_key in LTX_2_0_AUDIO_VAE_RENAME_DICT.items():
4086+
new_key = new_key.replace(replace_key, rename_key)
4087+
4088+
update_state_dict_inplace(converted_state_dict, key, new_key)
4089+
4090+
return converted_state_dict

0 commit comments

Comments
 (0)