@@ -82,7 +82,6 @@ def get_git_commit_hash():
8282
8383
8484def call_pipeline (config , pipeline , prompt , negative_prompt ):
85- # Set default generation arguments
8685 generator = jax .random .key (config .seed ) if hasattr (config , "seed" ) else jax .random .key (0 )
8786 guidance_scale = config .guidance_scale if hasattr (config , "guidance_scale" ) else 3.0
8887
@@ -100,6 +99,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
10099 decode_noise_scale = getattr (config , "decode_noise_scale" , None ),
101100 max_sequence_length = getattr (config , "max_sequence_length" , 1024 ),
102101 dtype = jnp .bfloat16 if getattr (config , "activations_dtype" , "bfloat16" ) == "bfloat16" else jnp .float32 ,
102+ output_type = getattr (config , "upsampler_output_type" , "pil" ),
103103 )
104104 return out
105105
@@ -115,9 +115,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
115115 else :
116116 max_logging .log ("Could not retrieve Git commit hash." )
117117
118+ checkpoint_loader = LTX2Checkpointer (config = config )
118119 if pipeline is None :
119- checkpoint_loader = LTX2Checkpointer (config = config )
120- pipeline , _ , _ = checkpoint_loader .load_checkpoint ()
120+ # Use the config flag to determine if the upsampler should be loaded
121+ run_latent_upsampler = getattr (config , "run_latent_upsampler" , False )
122+ pipeline , _ , _ = checkpoint_loader .load_checkpoint (load_upsampler = run_latent_upsampler )
121123
122124 # If LoRA is specified, inject layers and load weights.
123125 if (
@@ -161,6 +163,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
161163 )
162164
163165 out = call_pipeline (config , pipeline , prompt , negative_prompt )
166+
164167 # out should have .frames and .audio
165168 videos = out .frames if hasattr (out , "frames" ) else out [0 ]
166169 audios = out .audio if hasattr (out , "audio" ) else None
@@ -169,6 +172,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
169172 max_logging .log (f"model name: { getattr (config , 'model_name' , 'ltx-video' )} " )
170173 max_logging .log (f"model path: { config .pretrained_model_name_or_path } " )
171174 max_logging .log (f"model type: { getattr (config , 'model_type' , 'T2V' )} " )
175+ if getattr (config , "run_latent_upsampler" , False ):
176+ max_logging .log (f"upsampler model path: { config .upsampler_model_path } " )
172177 max_logging .log (f"hardware: { jax .devices ()[0 ].platform } " )
173178 max_logging .log (f"number of devices: { jax .device_count ()} " )
174179 max_logging .log (f"per_device_batch_size: { config .per_device_batch_size } " )
0 commit comments