Skip to content

Commit d61134c

Browse files
committed
generate_audio flag added
1 parent 56e4606 commit d61134c

1 file changed

Lines changed: 13 additions & 11 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1640,20 +1640,22 @@ def convert_to_vel(lat, x0, sig):
16401640
# =======================================================================
16411641

16421642
# Denormalize and Unpack Audio (Order important: Denorm THEN Unpack)
1643-
audio_latents = self._denormalize_audio_latents(
1644-
audio_latents_jax, self.audio_vae.latents_mean.value, self.audio_vae.latents_std.value
1645-
)
1643+
audio_latents = None
1644+
if getattr(self.config, "generate_audio", True) and self.audio_vae is not None:
1645+
audio_latents = self._denormalize_audio_latents(
1646+
audio_latents_jax, self.audio_vae.latents_mean.value, self.audio_vae.latents_std.value
1647+
)
16461648

1647-
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
1648-
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
1649+
num_mel_bins = self.audio_vae.config.mel_bins
1650+
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
16491651

1650-
audio_latents = self._unpack_audio_latents(
1651-
audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins, num_channels=audio_channels
1652-
)
1652+
audio_latents = self._unpack_audio_latents(
1653+
audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins, num_channels=audio_channels
1654+
)
16531655

1654-
# Audio VAE expects channels last (B, T, F, C) but unpack returns (B, C, T, F)
1655-
if audio_latents.ndim == 4:
1656-
audio_latents = audio_latents.transpose(0, 2, 3, 1)
1656+
# Audio VAE expects channels last (B, T, F, C) but unpack returns (B, C, T, F)
1657+
if audio_latents.ndim == 4:
1658+
audio_latents = audio_latents.transpose(0, 2, 3, 1)
16571659

16581660
if output_type == "latent":
16591661
return LTX2PipelineOutput(frames=latents, audio=audio_latents)

0 commit comments

Comments
 (0)