|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +import numpy as np |
| 4 | +from flax import nnx |
| 5 | +from flax.traverse_util import unflatten_dict, flatten_dict |
| 6 | +from maxdiffusion.utils import max_logging |
| 7 | +from maxdiffusion.models.flax_pytorch_utils import ( |
| 8 | + load_sharded_checkpoint, |
| 9 | + validate_flax_state_dict, |
| 10 | +) |
| 11 | +from maxdiffusion.models.ltx2.ltx2_utils import ( |
| 12 | + _tuple_str_to_int, |
| 13 | + LTX_2_0_VIDEO_VAE_RENAME_DICT, |
| 14 | +) |
| 15 | + |
| 16 | +LTX_2_3_VIDEO_VAE_RENAME_DICT = { |
| 17 | + **LTX_2_0_VIDEO_VAE_RENAME_DICT, |
| 18 | + # Decoder extra blocks |
| 19 | + "up_blocks.7": "up_blocks.3.upsamplers.0", |
| 20 | + "up_blocks.8": "up_blocks.3", |
| 21 | +} |
| 22 | + |
| 23 | +LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = { |
| 24 | + "connectors.": "", |
| 25 | + "video_embeddings_connector": "video_connector", |
| 26 | + "audio_embeddings_connector": "audio_connector", |
| 27 | + "transformer_1d_blocks": "transformer_blocks", |
| 28 | + "text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in", |
| 29 | + "text_embedding_projection.video_aggregate_embed": "video_text_proj_in", |
| 30 | + "q_norm": "norm_q", |
| 31 | + "k_norm": "norm_k", |
| 32 | +} |
| 33 | + |
| 34 | +def load_connectors_weights( |
| 35 | + pretrained_model_name_or_path: str, |
| 36 | + eval_shapes: dict, |
| 37 | + device: str, |
| 38 | + hf_download: bool = True, |
| 39 | + subfolder: str = "", |
| 40 | + filename: str = None, |
| 41 | +): |
| 42 | + device = jax.local_devices(backend=device)[0] |
| 43 | + max_logging.log(f"Load and port {pretrained_model_name_or_path} Connectors on {device}") |
| 44 | + |
| 45 | + with jax.default_device(device): |
| 46 | + tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device, filename=filename) |
| 47 | + flax_state_dict = {} |
| 48 | + cpu = jax.local_devices(backend="cpu")[0] |
| 49 | + flattened_eval = flatten_dict(eval_shapes) |
| 50 | + |
| 51 | + for pt_key, tensor in tensors.items(): |
| 52 | + if not any(x in pt_key for x in ["connectors.", "video_embeddings_connector", "audio_embeddings_connector"]): |
| 53 | + continue |
| 54 | + |
| 55 | + flax_key_str = pt_key |
| 56 | + for replace_key, rename_to in LTX_2_3_CONNECTORS_KEYS_RENAME_DICT.items(): |
| 57 | + flax_key_str = flax_key_str.replace(replace_key, rename_to) |
| 58 | + |
| 59 | + flax_key = _tuple_str_to_int(flax_key_str.split(".")) |
| 60 | + flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu) |
| 61 | + |
| 62 | + filtered_eval_shapes = { |
| 63 | + k: v for k, v in flattened_eval.items() if not any("dropout" in str(x) or "rngs" in str(x) for x in k) |
| 64 | + } |
| 65 | + validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict) |
| 66 | + return unflatten_dict(flax_state_dict) |
0 commit comments