[feat] Add stable_audio T2A Generation #1080
**Open** — happy-harvey wants to merge 1 commit into `hao-ai-lab:main` from `happy-harvey:harvey/audio_dev`
**New file** (`@@ -0,0 +1,39 @@`):

# Stable Audio

Text-to-audio with **Stable Audio Open 1.0** (`stabilityai/stable-audio-open-1.0`). Weights are loaded from HuggingFace or a local path (unified `model.safetensors` + `model_config.json`).

## Setup

**1. HuggingFace login** (accept the model terms on the Hub if required):

```bash
hf auth login
```

**2. Install deps** (k-diffusion, alias-free-torch, einops-exts):

```bash
pip install .[stable-audio]
```

## Usage

**CLI:**

```bash
python examples/inference/basic/stable_audio_basic.py
python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8 --output out.wav
```

**Python:**

```python
from fastvideo import VideoGenerator

gen = VideoGenerator.from_pretrained("stabilityai/stable-audio-open-1.0", num_gpus=1)
out = gen.generate_audio(prompt="A beautiful piano arpeggio", duration_seconds=10.0)
# out["audio"]: (B, C, T), out["sample_rate"]: 44100
gen.shutdown()
```

Optional: pass `--no-cpu-offload` for higher GPU utilization (requires more VRAM). Max duration is ~47 s at 44.1 kHz.
**New file** (`@@ -0,0 +1,128 @@`):

```python
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Minimal example: generate audio from a text prompt using Stable Audio Open.

Requires: pip install .[stable-audio] (no stable-audio-tools clone needed).

Usage:
    python examples/inference/basic/stable_audio_basic.py
    python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8
    python examples/inference/basic/stable_audio_basic.py --no-cpu-offload  # higher GPU utilization
"""
import argparse
import os

import numpy as np
import torch

from fastvideo import VideoGenerator


def save_audio_wav(audio: torch.Tensor, sample_rate: int, path: str) -> None:
    """Save an audio tensor (B, C, T) to a WAV file as interleaved 16-bit PCM."""
    import wave

    if audio.ndim == 3:
        audio = audio[0]
    audio_np = audio.detach().cpu().float().numpy()
    audio_np = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_np * 32767.0).astype(np.int16)
    if audio_int16.ndim == 1:
        audio_int16 = audio_int16[None, :]  # mono: (T,) -> (1, T), channels first
    num_channels = audio_int16.shape[0]
    frames_bytes = audio_int16.T.tobytes()  # interleave to (T, C) frame order

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(frames_bytes)


def main() -> None:
    parser = argparse.ArgumentParser(description="Stable Audio text-to-audio generation")
    parser.add_argument(
        "--model-path",
        type=str,
        default="stabilityai/stable-audio-open-1.0",
        help="Path to model or HuggingFace model ID (e.g. stabilityai/stable-audio-open-1.0)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A beautiful piano arpeggio",
        help="Text description of the audio to generate",
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=10.0,
        help="Duration in seconds (default: 10)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="stable_audio_output.wav",
        help="Output WAV file path",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of denoising steps (default: 100)",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=6.0,
        help="Classifier-free guidance scale (default: 6.0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--no-cpu-offload",
        action="store_true",
        help="Disable CPU offload for higher GPU utilization (requires more VRAM)",
    )
    args = parser.parse_args()

    offload_kwargs = {}
    if args.no_cpu_offload:
        offload_kwargs = dict(
            dit_cpu_offload=False,
            text_encoder_cpu_offload=False,
            vae_cpu_offload=False,
        )

    generator = VideoGenerator.from_pretrained(
        args.model_path,
        num_gpus=1,
        **offload_kwargs,
    )

    result = generator.generate_audio(
        prompt=args.prompt,
        duration_seconds=args.duration,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
    )

    generator.shutdown()

    save_audio_wav(result["audio"], result["sample_rate"], args.output)
    print(f"Saved audio to {args.output}")
    print(f"  Shape: {result['audio'].shape}, sample_rate: {result['sample_rate']} Hz")
    if result.get("generation_time"):
        print(f"  Generation time: {result['generation_time']:.1f}s")


if __name__ == "__main__":
    main()
```
**New file** (`@@ -0,0 +1,39 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
"""
Stable Audio DiT config for FastVideo.
"""
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


@dataclass
class StableAudioDiTArchConfig(DiTArchConfig):
    """Arch config for Stable Audio DiT."""

    # Iterator strips the model.model. prefix; map inner keys to the wrapper's model.*
    param_names_mapping: dict = field(default_factory=lambda: {
        r"^(.*)$": r"model.\1",
    })
    reverse_param_names_mapping: dict = field(default_factory=dict)
    lora_param_names_mapping: dict = field(default_factory=dict)
    _fsdp_shard_conditions: list = field(default_factory=list)

    # HF config fields (from transformer/config.json)
    attention_head_dim: int = 64
    cross_attention_dim: int = 768
    cross_attention_input_dim: int = 768
    global_states_input_dim: int = 1536
    num_key_value_attention_heads: int = 12
    num_layers: int = 24
    sample_size: int = 1024
    time_proj_dim: int = 256


@dataclass
class StableAudioDiTConfig(DiTConfig):
    """Config for Stable Audio DiffusionTransformer."""

    arch_config: DiTArchConfig = field(default_factory=StableAudioDiTArchConfig)
    unified_checkpoint_path: str | None = None
    transformer_key_prefix: str = "model.model."
```
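The `param_names_mapping` above is a regex-to-replacement table: the single rule `r"^(.*)$" -> r"model.\1"` prefixes every stripped checkpoint key with `model.`. A minimal sketch of how such a table could be applied (the `remap` helper and the example key are hypothetical, not from this PR):

```python
import re

# Single catch-all rule, as in StableAudioDiTArchConfig.param_names_mapping.
param_names_mapping = {r"^(.*)$": r"model.\1"}

def remap(name: str, mapping: dict[str, str]) -> str:
    """Apply the first matching regex rule to a parameter name."""
    for pattern, repl in mapping.items():
        new_name, n_subs = re.subn(pattern, repl, name)
        if n_subs:
            return new_name
    return name  # no rule matched; keep the original key

# A key after the loader has stripped the "model.model." prefix:
print(remap("transformer_blocks.0.attn.to_q.weight", param_names_mapping))
# -> "model.transformer_blocks.0.attn.to_q.weight"
```

With only one catch-all rule the loop is trivial, but the same shape supports per-layer rewrites if more patterns are added later.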
**New file** (`@@ -0,0 +1,22 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
"""Stable Audio pipeline config."""
from dataclasses import dataclass, field

from fastvideo.configs.models import DiTConfig
from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig
from fastvideo.configs.pipelines.base import PipelineConfig


@dataclass
class StableAudioPipelineConfig(PipelineConfig):
    """Config for Stable Audio text-to-audio pipeline.

    Matches stable-audio-open-1.0: 44.1kHz, Oobleck VAE, T5+seconds conditioning.
    """

    dit_config: DiTConfig = field(default_factory=StableAudioDiTConfig)

    # Audio-specific
    sample_rate: int = 44100
    sample_size: int = 2097152  # Max ~47.5s at 44.1kHz
    embedded_cfg_scale: float = 6.0
```
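The "~47.5 s" ceiling quoted in the config (and the "~47 s" in the README) falls straight out of these two constants, since `sample_size` is a per-channel sample count. A quick sanity check:

```python
SAMPLE_RATE = 44100        # Hz, from StableAudioPipelineConfig
SAMPLE_SIZE = 2097152      # = 2**21 samples per channel, from the same config

max_seconds = SAMPLE_SIZE / SAMPLE_RATE
print(f"max duration ~= {max_seconds:.2f} s")  # ~= 47.55 s
```

So a requested `duration_seconds` above ~47.55 cannot fit in one generation window at 44.1 kHz.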
**New file** (`@@ -0,0 +1,42 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
"""Sampling parameters for Stable Audio text-to-audio generation."""
from dataclasses import dataclass

from fastvideo.configs.sample.base import SamplingParam


@dataclass
class StableAudioSamplingParam(SamplingParam):
    """Default sampling parameters for Stable Audio Open text-to-audio.

    Matches stable-audio-tools defaults: 44.1kHz, ~47.5s max duration,
    250 steps, cfg_scale=6.
    """

    data_type: str = "audio"

    # Audio-specific
    sample_rate: int = 44100
    duration_seconds: float = 10.0  # Output duration in seconds
    seconds_start: float = 0.0      # Conditioning: start offset (seconds)
    seconds_total: float = 10.0     # Conditioning: total duration (seconds)

    # Override video defaults for audio
    num_frames: int = 1
    height: int = 1
    width: int = 1
    output_video_name: str | None = None  # For audio, use output_audio_name
    negative_prompt: str = ""

    # Denoising (stable-audio-tools demo: 250 steps, cfg 3/6/9)
    num_inference_steps: int = 250
    guidance_scale: float = 6.0
    seed: int = 42

    def __post_init__(self) -> None:
        self.data_type = "audio"

    @property
    def sample_size(self) -> int:
        """Audio samples (time dimension). For stereo, shape is (B, 2, sample_size)."""
        return int(self.duration_seconds * self.sample_rate)
```
The logic for handling 1D (mono) audio arrays is incorrect. Reshaping a 1D array of shape `(T,)` to `(T, 1)` using `[:, None]` results in `num_channels` being incorrectly interpreted as `T` and `num_frames` as `1` in the subsequent lines. To correctly handle mono audio, the array should be reshaped to `(1, T)`. While this may not affect the current Stable Audio model, which produces stereo output, fixing this will make the `save_audio_wav` utility function more robust for general use.
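The fix the comment describes can be sketched as follows; the `to_channels_first` helper name is illustrative, not part of the PR:

```python
import numpy as np

def to_channels_first(audio_int16: np.ndarray) -> np.ndarray:
    """Ensure a (C, T) layout; promote 1-D mono (T,) to (1, T), not (T, 1)."""
    if audio_int16.ndim == 1:
        audio_int16 = audio_int16[None, :]  # prepend the channel axis
    return audio_int16

mono = np.zeros(1000, dtype=np.int16)
print(to_channels_first(mono).shape)  # (1, 1000): 1 channel, 1000 frames
```

With `(1, T)`, the downstream `audio_int16.shape[0]` correctly reads 1 channel, and `audio_int16.T.tobytes()` emits T single-sample frames, whereas `[:, None]` would declare T channels of 1 frame each.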