@@ -1640,20 +1640,22 @@ def convert_to_vel(lat, x0, sig):
16401640 # =======================================================================
16411641
16421642 # Denormalize and Unpack Audio (Order important: Denorm THEN Unpack)
1643- audio_latents = self ._denormalize_audio_latents (
1644- audio_latents_jax , self .audio_vae .latents_mean .value , self .audio_vae .latents_std .value
1645- )
1643+ audio_latents = None
1644+ if getattr (self .config , "generate_audio" , True ) and self .audio_vae is not None :
1645+ audio_latents = self ._denormalize_audio_latents (
1646+ audio_latents_jax , self .audio_vae .latents_mean .value , self .audio_vae .latents_std .value
1647+ )
16461648
1647- num_mel_bins = self .audio_vae .config .mel_bins if getattr ( self , "audio_vae" , None ) is not None else 64
1648- latent_mel_bins = num_mel_bins // self .audio_vae_mel_compression_ratio
1649+ num_mel_bins = self .audio_vae .config .mel_bins
1650+ latent_mel_bins = num_mel_bins // self .audio_vae_mel_compression_ratio
16491651
1650- audio_latents = self ._unpack_audio_latents (
1651- audio_latents , audio_num_frames , num_mel_bins = latent_mel_bins , num_channels = audio_channels
1652- )
1652+ audio_latents = self ._unpack_audio_latents (
1653+ audio_latents , audio_num_frames , num_mel_bins = latent_mel_bins , num_channels = audio_channels
1654+ )
16531655
1654- # Audio VAE expects channels last (B, T, F, C) but unpack returns (B, C, T, F)
1655- if audio_latents .ndim == 4 :
1656- audio_latents = audio_latents .transpose (0 , 2 , 3 , 1 )
1656+ # Audio VAE expects channels last (B, T, F, C) but unpack returns (B, C, T, F)
1657+ if audio_latents .ndim == 4 :
1658+ audio_latents = audio_latents .transpose (0 , 2 , 3 , 1 )
16571659
16581660 if output_type == "latent" :
16591661 return LTX2PipelineOutput (frames = latents , audio = audio_latents )
0 commit comments