Skip to content

Commit 38c340c

Browse files
committed
add Flax/JAX implementation of LTX-2 Latent Upsampler
1 parent 6de9d57 commit 38c340c

9 files changed

Lines changed: 1000 additions & 21 deletions

File tree

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,19 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
7979
return restored_checkpoint, step
8080

8181
def load_checkpoint(
82-
self, step=None, vae_only=False, load_transformer=True
82+
self, step=None, vae_only=False, load_transformer=True, load_upsampler=False
8383
) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
8484
restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
8585
opt_state = None
8686

8787
if restored_checkpoint:
8888
max_logging.log("Loading LTX2 pipeline from checkpoint")
89-
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
89+
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer, load_upsampler)
9090
if "opt_state" in restored_checkpoint.ltx2_state.keys():
9191
opt_state = restored_checkpoint.ltx2_state["opt_state"]
9292
else:
9393
max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
94-
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
94+
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer, load_upsampler)
9595

9696
return pipeline, opt_state, step
9797

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,13 @@ jit_initializers: True
102102
enable_single_replica_ckpt_restoring: False
103103
seed: 0
104104
audio_format: "s16"
105+
106+
# LTX-2 Latent Upsampler
107+
run_latent_upsampler: False
108+
upsampler_model_path: "Lightricks/LTX-2"
109+
upsampler_spatial_patch_size: 1
110+
upsampler_temporal_patch_size: 1
111+
upsampler_adain_factor: 0.0
112+
upsampler_tone_map_compression_ratio: 0.0
113+
upsampler_rational_spatial_scale: 2.0
114+
upsampler_output_type: "pil"

src/maxdiffusion/generate_ltx2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def get_git_commit_hash():
8181

8282

8383
def call_pipeline(config, pipeline, prompt, negative_prompt):
84-
# Set default generation arguments
8584
generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
8685
guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0
8786

@@ -99,6 +98,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9998
decode_noise_scale=getattr(config, "decode_noise_scale", None),
10099
max_sequence_length=getattr(config, "max_sequence_length", 1024),
101100
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
101+
output_type=getattr(config, "upsampler_output_type", "pil"),
102102
)
103103
return out
104104

@@ -114,9 +114,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
114114
else:
115115
max_logging.log("Could not retrieve Git commit hash.")
116116

117+
checkpoint_loader = LTX2Checkpointer(config=config)
117118
if pipeline is None:
118-
checkpoint_loader = LTX2Checkpointer(config=config)
119-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
119+
# Use the config flag to determine if the upsampler should be loaded
120+
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
121+
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
120122

121123
pipeline.enable_vae_slicing()
122124
pipeline.enable_vae_tiling()
@@ -135,6 +137,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
135137
)
136138

137139
out = call_pipeline(config, pipeline, prompt, negative_prompt)
140+
138141
# out should have .frames and .audio
139142
videos = out.frames if hasattr(out, "frames") else out[0]
140143
audios = out.audio if hasattr(out, "audio") else None
@@ -143,6 +146,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
143146
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
144147
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
145148
max_logging.log(f"model type: {getattr(config, 'model_type', 'T2V')}")
149+
if getattr(config, "run_latent_upsampler", False):
150+
max_logging.log(f"upsampler model path: {config.upsampler_model_path}")
146151
max_logging.log(f"hardware: {jax.devices()[0].platform}")
147152
max_logging.log(f"number of devices: {jax.device_count()}")
148153
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")

0 commit comments

Comments
 (0)