stable_audio model/pipeline review #13629

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Files/categories reviewed: target pipeline/model files, public lazy imports, top-level exports, config/loading/device-map behavior, dtype/device handling, offload-related tests, attention processor behavior, docs, examples, fast/nightly/slow test coverage.

Verification note: attempted .venv\Scripts\python.exe -m pytest tests/pipelines/stable_audio/test_stable_audio.py -q, but local test collection fails before the Stable Audio tests run because this Windows torch build lacks torch._C._distributed_c10d when importing FSDP. The narrow reproduction snippets below were checked with .venv.

Duplicate-search status: searched GitHub Issues/PRs for stable_audio, StableAudioPipeline, StableAudioDiTModel device_map, StableAudioAttnProcessor2_0 set_attention_backend, and initial_audio_waveforms num_waveforms_per_prompt. Found related but not exact duplicates: #10861 for initial-audio scaling and #8989 for sequential offload testing. No exact duplicate found for the batch-order, _no_split_modules, attention-backend, dtype, or docs findings.

Issue 1: Batched initial audio is paired with the wrong prompt when generating multiple waveforms

Affected code:

encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
latents = encoded_audio + latents

bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
# duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
text_audio_duration_embeds = text_audio_duration_embeds.view(
    bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
)
audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
audio_duration_embeds = audio_duration_embeds.view(
    bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
)

Problem:
text_audio_duration_embeds is expanded per prompt as [prompt0, prompt0, prompt1, prompt1], but encoded initial audio is expanded with encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)), producing [audio0, audio1, audio0, audio1]. For batched audio-to-audio with num_waveforms_per_prompt > 1, prompts and initial audio become misaligned.

Impact:
Users requesting multiple variations per prompt with batched initial_audio_waveforms condition some generations on another prompt's audio. Existing tests only assert output shape, so this does not get caught.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import StableAudioPipeline

class DummyLatentDist:
    def __init__(self, sample):
        self._sample = sample
    def sample(self, generator=None):
        return self._sample

class DummyVAE:
    hop_length = 1
    def encode(self, audio):
        return SimpleNamespace(latent_dist=DummyLatentDist(audio[:, :1, :]))

pipe = StableAudioPipeline.__new__(StableAudioPipeline)
pipe.scheduler = SimpleNamespace(init_noise_sigma=0.0)
pipe.transformer = SimpleNamespace(config=SimpleNamespace(sample_size=2))
pipe.vae = DummyVAE()

initial_audio = torch.tensor([[[10.0, 10.0]], [[20.0, 20.0]]])
latents = StableAudioPipeline.prepare_latents(
    pipe, batch_size=4, num_channels_vae=1, sample_size=2,
    dtype=torch.float32, device=torch.device("cpu"),
    generator=torch.Generator().manual_seed(0),
    initial_audio_waveforms=initial_audio,
    num_waveforms_per_prompt=2,
    audio_channels=1,
)
print(latents[:, 0, 0].tolist())  # [10.0, 20.0, 10.0, 20.0], expected [10.0, 10.0, 20.0, 20.0]

Relevant precedent:
repeat_interleave(..., dim=0) is the common pattern for per-prompt expansion, e.g. qwenimage modular inputs.

Suggested fix:

encoded_audio = encoded_audio.repeat_interleave(num_waveforms_per_prompt, dim=0)
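
A quick ordering check of the two expansion strategies (plain torch, independent of the pipeline; the two rows stand in for audio0/audio1):

import torch

encoded_audio = torch.tensor([[0.0], [1.0]])  # two "initial audio" rows: audio0, audio1
num_waveforms_per_prompt = 2

# current behavior: tiles the whole batch -> [audio0, audio1, audio0, audio1]
print(encoded_audio.repeat((num_waveforms_per_prompt, 1)).flatten().tolist())  # [0.0, 1.0, 0.0, 1.0]

# suggested fix: repeats each row in place -> [audio0, audio0, audio1, audio1],
# matching the [prompt0, prompt0, prompt1, prompt1] expansion of text_audio_duration_embeds
print(encoded_audio.repeat_interleave(num_waveforms_per_prompt, dim=0).flatten().tolist())  # [0.0, 0.0, 1.0, 1.0]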

Issue 2: StableAudioDiTModel cannot be loaded with device_map, despite docs using device_map="balanced"

Affected code:

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]

pipeline = StableAudioPipeline.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

Problem:
StableAudioDiTModel sets _supports_gradient_checkpointing = True but does not define _no_split_modules. Diffusers model loading raises for device_map="balanced"/"auto" unless _no_split_modules is implemented. The Stable Audio quantization docs currently show StableAudioPipeline.from_pretrained(..., device_map="balanced"), which is not supported by the transformer class.

Impact:
The documented quantized loading path is broken for the Stable Audio transformer, and users cannot use Diffusers device-map placement for the model.

Reproduction:

from diffusers import StableAudioDiTModel

model = StableAudioDiTModel(
    sample_size=4, in_channels=3, num_layers=1,
    attention_head_dim=4, num_attention_heads=2,
    num_key_value_attention_heads=2, out_channels=3,
    cross_attention_dim=4, time_proj_dim=8,
    global_states_input_dim=8, cross_attention_input_dim=4,
)

try:
    print(model._get_no_split_modules("balanced"))
except Exception as e:
    print(type(e).__name__, str(e).splitlines()[0])
# ValueError StableAudioDiTModel does not support `device_map='balanced'`.

Relevant precedent:

_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock"]

Suggested fix:

class StableAudioDiTModel(ModelMixin, AttentionMixin, ConfigMixin):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["StableAudioDiTBlock"]

Issue 3: Stable Audio attention ignores set_attention_backend

Affected code:

from ..attention_processor import Attention, StableAudioAttnProcessor2_0

    processor=StableAudioAttnProcessor2_0(),
)

# 2. Cross-Attn
self.norm2 = nn.LayerNorm(dim, norm_eps, True)
self.attn2 = Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    kv_heads=num_key_value_attention_heads,
    dropout=dropout,
    bias=False,
    upcast_attention=upcast_attention,
    out_bias=False,
    processor=StableAudioAttnProcessor2_0(),

class StableAudioAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_partial_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: tuple[torch.Tensor],
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        rot_dim = freqs_cis[0].shape[-1]
        x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
        x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
        out = torch.cat((x_rotated, x_unrotated), dim=-1)
        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = query.shape[-1] // attn.heads
        kv_heads = key.shape[-1] // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        if kv_heads != attn.heads:
            # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
            heads_per_kv_head = attn.heads // kv_heads
            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
            value = torch.repeat_interleave(
                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
            )

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_emb is not None:
            query_dtype = query.dtype
            key_dtype = key.dtype
            query = query.to(torch.float32)
            key = key.to(torch.float32)

            rot_dim = rotary_emb[0].shape[-1]
            query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
            query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
            query = torch.cat((query_rotated, query_unrotated), dim=-1)

            if not attn.is_cross_attention:
                key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
                key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
                key = torch.cat((key_rotated, key_unrotated), dim=-1)

            query = query.to(query_dtype)
            key = key.to(key_dtype)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

Problem:
StableAudioAttnProcessor2_0 lives in the shared attention processor file, has no _attention_backend / _parallel_config fields, and calls F.scaled_dot_product_attention directly. ModelMixin.set_attention_backend() only updates processors with _attention_backend, so Stable Audio processors remain unchanged.

Impact:
Users cannot select Flash/Sage/Flex/native backend behavior for Stable Audio even though StableAudioDiTModel inherits AttentionMixin. This also leaves Stable Audio outside the newer attention-dispatch and context-parallel patterns.

Reproduction:

from diffusers import StableAudioDiTModel

model = StableAudioDiTModel(
    sample_size=4, in_channels=3, num_layers=1,
    attention_head_dim=4, num_attention_heads=2,
    num_key_value_attention_heads=2, out_channels=3,
    cross_attention_dim=4, time_proj_dim=8,
    global_states_input_dim=8, cross_attention_input_dim=4,
)
model.set_attention_backend("native")
print([(type(p).__name__, hasattr(p, "_attention_backend")) for p in model.attn_processors.values()])
# [('StableAudioAttnProcessor2_0', False), ('StableAudioAttnProcessor2_0', False)]

Relevant precedent:

class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

Suggested fix:
Refactor the Stable Audio attention processor to the model-file attention pattern: define processor state fields, use dispatch_attention_fn, and keep Q/K/V in the (batch, sequence, heads, head_dim) layout expected by the dispatcher. This is a moderate refactor because the current implementation uses (batch, heads, sequence, head_dim) around RoPE.
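
A small layout sanity check for that refactor (assumes dispatch_attention_fn is importable from diffusers.models.attention_dispatch, as the model-file processors such as Flux use it; adjust the import if it lives elsewhere in the installed version):

import torch
import torch.nn.functional as F
from diffusers.models.attention_dispatch import dispatch_attention_fn

batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# current Stable Audio path: SDPA on (batch, heads, seq, head_dim)
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# dispatcher path: (batch, seq, heads, head_dim); backend=None should fall back to the default/native backend
out_dispatch = dispatch_attention_fn(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), backend=None)

# with the native backend the two should agree once transposed back
print(torch.allclose(out_sdpa, out_dispatch.transpose(1, 2), atol=1e-5))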

Issue 4: User-provided latents keep their original dtype

Affected code:

if latents is None:
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
    latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

Problem:
When latents is supplied, prepare_latents() only does latents.to(device). It does not cast to the dtype selected for the pipeline call. With a half-precision Stable Audio transformer, float32 user latents are forwarded into half-precision Conv/Linear layers.

Impact:
Mixed-precision calls can fail at runtime or run with an unintended latent dtype. This is especially relevant because the slow test path supplies precomputed latents.
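
The dtype mismatch is easy to see in isolation with a half-precision layer standing in for the transformer's first projection (plain torch, unrelated to the pipeline internals):

import torch
import torch.nn as nn

proj = nn.Linear(8, 8).to(torch.float16)  # stand-in for a half-precision transformer layer
latents_fp32 = torch.randn(1, 8)          # float32 user-provided latents

try:
    proj(latents_fp32)
except RuntimeError as e:
    print(type(e).__name__, str(e).splitlines()[0])  # dtype-mismatch error on CPU
# on some devices/ops the mismatch silently promotes instead, which is the
# "unintended latent dtype" case described above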

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import StableAudioPipeline

pipe = StableAudioPipeline.__new__(StableAudioPipeline)
pipe.scheduler = SimpleNamespace(init_noise_sigma=1.0)

latents = torch.randn(1, 3, 4, dtype=torch.float32)
out = StableAudioPipeline.prepare_latents(
    pipe, batch_size=1, num_channels_vae=3, sample_size=4,
    dtype=torch.float16, device=torch.device("cpu"),
    generator=None, latents=latents,
)
print(out.dtype)  # torch.float32, expected torch.float16

Relevant precedent:
Newer pipelines often recast user-provided latents to the active latent/model dtype before denoising; the QwenImage and Flux paths do this, for example.

Suggested fix:

if latents is None:
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
    latents = latents.to(device=device, dtype=dtype)

Issue 5: Stable Audio has no model-level tests and no @slow tests

Affected code:

@unittest.skip("Not supported yet")
def test_sequential_cpu_offload_forward_pass(self):
pass
@unittest.skip("Not supported yet")
def test_sequential_offload_forward_pass_twice(self):
pass
@unittest.skip("Test not supported because `rotary_embed_dim` doesn't have any sensible default.")
def test_encode_prompt_works_in_isolation(self):
pass

@nightly
@require_torch_accelerator
class StableAudioPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 64, 1024))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "audio_end_in_s": 30,
            "guidance_scale": 2.5,
        }
        return inputs

    def test_stable_audio(self):
        stable_audio_pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0")
        stable_audio_pipe = stable_audio_pipe.to(torch_device)
        stable_audio_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 25
        audio = stable_audio_pipe(**inputs).audios[0]

        assert audio.ndim == 2
        assert audio.shape == (2, int(inputs["audio_end_in_s"] * stable_audio_pipe.vae.sampling_rate))
        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[0, 447590:447600]
        # fmt: off
        expected_slices = Expectations(
            {
                ("xpu", 3): np.array([-0.0285, 0.1083, 0.1863, 0.3165, 0.5312, 0.6971, 0.6958, 0.6177, 0.5598, 0.5048]),
                ("cuda", 7): np.array([-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]),
                ("cuda", 8): np.array([-0.0285, 0.1082, 0.1862, 0.3163, 0.5306, 0.6964, 0.6953, 0.6172, 0.5593, 0.5044]),
            }
        )
        # fmt: on
        expected_slice = expected_slices.get_expectation()

        max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
        assert max_diff < 1.5e-3

Problem:
The family has fast pipeline tests and a @nightly integration test, but no tests/models/transformers/test_models_stable_audio*.py coverage and no @slow Stable Audio test. The pipeline also skips the sequential offload tests and the encode-prompt isolation test. The sequential offload skip is already tracked by open issue #8989.

Impact:
Model serialization/loading, attention backend behavior, _no_split_modules/device-map support, compile behavior, and model-level attention masks are not covered by the standard model test mixins. Missing slow coverage also means the non-nightly slow suite does not exercise the published checkpoint.

Reproduction:

from pathlib import Path

model_tests = list(Path("tests/models/transformers").glob("*stable*audio*.py"))
pipeline_test = Path("tests/pipelines/stable_audio/test_stable_audio.py").read_text()

print(model_tests)          # []
print("@slow" in pipeline_test)    # False
print("@nightly" in pipeline_test) # True

Relevant precedent:

class TestLongCatAudioDiTTransformer(LongCatAudioDiTTransformerTesterConfig, ModelTesterMixin):
    pass


class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    def test_layerwise_casting_memory(self):
        pytest.skip(
            "LongCatAudioDiTTransformer tiny test config does not provide stable layerwise casting peak memory "
            "coverage."
        )


class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin):
    pass


class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterConfig, AttentionTesterMixin):
    pass

Suggested fix:
Add a StableAudioDiTModel model tester using ModelTesterMixin, AttentionTesterMixin, and compile/memory coverage where supported. Add or mark a published-checkpoint pipeline test with @slow so Stable Audio is covered outside nightly-only CI. Keep #8989 referenced for sequential offload until that behavior is fixed or explicitly unsupported.
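
As a stand-in until a proper tests/models/transformers/test_models_stable_audio.py exists, a minimal save/load round-trip of the tiny config used in the reproductions above illustrates the kind of check ModelTesterMixin would provide:

import tempfile

import torch
from diffusers import StableAudioDiTModel

model = StableAudioDiTModel(
    sample_size=4, in_channels=3, num_layers=1,
    attention_head_dim=4, num_attention_heads=2,
    num_key_value_attention_heads=2, out_channels=3,
    cross_attention_dim=4, time_proj_dim=8,
    global_states_input_dim=8, cross_attention_input_dim=4,
)

with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)
    reloaded = StableAudioDiTModel.from_pretrained(tmp)

state, reloaded_state = model.state_dict(), reloaded.state_dict()
print(all(torch.equal(state[k], reloaded_state[k]) for k in state))  # True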

Issue 6: Stable Audio docs claim waveform scoring that the pipeline does not implement

Affected code:

During inference:
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: str | list[str] = None,
    audio_end_in_s: float | None = None,
    audio_start_in_s: float | None = 0.0,
    num_inference_steps: int = 100,
    guidance_scale: float = 7.0,
    negative_prompt: str | list[str] | None = None,
    num_waveforms_per_prompt: int | None = 1,
    eta: float = 0.0,
    generator: torch.Generator | list[torch.Generator] | None = None,
    latents: torch.Tensor | None = None,
    initial_audio_waveforms: torch.Tensor | None = None,
    initial_audio_sampling_rate: torch.Tensor | None = None,
    prompt_embeds: torch.Tensor | None = None,
    negative_prompt_embeds: torch.Tensor | None = None,
    attention_mask: torch.LongTensor | None = None,
    negative_attention_mask: torch.LongTensor | None = None,
    return_dict: bool = True,
    callback: Callable[[int, int, torch.Tensor], None] | None = None,
    callback_steps: int | None = 1,
    output_type: str | None = "pt",
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `list[str]`, *optional*):
            The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
        audio_end_in_s (`float`, *optional*, defaults to 47.55):
            Audio end index in seconds.
        audio_start_in_s (`float`, *optional*, defaults to 0):
            Audio start index in seconds.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            A higher guidance scale value encourages the model to generate audio that is closely linked to the text
            `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `list[str]`, *optional*):
            The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
            The number of waveforms to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        initial_audio_waveforms (`torch.Tensor`, *optional*):
            Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
            `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
            corresponds to the number of prompts passed to the model.
        initial_audio_sampling_rate (`int`, *optional*):
            Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
            *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
            argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
            inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
            `negative_prompt` input argument.
        attention_mask (`torch.LongTensor`, *optional*):
            Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
            be computed from `prompt` input argument.
        negative_attention_mask (`torch.LongTensor`, *optional*):
            Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        output_type (`str`, *optional*, defaults to `"pt"`):
            The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
            `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
            model (LDM) output.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated audio.
    """
    # 0. Convert audio input length from seconds to latent length
    downsample_ratio = self.vae.hop_length

    max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
    if audio_end_in_s is None:
        audio_end_in_s = max_audio_length_in_s

    if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
        raise ValueError(
            f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
        )

    waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
    waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
    waveform_length = int(self.transformer.config.sample_size)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        audio_start_in_s,
        audio_end_in_s,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        negative_attention_mask,
        initial_audio_waveforms,
        initial_audio_sampling_rate,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self.encode_prompt(
        prompt,
        device,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        negative_attention_mask,
    )

    # Encode duration
    seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
        audio_start_in_s,
        audio_end_in_s,
        device,
        do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
        batch_size,
    )

    # Create text_audio_duration_embeds and audio_duration_embeds
    text_audio_duration_embeds = torch.cat(
        [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
    )

    audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)

    # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
    # to concatenate it to the embeddings
    if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
        negative_text_audio_duration_embeds = torch.zeros_like(
            text_audio_duration_embeds, device=text_audio_duration_embeds.device
        )
        text_audio_duration_embeds = torch.cat(
            [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
        )
        audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)

    bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
    # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
    text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
    text_audio_duration_embeds = text_audio_duration_embeds.view(
        bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
    )

    audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
    audio_duration_embeds = audio_duration_embeds.view(
        bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_vae = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_waveforms_per_prompt,
        num_channels_vae,
        waveform_length,
        text_audio_duration_embeds.dtype,
        device,
        generator,
        latents,
        initial_audio_waveforms,
        num_waveforms_per_prompt,
        audio_channels=self.vae.config.audio_channels,
    )

    # 6. Prepare extra step kwargs
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare rotary positional embedding
    rotary_embedding = get_1d_rotary_pos_embed(
        self.rotary_embed_dim,
        latents.shape[2] + audio_duration_embeds.shape[1],
        use_real=True,
        repeat_interleave_real=False,
    )

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.transformer(
                latent_model_input,
                t.unsqueeze(0),
                encoder_hidden_states=text_audio_duration_embeds,
                global_hidden_states=audio_duration_embeds,
                rotary_embedding=rotary_embedding,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    # 9. Post-processing
    if not output_type == "latent":
        audio = self.vae.decode(latents).sample
    else:
        return AudioPipelineOutput(audios=latents)

    audio = audio[:, :, waveform_start:waveform_end]

    if output_type == "np":
        audio = audio.cpu().float().numpy()

    self.maybe_free_model_hooks()

    if not return_dict:
        return (audio,)

    return AudioPipelineOutput(audios=audio)

Problem:
The docs say num_waveforms_per_prompt > 1 performs automatic scoring and ranks outputs by prompt similarity. StableAudioPipeline has no scoring component or score_waveforms() path; it simply returns generated audio in batch order. The wording appears to have been carried over from the AudioLDM2/MusicLDM docs, which do implement this behavior.

Impact:
Users are told generated Stable Audio waveforms are ranked when they are not.

Reproduction:

import inspect
from diffusers import StableAudioPipeline

print(hasattr(StableAudioPipeline, "score_waveforms"))                  # False
print("score_waveforms" in inspect.getsource(StableAudioPipeline.__call__))  # False

Relevant precedent:
AudioLDM2 implements waveform scoring:

def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
    if not is_librosa_available():
        logger.info(
            "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
            "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
            "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
        )
        return audio
    inputs = self.tokenizer(text, return_tensors="pt", padding=True)
    resampled_audio = librosa.resample(
        audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
    )
    inputs["input_features"] = self.feature_extractor(
        list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
    ).input_features.type(dtype)
    inputs = inputs.to(device)
    # compute the audio-text similarity score using the CLAP model
    logits_per_text = self.text_encoder(**inputs).logits_per_text
    # sort by the highest matching generations per prompt
    indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]

Suggested fix:
Remove the scoring/ranking sentence from the Stable Audio docs, or implement an actual scoring component before documenting ranking behavior.
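
If the sentence is kept rather than removed, one possible rewording (suggestion only): "Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. The waveforms are returned in the order they were generated; no automatic scoring or ranking against the prompt is performed."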
