@@ -654,8 +654,9 @@ def _create_common_components(cls, config: HyperParameters, vae_only=False):
654654 components ["tokenizer" ] = cls .load_tokenizer (config )
655655 components ["text_encoder" ] = cls .load_text_encoder (config )
656656 components ["connectors" ] = cls .load_connectors (devices_array , mesh , rngs , config )
657- components ["audio_vae" ] = cls .load_audio_vae (devices_array , mesh , rngs , config )
658- components ["vocoder" ] = cls .load_vocoder (devices_array , mesh , rngs , config )
657+ if getattr (config , "generate_audio" , True ):
658+ components ["audio_vae" ] = cls .load_audio_vae (devices_array , mesh , rngs , config )
659+ components ["vocoder" ] = cls .load_vocoder (devices_array , mesh , rngs , config )
659660 components ["scheduler" ] = cls .load_scheduler (config )
660661
661662 return components
@@ -1325,15 +1326,18 @@ def __call__(
13251326 )
13261327 audio_num_frames = round (duration_s * audio_latents_per_second )
13271328
1328- audio_latents = self .prepare_audio_latents (
1329- batch_size = batch_size ,
1330- num_channels_latents = audio_channels ,
1331- audio_latent_length = audio_num_frames ,
1332- noise_scale = noise_scale ,
1333- dtype = dtype ,
1334- generator = key_audio ,
1335- latents = audio_latents ,
1336- )
1329+ if getattr (self .config , "generate_audio" , True ):
1330+ audio_latents = self .prepare_audio_latents (
1331+ batch_size = batch_size ,
1332+ num_channels_latents = audio_channels ,
1333+ audio_latent_length = audio_num_frames ,
1334+ noise_scale = noise_scale ,
1335+ dtype = dtype ,
1336+ generator = key_audio ,
1337+ latents = audio_latents ,
1338+ )
1339+ else :
1340+ audio_latents = jnp .zeros ((batch_size , audio_channels , audio_num_frames ), dtype = dtype )
13371341
13381342 # 5. Prepare Timesteps
13391343 sigmas = jnp .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
@@ -1700,24 +1704,27 @@ def convert_to_vel(lat, x0, sig):
17001704 video = self .video_processor .postprocess_video (torch .from_numpy (video_np ), output_type = output_type )
17011705
17021706 # Decode Audio
1703- import time
1704- audio_latents = audio_latents .astype (self .audio_vae .dtype )
1705- generated_mel_spectrograms = self .audio_vae .decode (audio_latents , return_dict = False )[0 ]
1706-
1707- # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
1708- generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
1709-
1710- vocoder_start_time = time .time ()
1711- # Cache the JITted function on the pipeline so it doesn't recompile on the 2nd run
1712- if not hasattr (self , "_jitted_vocoder" ):
1713- self ._jitted_vocoder = nnx .jit (lambda m , x : m (x ))
1714-
1715- audio = self ._jitted_vocoder (self .vocoder , generated_mel_spectrograms )
1716- jax .block_until_ready (audio )
1717- max_logging .log (f"⏱️ BWE Vocoder Execution Time: { time .time () - vocoder_start_time :.4f} seconds" )
1707+ if getattr (self .config , "generate_audio" , True ) and self .audio_vae is not None :
1708+ import time
1709+ audio_latents = audio_latents .astype (self .audio_vae .dtype )
1710+ generated_mel_spectrograms = self .audio_vae .decode (audio_latents , return_dict = False )[0 ]
1711+
1712+ # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
1713+ generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
1714+
1715+ vocoder_start_time = time .time ()
1716+ # Cache the JITted function on the pipeline so it doesn't recompile on the 2nd run
1717+ if not hasattr (self , "_jitted_vocoder" ):
1718+ self ._jitted_vocoder = nnx .jit (lambda m , x : m (x ))
1719+
1720+ audio = self ._jitted_vocoder (self .vocoder , generated_mel_spectrograms )
1721+ jax .block_until_ready (audio )
1722+ max_logging .log (f"⏱️ BWE Vocoder Execution Time: { time .time () - vocoder_start_time :.4f} seconds" )
17181723
1719- # Convert audio to numpy
1720- audio = np .array (audio )
1724+ # Convert audio to numpy
1725+ audio = np .array (audio )
1726+ else :
1727+ audio = None
17211728
17221729 return LTX2PipelineOutput (frames = video , audio = audio )
17231730
0 commit comments