@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: str | list[str] = None,
    audio_end_in_s: float | None = None,
    audio_start_in_s: float | None = 0.0,
    num_inference_steps: int = 100,
    guidance_scale: float = 7.0,
    negative_prompt: str | list[str] | None = None,
    num_waveforms_per_prompt: int | None = 1,
    eta: float = 0.0,
    generator: torch.Generator | list[torch.Generator] | None = None,
    latents: torch.Tensor | None = None,
    initial_audio_waveforms: torch.Tensor | None = None,
    initial_audio_sampling_rate: torch.Tensor | None = None,
    prompt_embeds: torch.Tensor | None = None,
    negative_prompt_embeds: torch.Tensor | None = None,
    attention_mask: torch.LongTensor | None = None,
    negative_attention_mask: torch.LongTensor | None = None,
    return_dict: bool = True,
    callback: Callable[[int, int, torch.Tensor], None] | None = None,
    callback_steps: int | None = 1,
    output_type: str | None = "pt",
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `list[str]`, *optional*):
            The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
        audio_end_in_s (`float`, *optional*, defaults to 47.55):
            Audio end index in seconds.
        audio_start_in_s (`float`, *optional*, defaults to 0):
            Audio start index in seconds.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to higher quality audio at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            A higher guidance scale value encourages the model to generate audio that is closely linked to the
            text `prompt` at the expense of lower sound quality. Guidance scale is enabled when
            `guidance_scale > 1`.
        negative_prompt (`str` or `list[str]`, *optional*):
            The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
            The number of waveforms to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper.
            Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor is generated by sampling using the supplied random `generator`.
        initial_audio_waveforms (`torch.Tensor`, *optional*):
            Optional initial audio waveforms to use as the starting point for generation. Must be of shape
            `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
            corresponds to the number of prompts passed to the model.
        initial_audio_sampling_rate (`int`, *optional*):
            Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
            *e.g.* prompt weighting. If not provided, text embeddings will be computed from the `prompt` input
            argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
            inputs, *e.g.* prompt weighting. If not provided, `negative_prompt_embeds` will be computed from the
            `negative_prompt` input argument.
        attention_mask (`torch.LongTensor`, *optional*):
            Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, the attention
            mask will be computed from the `prompt` input argument.
        negative_attention_mask (`torch.LongTensor`, *optional*):
            Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that is called every `callback_steps` steps during inference. The function is called with
            the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called
            at every step.
        output_type (`str`, *optional*, defaults to `"pt"`):
            The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
            `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
            model (LDM) output.

    Examples:

    Returns:
        [`~pipelines.AudioPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated audio.
    """
    # 0. Convert audio input length from seconds to latent length
    downsample_ratio = self.vae.hop_length

    max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
    if audio_end_in_s is None:
        audio_end_in_s = max_audio_length_in_s

    if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
        raise ValueError(
            f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
        )

    waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
    waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
    waveform_length = int(self.transformer.config.sample_size)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        audio_start_in_s,
        audio_end_in_s,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        negative_attention_mask,
        initial_audio_waveforms,
        initial_audio_sampling_rate,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self.encode_prompt(
        prompt,
        device,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        attention_mask,
        negative_attention_mask,
    )

    # Encode duration
    seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
        audio_start_in_s,
        audio_end_in_s,
        device,
        do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
        batch_size,
    )

    # Create text_audio_duration_embeds and audio_duration_embeds
    text_audio_duration_embeds = torch.cat(
        [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
    )

    audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)

    # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
    # to concatenate them to the embeddings
    if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
        negative_text_audio_duration_embeds = torch.zeros_like(
            text_audio_duration_embeds, device=text_audio_duration_embeds.device
        )
        text_audio_duration_embeds = torch.cat(
            [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
        )
        audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)

    bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
    # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
    text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
    text_audio_duration_embeds = text_audio_duration_embeds.view(
        bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
    )

    audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
    audio_duration_embeds = audio_duration_embeds.view(
        bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_vae = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_waveforms_per_prompt,
        num_channels_vae,
        waveform_length,
        text_audio_duration_embeds.dtype,
        device,
        generator,
        latents,
        initial_audio_waveforms,
        num_waveforms_per_prompt,
        audio_channels=self.vae.config.audio_channels,
    )

    # 6. Prepare extra step kwargs
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare rotary positional embedding
    rotary_embedding = get_1d_rotary_pos_embed(
        self.rotary_embed_dim,
        latents.shape[2] + audio_duration_embeds.shape[1],
        use_real=True,
        repeat_interleave_real=False,
    )

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.transformer(
                latent_model_input,
                t.unsqueeze(0),
                encoder_hidden_states=text_audio_duration_embeds,
                global_hidden_states=audio_duration_embeds,
                rotary_embedding=rotary_embedding,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    # 9. Post-processing
    if not output_type == "latent":
        audio = self.vae.decode(latents).sample
    else:
        return AudioPipelineOutput(audios=latents)

    audio = audio[:, :, waveform_start:waveform_end]

    if output_type == "np":
        audio = audio.cpu().float().numpy()

    self.maybe_free_model_hooks()

    if not return_dict:
        return (audio,)

    return AudioPipelineOutput(audios=audio)
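
The `Examples:` placeholder in the docstring above is filled at runtime from `EXAMPLE_DOC_STRING`. For reference, a minimal usage sketch of this call signature, assuming the `stabilityai/stable-audio-open-1.0` checkpoint and a CUDA device (prompt and parameter values are illustrative):

```python
import soundfile as sf
import torch

from diffusers import StableAudioPipeline

# Load the published checkpoint in half precision and move it to the GPU.
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Generate a single 10-second waveform for one prompt.
audio = pipe(
    prompt="The sound of a hammer hitting a wooden surface.",
    negative_prompt="Low quality.",
    num_inference_steps=200,
    audio_end_in_s=10.0,
    num_waveforms_per_prompt=1,
    generator=torch.Generator("cuda").manual_seed(0),
).audios

# `audios` is a (batch, channels, samples) tensor; write the first waveform to disk.
sf.write("hammer.wav", audio[0].T.float().cpu().numpy(), pipe.vae.config.sampling_rate)
```
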
`stable_audio` model/pipeline review

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Files/categories reviewed: target pipeline/model files, public lazy imports, top-level exports, config/loading/device-map behavior, dtype/device handling, offload-related tests, attention processor behavior, docs, examples, fast/nightly/slow test coverage.

Verification note: attempted `.venv\Scripts\python.exe -m pytest tests/pipelines/stable_audio/test_stable_audio.py -q`, but local test collection fails before the Stable Audio tests run because this Windows torch build lacks `torch._C._distributed_c10d` while importing FSDP. Narrow reproduction snippets below were checked with `.venv`.

Duplicate-search status: searched GitHub Issues/PRs for `stable_audio`, `StableAudioPipeline`, `StableAudioDiTModel device_map`, `StableAudioAttnProcessor2_0 set_attention_backend`, and `initial_audio_waveforms num_waveforms_per_prompt`. Found related but not exact duplicates: #10861 for initial-audio scaling and #8989 for sequential offload testing. No exact duplicate found for the batch-order, `_no_split_modules`, attention-backend, dtype, or docs findings.

Issue 1: Batched initial audio is paired with the wrong prompt when generating multiple waveforms
Affected code:
diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py (lines 485 to 487 in 0f1abc4)
diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py (lines 670 to 680 in 0f1abc4)

Problem:
`text_audio_duration_embeds` is expanded per prompt as `[prompt0, prompt0, prompt1, prompt1]`, but encoded initial audio is expanded with `encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))`, producing `[audio0, audio1, audio0, audio1]`. For batched audio-to-audio with `num_waveforms_per_prompt > 1`, prompts and initial audio become misaligned.

Impact:
Users requesting multiple variations per prompt with batched `initial_audio_waveforms` condition some generations on another prompt's audio. Existing tests only assert output shape, so this does not get caught.

Reproduction:
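A full audio-to-audio run needs real weights, but the ordering mismatch can be shown in isolation. In the sketch below the variable names mirror the pipeline's, while the shapes and values are made up for illustration:

```python
import torch

num_waveforms_per_prompt = 2
prompt_embeds = torch.stack([torch.full((3, 4), 0.0), torch.full((3, 4), 1.0)])    # [prompt0, prompt1]
encoded_audio = torch.stack([torch.full((3, 4), 10.0), torch.full((3, 4), 11.0)])  # [audio0, audio1]

# Pipeline-style prompt expansion: repeat then view -> [prompt0, prompt0, prompt1, prompt1]
prompt_expanded = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1).view(-1, 3, 4)

# Pipeline-style initial-audio expansion: tile along dim 0 -> [audio0, audio1, audio0, audio1]
audio_expanded = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))

for i in range(4):
    print(int(prompt_expanded[i, 0, 0]), int(audio_expanded[i, 0, 0]))
# prints (0, 10), (0, 11), (1, 10), (1, 11): the middle two generations pair
# prompt0 with audio1 and prompt1 with audio0.
```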
Relevant precedent:
`repeat_interleave(..., dim=0)` is the common pattern for per-prompt expansion, e.g. the `qwenimage` modular inputs.

Suggested fix:
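A minimal sketch of the suggested expansion, shown on toy tensors rather than inside `prepare_latents` (the surrounding pipeline code is not reproduced here):

```python
import torch

num_waveforms_per_prompt = 2
encoded_audio = torch.stack([torch.zeros(3, 4), torch.ones(3, 4)])  # [audio0, audio1]

# Current expansion tiles the whole batch: [audio0, audio1, audio0, audio1].
tiled = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))

# Suggested expansion keeps each prompt's audio adjacent: [audio0, audio0, audio1, audio1],
# matching the [prompt0, prompt0, prompt1, prompt1] order used for the text embeddings.
interleaved = encoded_audio.repeat_interleave(num_waveforms_per_prompt, dim=0)

print(tiled[:, 0, 0].tolist())        # [0.0, 1.0, 0.0, 1.0]
print(interleaved[:, 0, 0].tolist())  # [0.0, 0.0, 1.0, 1.0]
```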
Issue 2: `StableAudioDiTModel` cannot be loaded with `device_map`, despite docs using `device_map="balanced"`

Affected code:
diffusers/src/diffusers/models/transformers/stable_audio_transformer.py (lines 206 to 208 in 0f1abc4)
diffusers/docs/source/en/api/pipelines/stable_audio.md (lines 66 to 72 in 0f1abc4)

Problem:
`StableAudioDiTModel` sets `_supports_gradient_checkpointing = True` but does not define `_no_split_modules`. Diffusers model loading raises for `device_map="balanced"`/`"auto"` unless `_no_split_modules` is implemented. The Stable Audio quantization docs currently show `StableAudioPipeline.from_pretrained(..., device_map="balanced")`, which is not supported by the transformer class.

Impact:
The documented quantized loading path is broken for the Stable Audio transformer, and users cannot use Diffusers device-map placement for the model.

Reproduction:
See sketch A1 in the appendix at the end of this review.

Relevant precedent:
diffusers/src/diffusers/models/transformers/transformer_flux.py (lines 565 to 566 in 0f1abc4)
diffusers/src/diffusers/models/transformers/transformer_wan.py (lines 546 to 548 in 0f1abc4)

Suggested fix:
Define `_no_split_modules` on `StableAudioDiTModel`, following the Flux/Wan precedent (see sketch A1 in the appendix).

Issue 3: Stable Audio attention ignores `set_attention_backend`

Affected code:
diffusers/src/diffusers/models/transformers/stable_audio_transformer.py (line 24 in 0f1abc4)
diffusers/src/diffusers/models/transformers/stable_audio_transformer.py (lines 105 to 121 in 0f1abc4)
diffusers/src/diffusers/models/attention_processor.py (lines 2991 to 3103 in 0f1abc4)

Problem:
`StableAudioAttnProcessor2_0` lives in the shared attention processor file, has no `_attention_backend`/`_parallel_config` fields, and calls `F.scaled_dot_product_attention` directly. `ModelMixin.set_attention_backend()` only updates processors with `_attention_backend`, so Stable Audio processors remain unchanged.

Impact:
Users cannot select Flash/Sage/Flex/native backend behavior for Stable Audio even though `StableAudioDiTModel` inherits `AttentionMixin`. This also leaves Stable Audio outside the newer attention-dispatch and context-parallel patterns.

Reproduction:
See sketch A2 in the appendix at the end of this review.

Relevant precedent:
diffusers/src/diffusers/models/transformers/transformer_flux.py (lines 75 to 125 in 0f1abc4)

Suggested fix:
Refactor the Stable Audio attention processor to the model-file attention pattern: define processor state fields, use `dispatch_attention_fn`, and keep Q/K/V in the `(batch, sequence, heads, head_dim)` layout expected by the dispatcher. This is a moderate refactor because the current implementation uses `(batch, heads, sequence, head_dim)` around RoPE.

Issue 4: User-provided latents keep their original dtype

Affected code:
diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py (lines 439 to 445 in 0f1abc4)

Problem:
When `latents` is supplied, `prepare_latents()` only does `latents.to(device)`. It does not cast to the `dtype` selected for the pipeline call. With a half-precision Stable Audio transformer, float32 user latents are forwarded into half-precision Conv/Linear layers.

Impact:
Mixed-precision calls can fail at runtime or run with an unintended latent dtype. This is especially relevant because the slow test path supplies precomputed latents.

Reproduction:
See sketch A3 in the appendix at the end of this review.

Relevant precedent:
Newer pipelines often recast latents to the active latent/model dtype before denoising; for example, the QwenImage and Flux paths recast around denoising.

Suggested fix:
Cast user-provided latents to the requested dtype in `prepare_latents()`, not only to the device (see sketch A3 in the appendix).

Issue 5: Stable Audio has no model-level tests and no `@slow` tests

Affected code:
diffusers/tests/pipelines/stable_audio/test_stable_audio.py (lines 413 to 423 in 0f1abc4)
diffusers/tests/pipelines/stable_audio/test_stable_audio.py (lines 426 to 478 in 0f1abc4)

Problem:
The family has fast pipeline tests and a `@nightly` integration test, but no `tests/models/transformers/test_models_stable_audio*.py` coverage and no `@slow` Stable Audio test. The pipeline also skips sequential offload tests and encode-prompt isolation. The sequential offload skip is already related to open issue #8989.

Impact:
Model serialization/loading, attention backend behavior, `_no_split_modules`/device-map support, compile behavior, and model-level attention masks are not covered by the standard model test mixins. Missing slow coverage also means the non-nightly slow suite does not exercise the published checkpoint.

Reproduction:
See sketch A4 in the appendix at the end of this review.

Relevant precedent:
diffusers/tests/models/transformers/test_models_transformer_longcat_audio_dit.py (lines 84 to 101 in 0f1abc4)

Suggested fix:
Add a `StableAudioDiTModel` model tester using `ModelTesterMixin`, `AttentionTesterMixin`, and compile/memory coverage where supported. Add or mark a published-checkpoint pipeline test with `@slow` so Stable Audio is covered outside nightly-only CI. Keep #8989 referenced for sequential offload until that behavior is fixed or explicitly unsupported.

Issue 6: Stable Audio docs claim waveform scoring that the pipeline does not implement

Affected code:
diffusers/docs/source/en/api/pipelines/stable_audio.md (lines 33 to 37 in 0f1abc4)
diffusers/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py (lines 490 to 764 in 0f1abc4)

Problem:
The docs say `num_waveforms_per_prompt > 1` performs automatic scoring and ranks outputs by prompt similarity. `StableAudioPipeline` has no scoring component or `score_waveforms()` path; it simply returns generated audio in batch order. This looks copied from AudioLDM2/MusicLDM behavior.

Impact:
Users are told generated Stable Audio waveforms are ranked when they are not.

Reproduction:
See sketch A5 in the appendix at the end of this review.

Relevant precedent:
AudioLDM2 implements waveform scoring:
diffusers/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py (lines 707 to 727 in 0f1abc4)

Suggested fix:
Remove the scoring/ranking sentence from the Stable Audio docs, or implement an actual scoring component before documenting ranking behavior.
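
Appendix: reproduction and fix sketches

The sketches referenced from the issues above are collected here. They are written against the commit under review but are illustrative rather than verified patches; checkpoint ids, toy config values, and stand-in modules are assumptions and are called out in the comments.

Sketch A1 (Issue 2): loading `StableAudioDiTModel` with a device map, plus the likely shape of the `_no_split_modules` fix.

```python
# Assumes the published stabilityai/stable-audio-open-1.0 checkpoint and network access.
# Loading the transformer with a device map is expected to fail today because
# StableAudioDiTModel defines no _no_split_modules.
import torch

from diffusers import StableAudioDiTModel

try:
    StableAudioDiTModel.from_pretrained(
        "stabilityai/stable-audio-open-1.0",
        subfolder="transformer",
        torch_dtype=torch.float16,
        device_map="auto",
    )
except Exception as err:  # diffusers rejects device_map for models without _no_split_modules
    print(type(err).__name__, err)

# Likely shape of the fix, following the Flux/Wan precedent (the block class name is an
# assumption; use the model's actual transformer block class):
#
#     class StableAudioDiTModel(ModelMixin, ConfigMixin, AttentionMixin):
#         _supports_gradient_checkpointing = True
#         _no_split_modules = ["StableAudioDiTBlock"]
```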
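Sketch A2 (Issue 3): `set_attention_backend` has nothing to update on the Stable Audio processors. The tiny random-weight config below is illustrative, not the published configuration.

```python
from diffusers import StableAudioDiTModel

model = StableAudioDiTModel(
    sample_size=32,
    in_channels=4,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=2,
    num_key_value_attention_heads=2,
    out_channels=4,
    cross_attention_dim=16,
    time_proj_dim=8,
    global_states_input_dim=16,
    cross_attention_input_dim=16,
)

# Every attention module uses StableAudioAttnProcessor2_0, and none of those processors
# carries an _attention_backend field, so set_attention_backend has nothing it can update
# (if the call raises instead of silently doing nothing, that is the same underlying gap).
print({type(p).__name__ for p in model.attn_processors.values()})
print(any(hasattr(p, "_attention_backend") for p in model.attn_processors.values()))  # False

model.set_attention_backend("native")
print({type(p).__name__ for p in model.attn_processors.values()})  # unchanged
```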
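Sketch A3 (Issue 4): the dtype hazard in isolation, with the suggested cast. The `Conv1d` stands in for the first half-precision layer the latents reach inside the transformer; it is not the actual Stable Audio module.

```python
import torch

layer = torch.nn.Conv1d(4, 4, kernel_size=3, padding=1).half()
user_latents = torch.randn(1, 4, 16)  # float32, as precomputed latents are typically supplied

try:
    layer(user_latents)
except RuntimeError as err:
    print(err)  # dtype mismatch: float32 input meets float16 weights

# Suggested behavior for prepare_latents(): cast supplied latents to the requested dtype
# in addition to moving them to the right device.
user_latents = user_latents.to(device=layer.weight.device, dtype=layer.weight.dtype)
print(user_latents.dtype)  # torch.float16, matching the transformer weights
```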
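Sketch A4 (Issue 5): quick coverage check, assuming it is run from the diffusers repository root at the reviewed commit.

```python
from pathlib import Path

# No model-level test module exists for the Stable Audio transformer:
print(sorted(Path("tests/models/transformers").glob("*stable_audio*")))  # expect: []

# The pipeline test module has a @nightly integration class but no @slow coverage:
text = Path("tests/pipelines/stable_audio/test_stable_audio.py").read_text()
print("@slow" in text, "@nightly" in text)  # expect: False True
```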
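Sketch A5 (Issue 6): class-level check only, no weights are loaded. AudioLDM2 exposes the scoring path that the Stable Audio docs describe; `StableAudioPipeline` does not.

```python
from diffusers import AudioLDM2Pipeline, StableAudioPipeline

print(hasattr(AudioLDM2Pipeline, "score_waveforms"))    # True
print(hasattr(StableAudioPipeline, "score_waveforms"))  # False
```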