diff --git a/setup.py b/setup.py
index 7e7470d3..ebea2f28 100644
--- a/setup.py
+++ b/setup.py
@@ -16,6 +16,7 @@
         'einops-exts==0.0.4',
         'ema-pytorch==0.2.3',
         'encodec==0.1.1',
+        'ffmpeg-python==0.2.0',
         'gradio>=3.42.0',
         'huggingface_hub',
         'importlib-resources==5.12.0',
diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py
index b46c8d43..65c85c3f 100644
--- a/stable_audio_tools/interface/gradio.py
+++ b/stable_audio_tools/interface/gradio.py
@@ -1,4 +1,7 @@
 import gc
+import os
+import ffmpeg
+
 import numpy as np
 import gradio as gr
 import json
@@ -185,16 +188,36 @@ def progress_callback(callback_info):
         scale_phi = cfg_rescale
     )
 
-    # Convert to WAV file
+    # Convert to WAV file (temporary file)
     audio = rearrange(audio, "b d n -> d (b n)")
     audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-    torchaudio.save("output.wav", audio, sample_rate)
+    torchaudio.save("temp_output.wav", audio, sample_rate)
+
+    # Trim audio using ffmpeg
+    trim_audio("temp_output.wav", "output.wav", seconds_total)
 
     # Let's look at a nice spectrogram too
     audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
 
     return ("output.wav", [audio_spectrogram, *preview_images])
 
+def trim_audio(input_file, output_file, duration_seconds):
+    """Trim ``input_file`` to its first ``duration_seconds`` seconds.
+
+    The trimmed audio is written to ``output_file`` (replacing any existing
+    file) and ``input_file`` is deleted afterwards, even if ffmpeg fails.
+
+    Raises:
+        ffmpeg.Error: if the ffmpeg process exits with a non-zero status.
+    """
+    trimmed = ffmpeg.input(input_file).audio.filter('atrim', end=duration_seconds)
+    try:
+        # Let ffmpeg replace an existing output (-y) instead of deleting it
+        # up front, so the old file is not removed before ffmpeg even starts.
+        ffmpeg.run(ffmpeg.output(trimmed, output_file), overwrite_output=True)
+    finally:
+        os.remove(input_file)  # the input is a temporary file
+
 def generate_uncond(
         steps=250,
         seed=-1,
@@ -399,7 +422,7 @@ def create_sampling_ui(model_config, inpainting=False):
             with gr.Row(visible = has_seconds_start or has_seconds_total):
                 # Timing controls
                 seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start)
-                seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
+                seconds_total_slider = gr.Slider(minimum=1, maximum=47, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
 
             with gr.Row():
                 # Steps slider