Skip to content

Commit 07f56d2

Browse files
Merge pull request #372 from AI-Hypercomputer:ltx2_lora
PiperOrigin-RevId: 899043023
2 parents 12b267c + 9810681 commit 07f56d2

4 files changed

Lines changed: 216 additions & 0 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,26 @@ enable_single_replica_ckpt_restoring: False
106106
seed: 0
107107
audio_format: "s16"
108108

109+
# LoRA parameters
110+
enable_lora: False
111+
112+
# Distilled LoRA
113+
# lora_config: {
114+
# lora_model_name_or_path: ["Lightricks/LTX-2"],
115+
# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
116+
# adapter_name: ["distilled-lora-384"],
117+
# rank: [384]
118+
# }
119+
120+
# Standard LoRA
121+
lora_config: {
122+
lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
123+
weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
124+
adapter_name: ["camera-control-dolly-in"],
125+
rank: [32]
126+
}
127+
128+
109129
# LTX-2 Latent Upsampler
110130
run_latent_upsampler: False
111131
upsampler_model_path: "Lightricks/LTX-2"

src/maxdiffusion/generate_ltx2.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.api_core.exceptions import GoogleAPIError
2626
import flax
2727
from maxdiffusion.utils.export_utils import export_to_video_with_audio
28+
from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader
2829

2930

3031
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -120,6 +121,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
120121
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
121122
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
122123

124+
# If LoRA is specified, inject layers and load weights.
125+
if (
126+
getattr(config, "enable_lora", False)
127+
and hasattr(config, "lora_config")
128+
and config.lora_config
129+
and config.lora_config.get("lora_model_name_or_path")
130+
):
131+
lora_loader = LTX2NNXLoraLoader()
132+
lora_config = config.lora_config
133+
paths = lora_config["lora_model_name_or_path"]
134+
weights = lora_config.get("weight_name", [None] * len(paths))
135+
scales = lora_config.get("scale", [1.0] * len(paths))
136+
ranks = lora_config.get("rank", [64] * len(paths))
137+
138+
for i in range(len(paths)):
139+
pipeline = lora_loader.load_lora_weights(
140+
pipeline,
141+
paths[i],
142+
transformer_weight_name=weights[i],
143+
rank=ranks[i],
144+
scale=scales[i],
145+
scan_layers=config.scan_layers,
146+
dtype=config.weights_dtype,
147+
)
148+
123149
pipeline.enable_vae_slicing()
124150
pipeline.enable_vae_tiling()
125151

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,98 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
703703
return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
704704

705705
return None
706+
707+
708+
def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """Translates an LTX2 NNX parameter path to its Diffusers/LoRA key.

  Args:
    nnx_path_str: Dot-joined NNX module path, e.g.
      "transformer_blocks.0.attn1.to_q" or "proj_in".
    scan_layers: When True, transformer-block paths carry no layer index
      (layers are scanned) and the returned key contains a literal "{}"
      placeholder where the layer index would go.

  Returns:
    The matching Diffusers/LoRA key string, or None when the path has no
    known translation.
  """
  # Per-transformer-block suffixes. Every attention module maps q/k/v
  # one-to-one; to_out gains a trailing ".0" (Diffusers wraps it in a list).
  attention_modules = (
      "attn1",
      "audio_attn1",
      "audio_attn2",
      "attn2",
      "audio_to_video_attn",
      "video_to_audio_attn",
  )
  block_suffix_map = {}
  for attn in attention_modules:
    for proj in ("to_q", "to_k", "to_v"):
      block_suffix_map[f"{attn}.{proj}"] = f"{attn}.{proj}"
    block_suffix_map[f"{attn}.to_out"] = f"{attn}.to_out.0"
  # Feed-forward layers: NNX flattens the Sequential, Diffusers keeps it.
  for ff in ("ff", "audio_ff"):
    block_suffix_map[f"{ff}.net_0"] = f"{ff}.net.0.proj"
    block_suffix_map[f"{ff}.net_2"] = f"{ff}.net.2"

  # Global (non-block) modules. Each AdaLN-style embedder exposes a ".linear"
  # plus a nested timestep embedder with two linear layers.
  adaln_pairs = (
      ("time_embed", "adaln_single"),
      ("audio_time_embed", "audio_adaln_single"),
      ("av_cross_attn_video_a2v_gate", "av_ca_a2v_gate_adaln_single"),
      ("av_cross_attn_audio_v2a_gate", "av_ca_v2a_gate_adaln_single"),
      ("av_cross_attn_audio_scale_shift", "av_ca_audio_scale_shift_adaln_single"),
      ("av_cross_attn_video_scale_shift", "av_ca_video_scale_shift_adaln_single"),
  )
  top_level_map = {
      "proj_in": "diffusion_model.patchify_proj",
      "audio_proj_in": "diffusion_model.audio_patchify_proj",
      "proj_out": "diffusion_model.proj_out",
      "audio_proj_out": "diffusion_model.audio_proj_out",
      "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1",
      "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
      "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
      "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
      # Connector: text-embedding aggregation lives outside diffusion_model.
      "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
  }
  for src, dst in adaln_pairs:
    top_level_map[f"{src}.linear"] = f"diffusion_model.{dst}.linear"
    for tail in ("emb.timestep_embedder.linear_1", "emb.timestep_embedder.linear_2"):
      top_level_map[f"{src}.{tail}"] = f"diffusion_model.{dst}.{tail}"

  hit = top_level_map.get(nnx_path_str)
  if hit is not None:
    return hit

  block_prefix = "transformer_blocks."
  if scan_layers:
    # Scanned layers: no index in the NNX path; emit a "{}" placeholder.
    if nnx_path_str.startswith(block_prefix):
      mapped = block_suffix_map.get(nnx_path_str[len(block_prefix):])
      if mapped is not None:
        return "diffusion_model.transformer_blocks.{}." + mapped
  else:
    # Unscanned layers: a numeric index separates the prefix and suffix.
    match = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str)
    if match:
      mapped = block_suffix_map.get(match.group(2))
      if mapped is not None:
        return f"diffusion_model.transformer_blocks.{match.group(1)}.{mapped}"

  return None
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""NNX-based LoRA loader for LTX2 models."""
16+
17+
from flax import nnx
18+
from .lora_base import LoRABaseMixin
19+
from .lora_pipeline import StableDiffusionLoraLoaderMixin
20+
from ..models import lora_nnx
21+
from .. import max_logging
22+
from . import lora_conversion_utils
23+
24+
25+
class LTX2NNXLoraLoader(LoRABaseMixin):
  """Loads and merges LoRA weights into an NNX-based LTX2 pipeline.

  The pipeline is expected to expose NNX modules under the ``transformer``
  attribute (and optionally ``connectors``).
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      dtype: str = "float32",
      **kwargs,
  ):
    """Merges LoRA weights into the pipeline from a checkpoint.

    Args:
      pipeline: Pipeline holding NNX ``transformer`` / ``connectors`` modules.
      lora_model_path: Hub repo id or local path of the LoRA checkpoint.
      transformer_weight_name: Safetensors file name inside the checkpoint.
      rank: LoRA rank used when merging.
      scale: LoRA scale applied during the merge.
      scan_layers: Whether the transformer uses scanned (stacked) layers.
      dtype: Dtype name used for the merged weights.
      **kwargs: Forwarded to ``lora_state_dict`` (e.g. revision, token).

    Returns:
      The same pipeline instance with LoRA weights merged in.
    """
    state_dict_loader = StableDiffusionLoraLoaderMixin()

    # Scanned layers need the stacked-merge variant.
    merge = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def to_diffusers_key(nnx_path_str):
      return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    state_dict = None
    if transformer_weight_name and hasattr(pipeline, "transformer"):
      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
      state_dict, _ = state_dict_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
      # Only pass transformer keys so the merge emits no spurious warnings.
      transformer_only = {key: val for key, val in state_dict.items() if key.startswith("diffusion_model")}
      merge(pipeline.transformer, transformer_only, rank, scale, to_diffusers_key, dtype=dtype)
    else:
      max_logging.log("transformer not found or no weight name provided for LoRA.")

    if hasattr(pipeline, "connectors"):
      max_logging.log(f"Merging LoRA into connectors with rank={rank}")
      # Reuse the state dict loaded above; load it here only if the
      # transformer branch was skipped.
      if state_dict is None and transformer_weight_name:
        state_dict, _ = state_dict_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)

      if state_dict is None:
        max_logging.log("Could not load LoRA state dict for connectors.")
      else:
        # Only pass connector keys so the merge emits no spurious warnings.
        connector_only = {key: val for key, val in state_dict.items() if key.startswith("text_embedding_projection")}
        merge(pipeline.connectors, connector_only, rank, scale, to_diffusers_key, dtype=dtype)

    return pipeline

0 commit comments

Comments
 (0)