Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions stable_audio_3/interface/diffusion_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@
_LENGTH_EXTRACT_RE = re.compile(r' Length: (\d+) seconds\.?\s*$')


def parse_inpaint_regions(starts_csv, ends_csv):
    """Parse matching comma-separated inpaint start/end regions.

    Args:
        starts_csv: Region start times in seconds as a comma-separated
            string (e.g. ``"4, 16"``). ``None`` and blank strings mean
            "no regions"; a bare number such as ``0`` counts as one value.
        ends_csv: Region end times in seconds, same format as
            ``starts_csv``.

    Returns:
        ``(None, None)`` when both inputs are empty. Otherwise
        ``(starts, ends)``, where each element is a single ``float`` when
        exactly one region was given and a list of floats otherwise.

    Raises:
        gr.Error: If a value is not a finite number, if only one of the two
            inputs is provided, if the inputs hold different numbers of
            regions, or if any region has a negative start or an end that
            is not after its start.
    """
    import math  # function-scope import keeps this block self-contained

    def parse(value):
        # A real number is a value, even when it is 0: `value or ""` alone
        # would silently turn a numeric 0 start into "nothing provided".
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        text = str(value) if is_number else str(value or "")
        items = [item.strip() for item in text.split(",")]
        try:
            numbers = [float(item) for item in items if item]
        except ValueError as exc:
            raise gr.Error("Inpaint regions must be comma-separated numbers.") from exc
        # float() happily parses "nan" and "inf"; NaN then defeats every
        # comparison in the range check below, so reject non-finite input here.
        if not all(math.isfinite(number) for number in numbers):
            raise gr.Error("Inpaint regions must be comma-separated numbers.")
        return numbers

    starts = parse(starts_csv)
    ends = parse(ends_csv)
    if not starts and not ends:
        return None, None
    if not starts or not ends:
        raise gr.Error("Inpaint starts and ends must both be provided.")
    if len(starts) != len(ends):
        raise gr.Error("Inpaint starts and ends must contain the same number of regions.")
    if any(start < 0 or end <= start for start, end in zip(starts, ends)):
        raise gr.Error("Each inpaint region must have a non-negative start before its end.")
    # A single region is unwrapped to scalars; several regions stay as lists.
    return (
        starts[0] if len(starts) == 1 else starts,
        ends[0] if len(ends) == 1 else ends,
    )


# when using a prompt in a filename
def condense_prompt(prompt):
pattern = r'[\\/:*?"<>|]'
Expand Down Expand Up @@ -132,6 +157,12 @@ def progress_callback(callback_info):
else:
inversion_params = None

mask_maskstart, mask_maskend = parse_inpaint_regions(
mask_maskstart, mask_maskend
)
if inpaint_audio is not None and mask_maskstart is None:
raise gr.Error("Add at least one inpaint start/end region.")

generate_args = {
"prompt": prompt,
"negative_prompt": negative_prompt,
Expand All @@ -140,7 +171,7 @@ def progress_callback(callback_info):
"cfg_scale": cfg_scale,
"cfg_interval": (cfg_interval_min, cfg_interval_max),
"lora_configs": lora_configs,
"batch_size": batch_size,
"batch_size": int(batch_size),
"sample_size": input_sample_size,
"seed": seed,
"sampler_type": sampler_type,
Expand Down Expand Up @@ -428,8 +459,14 @@ def update_dist_shift_state(shift_type, *params):
outputs=[dist_shift_state, logsnr_params_row, flux_params_row, full_params_row],
)

# Hidden state for batch_size (no UI control, but needed for function signature)
batch_size_state = gr.State(value=1)
with gr.Accordion("Batch", open=False):
batch_size_number = gr.Number(
label="Batch size",
value=1,
minimum=1,
precision=0,
info="Generate multiple variations in one run.",
)

with gr.Accordion("Output params", open=False):
# Output params
Expand Down Expand Up @@ -472,19 +509,8 @@ def init_audio_type_switch(choice):

with gr.Accordion("Inpainting", open=False):
inpaint_audio_input = gr.Audio(label="Inpaint audio", waveform_options=gr.WaveformOptions(show_recording_waveform=False))
mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=sample_size//sample_rate, step=0.1, value=0, label="Mask Start (sec)")
mask_maskend_slider = gr.Slider(minimum=0.0, maximum=sample_size//sample_rate, step=0.1, value=sample_size//sample_rate, label="Mask End (sec)")

# Update inpainting slider ranges when seconds_total changes.
# Only seconds_total is an input — reading the mask sliders here would cause
# validation errors since their values may exceed the about-to-be-reduced maximum.
def update_inpaint_sliders(seconds_total):
    """Build the slider updates for a new total duration.

    The upper bound is ``seconds_total`` clamped to at least 1. The first
    update only raises/lowers the maximum; the second also resets the value
    to that maximum.
    """
    upper_bound = max(seconds_total, 1)
    start_update = gr.update(maximum=upper_bound)
    end_update = gr.update(maximum=upper_bound, value=upper_bound)
    return start_update, end_update
seconds_total_slider.change(update_inpaint_sliders, inputs=[seconds_total_slider], outputs=[mask_maskstart_slider, mask_maskend_slider])
mask_maskstart_slider = gr.Textbox(label="Mask starts (sec)", placeholder="4 or 4, 16", info="Comma-separate values to inpaint multiple regions.")
mask_maskend_slider = gr.Textbox(label="Mask ends (sec)", placeholder="8 or 8, 20")

inputs = [
prompt,
Expand Down Expand Up @@ -514,7 +540,7 @@ def update_inpaint_sliders(seconds_total):
inversion_gamma_slider,
inversion_unconditional_checkbox,
duration_padding_slider,
batch_size_state,
batch_size_number,
dist_shift_state,
] + lora_ui_inputs

Expand Down
27 changes: 27 additions & 0 deletions tests/test_gradio_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from stable_audio_3.interface.diffusion_cond import parse_inpaint_regions


def test_parse_single_inpaint_region():
    """One start/end pair comes back as a pair of scalars, not lists."""
    expected = (4.0, 8.0)
    parsed = parse_inpaint_regions("4", "8")
    assert parsed == expected


def test_parse_multiple_inpaint_regions():
    """Several comma-separated regions come back as parallel lists."""
    expected_starts = [4.0, 16.0]
    expected_ends = [8.0, 20.0]
    parsed = parse_inpaint_regions("4, 16", "8, 20")
    assert parsed == (expected_starts, expected_ends)


def test_parse_empty_inpaint_regions():
    """Two blank inputs mean "no regions" and yield a pair of Nones."""
    parsed = parse_inpaint_regions("", "")
    assert parsed == (None, None)


@pytest.mark.parametrize(
    ("starts", "ends"),
    [
        ("4", ""),  # end missing
        ("4, 16", "8"),  # region count mismatch
        ("8", "4"),  # end before start
        ("-1", "4"),  # negative start
    ],
)
def test_parse_invalid_inpaint_regions(starts, ends):
    """Invalid start/end combinations are rejected with a validation error.

    A bare ``pytest.raises(Exception)`` would also pass on an unrelated
    failure such as a ``TypeError`` or ``NameError``. Matching on the
    message pins the failure to the parser's own validation errors, all of
    which mention "inpaint".
    """
    with pytest.raises(Exception, match=r"(?i)inpaint"):
        parse_inpaint_regions(starts, ends)
Comment on lines +25 to +27
Loading