|
112 | 112 | "model.diffusion_model.transformer_blocks.27.scale_shift_table", |
113 | 113 | "patchify_proj.weight", |
114 | 114 | "transformer_blocks.27.scale_shift_table", |
115 | | - "vae.per_channel_statistics.mean-of-means", |
| 115 | + "vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8 |
| 116 | + "vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0 |
116 | 117 | ], |
117 | 118 | "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", |
118 | 119 | "autoencoder-dc-sana": "encoder.project_in.conv.bias", |
|
147 | 148 | "net.pos_embedder.dim_spatial_range", |
148 | 149 | ], |
149 | 150 | "flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"], |
| 151 | + "ltx2": [ |
| 152 | + "model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight", |
| 153 | + "vae.per_channel_statistics.mean-of-means", |
| 154 | + "audio_vae.per_channel_statistics.mean-of-means", |
| 155 | + ], |
150 | 156 | } |
151 | 157 |
|
152 | 158 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = { |
|
228 | 234 | "z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"}, |
229 | 235 | "z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"}, |
230 | 236 | "z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"}, |
| 237 | + "ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"}, |
231 | 238 | } |
232 | 239 |
|
233 | 240 | # Use to configure model sample size when original config is provided |
@@ -796,6 +803,9 @@ def infer_diffusers_model_type(checkpoint): |
796 | 803 | elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint: |
797 | 804 | model_type = "z-image-turbo-controlnet" |
798 | 805 |
|
| 806 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]): |
| 807 | + model_type = "ltx2-dev" |
| 808 | + |
799 | 809 | else: |
800 | 810 | model_type = "v1" |
801 | 811 |
|
@@ -3920,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa |
3920 | 3930 | return converted_state_dict |
3921 | 3931 | else: |
3922 | 3932 | raise ValueError("Unknown Z-Image Turbo ControlNet type.") |
| 3933 | + |
| 3934 | + |
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
    """Map an original LTX-2 transformer state dict onto the diffusers naming scheme.

    Args:
        checkpoint: Mutable mapping of original parameter names to tensors. It is
            drained in place; on return it is empty.
        **kwargs: Ignored; accepted for signature compatibility with the other
            single-file converters.

    Returns:
        A new dict holding the same tensors under diffusers-style keys. Keys for
        modules with no diffusers counterpart (the embeddings connectors) are
        dropped.
    """
    # Ordered substring replacements. Each rule is applied sequentially to every
    # key, so insertion order is load-bearing (e.g. "patchify_proj" must run
    # before any later rule could see its rewritten output).
    rename_rules = {
        # Transformer prefix
        "model.diffusion_model.": "",
        # Input patchify projections
        "patchify_proj": "proj_in",
        "audio_patchify_proj": "audio_proj_in",
        # Modulation parameters. "adaln_single"/"audio_adaln_single" are
        # substrings of the keys below, so those two are handled separately in
        # the special pass after these renames have run.
        "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
        "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
        "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
        "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
        # Per-block cross-attention modulation tables
        "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
        "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
        # Attention QK norms
        "q_norm": "norm_q",
        "k_norm": "norm_k",
    }

    def _drop(key, state_dict):
        # Module has no diffusers equivalent; discard its parameters.
        del state_dict[key]

    def _remap_adaln_single(key, state_dict):
        # Only actual parameters (weights/biases) are relocated; anything else
        # that happens to contain "adaln_single" is left untouched.
        if ".weight" not in key and ".bias" not in key:
            return
        for src, dst in (
            ("adaln_single.", "time_embed."),
            ("audio_adaln_single.", "audio_time_embed."),
        ):
            if key.startswith(src):
                state_dict[key.replace(src, dst)] = state_dict.pop(key)
                return

    # Logic that cannot be expressed as a plain 1:1 substring rename.
    special_handlers = {
        "video_embeddings_connector": _drop,
        "audio_embeddings_connector": _drop,
        "adaln_single": _remap_adaln_single,
    }

    # Drain the input checkpoint into a fresh dict, preserving key order.
    converted = {name: checkpoint.pop(name) for name in list(checkpoint)}

    # Pass 1: plain substring renames.
    for old_name in list(converted):
        new_name = old_name
        for src, dst in rename_rules.items():
            new_name = new_name.replace(src, dst)
        converted[new_name] = converted.pop(old_name)

    # Pass 2: special-cased keys.
    for name in list(converted):
        for marker, handler in special_handlers.items():
            if marker in name:
                handler(name, converted)

    return converted
| 4006 | + |
| 4007 | + |
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
    """Map an original LTX-2 video VAE state dict onto the diffusers naming scheme.

    Args:
        checkpoint: Mutable mapping of original parameter names to tensors. It is
            drained in place; on return it is empty.
        **kwargs: Ignored; accepted for signature compatibility with the other
            single-file converters.

    Returns:
        A new dict holding the same tensors under diffusers-style keys. Unused
        per-channel statistics are dropped.
    """
    # Ordered substring replacements, applied sequentially to every key.
    # Insertion order is load-bearing: a rule whose output contains an earlier
    # rule's source text (e.g. "up_blocks.1" -> "up_blocks.0.upsamplers.0") is
    # only safe because that earlier rule has already run.
    rename_rules = {
        # Video VAE prefix
        "vae.": "",
        # Encoder: odd-numbered original blocks become downsamplers attached to
        # the preceding diffusers down block; block 8 is the mid block.
        "down_blocks.0": "down_blocks.0",
        "down_blocks.1": "down_blocks.0.downsamplers.0",
        "down_blocks.2": "down_blocks.1",
        "down_blocks.3": "down_blocks.1.downsamplers.0",
        "down_blocks.4": "down_blocks.2",
        "down_blocks.5": "down_blocks.2.downsamplers.0",
        "down_blocks.6": "down_blocks.3",
        "down_blocks.7": "down_blocks.3.downsamplers.0",
        "down_blocks.8": "mid_block",
        # Decoder: block 0 is the mid block; odd-numbered blocks become
        # upsamplers attached to the following diffusers up block.
        "up_blocks.0": "mid_block",
        "up_blocks.1": "up_blocks.0.upsamplers.0",
        "up_blocks.2": "up_blocks.0",
        "up_blocks.3": "up_blocks.1.upsamplers.0",
        "up_blocks.4": "up_blocks.1",
        "up_blocks.5": "up_blocks.2.upsamplers.0",
        "up_blocks.6": "up_blocks.2",
        # Common, for all 3D ResNets
        "res_blocks": "resnets",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    def _drop(key, state_dict):
        # Statistics that diffusers does not carry over.
        del state_dict[key]

    special_handlers = {
        "per_channel_statistics.channel": _drop,
        "per_channel_statistics.mean-of-stds": _drop,
    }

    # Drain the input checkpoint into a fresh dict, preserving key order.
    converted = {name: checkpoint.pop(name) for name in list(checkpoint)}

    # Pass 1: plain substring renames.
    for old_name in list(converted):
        new_name = old_name
        for src, dst in rename_rules.items():
            new_name = new_name.replace(src, dst)
        converted[new_name] = converted.pop(old_name)

    # Pass 2: drop keys with no diffusers counterpart.
    for name in list(converted):
        for marker, handler in special_handlers.items():
            if marker in name:
                handler(name, converted)

    return converted
| 4067 | + |
| 4068 | + |
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
    """Map an original LTX-2 audio VAE state dict onto the diffusers naming scheme.

    Args:
        checkpoint: Mutable mapping of original parameter names to tensors. It is
            drained in place; on return it is empty.
        **kwargs: Ignored; accepted for signature compatibility with the other
            single-file converters.

    Returns:
        A new dict holding the same tensors under diffusers-style keys.
    """
    # Ordered substring replacements, applied sequentially to every key.
    rename_rules = {
        # Audio VAE prefix
        "audio_vae.": "",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    converted = {}
    for old_name in list(checkpoint):
        new_name = old_name
        for src, dst in rename_rules.items():
            new_name = new_name.replace(src, dst)
        converted[new_name] = checkpoint.pop(old_name)

    return converted
0 commit comments