Skip to content

Commit f2d8574

Browse files
committed
jit vocoder
1 parent de7f386 commit f2d8574

2 files changed

Lines changed: 7 additions & 20 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -678,54 +678,35 @@ def __init__(
678678
window_type="hann",
679679
)
680680

681+
@nnx.jit
681682
def __call__(self, mel_spec: Array) -> Array:
682-
print(f"=== BWE Vocoder Debug ===")
683-
print(f"Input mel_spec - shape: {mel_spec.shape}, min: {mel_spec.min()}, max: {mel_spec.max()}")
684-
685683
x = self.vocoder(mel_spec)
686-
print(f"Base vocoder output (x) - shape: {x.shape}, min: {x.min()}, max: {x.max()}")
687-
688684
x = jnp.transpose(x, (0, 2, 1))
689685
batch_size, num_samples, num_channels = x.shape
690-
print(f"Transposed x - shape: {x.shape}")
691686

692687
remainder = num_samples % self.hop_length
693688
if remainder != 0:
694689
x = jnp.pad(x, ((0, 0), (0, self.hop_length - remainder), (0, 0)))
695-
print(f"Padded x - shape: {x.shape}")
696690

697691
x_flattened = x.transpose(0, 2, 1).reshape(-1, x.shape[1], 1)
698-
print(f"x_flattened - shape: {x_flattened.shape}")
699-
700692
log_mel, _, _, _ = self.mel_stft(x_flattened)
701-
print(f"MelSTFT output (log_mel) before reshape - shape: {log_mel.shape}, min: {log_mel.min()}, max: {log_mel.max()}")
702693

703694
log_mel = log_mel.reshape(batch_size, num_channels, -1, log_mel.shape[-1])
704-
print(f"Reshaped log_mel - shape: {log_mel.shape}")
705-
706695
residual = self.bwe_generator(log_mel, time_last=False)
707-
print(f"BWE generator output (residual) - shape: {residual.shape}, min: {residual.min()}, max: {residual.max()}")
708-
709696
skip = self.resampler(x)
710-
print(f"Resampler output (skip) - shape: {skip.shape}, min: {skip.min()}, max: {skip.max()}")
711697

712698
residual = jnp.transpose(residual, (0, 2, 1))
713699

714700
if residual.shape[1] < skip.shape[1]:
715701
residual = jnp.pad(residual, ((0, 0), (0, skip.shape[1] - residual.shape[1]), (0, 0)), mode='edge')
716702
elif residual.shape[1] > skip.shape[1]:
717703
residual = residual[:, :skip.shape[1], :]
718-
print(f"Matched residual - shape: {residual.shape}")
719704

720705
raw_waveform = residual + skip
721-
print(f"Raw waveform (residual + skip) - min: {raw_waveform.min()}, max: {raw_waveform.max()}")
722-
723706
waveform = jnp.clip(raw_waveform, -1, 1)
724707

725708
output_samples = num_samples * self.output_sampling_rate // self.input_sampling_rate
726709
waveform = waveform[:, :output_samples, :]
727710
waveform = jnp.transpose(waveform, (0, 2, 1))
728-
print(f"Final waveform - shape: {waveform.shape}, min: {waveform.min()}, max: {waveform.max()}")
729-
print(f"=========================")
730711

731712
return waveform

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,12 +1655,18 @@ def convert_to_vel(lat, x0, sig):
16551655
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
16561656

16571657
# Decode Audio
1658+
import time
16581659
audio_latents = audio_latents.astype(self.audio_vae.dtype)
16591660
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
16601661

16611662
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
16621663
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
1664+
1665+
vocoder_start_time = time.time()
16631666
audio = self.vocoder(generated_mel_spectrograms)
1667+
jax.block_until_ready(audio)
1668+
vocoder_execution_time = time.time() - vocoder_start_time
1669+
max_logging.log(f"BWE Vocoder Execution Time: {vocoder_execution_time:.4f} seconds")
16641670

16651671
# Convert audio to numpy
16661672
audio = np.array(audio)

0 commit comments

Comments
 (0)