@@ -81,7 +81,6 @@ def get_git_commit_hash():
8181
8282
8383def call_pipeline (config , pipeline , prompt , negative_prompt ):
84- # Set default generation arguments
8584 generator = jax .random .key (config .seed ) if hasattr (config , "seed" ) else jax .random .key (0 )
8685 guidance_scale = config .guidance_scale if hasattr (config , "guidance_scale" ) else 3.0
8786
@@ -99,6 +98,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9998 decode_noise_scale = getattr (config , "decode_noise_scale" , None ),
10099 max_sequence_length = getattr (config , "max_sequence_length" , 1024 ),
101100 dtype = jnp .bfloat16 if getattr (config , "activations_dtype" , "bfloat16" ) == "bfloat16" else jnp .float32 ,
101+ output_type = getattr (config , "upsampler_output_type" , "pil" ),
102102 )
103103 return out
104104
@@ -114,9 +114,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
114114 else :
115115 max_logging .log ("Could not retrieve Git commit hash." )
116116
117+ checkpoint_loader = LTX2Checkpointer (config = config )
117118 if pipeline is None :
118- checkpoint_loader = LTX2Checkpointer (config = config )
119- pipeline , _ , _ = checkpoint_loader .load_checkpoint ()
119+ # Use the config flag to determine if the upsampler should be loaded
120+ run_latent_upsampler = getattr (config , "run_latent_upsampler" , False )
121+ pipeline , _ , _ = checkpoint_loader .load_checkpoint (load_upsampler = run_latent_upsampler )
120122
121123 pipeline .enable_vae_slicing ()
122124 pipeline .enable_vae_tiling ()
@@ -135,6 +137,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
135137 )
136138
137139 out = call_pipeline (config , pipeline , prompt , negative_prompt )
140+
138141 # out should have .frames and .audio
139142 videos = out .frames if hasattr (out , "frames" ) else out [0 ]
140143 audios = out .audio if hasattr (out , "audio" ) else None
@@ -143,6 +146,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
143146 max_logging .log (f"model name: { getattr (config , 'model_name' , 'ltx-video' )} " )
144147 max_logging .log (f"model path: { config .pretrained_model_name_or_path } " )
145148 max_logging .log (f"model type: { getattr (config , 'model_type' , 'T2V' )} " )
149+ if getattr (config , "run_latent_upsampler" , False ):
150+ max_logging .log (f"upsampler model path: { config .upsampler_model_path } " )
146151 max_logging .log (f"hardware: { jax .devices ()[0 ].platform } " )
147152 max_logging .log (f"number of devices: { jax .device_count ()} " )
148153 max_logging .log (f"per_device_batch_size: { config .per_device_batch_size } " )
0 commit comments