Commit d9ffee4 (parent 31c0f1b)

harveyhappy-harvey authored and committed

[feat] Add stable_audio T2A Generation

32 files changed: 1849 additions & 16 deletions

docs/inference/inference_quick_start.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -67,6 +67,8 @@ More inference example scripts can be found in `scripts/inference/`
 
 Please see the [support matrix](support_matrix.md) for the list of supported models and their available optimizations.
 
+For **text-to-audio** generation (Stable Audio), see [Stable Audio](stable_audio.md).
+
 ## Image-to-Video Generation
 
 You can generate a video starting from an initial image:
```

docs/inference/stable_audio.md

Lines changed: 126 additions & 0 deletions

# Stable Audio

FastVideo supports text-to-audio (T2A) generation via the **Stable Audio Open** model from Stability AI. This page describes supported models, installation, usage, and known limitations.

## Supported Models and Weights

| Model | HuggingFace ID | Local Path |
|-------|----------------|------------|
| Stable Audio Open 1.0 | `stabilityai/stable-audio-open-1.0` | `official_weights/stable-audio-open-1.0` |

**Weight format**:

- **HuggingFace**: Pass the model ID (e.g. `stabilityai/stable-audio-open-1.0`) to `VideoGenerator.from_pretrained()`. FastVideo will download and cache the model on first use.
- **Local**: Place a unified checkpoint (`model.safetensors` or `model.ckpt`) and `model_config.json` at the model root. Use the directory path as `model_path`.

## Installation and Dependencies

### Conflict with `stable-audio-tools` (Python 3.12)

The `stable-audio-tools` PyPI package has dependencies that **fail to build on Python 3.12** (e.g. PyWavelets). Do **not** run `pip install stable-audio-tools` directly.

Use the following two-step install:

```bash
# 1. Install stable-audio-tools without its dependencies
pip install stable-audio-tools --no-deps

# 2. Install compatible inference dependencies
pip install .[stable-audio]
# or: pip install k-diffusion v-diffusion-pytorch prefigure ema-pytorch local-attention alias-free-torch
```

If FastVideo is already installed:

```bash
pip install stable-audio-tools --no-deps
pip install fastvideo[stable-audio]
```

### Dependencies Installed by `[stable-audio]`

- `k-diffusion>=0.1.1`
- `v-diffusion-pytorch>=0.0.2`
- `prefigure>=0.0.9`
- `ema-pytorch>=0.2.3`
- `local-attention>=1.8.6`
- `alias-free-torch>=0.0.6`

These versions are compatible with FastVideo. `stable-audio-tools` declares stricter pins; the `--no-deps` install avoids pulling in conflicting packages (PyWavelets, encodec, etc.) that are not required for inference.
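To confirm the two-step install left nothing missing, a quick check can be run (a sketch; the module names below are the assumed import names of the listed packages, which replace the hyphens in their PyPI names with underscores):

```python
import importlib.util

# Assumed import names for the [stable-audio] packages listed above
# (PyPI names use hyphens; import names use underscores).
REQUIRED = ["k_diffusion", "prefigure", "ema_pytorch", "local_attention"]

def missing_packages(names: list[str]) -> list[str]:
    """Return the module names that cannot be found in the current environment."""
    return [n for n in names if importlib.util.find_spec(n) is None]

if missing := missing_packages(REQUIRED):
    print("Missing packages:", ", ".join(missing))
```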
## Running the Example

### Basic usage

```bash
python examples/inference/basic/stable_audio_basic.py
```

### With custom parameters

```bash
python examples/inference/basic/stable_audio_basic.py \
    --prompt "A gentle rain on a wooden roof" \
    --duration 10 \
    --steps 250 \
    --output my_audio.wav
```

### Main parameters

| Argument | Default | Description |
|----------|---------|-------------|
| `--model-path` | `stabilityai/stable-audio-open-1.0` | Model path or HuggingFace model ID |
| `--prompt` | `A beautiful piano arpeggio` | Text description of the audio to generate |
| `--duration` | `10.0` | Output duration in seconds |
| `--output` | `outputs_audio/stable_audio_output.wav` | Output WAV file path |
| `--steps` | `250` | Number of denoising steps (`num_inference_steps`) |
| `--guidance-scale` | `6.0` | Classifier-free guidance scale |
| `--seed` | `42` | Random seed |
| `--no-cpu-offload` | (flag) | Disable CPU offload for higher GPU utilization (requires more VRAM) |

### Programmatic usage

```python
from fastvideo import VideoGenerator

generator = VideoGenerator.from_pretrained(
    "stabilityai/stable-audio-open-1.0",
    num_gpus=1,
)

result = generator.generate_audio(
    prompt="A beautiful piano arpeggio",
    duration_seconds=10.0,
    num_inference_steps=250,
    guidance_scale=6.0,
    seed=42,
)

# result["audio"]: torch.Tensor (B, C, T)
# result["sample_rate"]: 44100
generator.shutdown()
```

### Sampling parameters (`generate_audio` kwargs)

| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration_seconds` | `10.0` | Output duration (seconds) |
| `num_inference_steps` | `250` | Denoising steps |
| `guidance_scale` | `6.0` | CFG scale |
| `seed` | `42` | Random seed |
| `seconds_start` | `0.0` | Conditioning start offset |
| `seconds_total` | Same as `duration_seconds` | Conditioning total duration |

`sample_rate` is fixed at **44.1 kHz** and comes from the pipeline config.
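Putting the table together, the timing-related kwargs can be collected like this (a sketch only; the names come from the table above, and `seconds_total` is written out even though it defaults to `duration_seconds`):

```python
# Sketch of generate_audio kwargs; built as a plain dict so the defaults
# are visible without calling into FastVideo.
sampling_kwargs = dict(
    duration_seconds=12.0,
    num_inference_steps=250,
    guidance_scale=6.0,
    seed=42,
    seconds_start=0.0,   # conditioning start offset
    seconds_total=12.0,  # defaults to duration_seconds when omitted
)
# generator.generate_audio(prompt="...", **sampling_kwargs)
```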
## Known Limitations

| Item | Description |
|------|-------------|
| **T2A only** | Only text-to-audio is supported. Audio-to-audio, stem separation, and other stable-audio-tools features are not implemented. |
| **Single model** | Only Stable Audio Open 1.0 is supported. |
| **VRAM** | ~6–8 GB for typical generation (10 s, 250 steps). Use `--no-cpu-offload` for higher GPU utilization; this increases VRAM use. |
| **Max duration** | ~47.5 s at 44.1 kHz (model `sample_size` limit). |
| **Differences from official** | Uses FastVideo's pipeline layout and executor; sampling logic matches stable-audio-tools (k-diffusion v-prediction, DPM++ 2M SDE). Minor numerical differences may occur due to implementation details. |

docs/inference/support_matrix.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -60,6 +60,7 @@ The `HuggingFace Model ID` can be directly pass to `from_pretrained()` methods a
 | Matrix Game 2.0 Base | `FastVideo/Matrix-Game-2.0-Base-Diffusers` | 352x640 ||||||
 | Matrix Game 2.0 GTA | `FastVideo/Matrix-Game-2.0-GTA-Diffusers` | 352x640 ||||||
 | Matrix Game 2.0 TempleRun | `FastVideo/Matrix-Game-2.0-TempleRun-Diffusers` | 352x640 ||||||
+| Stable Audio Open 1.0 (T2A) | `stabilityai/stable-audio-open-1.0` | 44.1 kHz stereo ||||||
 
 **Note**: Wan2.2 TI2V 5B has some quality issues when performing I2V generation. We are working on fixing this issue.
@@ -80,3 +81,6 @@ The `HuggingFace Model ID` can be directly pass to `from_pretrained()` methods a
 - Image-to-video game world models with keyboard/mouse control input
 - Three variants available: Base (universal), GTA, and TempleRun
 - Each variant has different keyboard dimensions for control inputs
+
+### Stable Audio Open 1.0
+- Text-to-audio (T2A) only. See [Stable Audio](stable_audio.md) for installation and usage.
```
examples/inference/basic/stable_audio_basic.py

Lines changed: 128 additions & 0 deletions

```python
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Minimal example: generate audio from a text prompt using Stable Audio Open.

Install (once): pip install stable-audio-tools --no-deps && pip install .[stable-audio]

Usage:
    python examples/inference/basic/stable_audio_basic.py
    python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8
    python examples/inference/basic/stable_audio_basic.py --no-cpu-offload  # higher GPU utilization
"""
import argparse
import os

import numpy as np
import torch

from fastvideo import VideoGenerator


def save_audio_wav(audio: torch.Tensor, sample_rate: int, path: str) -> None:
    """Save audio tensor (B, C, T) to WAV file. Output is stereo interleaved."""
    import wave

    if audio.ndim == 3:
        audio = audio[0]
    audio_np = audio.detach().cpu().float().numpy()
    audio_np = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_np * 32767.0).astype(np.int16)
    if audio_int16.ndim == 1:
        audio_int16 = audio_int16[:, None]
    num_channels = audio_int16.shape[0]
    # Transpose (C, T) -> (T, C) so samples are interleaved per frame
    frames_bytes = audio_int16.T.tobytes()

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(frames_bytes)


def main() -> None:
    parser = argparse.ArgumentParser(description="Stable Audio text-to-audio generation")
    parser.add_argument(
        "--model-path",
        type=str,
        default="stabilityai/stable-audio-open-1.0",
        help="Path to model or HuggingFace model ID (e.g. stabilityai/stable-audio-open-1.0)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A beautiful piano arpeggio",
        help="Text description of the audio to generate",
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=10.0,
        help="Duration in seconds (default: 10)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="outputs_audio/stable_audio_output.wav",
        help="Output WAV file path",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=250,
        help="Number of denoising steps (default: 250)",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=6.0,
        help="Classifier-free guidance scale (default: 6.0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--no-cpu-offload",
        action="store_true",
        help="Disable CPU offload for higher GPU utilization (requires more VRAM)",
    )
    args = parser.parse_args()

    offload_kwargs = {}
    if args.no_cpu_offload:
        offload_kwargs = dict(
            dit_cpu_offload=False,
            text_encoder_cpu_offload=False,
            vae_cpu_offload=False,
        )

    generator = VideoGenerator.from_pretrained(
        args.model_path,
        num_gpus=1,
        **offload_kwargs,
    )

    result = generator.generate_audio(
        prompt=args.prompt,
        duration_seconds=args.duration,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
    )

    generator.shutdown()

    save_audio_wav(result["audio"], result["sample_rate"], args.output)
    print(f"Saved audio to {args.output}")
    print(f"  Shape: {result['audio'].shape}, sample_rate: {result['sample_rate']} Hz")
    if result.get("generation_time"):
        print(f"  Generation time: {result['generation_time']:.1f}s")


if __name__ == "__main__":
    main()
```
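The WAV layout `save_audio_wav` relies on (16-bit PCM, channels interleaved per frame) can be sanity-checked with the stdlib alone, no torch required. This is an illustrative sketch: it writes one stereo frame and reads it back unchanged.

```python
import os
import struct
import tempfile
import wave

# One stereo frame: left sample = 1000, right sample = -1000, as little-endian int16,
# matching the 2-byte sample width and interleaved ordering used above.
path = os.path.join(tempfile.mkdtemp(), "probe.wav")
frames = struct.pack("<2h", 1000, -1000)

with wave.open(path, "wb") as w:
    w.setnchannels(2)      # stereo
    w.setsampwidth(2)      # 2 bytes per sample -> int16
    w.setframerate(44100)  # Stable Audio's fixed sample rate
    w.writeframes(frames)

with wave.open(path, "rb") as r:
    assert r.getnchannels() == 2 and r.getframerate() == 44100
    roundtrip = r.readframes(1)
assert roundtrip == frames
```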

fastvideo/configs/models/dits/__init__.py

Lines changed: 12 additions & 1 deletion

```diff
@@ -3,15 +3,26 @@
 from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
 from fastvideo.configs.models.dits.hunyuanvideo15 import HunyuanVideo15Config
 from fastvideo.configs.models.dits.lingbotworld import LingBotWorldVideoConfig
+from fastvideo.configs.models.dits.hyworld import HYWorldConfig
 from fastvideo.configs.models.dits.longcat import LongCatVideoConfig
 from fastvideo.configs.models.dits.ltx2 import LTX2VideoConfig
+from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
-from fastvideo.configs.models.dits.hyworld import HYWorldConfig
 
 __all__ = [
-    "HunyuanVideoConfig", "HunyuanVideo15Config", "WanVideoConfig",
-    "StepVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig",
-    "LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig",
-    "LingBotWorldVideoConfig"
+    "HunyuanVideoConfig",
+    "HunyuanVideo15Config",
+    "WanVideoConfig",
+    "StepVideoConfig",
+    "CosmosVideoConfig",
+    "Cosmos25VideoConfig",
+    "LongCatVideoConfig",
+    "LTX2VideoConfig",
+    "HYWorldConfig",
+    "LingBotWorldVideoConfig",
+    "StableAudioDiTConfig",
 ]
```
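The old `__all__` in this hunk left `"LingBotWorldVideoConfig"` without a trailing comma before the next entry, which is a classic hazard worth noting: Python concatenates adjacent string literals implicitly, silently fusing two list entries into one bogus export name.

```python
# A missing comma between adjacent string literals silently merges them:
entries = [
    "LingBotWorldVideoConfig"  # <- no comma here...
    "HunyuanVideoConfig",      # ...so the two lines fuse into one string
]
print(entries)  # a single fused entry, not two
```

Linters such as ruff flag this pattern (implicit string concatenation in collections) precisely because it usually hides a typo rather than intent.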
fastvideo/configs/models/dits/stable_audio.py

Lines changed: 39 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
"""
Stable Audio DiT config for FastVideo.
"""
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


@dataclass
class StableAudioDiTArchConfig(DiTArchConfig):
    """Arch config for Stable Audio DiT."""

    # Iterator strips the "model.model." prefix; map inner keys to the
    # wrapper's "model.*" namespace.
    param_names_mapping: dict = field(default_factory=lambda: {
        r"^(.*)$": r"model.\1",
    })
    reverse_param_names_mapping: dict = field(default_factory=dict)
    lora_param_names_mapping: dict = field(default_factory=dict)
    _fsdp_shard_conditions: list = field(default_factory=list)

    # HF config fields (from transformer/config.json)
    attention_head_dim: int = 64
    cross_attention_dim: int = 768
    cross_attention_input_dim: int = 768
    global_states_input_dim: int = 1536
    num_key_value_attention_heads: int = 12
    num_layers: int = 24
    sample_size: int = 1024
    time_proj_dim: int = 256


@dataclass
class StableAudioDiTConfig(DiTConfig):
    """Config for Stable Audio DiffusionTransformer."""

    arch_config: DiTArchConfig = field(default_factory=StableAudioDiTArchConfig)
    unified_checkpoint_path: str | None = None
    transformer_key_prefix: str = "model.model."
```
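The `param_names_mapping` above can be illustrated with plain `re` (a hypothetical sketch; the checkpoint key and the prefix-stripping step are assumptions based on `transformer_key_prefix`, not FastVideo's actual loader code):

```python
import re

# Same mapping as in StableAudioDiTArchConfig: every stripped key is
# rewritten onto the wrapper's "model.*" namespace.
PARAM_NAMES_MAPPING = {r"^(.*)$": r"model.\1"}

def map_param_name(name: str) -> str:
    """Apply the first matching regex rule to a checkpoint parameter name."""
    for pattern, repl in PARAM_NAMES_MAPPING.items():
        if re.match(pattern, name):
            return re.sub(pattern, repl, name)
    return name

# Hypothetical checkpoint key; the loader is assumed to strip the
# "model.model." prefix before the mapping is applied.
key = "model.model.blocks.0.attn.to_q.weight"
inner = key.removeprefix("model.model.")
print(map_param_name(inner))  # -> model.blocks.0.attn.to_q.weight
```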
Lines changed: 22 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
"""Stable Audio pipeline config."""
from dataclasses import dataclass, field

from fastvideo.configs.models import DiTConfig
from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig
from fastvideo.configs.pipelines.base import PipelineConfig


@dataclass
class StableAudioPipelineConfig(PipelineConfig):
    """Config for Stable Audio text-to-audio pipeline.

    Matches stable-audio-open-1.0: 44.1 kHz, Oobleck VAE, T5 + seconds conditioning.
    """

    dit_config: DiTConfig = field(default_factory=StableAudioDiTConfig)

    # Audio-specific
    sample_rate: int = 44100
    sample_size: int = 2097152  # max ~47.5 s at 44.1 kHz
    embedded_cfg_scale: float = 6.0
```
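The "~47.5 s" maximum duration quoted in the docs and in the `sample_size` comment follows directly from these two config values:

```python
# Max duration = samples per channel / samples per second
SAMPLE_RATE = 44100      # Hz, fixed by the pipeline config
SAMPLE_SIZE = 2_097_152  # max samples, i.e. 2**21

max_seconds = SAMPLE_SIZE / SAMPLE_RATE
print(f"{max_seconds:.2f} s")  # -> 47.55 s
```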
