From 4119a12f59f6268aa9c982abe78c69b763a3acba Mon Sep 17 00:00:00 2001 From: cyber <19499442+cyberofficial@users.noreply.github.com> Date: Sun, 16 Jun 2024 01:15:41 -0400 Subject: [PATCH] [update] Auto trim the audio file * Auto trims the audio file based on the seconds totals * Seconds total limited to 1 second minimum and 47 maximum * FFMPEG Python added as a requirement in setup.py * Import os added to imports to handle file removal of output file if older one exists. (Prevents future conflicts) --- setup.py | 1 + stable_audio_tools/interface/gradio.py | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 7e7470d3..ebea2f28 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ 'einops-exts==0.0.4', 'ema-pytorch==0.2.3', 'encodec==0.1.1', + 'ffmpeg-python==0.2.0', 'gradio>=3.42.0', 'huggingface_hub', 'importlib-resources==5.12.0', diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py index b46c8d43..65c85c3f 100644 --- a/stable_audio_tools/interface/gradio.py +++ b/stable_audio_tools/interface/gradio.py @@ -1,4 +1,7 @@ import gc +import os +import ffmpeg + import numpy as np import gradio as gr import json @@ -185,16 +188,30 @@ def progress_callback(callback_info): scale_phi = cfg_rescale ) - # Convert to WAV file + # Convert to WAV file (temporary file) audio = rearrange(audio, "b d n -> d (b n)") audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - torchaudio.save("output.wav", audio, sample_rate) + torchaudio.save("temp_output.wav", audio, sample_rate) + + # Trim audio using ffmpeg + trim_audio("temp_output.wav", "output.wav", seconds_total) # Let's look at a nice spectrogram too audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) return ("output.wav", [audio_spectrogram, *preview_images]) +def trim_audio(input_file, output_file, duration_seconds): + stream = ffmpeg.input(input_file) + audio_stream = stream.audio + trimmed = audio_stream.filter('atrim', end=duration_seconds) + output = ffmpeg.output(trimmed, output_file) + if os.path.exists(output_file): + os.remove(output_file) + ffmpeg.run(output) + os.remove(input_file) # removes the temp file + return + def generate_uncond( steps=250, seed=-1, @@ -399,7 +416,7 @@ def create_sampling_ui(model_config, inpainting=False): with gr.Row(visible = has_seconds_start or has_seconds_total): # Timing controls seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start) - seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) + seconds_total_slider = gr.Slider(minimum=1, maximum=47, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) with gr.Row(): # Steps slider