Skip to content

Commit 80771b1

Browse files
authored
LTXVid text2vid pipeline (#208)
* set up files for ltxvid * ltx-video-transformer-setup * formatting * conversion script added * format fixed * conversion script checked * comments removed * Added running instructions * edited instruction * ruff check error fixed * mesh edit * key error fix * transformer step and test * removed diffusers import * fixed mesh * changed path * changed path * changed config path * ruff check * changed back pyconfig * changed sharding back * removed testing for now * Update pyconfig.py * Update max_utils.py * Update ltx_video.yml * Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred * Delete src/maxdiffusion/tests/ltx_transformer_step_test.py * added header * ruff fixed * added header * license headers * exclude test * auto script * headers * pulled * change base branch * save now * load transformer error * later * changed repeatable layer * Update max_utils.py * functional * moved upsampler * initial cleaning * multiscale pipeline * remove init * new empty folders * downloaded files * changed upsampler * kept latents as jnp * prepare latents * save * fixed transformer init * error attribute weight already exist * baseline pipeline cleaned * pipeline cleaned * added timing * pipeline cleaned, licence added * changed output to cmd line * added init file * changed input format * updated requirements * merged in conversion script * fixed importing error * fixed importing issue * merged from main * Delete myenv directory * changed ckpt name * style fix
1 parent 76f84a1 commit 80771b1

51 files changed

Lines changed: 10854 additions & 8 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Whitespace-only changes.

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
ruff check .
5555
- name: PyTest
5656
run: |
57-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
57+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
5858
# add_pull_ready:
5959
# if: github.ref != 'refs/heads/main'
6060
# permissions:

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ pytest==8.2.2
2323
tensorflow>=2.17.0
2424
tensorflow-datasets>=4.9.6
2525
ruff>=0.1.5,<=0.2
26+
git+https://github.com/Lightricks/LTX-Video
27+
git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
2628
opencv-python-headless==4.10.0.84
2729
orbax-checkpoint
2830
tokenizers==0.21.0

setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,4 @@ else
112112
fi
113113

114114
# Install maxdiffusion
115-
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
115+
pip3 install -U . || echo "Failed to install maxdiffusion" >&2

src/maxdiffusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@
374374
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
375375
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
376376
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
377+
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
377378
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
378379
_import_structure["schedulers"].extend(
379380
[
@@ -453,6 +454,7 @@
453454
from .models.modeling_flax_utils import FlaxModelMixin
454455
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
455456
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
457+
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
456458
from .models.vae_flax import FlaxAutoencoderKL
457459
from .pipelines import FlaxDiffusionPipeline
458460
from .schedulers import (

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,11 @@ def load_state_if_possible(
213213
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
214214
try:
215215
if not enable_single_replica_ckpt_restoring:
216-
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
217-
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
216+
if checkpoint_item == "ltxvid_transformer":
217+
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
218+
else:
219+
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
220+
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
218221

219222
def map_to_pspec(data):
220223
pspec = data.sharding.spec
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Hardware / runtime
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
# Numeric precision for model weights and activations.
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

# Checkpoints
text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
frame_rate: 30
max_sequence_length: 512
sampler: "from_checkpoint"

# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 512
num_frames: 88
flow_shift: 5.0
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
# 0 disables prompt enhancement; otherwise only prompts with fewer words
# than this threshold get enhanced (see run() in the generation script).
prompt_enhancement_words_threshold: 120
stg_mode: "attention_values"
decode_timestep: 0.05
decode_noise_scale: 0.025
seed: 10


# Per-pass guidance schedules. Each list below has one entry per
# guidance_timesteps breakpoint.
first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  skip_initial_inference_steps: 0
  cfg_star_rescale: True

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  skip_final_inference_steps: 0
  cfg_star_rescale: True

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_heads', 'fsdp'],
  ['activation_batch', 'data'],
  ['activation_kv', 'tensor'],
  ['mlp','tensor'],
  ['embed','fsdp'],
  ['heads', 'tensor'],
  ['norm', 'fsdp'],
  ['conv_batch', ['data','fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
  ['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import numpy as np
18+
from absl import app
19+
from typing import Sequence
20+
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
21+
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
22+
from maxdiffusion import pyconfig, max_logging
23+
import imageio
24+
from datetime import datetime
25+
import os
26+
import time
27+
from pathlib import Path
28+
29+
30+
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
  """Compute the per-side padding that grows a source frame to the target size.

  Any odd remainder goes to the bottom/right edge.

  Returns:
    A ``(pad_left, pad_right, pad_top, pad_bottom)`` tuple.
  """
  extra_height = target_height - source_height
  extra_width = target_width - source_width

  # Split each dimension's total padding in half; the second half absorbs
  # the extra pixel when the total is odd.
  top = extra_height // 2
  left = extra_width // 2
  return (left, extra_width - left, top, extra_height - top)
45+
46+
47+
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
  """Build a short, filesystem-safe slug from a prompt.

  Keeps only letters and whitespace (lowercased), then greedily takes whole
  words while the hyphen-joined result still fits within ``max_len``.

  Args:
    text: Free-form prompt text.
    max_len: Maximum length of the returned slug, separators included.

  Returns:
    The selected words joined with "-"; empty string if no word fits.
  """
  # Remove non-letters and convert to lowercase so the slug is path-safe.
  clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())

  # Split into words
  words = clean_text.split()

  # Build result string keeping track of length
  result = []
  current_length = 0

  for word in words:
    # One extra character for the "-" separator, except before the first
    # word. (Bug fix: the separator was previously not counted, so the
    # joined result could exceed max_len.)
    separator_length = 1 if result else 0
    new_length = current_length + separator_length + len(word)

    if new_length <= max_len:
      result.append(word)
      current_length = new_length
    else:
      break

  return "-".join(result)
69+
70+
71+
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
  """Return the first non-existing path `<dir>/<base>_<slug>_<AxBxC>_<i><endswith><ext>`.

  The slug is derived from ``prompt`` (at most 30 chars) and the index ``i``
  counts up from 0.

  Raises:
    FileExistsError: if all ``index_range`` candidate names are taken.
  """
  slug = convert_prompt_to_filename(prompt, max_len=30)
  dims = "x".join(str(d) for d in resolution)
  stem = f"{base}_{slug}_{dims}"
  suffix = endswith if endswith else ""
  for index in range(index_range):
    candidate = dir / f"{stem}_{index}{suffix}{ext}"
    if not os.path.exists(candidate):
      return candidate
  raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
86+
87+
88+
def run(config):
  """Run LTX-Video text-to-video generation and save the result to disk.

  Reads all generation settings (prompt, resolution, frame count, pipeline
  type, seed, ...) from ``config``, pads the requested geometry to the sizes
  the model accepts, runs the (optionally multi-scale) pipeline, crops the
  padding back off, and writes each batch element as an .mp4 video — or a
  .png when only one frame was produced — under outputs/<YYYY-MM-DD>/.

  Args:
    config: pyconfig-style config object (see the ltx_video yaml config).
  """
  # Spatial dims are rounded up to multiples of 32 and the frame count to the
  # next 8*k + 1 (e.g. 88 -> 89) — presumably the model's spatial/temporal
  # downsampling factors. TODO(review): confirm against the VAE strides.
  height_padded = ((config.height - 1) // 32 + 1) * 32
  width_padded = ((config.width - 1) // 32 + 1) * 32
  num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
  # (pad_left, pad_right, pad_top, pad_bottom) needed to undo the rounding.
  padding = calculate_padding(config.height, config.width, height_padded, width_padded)
  # Enhance only short prompts; a threshold of 0 disables enhancement.
  prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
  prompt_word_count = len(config.prompt.split())
  enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold

  pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
  if config.pipeline_type == "multi-scale":
    # Wrap the base pipeline for the two-pass (first_pass/second_pass) flow.
    pipeline = LTXMultiScalePipeline(pipeline)
  s0 = time.perf_counter()
  images = pipeline(
      height=height_padded,
      width=width_padded,
      num_frames=num_frames_padded,
      is_video=True,
      output_type="pt",
      config=config,
      enhance_prompt=enhance_prompt,
      seed=config.seed,
  )
  # First invocation includes JIT compilation, hence "compile time".
  max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")

  # Crop the padding back off. Negated pads act as from-the-end slice stops;
  # a zero pad must be replaced by the full extent, since x[..., :0] is empty.
  (pad_left, pad_right, pad_top, pad_bottom) = padding
  pad_bottom = -pad_bottom
  pad_right = -pad_right
  if pad_bottom == 0:
    pad_bottom = images.shape[3]
  if pad_right == 0:
    pad_right = images.shape[4]
  # images is (B, C, F, H, W); also trim the temporal padding to num_frames.
  images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
  output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
  output_dir.mkdir(parents=True, exist_ok=True)

  for i in range(images.shape[0]):
    # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
    video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
    # Unnormalizing images to [0, 255] range
    video_np = (video_np * 255).astype(np.uint8)
    fps = config.frame_rate
    height, width = video_np.shape[1:3]
    # In case a single image is generated
    if video_np.shape[0] == 1:
      output_filename = get_unique_filename(
          f"image_output_{i}",
          ".png",
          prompt=config.prompt,
          resolution=(height, width, config.num_frames),
          dir=output_dir,
      )
      imageio.imwrite(output_filename, video_np[0])
    else:
      output_filename = get_unique_filename(
          f"video_output_{i}",
          ".mp4",
          prompt=config.prompt,
          resolution=(height, width, config.num_frames),
          dir=output_dir,
      )
      # Write video
      with imageio.get_writer(output_filename, fps=fps) as video:
        for frame in video_np:
          video.append_data(frame)
153+
154+
155+
def main(argv: Sequence[str]) -> None:
  """absl entry point: parse argv into the global pyconfig, then generate."""
  pyconfig.initialize(argv)
  run(pyconfig.config)


if __name__ == "__main__":
  app.run(main)

src/maxdiffusion/max_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,10 @@ def setup_initial_state(
405405
config.enable_single_replica_ckpt_restoring,
406406
)
407407
if state:
408-
state = state[checkpoint_item]
408+
if checkpoint_item == "ltxvid_transformer":
409+
state = state
410+
else:
411+
state = state[checkpoint_item]
409412
if not state:
410413
max_logging.log(f"Could not find the item in orbax, creating state...")
411414
init_train_state_partial = functools.partial(

src/maxdiffusion/models/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
# limitations under the License.
1414

1515
from typing import TYPE_CHECKING
16-
17-
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
18-
16+
from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
1917

2018
_import_structure = {}
2119

@@ -32,6 +30,7 @@
3230
from .vae_flax import FlaxAutoencoderKL
3331
from .lora import *
3432
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
33+
from .ltx_video.transformers.transformer3d import Transformer3DModel
3534

3635
else:
3736
import sys

0 commit comments

Comments
 (0)