Skip to content

Commit 4c47e5b

Browse files
committed
generate_audio flag added
1 parent c041251 commit 4c47e5b

3 files changed

Lines changed: 37 additions & 28 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ audio_stg_scale: 1.0
3636
modality_scale: 3.0
3737
audio_modality_scale: 3.0
3838
use_cross_timestep: true
39+
generate_audio: true   # NOTE(review): was `True` — other booleans in this config use lowercase `true`/`false`; YAML 1.2 only defines lowercase
3940
spatio_temporal_guidance_blocks: [28]
4041
fps: 24
4142
pipeline_type: multi-scale

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ audio_stg_scale: 0.0
4242
modality_scale: 1.0
4343
audio_modality_scale: 1.0
4444
use_cross_timestep: false
45+
generate_audio: true   # NOTE(review): was `True` — matches `use_cross_timestep: false` style above; YAML 1.2 booleans are lowercase
4546
spatio_temporal_guidance_blocks: []
4647
noise_scale: 1.0
4748
fps: 24

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,9 @@ def _create_common_components(cls, config: HyperParameters, vae_only=False):
654654
components["tokenizer"] = cls.load_tokenizer(config)
655655
components["text_encoder"] = cls.load_text_encoder(config)
656656
components["connectors"] = cls.load_connectors(devices_array, mesh, rngs, config)
657-
components["audio_vae"] = cls.load_audio_vae(devices_array, mesh, rngs, config)
658-
components["vocoder"] = cls.load_vocoder(devices_array, mesh, rngs, config)
657+
if getattr(config, "generate_audio", True):
658+
components["audio_vae"] = cls.load_audio_vae(devices_array, mesh, rngs, config)
659+
components["vocoder"] = cls.load_vocoder(devices_array, mesh, rngs, config)
659660
components["scheduler"] = cls.load_scheduler(config)
660661

661662
return components
@@ -1325,15 +1326,18 @@ def __call__(
13251326
)
13261327
audio_num_frames = round(duration_s * audio_latents_per_second)
13271328

1328-
audio_latents = self.prepare_audio_latents(
1329-
batch_size=batch_size,
1330-
num_channels_latents=audio_channels,
1331-
audio_latent_length=audio_num_frames,
1332-
noise_scale=noise_scale,
1333-
dtype=dtype,
1334-
generator=key_audio,
1335-
latents=audio_latents,
1336-
)
1329+
if getattr(self.config, "generate_audio", True):
1330+
audio_latents = self.prepare_audio_latents(
1331+
batch_size=batch_size,
1332+
num_channels_latents=audio_channels,
1333+
audio_latent_length=audio_num_frames,
1334+
noise_scale=noise_scale,
1335+
dtype=dtype,
1336+
generator=key_audio,
1337+
latents=audio_latents,
1338+
)
1339+
else:
1340+
audio_latents = jnp.zeros((batch_size, audio_channels, audio_num_frames), dtype=dtype)
13371341

13381342
# 5. Prepare Timesteps
13391343
sigmas = jnp.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
@@ -1700,24 +1704,27 @@ def convert_to_vel(lat, x0, sig):
17001704
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
17011705

17021706
# Decode Audio
1703-
import time
1704-
audio_latents = audio_latents.astype(self.audio_vae.dtype)
1705-
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
1706-
1707-
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
1708-
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
1709-
1710-
vocoder_start_time = time.time()
1711-
# Cache the JITted function on the pipeline so it doesn't recompile on the 2nd run
1712-
if not hasattr(self, "_jitted_vocoder"):
1713-
self._jitted_vocoder = nnx.jit(lambda m, x: m(x))
1714-
1715-
audio = self._jitted_vocoder(self.vocoder, generated_mel_spectrograms)
1716-
jax.block_until_ready(audio)
1717-
max_logging.log(f"⏱️ BWE Vocoder Execution Time: {time.time() - vocoder_start_time:.4f} seconds")
1707+
if getattr(self.config, "generate_audio", True) and self.audio_vae is not None:
1708+
import time
1709+
audio_latents = audio_latents.astype(self.audio_vae.dtype)
1710+
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
1711+
1712+
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
1713+
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
1714+
1715+
vocoder_start_time = time.time()
1716+
# Cache the JITted function on the pipeline so it doesn't recompile on the 2nd run
1717+
if not hasattr(self, "_jitted_vocoder"):
1718+
self._jitted_vocoder = nnx.jit(lambda m, x: m(x))
1719+
1720+
audio = self._jitted_vocoder(self.vocoder, generated_mel_spectrograms)
1721+
jax.block_until_ready(audio)
1722+
max_logging.log(f"⏱️ BWE Vocoder Execution Time: {time.time() - vocoder_start_time:.4f} seconds")
17181723

1719-
# Convert audio to numpy
1720-
audio = np.array(audio)
1724+
# Convert audio to numpy
1725+
audio = np.array(audio)
1726+
else:
1727+
audio = None
17211728

17221729
return LTX2PipelineOutput(frames=video, audio=audio)
17231730

0 commit comments

Comments (0)