Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/inference/inference_quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ More inference example scripts can be found in `scripts/inference/`

Please see the [support matrix](support_matrix.md) for the list of supported models and their available optimizations.

For **text-to-audio** generation (Stable Audio), see [Stable Audio](stable_audio.md).

## Image-to-Video Generation

You can generate a video starting from an initial image:
Expand Down
39 changes: 39 additions & 0 deletions docs/inference/stable_audio.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Stable Audio

Text-to-audio with **Stable Audio Open 1.0** (`stabilityai/stable-audio-open-1.0`). Weights are loaded from HuggingFace or a local path (unified `model.safetensors` + `model_config.json`).

## Setup

**1. HuggingFace login** (accept model terms on the Hub if required):

```bash
hf auth login
```

**2. Install deps** (k-diffusion, alias-free-torch, einops-exts):

```bash
pip install .[stable-audio]
```

## Usage

**CLI:**

```bash
python examples/inference/basic/stable_audio_basic.py
python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8 --output out.wav
```

**Python:**

```python
from fastvideo import VideoGenerator

gen = VideoGenerator.from_pretrained("stabilityai/stable-audio-open-1.0", num_gpus=1)
out = gen.generate_audio(prompt="A beautiful piano arpeggio", duration_seconds=10.0)
# out["audio"]: (B, C, T), out["sample_rate"]: 44100
gen.shutdown()
```

Optional: `--no-cpu-offload` for higher GPU use (more VRAM). Max duration ~47 s at 44.1 kHz.
4 changes: 4 additions & 0 deletions docs/inference/support_matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ The `HuggingFace Model ID` can be directly passed to `from_pretrained()` methods a
| Matrix Game 2.0 Base | `FastVideo/Matrix-Game-2.0-Base-Diffusers` | 352x640 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ |
| Matrix Game 2.0 GTA | `FastVideo/Matrix-Game-2.0-GTA-Diffusers` | 352x640 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ |
| Matrix Game 2.0 TempleRun | `FastVideo/Matrix-Game-2.0-TempleRun-Diffusers` | 352x640 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ |
| Stable Audio Open 1.0 (T2A) | `stabilityai/stable-audio-open-1.0` | 44.1 kHz stereo | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ |

**Note**: Wan2.2 TI2V 5B has some quality issues when performing I2V generation. We are working on fixing this issue.

Expand All @@ -80,3 +81,6 @@ The `HuggingFace Model ID` can be directly pass to `from_pretrained()` methods a
- Image-to-video game world models with keyboard/mouse control input
- Three variants available: Base (universal), GTA, and TempleRun
- Each variant has different keyboard dimensions for control inputs

### Stable Audio Open 1.0
- Text-to-audio (T2A) only. See [Stable Audio](stable_audio.md) for installation and usage.
128 changes: 128 additions & 0 deletions examples/inference/basic/stable_audio_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Minimal example: generate audio from a text prompt using Stable Audio Open.

Requires: pip install .[stable-audio] (no stable-audio-tools clone needed).

Usage:
python examples/inference/basic/stable_audio_basic.py
python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8
python examples/inference/basic/stable_audio_basic.py --no-cpu-offload # higher GPU utilization
"""
import argparse
import os

import numpy as np
import torch

from fastvideo import VideoGenerator


def save_audio_wav(audio: torch.Tensor, sample_rate: int, path: str) -> None:
    """Save an audio tensor to a 16-bit PCM WAV file.

    Args:
        audio: Tensor shaped (B, C, T), (C, T), or (T,). Only the first
            batch element is written; samples are clipped to [-1.0, 1.0]
            before conversion to int16.
        sample_rate: Sample rate in Hz written to the WAV header.
        path: Output file path; parent directories are created if needed.
    """
    import wave

    if audio.ndim == 3:
        # Keep only the first item of the batch.
        audio = audio[0]
    audio_np = audio.detach().cpu().float().numpy()
    audio_np = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_np * 32767.0).astype(np.int16)
    if audio_int16.ndim == 1:
        # Mono input: promote (T,) to (1, T) so axis 0 is the channel axis.
        # The previous `[:, None]` produced (T, 1), which mislabeled T as
        # the channel count in the WAV header.
        audio_int16 = audio_int16[np.newaxis, :]

    num_channels = audio_int16.shape[0]
    # wave expects interleaved frames, so transpose (C, T) -> (T, C).
    frames_bytes = audio_int16.T.tobytes()

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(2)  # 16-bit PCM
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(frames_bytes)


def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI argument parser for the Stable Audio example."""
    parser = argparse.ArgumentParser(description="Stable Audio text-to-audio generation")
    parser.add_argument(
        "--model-path",
        type=str,
        default="stabilityai/stable-audio-open-1.0",
        help="Path to model or HuggingFace model ID (e.g. stabilityai/stable-audio-open-1.0)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A beautiful piano arpeggio",
        help="Text description of the audio to generate",
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=10.0,
        help="Duration in seconds (default: 10)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="stable_audio_output.wav",
        help="Output WAV file path",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of denoising steps (default: 100)",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=6.0,
        help="Classifier-free guidance scale (default: 6.0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--no-cpu-offload",
        action="store_true",
        help="Disable CPU offload for higher GPU utilization (requires more VRAM)",
    )
    return parser


def main() -> None:
    """Generate audio from a text prompt and save it as a WAV file."""
    args = _build_parser().parse_args()

    # CPU offload is on by default; --no-cpu-offload trades VRAM for GPU utilization.
    offload_kwargs: dict = {}
    if args.no_cpu_offload:
        offload_kwargs = dict(
            dit_cpu_offload=False,
            text_encoder_cpu_offload=False,
            vae_cpu_offload=False,
        )

    generator = VideoGenerator.from_pretrained(
        args.model_path,
        num_gpus=1,
        **offload_kwargs,
    )

    # Ensure the generator is shut down even if generation fails, so worker
    # processes / GPU memory are not leaked on error.
    try:
        result = generator.generate_audio(
            prompt=args.prompt,
            duration_seconds=args.duration,
            num_inference_steps=args.steps,
            guidance_scale=args.guidance_scale,
            seed=args.seed,
        )
    finally:
        generator.shutdown()

    save_audio_wav(result["audio"], result["sample_rate"], args.output)
    print(f"Saved audio to {args.output}")
    print(f" Shape: {result['audio'].shape}, sample_rate: {result['sample_rate']} Hz")
    if result.get("generation_time"):
        print(f" Generation time: {result['generation_time']:.1f}s")


if __name__ == "__main__":
    main()
27 changes: 22 additions & 5 deletions fastvideo/configs/models/dits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,32 @@
from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
from fastvideo.configs.models.dits.hunyuanvideo15 import HunyuanVideo15Config
from fastvideo.configs.models.dits.lingbotworld import LingBotWorldVideoConfig
from fastvideo.configs.models.dits.hyworld import HYWorldConfig
from fastvideo.configs.models.dits.longcat import LongCatVideoConfig
from fastvideo.configs.models.dits.ltx2 import LTX2VideoConfig
from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig
from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
from fastvideo.configs.models.dits.hyworld import HYWorldConfig

__all__ = [
"HunyuanVideoConfig", "HunyuanVideo15Config", "WanVideoConfig",
"StepVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig",
"LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig",
"LingBotWorldVideoConfig"
"HunyuanVideoConfig",
"HunyuanVideo15Config",
"WanVideoConfig",
"StepVideoConfig",
"CosmosVideoConfig",
"Cosmos25VideoConfig",
"LongCatVideoConfig",
"LTX2VideoConfig",
"HYWorldConfig",
"LingBotWorldVideoConfig",
"HunyuanVideoConfig",
"HunyuanVideo15Config",
"WanVideoConfig",
"StepVideoConfig",
"CosmosVideoConfig",
"Cosmos25VideoConfig",
"LongCatVideoConfig",
"LTX2VideoConfig",
"HYWorldConfig",
"StableAudioDiTConfig",
]
39 changes: 39 additions & 0 deletions fastvideo/configs/models/dits/stable_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
"""
Stable Audio DiT config for FastVideo.
"""
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


@dataclass
class StableAudioDiTArchConfig(DiTArchConfig):
    """Arch config for Stable Audio DiT.

    Field defaults mirror the model's HF transformer config; the mapping
    fields control how checkpoint parameter names are rewritten at load time.
    """

    # Iterator strips model.model. prefix; map inner keys to wrapper's model.*
    param_names_mapping: dict = field(default_factory=lambda: {
        r"^(.*)$": r"model.\1",
    })
    # No reverse or LoRA name mappings are defined for this model (empty).
    reverse_param_names_mapping: dict = field(default_factory=dict)
    lora_param_names_mapping: dict = field(default_factory=dict)
    # Empty list: no FSDP shard conditions configured for this architecture.
    _fsdp_shard_conditions: list = field(default_factory=list)

    # HF config fields (from transformer/config.json)
    # NOTE(review): values presumably match stabilityai/stable-audio-open-1.0's
    # published transformer config — confirm against the Hub config.json.
    attention_head_dim: int = 64
    cross_attention_dim: int = 768
    cross_attention_input_dim: int = 768
    global_states_input_dim: int = 1536
    num_key_value_attention_heads: int = 12
    num_layers: int = 24
    sample_size: int = 1024
    time_proj_dim: int = 256


@dataclass
class StableAudioDiTConfig(DiTConfig):
    """Config for Stable Audio DiffusionTransformer."""

    arch_config: DiTArchConfig = field(default_factory=StableAudioDiTArchConfig)
    # Optional path to a single unified model.safetensors checkpoint;
    # None means weights are resolved from the model path / Hub layout.
    unified_checkpoint_path: str | None = None
    # Prefix under which transformer weights live inside the unified
    # checkpoint; the loader strips/maps this (see param_names_mapping).
    transformer_key_prefix: str = "model.model."
22 changes: 22 additions & 0 deletions fastvideo/configs/pipelines/stable_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
"""Stable Audio pipeline config."""
from dataclasses import dataclass, field

from fastvideo.configs.models import DiTConfig
from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig
from fastvideo.configs.pipelines.base import PipelineConfig


@dataclass
class StableAudioPipelineConfig(PipelineConfig):
    """Config for Stable Audio text-to-audio pipeline.

    Matches stable-audio-open-1.0: 44.1kHz, Oobleck VAE, T5+seconds conditioning.
    """

    dit_config: DiTConfig = field(default_factory=StableAudioDiTConfig)

    # Audio-specific
    sample_rate: int = 44100  # output sample rate in Hz
    sample_size: int = 2097152  # Max ~47.5s at 44.1kHz
    embedded_cfg_scale: float = 6.0  # default classifier-free guidance scale
42 changes: 42 additions & 0 deletions fastvideo/configs/sample/stable_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
"""Sampling parameters for Stable Audio text-to-audio generation."""
from dataclasses import dataclass

from fastvideo.configs.sample.base import SamplingParam


@dataclass
class StableAudioSamplingParam(SamplingParam):
    """Default sampling parameters for Stable Audio Open text-to-audio.

    Matches stable-audio-tools defaults: 44.1kHz, ~47.5s max duration,
    250 steps, cfg_scale=6.
    """

    data_type: str = "audio"

    # Audio-specific
    sample_rate: int = 44100
    duration_seconds: float = 10.0  # Output duration in seconds
    seconds_start: float = 0.0  # Conditioning: start offset (seconds)
    # NOTE(review): seconds_total does not track duration_seconds automatically;
    # callers that change duration_seconds presumably should update it too — confirm.
    seconds_total: float = 10.0  # Conditioning: total duration (seconds)

    # Override video defaults for audio
    num_frames: int = 1
    height: int = 1
    width: int = 1
    output_video_name: str | None = None  # For audio, use output_audio_name
    negative_prompt: str = ""

    # Denoising (stable-audio-tools demo: 250 steps, cfg 3/6/9)
    num_inference_steps: int = 250
    guidance_scale: float = 6.0
    seed: int = 42

    def __post_init__(self) -> None:
        # Force audio data type regardless of any value passed to __init__.
        self.data_type = "audio"

    @property
    def sample_size(self) -> int:
        """Audio samples (time dimension). For stereo, shape is (B, 2, sample_size)."""
        return int(self.duration_seconds * self.sample_rate)
Loading