diff --git a/docs/inference/inference_quick_start.md b/docs/inference/inference_quick_start.md index a55c85903c..1e0ea40e7a 100644 --- a/docs/inference/inference_quick_start.md +++ b/docs/inference/inference_quick_start.md @@ -67,6 +67,8 @@ More inference example scripts can be found in `scripts/inference/` Please see the [support matrix](support_matrix.md) for the list of supported models and their available optimizations. +For **text-to-audio** generation (Stable Audio), see [Stable Audio](stable_audio.md). + ## Image-to-Video Generation You can generate a video starting from an initial image: diff --git a/docs/inference/stable_audio.md b/docs/inference/stable_audio.md new file mode 100644 index 0000000000..61957ad958 --- /dev/null +++ b/docs/inference/stable_audio.md @@ -0,0 +1,39 @@ +# Stable Audio + +Text-to-audio with **Stable Audio Open 1.0** (`stabilityai/stable-audio-open-1.0`). Weights are loaded from HuggingFace or a local path (unified `model.safetensors` + `model_config.json`). + +## Setup + +**1. HuggingFace login** (accept model terms on the Hub if required): + +```bash +hf auth login +``` + +**2. Install deps** (k-diffusion, alias-free-torch, einops-exts): + +```bash +pip install .[stable-audio] +``` + +## Usage + +**CLI:** + +```bash +python examples/inference/basic/stable_audio_basic.py +python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8 --output out.wav +``` + +**Python:** + +```python +from fastvideo import VideoGenerator + +gen = VideoGenerator.from_pretrained("stabilityai/stable-audio-open-1.0", num_gpus=1) +out = gen.generate_audio(prompt="A beautiful piano arpeggio", duration_seconds=10.0) +# out["audio"]: (B, C, T), out["sample_rate"]: 44100 +gen.shutdown() +``` + +Optional: `--no-cpu-offload` for higher GPU use (more VRAM). Max duration ~47 s at 44.1 kHz. diff --git a/docs/inference/support_matrix.md b/docs/inference/support_matrix.md index afdec09cd1..3a19646243 100644 --- a/docs/inference/support_matrix.md +++ b/docs/inference/support_matrix.md @@ -60,6 +60,7 @@ The `HuggingFace Model ID` can be directly pass to `from_pretrained()` methods a | Matrix Game 2.0 Base | `FastVideo/Matrix-Game-2.0-Base-Diffusers` | 352x640 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | | Matrix Game 2.0 GTA | `FastVideo/Matrix-Game-2.0-GTA-Diffusers` | 352x640 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | | Matrix Game 2.0 TempleRun | `FastVideo/Matrix-Game-2.0-TempleRun-Diffusers` | 352x640 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | +| Stable Audio Open 1.0 (T2A) | `stabilityai/stable-audio-open-1.0` | 44.1 kHz stereo | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | **Note**: Wan2.2 TI2V 5B has some quality issues when performing I2V generation. We are working on fixing this issue. @@ -80,3 +81,6 @@ The `HuggingFace Model ID` can be directly pass to `from_pretrained()` methods a - Image-to-video game world models with keyboard/mouse control input - Three variants available: Base (universal), GTA, and TempleRun - Each variant has different keyboard dimensions for control inputs + +### Stable Audio Open 1.0 +- Text-to-audio (T2A) only. See [Stable Audio](stable_audio.md) for installation and usage. diff --git a/examples/inference/basic/stable_audio_basic.py b/examples/inference/basic/stable_audio_basic.py new file mode 100644 index 0000000000..634e7d765a --- /dev/null +++ b/examples/inference/basic/stable_audio_basic.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +""" +Minimal example: generate audio from a text prompt using Stable Audio Open. 
+ +Requires: pip install .[stable-audio] (no stable-audio-tools clone needed). + +Usage: + python examples/inference/basic/stable_audio_basic.py + python examples/inference/basic/stable_audio_basic.py --prompt "A gentle rain" --duration 8 + python examples/inference/basic/stable_audio_basic.py --no-cpu-offload # higher GPU utilization +""" +import argparse +import os + +import numpy as np +import torch + +from fastvideo import VideoGenerator + + +def save_audio_wav(audio: torch.Tensor, sample_rate: int, path: str) -> None: + """Save audio tensor (B, C, T) to WAV file. Output is stereo interleaved.""" + import wave + + if audio.ndim == 3: + audio = audio[0] + audio_np = audio.detach().cpu().float().numpy() + audio_np = np.clip(audio_np, -1.0, 1.0) + audio_int16 = (audio_np * 32767.0).astype(np.int16) + if audio_int16.ndim == 1: + audio_int16 = audio_int16[:, None] + num_channels = audio_int16.shape[0] + num_frames = audio_int16.shape[1] + frames_bytes = audio_int16.T.tobytes() + + os.makedirs(os.path.dirname(path) or ".", exist_ok=True) + with wave.open(path, "wb") as wav_file: + wav_file.setnchannels(num_channels) + wav_file.setsampwidth(2) + wav_file.setframerate(sample_rate) + wav_file.writeframes(frames_bytes) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Stable Audio text-to-audio generation") + parser.add_argument( + "--model-path", + type=str, + default="stabilityai/stable-audio-open-1.0", + help="Path to model or HuggingFace model ID (e.g. stabilityai/stable-audio-open-1.0)", + ) + parser.add_argument( + "--prompt", + type=str, + default="A beautiful piano arpeggio", + help="Text description of the audio to generate", + ) + parser.add_argument( + "--duration", + type=float, + default=10.0, + help="Duration in seconds (default: 10)", + ) + parser.add_argument( + "--output", + type=str, + default="stable_audio_output.wav", + help="Output WAV file path", + ) + parser.add_argument( + "--steps", + type=int, + default=100, + help="Number of denoising steps (default: 100)", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=6.0, + help="Classifier-free guidance scale (default: 6.0)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed", + ) + parser.add_argument( + "--no-cpu-offload", + action="store_true", + help="Disable CPU offload for higher GPU utilization (requires more VRAM)", + ) + args = parser.parse_args() + + offload_kwargs = {} + if args.no_cpu_offload: + offload_kwargs = dict( + dit_cpu_offload=False, + text_encoder_cpu_offload=False, + vae_cpu_offload=False, + ) + + generator = VideoGenerator.from_pretrained( + args.model_path, + num_gpus=1, + **offload_kwargs, + ) + + result = generator.generate_audio( + prompt=args.prompt, + duration_seconds=args.duration, + num_inference_steps=args.steps, + guidance_scale=args.guidance_scale, + seed=args.seed, + ) + + generator.shutdown() + + save_audio_wav(result["audio"], result["sample_rate"], args.output) + print(f"Saved audio to {args.output}") + print(f" Shape: {result['audio'].shape}, sample_rate: {result['sample_rate']} Hz") + if result.get("generation_time"): + print(f" Generation time: {result['generation_time']:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index dfe6f4aef9..80c91a55b3 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -3,15 +3,32 @@ from 
fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
 from fastvideo.configs.models.dits.hunyuanvideo15 import HunyuanVideo15Config
 from fastvideo.configs.models.dits.lingbotworld import LingBotWorldVideoConfig
+from fastvideo.configs.models.dits.hyworld import HYWorldConfig
 from fastvideo.configs.models.dits.longcat import LongCatVideoConfig
 from fastvideo.configs.models.dits.ltx2 import LTX2VideoConfig
+from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
-from fastvideo.configs.models.dits.hyworld import HYWorldConfig
 
 __all__ = [
-    "HunyuanVideoConfig", "HunyuanVideo15Config", "WanVideoConfig",
-    "StepVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig",
-    "LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig",
-    "LingBotWorldVideoConfig"
+    "HunyuanVideoConfig",
+    "HunyuanVideo15Config",
+    "WanVideoConfig",
+    "StepVideoConfig",
+    "CosmosVideoConfig",
+    "Cosmos25VideoConfig",
+    "LongCatVideoConfig",
+    "LTX2VideoConfig",
+    "HYWorldConfig",
+    "LingBotWorldVideoConfig",
+    "StableAudioDiTConfig",
 ]
diff --git a/fastvideo/configs/models/dits/stable_audio.py b/fastvideo/configs/models/dits/stable_audio.py
new file mode 100644
index 0000000000..fc358d0275
--- /dev/null
+++ b/fastvideo/configs/models/dits/stable_audio.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Stable Audio DiT config for FastVideo.
+"""
+from dataclasses import dataclass, field
+
+from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig
+
+
+@dataclass
+class StableAudioDiTArchConfig(DiTArchConfig):
+    """Arch config for Stable Audio DiT."""
+
+    # Iterator strips model.model. prefix; map inner keys to wrapper's model.*
+    param_names_mapping: dict = field(default_factory=lambda: {
+        r"^(.*)$": r"model.\1",
+    })
+    reverse_param_names_mapping: dict = field(default_factory=dict)
+    lora_param_names_mapping: dict = field(default_factory=dict)
+    _fsdp_shard_conditions: list = field(default_factory=list)
+
+    # HF config fields (from transformer/config.json)
+    attention_head_dim: int = 64
+    cross_attention_dim: int = 768
+    cross_attention_input_dim: int = 768
+    global_states_input_dim: int = 1536
+    num_key_value_attention_heads: int = 12
+    num_layers: int = 24
+    sample_size: int = 1024
+    time_proj_dim: int = 256
+
+
+@dataclass
+class StableAudioDiTConfig(DiTConfig):
+    """Config for Stable Audio DiffusionTransformer."""
+
+    arch_config: DiTArchConfig = field(default_factory=StableAudioDiTArchConfig)
+    unified_checkpoint_path: str | None = None
+    transformer_key_prefix: str = "model.model."
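Aside on the mapping above: the combination of `transformer_key_prefix` and the single-regex `param_names_mapping` is easier to follow with a concrete key. The sketch below is illustrative only (the layer path is a made-up example; in the actual loader the mapping is applied via `get_param_names_mapping` during weight loading) and shows how a key from the unified `model.safetensors` ends up addressing the wrapper's inner `model` module:

```python
import re

# Hypothetical key from the unified checkpoint (illustrative path, not a real layer name).
ckpt_key = "model.model.transformer.layers.0.self_attn.q_proj.weight"

# 1) The weights iterator strips the configured transformer_key_prefix ("model.model.").
key_prefix = "model.model."
stripped = ckpt_key[len(key_prefix):] if ckpt_key.startswith(key_prefix) else ckpt_key

# 2) param_names_mapping rewrites the stripped key onto the wrapper's `model` attribute,
#    since StableAudioDiTModel holds the vendored DiffusionTransformer as `self.model`.
mapping = {r"^(.*)$": r"model.\1"}
target = stripped
for pattern, repl in mapping.items():
    target = re.sub(pattern, repl, target)

print(target)  # -> model.transformer.layers.0.self_attn.q_proj.weight
```

Keys that do not start with the prefix (e.g. `pretransform.model.*` or `conditioner.conditioners.*`) are skipped by the transformer loader and are instead picked up by the pretransform and conditioner wrappers introduced later in this diff.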
diff --git a/fastvideo/configs/pipelines/stable_audio.py b/fastvideo/configs/pipelines/stable_audio.py new file mode 100644 index 0000000000..0ffe119907 --- /dev/null +++ b/fastvideo/configs/pipelines/stable_audio.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Stable Audio pipeline config.""" +from dataclasses import dataclass, field + +from fastvideo.configs.models import DiTConfig +from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig +from fastvideo.configs.pipelines.base import PipelineConfig + + +@dataclass +class StableAudioPipelineConfig(PipelineConfig): + """Config for Stable Audio text-to-audio pipeline. + + Matches stable-audio-open-1.0: 44.1kHz, Oobleck VAE, T5+seconds conditioning. + """ + + dit_config: DiTConfig = field(default_factory=StableAudioDiTConfig) + + # Audio-specific + sample_rate: int = 44100 + sample_size: int = 2097152 # Max ~47.5s at 44.1kHz + embedded_cfg_scale: float = 6.0 diff --git a/fastvideo/configs/sample/stable_audio.py b/fastvideo/configs/sample/stable_audio.py new file mode 100644 index 0000000000..33b13409ed --- /dev/null +++ b/fastvideo/configs/sample/stable_audio.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Sampling parameters for Stable Audio text-to-audio generation.""" +from dataclasses import dataclass + +from fastvideo.configs.sample.base import SamplingParam + + +@dataclass +class StableAudioSamplingParam(SamplingParam): + """Default sampling parameters for Stable Audio Open text-to-audio. + + Matches stable-audio-tools defaults: 44.1kHz, ~47.5s max duration, + 250 steps, cfg_scale=6. + """ + + data_type: str = "audio" + + # Audio-specific + sample_rate: int = 44100 + duration_seconds: float = 10.0 # Output duration in seconds + seconds_start: float = 0.0 # Conditioning: start offset (seconds) + seconds_total: float = 10.0 # Conditioning: total duration (seconds) + + # Override video defaults for audio + num_frames: int = 1 + height: int = 1 + width: int = 1 + output_video_name: str | None = None # For audio, use output_audio_name + negative_prompt: str = "" + + # Denoising (stable-audio-tools demo: 250 steps, cfg 3/6/9) + num_inference_steps: int = 250 + guidance_scale: float = 6.0 + seed: int = 42 + + def __post_init__(self) -> None: + self.data_type = "audio" + + @property + def sample_size(self) -> int: + """Audio samples (time dimension). For stereo, shape is (B, 2, sample_size).""" + return int(self.duration_seconds * self.sample_rate) diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py index b6846e680a..63bba05e7c 100644 --- a/fastvideo/entrypoints/video_generator.py +++ b/fastvideo/entrypoints/video_generator.py @@ -230,6 +230,75 @@ def _is_image_workload(self) -> bool: return False return args.workload_type.value.endswith("2i") + def generate_audio( + self, + prompt: str, + sampling_param: SamplingParam | None = None, + **kwargs, + ) -> dict[str, Any]: + """ + Generate audio from a text prompt (Stable Audio pipeline). + + Args: + prompt: Text description of the audio to generate. + sampling_param: Sampling parameters. If None, loaded from model. + **kwargs: Override sampling params (duration_seconds, num_inference_steps, + guidance_scale, seed, etc.). + + Returns: + Dict with "audio" (torch.Tensor), "sample_rate" (int), "prompt", etc. 
+ """ + if sampling_param is None: + sampling_param = SamplingParam.from_pretrained( + self.fastvideo_args.model_path) + sampling_param = deepcopy(sampling_param) + sampling_param.update(kwargs) + sampling_param.prompt = prompt.strip() + + pipeline_config = self.fastvideo_args.pipeline_config + sample_rate = getattr(pipeline_config, "sample_rate", 44100) + duration = sampling_param.duration_seconds or 10.0 + + batch = ForwardBatch( + **shallow_asdict(sampling_param), + eta=0.0, + n_tokens=1, + VSA_sparsity=self.fastvideo_args.VSA_sparsity, + ) + batch.sample_rate = sample_rate + batch.duration_seconds = duration + batch.seconds_total = duration + if batch.seconds_start is None: + batch.seconds_start = 0.0 + batch.do_classifier_free_guidance = (batch.guidance_scale or 1.0) > 1.0 + + result_container: dict[str, Any] = {} + + def execute_forward_thread(): + result_container["output_batch"] = self.executor.execute_forward( + batch, self.fastvideo_args) + + start_time = time.perf_counter() + thread = threading.Thread(target=execute_forward_thread) + thread.start() + thread.join() + gen_time = time.perf_counter() - start_time + + output_batch = result_container["output_batch"] + audio = output_batch.output.cpu() + if torch.is_tensor(audio) and audio.ndim == 3: + pass + elif torch.is_tensor(audio) and audio.ndim == 2: + audio = audio.unsqueeze(0) + + return { + "audio": audio, + "sample_rate": sample_rate, + "prompt": prompt, + "duration_seconds": duration, + "generation_time": gen_time, + } + def _prepare_output_path( self, output_path: str, diff --git a/fastvideo/models/dits/stable_audio.py b/fastvideo/models/dits/stable_audio.py new file mode 100644 index 0000000000..9dfb7b8d7d --- /dev/null +++ b/fastvideo/models/dits/stable_audio.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio DiT model for FastVideo. + +Uses in-repo DiffusionTransformer (sat_dit) for loading from unified +model.safetensors checkpoint. No stable-audio-tools clone required. +""" +from __future__ import annotations + +import json +import os +from typing import Any + +import torch +import torch.nn as nn + +from fastvideo.configs.models.dits.stable_audio import ( + StableAudioDiTArchConfig, + StableAudioDiTConfig, +) +from fastvideo.logger import init_logger +from fastvideo.models.stable_audio.sat_dit import DiffusionTransformer + +logger = init_logger(__name__) + + +def _hf_config_to_dit_kwargs(config: dict) -> dict: + """Map HF transformer config to stable-audio-tools DiffusionTransformer kwargs.""" + embed_dim = config.get("global_states_input_dim", 1536) + return { + "io_channels": config.get("in_channels", 64), + "embed_dim": embed_dim, + "cond_token_dim": config.get("cross_attention_dim", 768), + "project_cond_tokens": False, + "global_cond_dim": embed_dim, + "project_global_cond": False, + "depth": config.get("num_layers", 24), + "num_heads": config.get("num_attention_heads", 24), + "patch_size": 1, + } + + +class StableAudioDiTModel(nn.Module): + """ + FastVideo wrapper for Stable Audio DiffusionTransformer. + + Loads from unified model.safetensors using param_names_mapping to strip + the model.model. prefix. Compatible with stable-audio-tools checkpoints. 
+ """ + + param_names_mapping = StableAudioDiTConfig().arch_config.param_names_mapping + reverse_param_names_mapping = ( + StableAudioDiTConfig().arch_config.reverse_param_names_mapping + ) + lora_param_names_mapping = ( + StableAudioDiTConfig().arch_config.lora_param_names_mapping + ) + + def __init__(self, config: dict | StableAudioDiTConfig, hf_config: dict | None = None): + super().__init__() + if isinstance(config, StableAudioDiTConfig): + arch = config.arch_config + embed_dim = getattr(arch, "global_states_input_dim", None) or getattr( + arch, "hidden_size", 1536 + ) + kwargs = { + "io_channels": getattr(arch, "in_channels", 64), + "embed_dim": embed_dim, + "cond_token_dim": getattr(arch, "cross_attention_dim", 768), + "project_cond_tokens": False, + "global_cond_dim": embed_dim, + "project_global_cond": False, + "depth": getattr(arch, "num_layers", 24), + "num_heads": getattr(arch, "num_attention_heads", 24), + "patch_size": 1, + } + else: + kwargs = _hf_config_to_dit_kwargs(config) + self.model = DiffusionTransformer(**kwargs) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + @classmethod + def from_pretrained( + cls, + model_path: str, + config: dict | None = None, + **kwargs, + ) -> "StableAudioDiTModel": + """Create from HF-style model path.""" + config_path = os.path.join(model_path, "config.json") + if os.path.isfile(config_path): + with open(config_path) as f: + config = json.load(f) + config = config or {} + config.pop("_class_name", None) + config.pop("_diffusers_version", None) + return cls(config=config, **kwargs) + + +EntryClass = StableAudioDiTModel diff --git a/fastvideo/models/loader/component_loader.py b/fastvideo/models/loader/component_loader.py index 9afe4a1c03..1839fcd399 100644 --- a/fastvideo/models/loader/component_loader.py +++ b/fastvideo/models/loader/component_loader.py @@ -772,17 +772,29 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): model_cls, _ = ModelRegistry.resolve_model_cls(cls_name) - # Find all safetensors files - safetensors_list = glob.glob( - os.path.join(str(model_path), "*.safetensors") - ) - if not safetensors_list: - raise ValueError(f"No safetensors files found in {model_path}") - - # Check if we should use custom initialization weights + # Check if we should use custom initialization weights or unified checkpoint custom_weights_path = getattr( fastvideo_args, "init_weights_from_safetensors", None ) + pipeline_dit_config = getattr( + fastvideo_args.pipeline_config, "dit_config", None + ) + unified_path = getattr(pipeline_dit_config, "unified_checkpoint_path", None) + if unified_path: + if not os.path.isabs(unified_path): + # Resolve relative to model root (parent of transformer/) + model_root = os.path.dirname(str(model_path)) + unified_path = os.path.join(model_root, unified_path) + if os.path.exists(unified_path): + custom_weights_path = unified_path + elif getattr(pipeline_dit_config, "transformer_key_prefix", None): + # Infer unified checkpoint: model.safetensors or model.ckpt in model root + model_root = os.path.dirname(str(model_path)) + for name in ("model.safetensors", "model.ckpt"): + inferred = os.path.join(model_root, name) + if os.path.exists(inferred): + custom_weights_path = inferred + break use_custom_weights = ( custom_weights_path and os.path.exists(custom_weights_path) @@ -802,10 +814,17 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): os.path.join(str(custom_weights_path), "*.safetensors") ) else: - assert custom_weights_path.endswith(".safetensors"), ( 
- "Custom initialization weights must be a safetensors file" + assert custom_weights_path.endswith((".safetensors", ".ckpt")), ( + "Custom initialization weights must be a .safetensors or .ckpt file" ) safetensors_list = [custom_weights_path] + else: + safetensors_list = glob.glob( + os.path.join(str(model_path), "*.safetensors") + ) + + if not safetensors_list: + raise ValueError(f"No safetensors files found in {model_path}") logger.info( "Loading model from %s safetensors files: %s", @@ -829,6 +848,13 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): or cls_name == "Cosmos25Transformer3DModel" or getattr(fastvideo_args.pipeline_config, "prefix", "") == "Cosmos25" ) + # For unified checkpoints (e.g. Stable Audio model.safetensors), filter + # keys by prefix when loading transformer weights + weight_key_prefix = getattr( + pipeline_dit_config, + "transformer_key_prefix", + None, + ) model = maybe_load_fsdp_model( model_cls=model_cls, init_params={"config": dit_config, "hf_config": hf_config}, @@ -848,6 +874,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): training_mode=fastvideo_args.training_mode, enable_torch_compile=fastvideo_args.enable_torch_compile, torch_compile_kwargs=fastvideo_args.torch_compile_kwargs, + weight_key_prefix=weight_key_prefix, ) total_params = sum(p.numel() for p in model.parameters()) diff --git a/fastvideo/models/loader/fsdp_load.py b/fastvideo/models/loader/fsdp_load.py index 9ba60320aa..a618f40989 100644 --- a/fastvideo/models/loader/fsdp_load.py +++ b/fastvideo/models/loader/fsdp_load.py @@ -22,7 +22,10 @@ from fastvideo.logger import init_logger from fastvideo.models.loader.utils import (get_param_names_mapping, hf_to_custom_state_dict) -from fastvideo.models.loader.weight_utils import safetensors_weights_iterator +from fastvideo.models.loader.weight_utils import ( + safetensors_weights_iterator, + unified_checkpoint_weights_iterator, +) from fastvideo.utils import set_mixed_precision_policy, is_pin_memory_available logger = init_logger(__name__) @@ -75,6 +78,7 @@ def maybe_load_fsdp_model( pin_cpu_memory: bool = True, enable_torch_compile: bool = False, torch_compile_kwargs: dict[str, Any] | None = None, + weight_key_prefix: str | None = None, ) -> torch.nn.Module: """ Load the model with FSDP if is training, else load the model without FSDP. 
@@ -136,7 +140,15 @@ def maybe_load_fsdp_model( fsdp_shard_conditions=model._fsdp_shard_conditions, pin_cpu_memory=pin_cpu_memory) - weight_iterator = safetensors_weights_iterator(weight_dir_list) + # Support unified checkpoints (model.safetensors or model.ckpt) + if len(weight_dir_list) == 1 and weight_dir_list[0].endswith(".ckpt"): + weight_iterator = unified_checkpoint_weights_iterator( + weight_dir_list[0], to_cpu=not cpu_offload, key_prefix=weight_key_prefix + ) + else: + weight_iterator = safetensors_weights_iterator( + weight_dir_list, key_prefix=weight_key_prefix + ) param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping) load_model_from_full_model_state_dict( model, diff --git a/fastvideo/models/loader/weight_utils.py b/fastvideo/models/loader/weight_utils.py index 3e6fa91ba8..2c20c59f0b 100644 --- a/fastvideo/models/loader/weight_utils.py +++ b/fastvideo/models/loader/weight_utils.py @@ -117,9 +117,41 @@ def filter_files_not_needed_for_inference( _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 +def unified_checkpoint_weights_iterator( + checkpoint_path: str, + to_cpu: bool = True, + key_prefix: str | None = None, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """ + Iterate over weights in a unified checkpoint (model.safetensors or model.ckpt). + Supports Stable Audio and similar checkpoints. + """ + if checkpoint_path.endswith(".safetensors"): + yield from safetensors_weights_iterator( + [checkpoint_path], to_cpu=to_cpu, key_prefix=key_prefix + ) + else: + # .ckpt or .pt: state_dict may be nested under "state_dict" + state = torch.load( + checkpoint_path, map_location="cpu" if to_cpu else None, weights_only=True + ) + if isinstance(state, dict) and "state_dict" in state: + state = state["state_dict"] + for full_name, param in state.items(): + if key_prefix is not None: + if not full_name.startswith(key_prefix): + continue + yield_name = full_name[len(key_prefix):] + else: + yield_name = full_name + yield yield_name, param + del state + + def safetensors_weights_iterator( hf_weights_files: list[str], to_cpu: bool = True, + key_prefix: str | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" enable_tqdm = not torch.distributed.is_initialized( @@ -132,9 +164,15 @@ def safetensors_weights_iterator( bar_format=_BAR_FORMAT, ): with safe_open(st_file, framework="pt", device=device) as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param + for full_name in f.keys(): # noqa: SIM118 + if key_prefix is not None: + if not full_name.startswith(key_prefix): + continue + yield_name = full_name[len(key_prefix):] + else: + yield_name = full_name + param = f.get_tensor(full_name) + yield yield_name, param def pt_weights_iterator( diff --git a/fastvideo/models/registry.py b/fastvideo/models/registry.py index fc48ddc177..343ee58088 100644 --- a/fastvideo/models/registry.py +++ b/fastvideo/models/registry.py @@ -23,6 +23,7 @@ # huggingface class name: (component_name, fastvideo module name, fastvideo class name) _TEXT_TO_VIDEO_DIT_MODELS = { + "StableAudioDiTModel": ("dits", "stable_audio", "StableAudioDiTModel"), "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoTransformer3DModel"), "HunyuanVideo15Transformer3DModel": diff --git a/fastvideo/models/stable_audio/__init__.py b/fastvideo/models/stable_audio/__init__.py new file mode 100644 index 
0000000000..4ee51259a0 --- /dev/null +++ b/fastvideo/models/stable_audio/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Stable Audio: pretransform (VAE), conditioner, sampling, DiT. + +All components are in-repo; no stable-audio-tools clone required. +- Sampling: k-diffusion only. Conditioner: t5 + number (conditioners_inline). +- Pretransform: Oobleck VAE (sat_factory). DiT: DiffusionTransformer (sat_dit). +""" +from .pretransform import StableAudioPretransform +from .conditioner import StableAudioConditioner + +__all__ = [ + "StableAudioPretransform", + "StableAudioConditioner", +] diff --git a/fastvideo/models/stable_audio/conditioner.py b/fastvideo/models/stable_audio/conditioner.py new file mode 100644 index 0000000000..773bbc31d3 --- /dev/null +++ b/fastvideo/models/stable_audio/conditioner.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio conditioner: T5 (prompt) + Number (seconds_start, seconds_total). + +Uses inlined conditioners (t5 + number only). Loads weights from unified +model.safetensors using conditioner.conditioners.* +""" +from __future__ import annotations + +import os +from typing import Any + +import torch +import torch.nn as nn + +from fastvideo.logger import init_logger +from fastvideo.models.stable_audio.conditioners_inline import ( + create_multi_conditioner_from_conditioning_config, +) + +logger = init_logger(__name__) + +CONDITIONER_KEY_PREFIX = "conditioner.conditioners." + + +def _create_multi_conditioner_from_config(model_config: dict) -> nn.Module: + """Create MultiConditioner from model_config (t5 + number only, no clone).""" + model_cfg = model_config.get("model", model_config) + conditioning_config = model_cfg.get("conditioning") + if conditioning_config is None: + raise ValueError("model_config must contain model.conditioning") + return create_multi_conditioner_from_conditioning_config(conditioning_config) + + +class StableAudioConditioner(nn.Module): + """ + FastVideo wrapper for Stable Audio conditioner. + + - prompt: T5-base (external, from HuggingFace) + - seconds_start, seconds_total: NumberEmbedder (loaded from checkpoint) + + Loads conditioner weights from unified model.safetensors. + """ + + def __init__( + self, + model_config: dict | str | None = None, + checkpoint_path: str | None = None, + ): + super().__init__() + if isinstance(model_config, str): + import json + with open(model_config) as f: + model_config = json.load(f) + if model_config is None: + model_config = {} + self._conditioner = _create_multi_conditioner_from_config(model_config) + + if checkpoint_path and os.path.exists(checkpoint_path): + self.load_from_unified_checkpoint(checkpoint_path) + + def load_from_unified_checkpoint(self, checkpoint_path: str) -> None: + """Load conditioner weights (seconds_start, seconds_total) from unified checkpoint.""" + from fastvideo.models.loader.weight_utils import unified_checkpoint_weights_iterator + + state_dict = {} + for key, tensor in unified_checkpoint_weights_iterator( + checkpoint_path, to_cpu=True, key_prefix=CONDITIONER_KEY_PREFIX + ): + # Checkpoint: conditioner.conditioners.seconds_start.embedder.xxx + # After strip: seconds_start.embedder.xxx + # MultiConditioner has self.conditioners = ModuleDict, so state_dict keys need "conditioners." prefix + state_dict["conditioners." 
+ key] = tensor + + self._conditioner.load_state_dict(state_dict, strict=False) + loaded = len(state_dict) + logger.info("Loaded conditioner from %s (%d keys)", checkpoint_path, loaded) + + def forward( + self, + batch_metadata: list[dict[str, Any]], + device: torch.device | str, + ) -> dict[str, Any]: + """Run conditioner on batch metadata. Same interface as stable-audio-tools.""" + return self._conditioner(batch_metadata, device) + + def __call__( + self, + batch_metadata: list[dict[str, Any]], + device: torch.device | str, + ) -> dict[str, Any]: + return self.forward(batch_metadata, device) diff --git a/fastvideo/models/stable_audio/conditioners_inline.py b/fastvideo/models/stable_audio/conditioners_inline.py new file mode 100644 index 0000000000..05b7279d82 --- /dev/null +++ b/fastvideo/models/stable_audio/conditioners_inline.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Minimal Stable Audio conditioners inlined for t5 + number (seconds) only. + +No dependency on stable-audio-tools. Supports model_config conditioning with +type "t5" and "number" (stable-audio-open-1.0). +""" +from __future__ import annotations + +import logging +import typing as tp +import warnings + +import torch +from einops import rearrange +from torch import nn + +# NumberEmbedder dependency: LearnedPositionalEmbedding -> TimePositionalEmbedding + + +class LearnedPositionalEmbedding(nn.Module): + """Continuous time embedding (from stable-audio-tools adp).""" + + def __init__(self, dim: int) -> None: + super().__init__() + assert dim % 2 == 0 + self.weights = nn.Parameter(torch.randn(dim // 2)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * 3.141592653589793 + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + return torch.cat((x, fouriered), dim=-1) + + +def _time_positional_embedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +class NumberEmbedder(nn.Module): + """Embed floats for conditioning (e.g. seconds_start, seconds_total).""" + + def __init__(self, features: int, dim: int = 256) -> None: + super().__init__() + self.features = features + self.embedding = _time_positional_embedding(dim=dim, out_features=features) + + def forward(self, x: tp.Union[tp.List[float], torch.Tensor]) -> torch.Tensor: + if not torch.is_tensor(x): + x = torch.tensor(x, device=next(self.embedding.parameters()).device) + shape = x.shape + x = rearrange(x, "... -> (...)") + embedding = self.embedding(x) + return embedding.view(*shape, self.features) + + +class Conditioner(nn.Module): + """Base conditioner.""" + + def __init__( + self, dim: int, output_dim: int, project_out: bool = False + ) -> None: + super().__init__() + self.proj_out = ( + nn.Linear(dim, output_dim) + if (dim != output_dim or project_out) + else nn.Identity() + ) + + def forward(self, x: tp.Any) -> tp.Any: + raise NotImplementedError() + + +class NumberConditioner(Conditioner): + """Conditioner for float lists (e.g. 
seconds_start, seconds_total).""" + + def __init__( + self, + output_dim: int, + min_val: float = 0.0, + max_val: float = 1.0, + ) -> None: + super().__init__(output_dim, output_dim) + self.min_val = min_val + self.max_val = max_val + self.embedder = NumberEmbedder(features=output_dim) + + def forward( + self, floats: tp.List[float], device: tp.Optional[torch.device] = None + ) -> tp.Tuple[torch.Tensor, torch.Tensor]: + floats = [float(x) for x in floats] + t = torch.tensor(floats).to(device) + t = t.clamp(self.min_val, self.max_val) + normalized = (t - self.min_val) / (self.max_val - self.min_val) + embedder_dtype = next(self.embedder.parameters()).dtype + normalized = normalized.to(embedder_dtype) + embeds = self.embedder(normalized).unsqueeze(1) + ones = torch.ones(embeds.shape[0], 1, device=embeds.device) + return [embeds, ones] + + +class T5Conditioner(Conditioner): + """T5 text conditioner (stable-audio-open uses t5-base).""" + + T5_MODELS = [ + "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl", + "google/t5-v1_1-xl", "google/t5-v1_1-xxl", + ] + T5_MODEL_DIMS = { + "t5-small": 512, "t5-base": 768, "t5-large": 1024, + "t5-3b": 1024, "t5-11b": 1024, + "google/t5-v1_1-xl": 2048, "google/t5-v1_1-xxl": 4096, + "google/flan-t5-small": 512, "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, "google/flan-t5-xl": 2048, + "google/flan-t5-xxl": 4096, + } + + def __init__( + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: int = 128, + enable_grad: bool = False, + project_out: bool = False, + ) -> None: + assert t5_model_name in self.T5_MODELS + super().__init__( + self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out + ) + from transformers import AutoTokenizer, T5EncoderModel + + self.max_length = max_length + self.enable_grad = enable_grad + prev = logging.getLogger().level + logging.getLogger().setLevel(logging.ERROR) + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) + model = T5EncoderModel.from_pretrained(t5_model_name) + model = model.train(enable_grad).requires_grad_(enable_grad) + model = model.to(torch.float16) + finally: + logging.getLogger().setLevel(prev) + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + def forward( + self, texts: tp.List[str], device: tp.Union[torch.device, str] + ) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.model.to(device) + self.proj_out.to(device) + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + self.model.eval() + with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled( + self.enable_grad + ): + embeddings = self.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + if not isinstance(self.proj_out, nn.Identity): + embeddings = embeddings.to( + next(self.proj_out.parameters()).dtype + ) + embeddings = self.proj_out(embeddings) + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + return embeddings, attention_mask + + +class MultiConditioner(nn.Module): + """Applies multiple conditioners keyed by config.""" + + def __init__( + self, + 
conditioners: tp.Dict[str, Conditioner], + default_keys: tp.Dict[str, str] | None = None, + pre_encoded_keys: tp.List[str] | None = None, + ) -> None: + super().__init__() + self.conditioners = nn.ModuleDict(conditioners) + self.default_keys = default_keys or {} + self.pre_encoded_keys = list(pre_encoded_keys or []) + + def forward( + self, + batch_metadata: tp.List[tp.Dict[str, tp.Any]], + device: tp.Union[torch.device, str], + ) -> tp.Dict[str, tp.Any]: + output: tp.Dict[str, tp.Any] = {} + for key, conditioner in self.conditioners.items(): + condition_key = key + inputs = [] + for x in batch_metadata: + if condition_key not in x: + condition_key = self.default_keys.get(condition_key, key) + if condition_key not in x: + raise ValueError( + f"Conditioner key {condition_key} not in metadata" + ) + val = x[condition_key] + if isinstance(val, (list, tuple)) and len(val) == 1: + val = val[0] + inputs.append(val) + if key in self.pre_encoded_keys: + output[key] = [torch.stack(inputs).to(device), None] + else: + output[key] = conditioner(inputs, device) + return output + + +def create_multi_conditioner_from_conditioning_config( + config: tp.Dict[str, tp.Any], +) -> MultiConditioner: + """ + Build MultiConditioner from model_config conditioning section. + Only supports conditioner types: "t5", "number" (stable-audio-open-1.0). + """ + conditioners: tp.Dict[str, Conditioner] = {} + cond_dim = config["cond_dim"] + default_keys = config.get("default_keys", {}) + pre_encoded_keys = config.get("pre_encoded_keys", []) + + for info in config["configs"]: + cid = info["id"] + ctype = info["type"] + cfg = {"output_dim": cond_dim, **info.get("config", {})} + + if ctype == "t5": + conditioners[cid] = T5Conditioner(**cfg) + elif ctype == "number": + conditioners[cid] = NumberConditioner(**cfg) + else: + raise ValueError( + f"Only t5 and number conditioners are supported inline; " + f"got type={ctype}. Use stable-audio-tools clone for others." + ) + + return MultiConditioner( + conditioners, + default_keys=default_keys, + pre_encoded_keys=pre_encoded_keys, + ) diff --git a/fastvideo/models/stable_audio/pretransform.py b/fastvideo/models/stable_audio/pretransform.py new file mode 100644 index 0000000000..c47eebc6cb --- /dev/null +++ b/fastvideo/models/stable_audio/pretransform.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio Oobleck VAE pretransform for FastVideo. + +Uses in-repo Oobleck VAE (sat_factory). Loads from unified model.safetensors +using pretransform.model.* prefix. No stable-audio-tools clone required. +""" +from __future__ import annotations + +import json +import os +from typing import Any + +import torch +import torch.nn as nn + +from fastvideo.logger import init_logger +from fastvideo.models.stable_audio.sat_factory import create_pretransform_from_config + +logger = init_logger(__name__) + +PRETRANSFORM_KEY_PREFIX = "pretransform.model." + + +def _create_pretransform_from_config(model_config: dict) -> nn.Module: + """Create pretransform from model_config (in-repo autoencoder path).""" + pretransform_cfg = model_config.get("model", model_config).get("pretransform") + if pretransform_cfg is None: + raise ValueError("model_config must contain model.pretransform") + sample_rate = model_config.get("sample_rate", 44100) + return create_pretransform_from_config(pretransform_cfg, sample_rate) + + +class StableAudioPretransform(nn.Module): + """ + FastVideo wrapper for Stable Audio Oobleck VAE pretransform. 
+ + Loads from unified model.safetensors using pretransform.model.* prefix. + Uses in-repo Oobleck VAE (no clone). + """ + + def __init__( + self, + model_config: dict | str | None = None, + checkpoint_path: str | None = None, + ): + super().__init__() + if isinstance(model_config, str): + with open(model_config) as f: + model_config = json.load(f) + if model_config is None: + model_config = {} + self._pretransform = _create_pretransform_from_config(model_config) + + if checkpoint_path and os.path.exists(checkpoint_path): + self.load_from_unified_checkpoint(checkpoint_path) + + def load_from_unified_checkpoint(self, checkpoint_path: str) -> None: + """Load pretransform weights from unified model.safetensors or model.ckpt.""" + from fastvideo.models.loader.weight_utils import unified_checkpoint_weights_iterator + + state_dict = {} + for key, tensor in unified_checkpoint_weights_iterator( + checkpoint_path, to_cpu=True, key_prefix=PRETRANSFORM_KEY_PREFIX + ): + # Inner model in AutoencoderPretransform is self.model + state_dict[key] = tensor + + # AutoencoderPretransform.load_state_dict passes to self.model + self._pretransform.load_state_dict(state_dict, strict=True) + logger.info("Loaded pretransform from %s (%d keys)", checkpoint_path, len(state_dict)) + + @property + def downsampling_ratio(self) -> int: + return self._pretransform.downsampling_ratio + + @property + def encoded_channels(self) -> int: + return self._pretransform.encoded_channels + + @property + def io_channels(self) -> int: + return self._pretransform.io_channels + + @property + def scale(self) -> float: + return getattr(self._pretransform, "scale", 1.0) + + def encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + return self._pretransform.encode(x, **kwargs) + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + return self._pretransform.decode(z, **kwargs) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """Encode for compatibility.""" + return self.encode(x, **kwargs) diff --git a/fastvideo/models/stable_audio/sampling.py b/fastvideo/models/stable_audio/sampling.py new file mode 100644 index 0000000000..d140a69a71 --- /dev/null +++ b/fastvideo/models/stable_audio/sampling.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio k-diffusion v-prediction sampling. + +Uses k-diffusion only (VDenoiser, get_sigmas_polyexponential, sample_dpmpp_2m_sde). +No dependency on stable-audio-tools. +""" +from __future__ import annotations + +from typing import Any + +import torch + + +# Default sampler config matching stable-audio-tools +DEFAULT_SAMPLER_TYPE = "dpmpp-2m-sde" +DEFAULT_SIGMA_MIN = 0.01 +DEFAULT_SIGMA_MAX = 100.0 +DEFAULT_RHO = 1.0 + + +def sample_stable_audio( + model_fn: torch.nn.Module, + noise: torch.Tensor, + *, + init_data: torch.Tensor | None = None, + steps: int = 100, + sampler_type: str = DEFAULT_SAMPLER_TYPE, + sigma_min: float = DEFAULT_SIGMA_MIN, + sigma_max: float = DEFAULT_SIGMA_MAX, + rho: float = DEFAULT_RHO, + device: str | torch.device = "cuda", + **extra_args: Any, +) -> torch.Tensor: + """ + Run k-diffusion sampling (v-prediction) compatible with Stable Audio. + + Uses get_sigmas_polyexponential and VDenoiser; default sampler is dpmpp-2m-sde. + Requires: pip install k-diffusion (or pip install .[stable-audio]). 
+ """ + import k_diffusion as K + + device_str = str(device) + denoiser = K.external.VDenoiser(model_fn) + sigmas = K.sampling.get_sigmas_polyexponential( + steps, sigma_min, sigma_max, rho, device=device_str + ) + noise = noise.to(device_str) * sigmas[0] + if init_data is not None: + x = init_data.to(device_str) + noise + else: + x = noise + + if sampler_type == "dpmpp-2m-sde": + return K.sampling.sample_dpmpp_2m_sde( + denoiser, x, sigmas, disable=False, callback=None, extra_args=extra_args + ) + if sampler_type == "dpmpp-2m": + return K.sampling.sample_dpmpp_2m( + denoiser, x, sigmas, disable=False, callback=None, extra_args=extra_args + ) + if sampler_type == "dpmpp-3m-sde": + return K.sampling.sample_dpmpp_3m_sde( + denoiser, x, sigmas, disable=False, callback=None, extra_args=extra_args + ) + if sampler_type == "k-heun": + return K.sampling.sample_heun( + denoiser, x, sigmas, disable=False, callback=None, extra_args=extra_args + ) + if sampler_type == "k-lms": + return K.sampling.sample_lms( + denoiser, x, sigmas, disable=False, callback=None, extra_args=extra_args + ) + raise ValueError(f"Unsupported sampler_type: {sampler_type}") diff --git a/fastvideo/models/stable_audio/sat_autoencoders.py b/fastvideo/models/stable_audio/sat_autoencoders.py new file mode 100644 index 0000000000..c3da90c9a0 --- /dev/null +++ b/fastvideo/models/stable_audio/sat_autoencoders.py @@ -0,0 +1,552 @@ +# SPDX-License-Identifier: Apache-2.0 +# Oobleck VAE + AudioAutoencoder for Stable Audio pretransform (in-repo, no clone). +from __future__ import annotations + +import math +from typing import Any, Dict + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from fastvideo.models.stable_audio.sat_blocks import SnakeBeta + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +def _checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +def get_activation( + activation: str, antialias: bool = False, channels: int | None = None +) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + if antialias: + act = nn.Identity() # skip alias_free_torch for in-repo + return act + + +class ResidualUnit(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dilation: int, + use_snake: bool = False, + antialias_activation: bool = False, + ): + super().__init__() + self.dilation = dilation + padding = (dilation * (7 - 1)) // 2 + self.layers = nn.Sequential( + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=out_channels, + ), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=7, + dilation=dilation, + padding=padding, + ), + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=out_channels, + ), + WNConv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + res = x + if self.training: + x = _checkpoint(self.layers, x) + else: + x = self.layers(x) + return x + res + + +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: 
int, + stride: int, + use_snake: bool = False, + antialias_activation: bool = False, + ): + super().__init__() + self.layers = nn.Sequential( + ResidualUnit( + in_channels, in_channels, 1, use_snake=use_snake + ), + ResidualUnit( + in_channels, in_channels, 3, use_snake=use_snake + ), + ResidualUnit( + in_channels, in_channels, 9, use_snake=use_snake + ), + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=in_channels, + ), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + use_snake: bool = False, + antialias_activation: bool = False, + use_nearest_upsample: bool = False, + ): + super().__init__() + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=1, + bias=False, + padding="same", + ), + ) + else: + upsample_layer = WNConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ) + self.layers = nn.Sequential( + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=in_channels, + ), + upsample_layer, + ResidualUnit(out_channels, out_channels, 1, use_snake=use_snake), + ResidualUnit(out_channels, out_channels, 3, use_snake=use_snake), + ResidualUnit(out_channels, out_channels, 9, use_snake=use_snake), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +class OobleckEncoder(nn.Module): + def __init__( + self, + in_channels: int = 2, + channels: int = 128, + latent_dim: int = 32, + c_mults: list = (1, 2, 4, 8), + strides: list = (2, 4, 8, 8), + use_snake: bool = False, + antialias_activation: bool = False, + ): + super().__init__() + self.in_channels = in_channels + c_mults = [1] + list(c_mults) + self.depth = len(c_mults) + layers = [ + WNConv1d( + in_channels=in_channels, + out_channels=c_mults[0] * channels, + kernel_size=7, + padding=3, + ) + ] + for i in range(self.depth - 1): + layers.append( + EncoderBlock( + c_mults[i] * channels, + c_mults[i + 1] * channels, + strides[i], + use_snake=use_snake, + ) + ) + layers += [ + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=c_mults[-1] * channels, + ), + WNConv1d( + in_channels=c_mults[-1] * channels, + out_channels=latent_dim, + kernel_size=3, + padding=1, + ), + ] + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__( + self, + out_channels: int = 2, + channels: int = 128, + latent_dim: int = 32, + c_mults: list = (1, 2, 4, 8), + strides: list = (2, 4, 8, 8), + use_snake: bool = False, + antialias_activation: bool = False, + use_nearest_upsample: bool = False, + final_tanh: bool = True, + ): + super().__init__() + self.out_channels = out_channels + c_mults = [1] + list(c_mults) + self.depth = len(c_mults) + layers = [ + WNConv1d( + in_channels=latent_dim, + out_channels=c_mults[-1] * channels, + kernel_size=7, + padding=3, + ), + ] + for i in range(self.depth - 1, 0, -1): + layers.append( + DecoderBlock( + 
c_mults[i] * channels, + c_mults[i - 1] * channels, + strides[i - 1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample, + ) + ) + layers += [ + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=c_mults[0] * channels, + ), + WNConv1d( + in_channels=c_mults[0] * channels, + out_channels=out_channels, + kernel_size=7, + padding=3, + bias=False, + ), + nn.Tanh() if final_tanh else nn.Identity(), + ] + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_dim: int, + downsampling_ratio: int, + sample_rate: int, + io_channels: int = 2, + bottleneck: Any = None, + pretransform: Any = None, + in_channels: int | None = None, + out_channels: int | None = None, + soft_clip: bool = False, + ): + super().__init__() + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + self.min_length = downsampling_ratio + if in_channels is not None: + self.in_channels = in_channels + if out_channels is not None: + self.out_channels = out_channels + self.bottleneck = bottleneck + self.encoder = encoder + self.decoder = decoder + self.pretransform = pretransform + self.soft_clip = soft_clip + self.is_discrete = ( + bottleneck is not None and getattr(bottleneck, "is_discrete", False) + ) + + def encode( + self, + audio: torch.Tensor, + skip_bottleneck: bool = False, + return_info: bool = False, + skip_pretransform: bool = False, + iterate_batch: bool = False, + **kwargs: Any, + ) -> torch.Tensor | tuple: + if self.pretransform is not None and not skip_pretransform: + with torch.no_grad(): + audio = self.pretransform.encode(audio) + if self.encoder is not None: + latents = self.encoder(audio) + else: + latents = audio + if self.bottleneck is not None and not skip_bottleneck: + latents = self.bottleneck.encode(latents, **kwargs) + if return_info: + return latents, {} + return latents + + def decode( + self, + latents: torch.Tensor, + skip_bottleneck: bool = False, + iterate_batch: bool = False, + **kwargs: Any, + ) -> torch.Tensor: + if self.bottleneck is not None and not skip_bottleneck: + latents = self.bottleneck.decode(latents) + decoded = self.decoder(latents, **kwargs) + if self.pretransform is not None: + with torch.no_grad(): + decoded = self.pretransform.decode(decoded) + if self.soft_clip: + decoded = torch.tanh(decoded) + return decoded + + def encode_audio( + self, + audio: torch.Tensor, + chunked: bool = False, + overlap: int = 32, + chunk_size: int = 128, + **kwargs: Any, + ) -> torch.Tensor: + if not chunked: + return self.encode(audio, **kwargs) + samples_per_latent = int(self.downsampling_ratio) + total_size = audio.shape[2] + batch_size = audio.shape[0] + chunk_size_samp = chunk_size * samples_per_latent + overlap_samp = overlap * samples_per_latent + hop_size = chunk_size_samp - overlap_samp + chunks_list = [] + i = 0 + for i in range(0, total_size - chunk_size_samp + 1, hop_size): + chunks_list.append(audio[:, :, i : i + chunk_size_samp]) + if i + chunk_size_samp < total_size: + chunks_list.append(audio[:, :, -chunk_size_samp:]) + chunks = torch.stack(chunks_list) + num_chunks = chunks.shape[0] + y_size = total_size // samples_per_latent + y_final = torch.zeros( + 
(batch_size, self.latent_dim, y_size), + dtype=chunks.dtype, + device=audio.device, + ) + for j in range(num_chunks): + y_chunk = self.encode(chunks[j]) + if j == num_chunks - 1: + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = j * hop_size // samples_per_latent + t_end = t_start + chunk_size_samp // samples_per_latent + ol = overlap_samp // samples_per_latent // 2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if j > 0: + t_start += ol + chunk_start += ol + if j < num_chunks - 1: + t_end -= ol + chunk_end -= ol + y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end] + return y_final + + def decode_audio( + self, + latents: torch.Tensor, + chunked: bool = False, + overlap: int = 32, + chunk_size: int = 128, + **kwargs: Any, + ) -> torch.Tensor: + if not chunked: + return self.decode(latents, **kwargs) + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks_list = [] + i = 0 + for i in range(0, total_size - chunk_size + 1, hop_size): + chunks_list.append(latents[:, :, i : i + chunk_size]) + if i + chunk_size < total_size: + chunks_list.append(latents[:, :, -chunk_size:]) + chunks = torch.stack(chunks_list) + num_chunks = chunks.shape[0] + samples_per_latent = int(self.downsampling_ratio) + y_size = total_size * samples_per_latent + y_final = torch.zeros( + (batch_size, self.out_channels, y_size), + dtype=chunks.dtype, + device=latents.device, + ) + for j in range(num_chunks): + y_chunk = self.decode(chunks[j]) + if j == num_chunks - 1: + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = j * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + ol = (overlap // 2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if j > 0: + t_start += ol + chunk_start += ol + if j < num_chunks - 1: + t_end -= ol + chunk_end -= ol + y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end] + return y_final + + +def create_encoder_from_config( + encoder_config: Dict[str, Any], + latent_dim_override: int | None = None, +) -> nn.Module: + encoder_type = encoder_config.get("type") + if encoder_type != "oobleck": + raise ValueError( + f"Only oobleck encoder is supported in-repo; got {encoder_type}" + ) + config = dict(encoder_config["config"]) + if latent_dim_override is not None: + config["latent_dim"] = latent_dim_override + encoder = OobleckEncoder(**config) + if not encoder_config.get("requires_grad", True): + for p in encoder.parameters(): + p.requires_grad = False + return encoder + + +def create_decoder_from_config( + decoder_config: Dict[str, Any], + latent_dim_override: int | None = None, +) -> nn.Module: + decoder_type = decoder_config.get("type") + if decoder_type != "oobleck": + raise ValueError( + f"Only oobleck decoder is supported in-repo; got {decoder_type}" + ) + config = dict(decoder_config["config"]) + if latent_dim_override is not None: + config["latent_dim"] = latent_dim_override + decoder = OobleckDecoder(**config) + if not decoder_config.get("requires_grad", True): + for p in decoder.parameters(): + p.requires_grad = False + return decoder + + +def _vae_sample(mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + stdev = torch.nn.functional.softplus(scale) + 1e-4 + return torch.randn_like(mean, device=mean.device, dtype=mean.dtype) * stdev + mean + + +class VAEBottleneck(nn.Module): + """Splits encoder output (2*latent_dim) into mean and scale, samples to latent_dim.""" + + def __init__(self) -> None: + 
super().__init__() + + def encode( + self, x: torch.Tensor, return_info: bool = False, **kwargs: Any + ) -> torch.Tensor | tuple: + mean, scale = x.chunk(2, dim=1) + out = _vae_sample(mean, scale) + if return_info: + return out, {} + return out + + def decode(self, x: torch.Tensor) -> torch.Tensor: + return x + + +def _create_bottleneck_from_config(bottleneck_config: Dict[str, Any] | None) -> nn.Module | None: + if bottleneck_config is None: + return None + bt_type = bottleneck_config.get("type") + if bt_type == "vae": + return VAEBottleneck() + raise ValueError(f"Only bottleneck type 'vae' is supported in-repo; got {bt_type}") + + +def create_autoencoder_from_config(config: Dict[str, Any]) -> AudioAutoencoder: + ae_config = config["model"] + latent_dim = ae_config["latent_dim"] + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + bottleneck = _create_bottleneck_from_config(ae_config.get("bottleneck")) + downsampling_ratio = ae_config["downsampling_ratio"] + io_channels = ae_config["io_channels"] + sample_rate = config["sample_rate"] + in_channels = ae_config.get("in_channels") + out_channels = ae_config.get("out_channels") + soft_clip = ae_config["decoder"].get("soft_clip", False) + return AudioAutoencoder( + encoder=encoder, + decoder=decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=None, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip, + ) diff --git a/fastvideo/models/stable_audio/sat_blocks.py b/fastvideo/models/stable_audio/sat_blocks.py new file mode 100644 index 0000000000..1ce8f6372f --- /dev/null +++ b/fastvideo/models/stable_audio/sat_blocks.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# Minimal blocks for Stable Audio DiT + VAE (from stable-audio-tools models/blocks.py). 
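+#
+# SnakeBeta applies x + (1 / (beta + 1e-9)) * sin(alpha * x)^2 with per-channel
+# alpha/beta parameters (stored in log scale by default), broadcast over
+# (batch, channels, time) tensors. FourierFeatures projects its input through a
+# fixed random matrix W and returns [cos(2*pi*x@W^T), sin(2*pi*x@W^T)]; it is
+# used for timestep embeddings in the DiT.
+#
+# Illustrative usage (shapes are an assumption for clarity, not enforced here):
+#   SnakeBeta(in_features=C)(x)          # x: (B, C, T) -> (B, C, T)
+#   FourierFeatures(1, 256)(t[:, None])  # t: (B,)      -> (B, 256)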
+import math + +import torch +from torch import nn + + +def _snake_beta(x: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor: + return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2) + + +class SnakeBeta(nn.Module): + """From stable-audio-tools blocks (BigVGAN-style).""" + + def __init__( + self, + in_features: int, + alpha: float = 1.0, + alpha_trainable: bool = True, + alpha_logscale: bool = True, + ): + super().__init__() + self.in_features = in_features + self.alpha_logscale = alpha_logscale + if alpha_logscale: + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + return _snake_beta(x, alpha, beta) + + +class FourierFeatures(nn.Module): + def __init__(self, in_features: int, out_features: int, std: float = 1.0): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter( + torch.randn([out_features // 2, in_features]) * std + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) diff --git a/fastvideo/models/stable_audio/sat_dit.py b/fastvideo/models/stable_audio/sat_dit.py new file mode 100644 index 0000000000..ba4913a6c9 --- /dev/null +++ b/fastvideo/models/stable_audio/sat_dit.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 +# Vendored from stable-audio-tools for in-repo Stable Audio DiT. 
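+#
+# DiffusionTransformer wraps a ContinuousTransformer over audio latents
+# x: (B, io_channels, T) with timestep t: (B,). Conditioning enters three ways:
+#   - cross_attn_cond: per-token embeddings (e.g. T5 prompt tokens), via cross-attention
+#   - global_embed: pooled conditioning, summed with the timestep embedding and
+#     either prepended as an extra token or injected via adaLN
+#   - prepend_cond / input_concat_cond: extra prepended tokens / extra channels
+#     interpolated to length T and concatenated
+# When cfg_scale != 1.0, forward() runs conditional and unconditional passes in a
+# single doubled batch and returns uncond + (cond - uncond) * cfg_scale,
+# optionally rescaled by scale_phi.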
+import math +import typing as tp + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F + +from fastvideo.models.stable_audio.sat_blocks import FourierFeatures +from fastvideo.models.stable_audio.sat_transformer import ContinuousTransformer + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_cond_type: tp.Literal["global", "input_concat"] = "global", + timestep_embed_dim=None, + diffusion_objective: tp.Literal["v", "rectified_flow", "rf_denoiser"] = "v", + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + self.timestep_cond_type = timestep_cond_type + + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + if timestep_cond_type == "global": + timestep_embed_dim = embed_dim + elif timestep_cond_type == "input_concat": + assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat" + input_concat_dim += timestep_embed_dim + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, timestep_embed_dim, bias=True), + nn.SiLU(), + nn.Linear(timestep_embed_dim, timestep_embed_dim, bias=True), + ) + + self.diffusion_objective = diffusion_objective + + if cond_token_dim > 0: + # Conditioning tokens + + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + + if self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + 
nn.init.zeros_(self.postprocess_conv.weight) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + return_info=False, + exit_layer_ix=None, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + prepend_length = prepend_cond.shape[1] + + if input_concat_cond is not None: + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + + if self.timestep_cond_type == "global": + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + elif self.timestep_cond_type == "input_concat": + x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1) + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend" and global_embed is not None: + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "continuous_transformer": + # Masks not currently implemented for continuous transformer + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, return_info=return_info, exit_layer_ix=exit_layer_ix, **extra_args, **kwargs) + + if return_info: + output, info = output + + # Avoid postprocessing on early exit + if exit_layer_ix is not None: + if return_info: + return output, info + else: + return output + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + 
negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + cfg_interval = (0, 1), + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + exit_layer_ix=None, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + + model_dtype = next(self.parameters()).dtype + + x = x.to(model_dtype) + + t = t.to(model_dtype) + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(model_dtype) + + if negative_cross_attn_cond is not None: + negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype) + + if input_concat_cond is not None: + input_concat_cond = input_concat_cond.to(model_dtype) + + if global_embed is not None: + global_embed = global_embed.to(model_dtype) + + if negative_global_embed is not None: + negative_global_embed = negative_global_embed.to(model_dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(model_dtype) + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # Early exit bypasses CFG processing + if exit_layer_ix is not None: + assert self.transformer_type == "continuous_transformer", "exit_layer_ix is only supported for continuous_transformer" + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + exit_layer_ix=exit_layer_ix, + **kwargs + ) + + # CFG dropout + if cfg_dropout_prob > 0.0 and cfg_scale == 1.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if self.diffusion_objective == "v": + sigma = torch.sin(t * math.pi / 2) + alpha = torch.cos(t * math.pi / 2) + elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]: + sigma = t + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None) and (cfg_interval[0] <= sigma[0] <= cfg_interval[1]): + + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention 
conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + prepend_cond_mask = batch_prepend_cond_mask, + return_info = return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + info["uncond_output"] = uncond_output + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + **kwargs + ) \ No newline at end of file diff --git a/fastvideo/models/stable_audio/sat_factory.py b/fastvideo/models/stable_audio/sat_factory.py new file mode 100644 index 0000000000..1161d4b895 --- /dev/null +++ b/fastvideo/models/stable_audio/sat_factory.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# create_pretransform_from_config for autoencoder type (in-repo, no clone). 
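+#
+# Illustrative shape of pretransform_config (only the keys read below matter;
+# concrete values come from the checkpoint's model_config.json):
+#
+#   {"type": "autoencoder",
+#    "config": {"encoder": {"type": "oobleck", "config": {...}},
+#               "decoder": {"type": "oobleck", "config": {...}},
+#               "bottleneck": {"type": "vae"},
+#               "latent_dim": ..., "downsampling_ratio": ..., "io_channels": ...},
+#    "scale": 1.0, "model_half": false, "iterate_batch": false, "chunked": false}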
+from __future__ import annotations + +from typing import Any, Dict + +from torch import nn + +from fastvideo.models.stable_audio.sat_autoencoders import create_autoencoder_from_config +from fastvideo.models.stable_audio.sat_pretransforms import AutoencoderPretransform + + +def create_pretransform_from_config( + pretransform_config: Dict[str, Any], sample_rate: int +) -> nn.Module: + pretransform_type = pretransform_config.get("type") + if pretransform_type != "autoencoder": + raise ValueError( + f"Only pretransform type 'autoencoder' is supported in-repo; " + f"got {pretransform_type}" + ) + cfg = pretransform_config["config"] + autoencoder_config = {"sample_rate": sample_rate, "model": cfg} + autoencoder = create_autoencoder_from_config(autoencoder_config) + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + pretransform = AutoencoderPretransform( + autoencoder, + scale=scale, + model_half=model_half, + iterate_batch=iterate_batch, + chunked=chunked, + ) + pretransform.enable_grad = pretransform_config.get("enable_grad", False) + pretransform.eval().requires_grad_(pretransform.enable_grad) + return pretransform diff --git a/fastvideo/models/stable_audio/sat_pretransforms.py b/fastvideo/models/stable_audio/sat_pretransforms.py new file mode 100644 index 0000000000..8bef8803a8 --- /dev/null +++ b/fastvideo/models/stable_audio/sat_pretransforms.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# AutoencoderPretransform for Stable Audio (in-repo, no clone). +from __future__ import annotations + +import torch +from torch import nn + + +class Pretransform(nn.Module): + def __init__( + self, + enable_grad: bool, + io_channels: int, + is_discrete: bool, + ): + super().__init__() + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + self.enable_grad = enable_grad + + def encode(self, x: torch.Tensor, **kwargs: object) -> torch.Tensor: + raise NotImplementedError + + def decode(self, z: torch.Tensor, **kwargs: object) -> torch.Tensor: + raise NotImplementedError + + +class AutoencoderPretransform(Pretransform): + def __init__( + self, + model: nn.Module, + scale: float = 1.0, + model_half: bool = False, + iterate_batch: bool = False, + chunked: bool = False, + ): + super().__init__( + enable_grad=False, + io_channels=model.io_channels, + is_discrete=( + getattr(model, "bottleneck", None) is not None + and getattr(model.bottleneck, "is_discrete", False) + ), + ) + self.model = model + self.model.requires_grad_(False).eval() + self.scale = scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = getattr(model, "sample_rate", 44100) + self.model_half = model_half + self.iterate_batch = iterate_batch + self.encoded_channels = model.latent_dim + self.chunked = chunked + self.num_quantizers = ( + getattr(model.bottleneck, "num_quantizers", None) + if getattr(model, "bottleneck", None) is not None + else None + ) + self.codebook_size = ( + getattr(model.bottleneck, "codebook_size", None) + if getattr(model, "bottleneck", None) is not None + else None + ) + if self.model_half: + self.model.half() + + def encode(self, x: torch.Tensor, **kwargs: object) -> torch.Tensor: + if self.model_half: + x = x.half() + encoded = self.model.encode_audio( + x, chunked=self.chunked, 
iterate_batch=self.iterate_batch, **kwargs + ) + if self.model_half: + encoded = encoded.float() + return encoded / self.scale + + def decode(self, z: torch.Tensor, **kwargs: object) -> torch.Tensor: + z = z * self.scale + if self.model_half: + z = z.half() + decoded = self.model.decode_audio( + z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs + ) + if self.model_half: + decoded = decoded.float() + return decoded + + def load_state_dict(self, state_dict: dict, strict: bool = True) -> None: + self.model.load_state_dict(state_dict, strict=strict) diff --git a/fastvideo/models/stable_audio/sat_transformer.py b/fastvideo/models/stable_audio/sat_transformer.py new file mode 100644 index 0000000000..7a4ec4eb79 --- /dev/null +++ b/fastvideo/models/stable_audio/sat_transformer.py @@ -0,0 +1,878 @@ +# SPDX-License-Identifier: Apache-2.0 +# Vendored from stable-audio-tools for in-repo Stable Audio DiT. +from functools import reduce + +from einops import rearrange +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.amp import autocast +from typing import Callable, Literal + +try: + from torch.nn.attention.flex_attention import flex_attention +except ImportError: + flex_attention = None + +try: + from flash_attn import flash_attn_func +except ImportError: + flash_attn_func = None + +from fastvideo.models.stable_audio.sat_utils import compile + +if flex_attention is not None: + try: + torch._dynamo.config.cache_size_limit = 5000 + flex_attention_compiled = torch.compile( + flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" + ) + except Exception: + flex_attention_compiled = flex_attention +else: + flex_attention_compiled = None + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + 
seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast("cuda", enabled=False) + def forward(self, t): + device = self.inv_freq.device + seq_len = t.shape[0] if t.dim() > 0 else t.numel() + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) + + if self.scale is None: + return freqs, 1.0 + + power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast("cuda", enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. 
GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + + t = (t * freqs.cos() * scale ) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=10.0): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * init_alpha) + self.gamma = nn.Parameter(torch.ones(dim)) + self.beta = nn.Parameter(torch.zeros(dim)) + + def forward(self, x): + x = F.tanh(self.alpha * x) + return self.gamma * x + self.beta + +class RunningInstanceNorm(nn.Module): + def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True): + super().__init__() + self.register_buffer("running_mean", torch.zeros(1,1,dim)) + self.register_buffer("running_std", torch.ones(1,1,dim)) + self.saturate = saturate + self.eps = eps + self.momentum = momentum + self.dim = dim + self.trainable_gain = trainable_gain + if self.trainable_gain: + self.gain = nn.Parameter(torch.ones(1)) + + def _update_stats(self, x): + self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum) + self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps) + + def forward(self, x): + if self.training: + self._update_stats(x) + x = (x - self.running_mean) / self.running_std + if self.saturate: + x = torch.asinh(x) + if self.trainable_gain: + x = x * self.gain + return x + +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False, force_fp32=False, eps=1e-5): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + self.eps = eps + + self.force_fp32 = force_fp32 + + def forward(self, x): + if not self.force_fp32: + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps) + else: + output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps) + return output.to(x.dtype) + +class LayerScale(nn.Module): + def __init__(self, dim, init_val = 1e-5): + super().__init__() + self.scale = nn.Parameter(torch.full([dim], init_val)) + def forward(self, x): + return x * self.scale + +# feedforward + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() 
+ + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + #@compile + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'dyt', 'none'] = 'none', + differential = False, + feat_scale = False + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + + self.differential = differential + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + if differential: + self.to_q = nn.Linear(dim, dim * 2, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False) + else: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + if differential: + self.to_qkv = nn.Linear(dim, dim * 5, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + if qk_norm not in ['l2', 'ln', 'dyt','none']: + raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}') + + self.qk_norm = qk_norm + + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + elif self.qk_norm == 'dyt': + self.q_norm = DynamicTanh(dim_heads) + self.k_norm = DynamicTanh(dim_heads) + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.zeros(dim)) + self.lambda_hf = nn.Parameter(torch.zeros(dim)) + + self.causal = causal + if causal: + print('Using `causal` argument disables FlexAttention. 
If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.') + + @compile + def apply_qk_layernorm(self, q, k): + q_type = q.dtype + k_type = k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + return q, k + + + def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None): + + if self.num_heads != self.kv_heads: + # Repeat interleave kv_heads to match q_heads for grouped query attention + heads_per_kv_head = self.num_heads // self.kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + flash_attn_available = flash_attn_func is not None + if flash_attn_sliding_window is not None and (not flash_attn_available): + print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available") + + if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None: + print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention") + + if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): + print(f"Disabling FlexAttention because causal is set") + flex_attention_block_mask = None + flex_attention_score_mod = None + + if (flex_attention_compiled is not None + and (flex_attention_block_mask is not None + or flex_attention_score_mod is not None)): + out = flex_attention_compiled( + q, k, v, + block_mask=flex_attention_block_mask, + score_mod=flex_attention_score_mod, + ) + elif flash_attn_available: + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v)) + + if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16: + q, k, v = map(lambda t: t.to(torch.float16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1]) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + else: + out = F.scaled_dot_product_attention(q, k, v, is_causal = causal) + return out + + + #@compile + def forward( + self, + x, + context = None, + rotary_pos_emb = None, + causal = None, + flex_attention_block_mask = None, + flex_attention_score_mod = None, + flash_attn_sliding_window = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + if self.differential: + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff)) + q = torch.stack([q, q_diff], dim = 1) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v)) + k = torch.stack([k, k_diff], dim = 1) + else: + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + if self.differential: + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff)) + q = torch.stack([q, q_diff], dim = 1) + k = torch.stack([k, k_diff], dim = 1) + else: + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: 
rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm != "none": + q, k = self.apply_qk_layernorm(q, k) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype = q.dtype + k_dtype = k.dtype + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + if q.shape[-2] >= k.shape[-2]: + ratio = q.shape[-2] / k.shape[-2] + q_freqs, k_freqs = freqs, ratio * freqs + else: + ratio = k.shape[-2] / q.shape[-2] + q_freqs, k_freqs = ratio * freqs, freqs + q = apply_rotary_pos_emb(q, q_freqs) + k = apply_rotary_pos_emb(k, k_freqs) + q = q.to(v.dtype) + k = k.to(v.dtype) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.differential: + q, q_diff = q.unbind(dim = 1) + k, k_diff = k.unbind(dim = 1) + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out = out - out_diff + else: + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + #@compile + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + add_rope = False, + layer_scale = False, + attn_kwargs = {}, + 
ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = min(dim_heads,dim) + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + if layer_scale and zero_init_branch_outputs: + print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False') + zero_init_branch_outputs = False + + self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) + + self.add_rope = add_rope + + self.self_attn = Attention( + dim, + dim_heads = self.dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.cross_attend = cross_attend + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.cross_attn = Attention( + dim, + dim_heads = self.dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.layer_ix = layer_ix + + self.conformer = None + if conformer: + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) + self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5) + + self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None + + @compile + def forward( + self, + x, + context = None, + global_cond=None, + rotary_pos_emb = None, + self_attention_block_mask = None, + self_attention_score_mod = None, + cross_attention_block_mask = None, + cross_attention_score_mod = None, + self_attention_flash_sliding_window = None, + cross_attention_flash_sliding_window = None + ): + if rotary_pos_emb is None and self.add_rope: + rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2]) + + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window) + x = x * torch.sigmoid(1 - gate_self) + x = self.self_attn_scale(x) + x = x + residual + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = self.ff_scale(x) + x = 
x + residual + + else: + x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)) + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + x = x + self.ff_scale(self.ff(self.ff_norm(x))) + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + final_cross_attn_ix=-1, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + num_memory_tokens=0, + sliding_window=None, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens) + + self.global_cond_embedder = None + if global_cond_dim is not None: + self.global_cond_embedder = nn.Sequential( + nn.Linear(global_cond_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6) + ) + + self.final_cross_attn_ix = final_cross_attn_ix + + self.sliding_window = sliding_window + + for i in range(depth): + should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i <= (self.final_cross_attn_ix)) + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = should_cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + prepend_embeds = None, + global_cond = None, + return_info = False, + use_checkpointing = True, + exit_layer_ix = None, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + model_dtype = next(self.parameters()).dtype + x = x.to(model_dtype) + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if self.num_memory_tokens > 0: + memory_tokens = self.memory_tokens.expand(batch, -1, -1) + x = torch.cat((memory_tokens, 
x), dim=1) + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + if global_cond is not None and self.global_cond_embedder is not None: + global_cond = self.global_cond_embedder(global_cond) + + # Iterate over the transformer layers + for layer_ix, layer in enumerate(self.layers): + + if use_checkpointing: + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, self_attention_flash_sliding_window = self.sliding_window, **kwargs) + else: + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, self_attention_flash_sliding_window = self.sliding_window, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + if exit_layer_ix is not None and layer_ix == exit_layer_ix: + x = x[:, self.num_memory_tokens:, :] + + if return_info: + return x, info + + return x + + x = x[:, self.num_memory_tokens:, :] + + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/fastvideo/models/stable_audio/sat_utils.py b/fastvideo/models/stable_audio/sat_utils.py new file mode 100644 index 0000000000..f755c71b86 --- /dev/null +++ b/fastvideo/models/stable_audio/sat_utils.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# Minimal utils for Stable Audio DiT (no torch.compile by default). + + +def compile(function, *args, **kwargs): + """No-op compile for stable-audio transformer; avoids torch.compile deps.""" + return function diff --git a/fastvideo/pipelines/basic/stable_audio/__init__.py b/fastvideo/pipelines/basic/stable_audio/__init__.py new file mode 100644 index 0000000000..b6cf2c5844 --- /dev/null +++ b/fastvideo/pipelines/basic/stable_audio/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Stable Audio pipeline for text-to-audio generation.""" diff --git a/fastvideo/pipelines/basic/stable_audio/stable_audio_pipeline.py b/fastvideo/pipelines/basic/stable_audio/stable_audio_pipeline.py new file mode 100644 index 0000000000..9084da62f9 --- /dev/null +++ b/fastvideo/pipelines/basic/stable_audio/stable_audio_pipeline.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio pipeline for text-to-audio generation. 
+ +Supports loading from: +- Unified format: model.safetensors or model.ckpt at model root + model_config.json +- HuggingFace: stabilityai/stable-audio-open-1.0 (uses unified when available) +""" +from __future__ import annotations + +import json +import os +from typing import Any + +from fastvideo.configs.pipelines.stable_audio import StableAudioPipelineConfig +from fastvideo.configs.sample.stable_audio import StableAudioSamplingParam +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import PipelineComponentLoader +from fastvideo.models.stable_audio.conditioner import StableAudioConditioner +from fastvideo.models.stable_audio.pretransform import StableAudioPretransform +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase +from fastvideo.pipelines.stages import ( + StableAudioConditioningStage, + StableAudioDecodingStage, + StableAudioDenoisingStage, + StableAudioInputValidationStage, + StableAudioLatentPreparationStage, +) + +logger = init_logger(__name__) + +UNIFIED_CHECKPOINT_NAMES = ("model.safetensors", "model.ckpt") + + +def _detect_unified_checkpoint(model_root: str) -> str | None: + """Return path to unified checkpoint if present, else None.""" + for name in UNIFIED_CHECKPOINT_NAMES: + path = os.path.join(model_root, name) + if os.path.isfile(path): + return path + return None + + +def _is_unified_format(model_path: str) -> bool: + """True if model root has unified checkpoint + model_config.json.""" + model_root = model_path.rstrip(os.sep) + config_path = os.path.join(model_root, "model_config.json") + if not os.path.isfile(config_path): + return False + return _detect_unified_checkpoint(model_root) is not None + + +class StableAudioPipeline(ComposedPipelineBase): + """ + Text-to-audio pipeline using Stable Audio Open. + + Loads from unified checkpoint (model.safetensors / model.ckpt + model_config.json) + or from HuggingFace-style diffusers layout when unified is not available. 
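+
+    Expected layout for the unified path (illustrative; see _load_unified):
+
+        <model_root>/
+            model.safetensors (or model.ckpt)
+            model_config.json
+            transformer/   <- diffusers-format DiT weights
+
+    Stage order: input validation -> conditioning (T5 prompt + seconds) ->
+    latent preparation -> denoising -> decoding (Oobleck VAE pretransform).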
+ """ + + pipeline_config_cls = StableAudioPipelineConfig + sampling_params_cls = StableAudioSamplingParam + + _required_config_modules = [ + "conditioner", + "pretransform", + "transformer", + ] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: + self.add_stage( + stage_name="input_validation_stage", + stage=StableAudioInputValidationStage(), + ) + self.add_stage( + stage_name="conditioning_stage", + stage=StableAudioConditioningStage( + conditioner=self.get_module("conditioner")), + ) + self.add_stage( + stage_name="latent_preparation_stage", + stage=StableAudioLatentPreparationStage( + pretransform=self.get_module("pretransform")), + ) + self.add_stage( + stage_name="denoising_stage", + stage=StableAudioDenoisingStage( + transformer=self.get_module("transformer"), ), + ) + self.add_stage( + stage_name="decoding_stage", + stage=StableAudioDecodingStage( + pretransform=self.get_module("pretransform"), ), + ) + + def load_modules( + self, + fastvideo_args: FastVideoArgs, + loaded_modules: dict[str, Any] | None = None, + ) -> dict[str, Any]: + model_path = self.model_path.rstrip(os.sep) + config_path = os.path.join(model_path, "model_config.json") + unified_ckpt = _detect_unified_checkpoint(model_path) + + if unified_ckpt is not None and os.path.isfile(config_path): + return self._load_unified(model_path, config_path, unified_ckpt, + fastvideo_args) + return self._load_diffusers(model_path, fastvideo_args, loaded_modules) + + def _load_unified( + self, + model_path: str, + config_path: str, + checkpoint_path: str, + fastvideo_args: FastVideoArgs, + ) -> dict[str, Any]: + """Load from unified checkpoint (model.safetensors / model.ckpt + model_config.json).""" + with open(config_path, encoding="utf-8") as f: + model_config = json.load(f) + + pretransform = StableAudioPretransform( + model_config=model_config, + checkpoint_path=checkpoint_path, + ) + conditioner = StableAudioConditioner( + model_config=model_config, + checkpoint_path=checkpoint_path, + ) + + transformer_path = os.path.join(model_path, "transformer") + transformer = PipelineComponentLoader.load_module( + module_name="transformer", + component_model_path=transformer_path, + transformers_or_diffusers="diffusers", + fastvideo_args=fastvideo_args, + ) + + sample_rate = getattr(fastvideo_args.pipeline_config, "sample_rate", + 44100) + logger.info( + "Loaded Stable Audio (unified) from %s, sample_rate=%d", + checkpoint_path, + sample_rate, + ) + + return { + "conditioner": conditioner, + "pretransform": pretransform, + "transformer": transformer, + } + + def _load_diffusers( + self, + model_path: str, + fastvideo_args: FastVideoArgs, + loaded_modules: dict[str, Any] | None, + ) -> dict[str, Any]: + """Load from HuggingFace diffusers layout (model_index.json + subdirs).""" + model_index = self._load_config(model_path) + logger.info("Loading Stable Audio (diffusers) from %s", model_index) + + model_index.pop("_class_name") + model_index.pop("_diffusers_version") + model_index.pop("_name_or_path", None) + model_index.pop("workload_type", None) + + config_path = os.path.join(model_path, "model_config.json") + if not os.path.isfile(config_path): + raise ValueError( + "Diffusers layout requires model_config.json for Stable Audio. 
" + "Ensure model_config.json exists, or use a unified checkpoint " + "(model.safetensors + model_config.json).") + with open(config_path, encoding="utf-8") as f: + json.load(f) # Validate model_config.json is valid + + unified_ckpt = _detect_unified_checkpoint(model_path) + if unified_ckpt is None: + raise ValueError( + "Stable Audio requires a unified checkpoint (model.safetensors or model.ckpt) " + "at the model root. The diffusers layout alone is not supported yet." + ) + + return self._load_unified(model_path, config_path, unified_ckpt, + fastvideo_args) + + +EntryClass = StableAudioPipeline diff --git a/fastvideo/pipelines/pipeline_batch_info.py b/fastvideo/pipelines/pipeline_batch_info.py index e928987987..46e0ba4e9f 100644 --- a/fastvideo/pipelines/pipeline_batch_info.py +++ b/fastvideo/pipelines/pipeline_batch_info.py @@ -138,6 +138,11 @@ class ForwardBatch: # Camera control inputs (LingBotWorld) c2ws_plucker_emb: torch.Tensor | None = None # Plucker embedding: [B, C, F_lat, H_lat, W_lat] + # Audio inputs (Stable Audio) + sample_rate: int | None = None + duration_seconds: float | None = None + seconds_start: float | None = None + seconds_total: float | None = None # Latent dimensions height_latents: list[int] | int | None = None diff --git a/fastvideo/pipelines/stages/__init__.py b/fastvideo/pipelines/stages/__init__.py index f085d6749c..e100cba584 100644 --- a/fastvideo/pipelines/stages/__init__.py +++ b/fastvideo/pipelines/stages/__init__.py @@ -32,6 +32,16 @@ from fastvideo.pipelines.stages.ltx2_text_encoding import LTX2TextEncodingStage from fastvideo.pipelines.stages.matrixgame_denoising import ( MatrixGameCausalDenoisingStage) +from fastvideo.pipelines.stages.stable_audio_conditioning import ( + StableAudioConditioningStage) +from fastvideo.pipelines.stages.stable_audio_decoding import ( + StableAudioDecodingStage) +from fastvideo.pipelines.stages.stable_audio_denoising import ( + StableAudioDenoisingStage) +from fastvideo.pipelines.stages.stable_audio_input_validation import ( + StableAudioInputValidationStage) +from fastvideo.pipelines.stages.stable_audio_latent_preparation import ( + StableAudioLatentPreparationStage) from fastvideo.pipelines.stages.hyworld_denoising import HYWorldDenoisingStage from fastvideo.pipelines.stages.stepvideo_encoding import ( StepvideoPromptEncodingStage) @@ -88,4 +98,10 @@ "LongCatVideoVAEEncodingStage", "LongCatKVCacheInitStage", "LongCatVCDenoisingStage", + # Stable Audio stages + "StableAudioConditioningStage", + "StableAudioDecodingStage", + "StableAudioDenoisingStage", + "StableAudioInputValidationStage", + "StableAudioLatentPreparationStage", ] diff --git a/fastvideo/pipelines/stages/stable_audio_conditioning.py b/fastvideo/pipelines/stages/stable_audio_conditioning.py new file mode 100644 index 0000000000..179100ec0a --- /dev/null +++ b/fastvideo/pipelines/stages/stable_audio_conditioning.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio conditioning stage: T5 (prompt) + NumberEmbedder (seconds). 
+""" +from __future__ import annotations + +import torch + +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.pipelines.stages.base import PipelineStage + + +class StableAudioConditioningStage(PipelineStage): + """Run Stable Audio conditioner: prompt (T5) + seconds_start, seconds_total.""" + + def __init__(self, conditioner) -> None: + super().__init__() + self.conditioner = conditioner + + @torch.no_grad() + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + if fastvideo_args.text_encoder_cpu_offload: + device = next(self.conditioner.parameters()).device + else: + device = get_local_torch_device() + self.conditioner = self.conditioner.to(device) + + prompts = batch.prompt + if isinstance(prompts, str): + prompts = [prompts] + batch_size = len(prompts) + + seconds_start = batch.seconds_start + seconds_total = batch.seconds_total + if seconds_start is None: + seconds_start = 0.0 + if seconds_total is None: + seconds_total = batch.duration_seconds or 10.0 + if isinstance(seconds_start, int | float): + seconds_start = [float(seconds_start)] * batch_size + if isinstance(seconds_total, int | float): + seconds_total = [float(seconds_total)] * batch_size + + metadata = [{ + "prompt": p, + "seconds_start": s0, + "seconds_total": st + } for p, s0, st in zip( + prompts, seconds_start, seconds_total, strict=False)] + + conditioning = self.conditioner(metadata, device) + + batch.extra["stable_audio_conditioning"] = conditioning + batch.is_prompt_processed = True + + cross_attn = conditioning.get("prompt", (None, None))[0] + if cross_attn is not None: + batch.prompt_embeds = [cross_attn] + mask = conditioning["prompt"][1] + batch.prompt_attention_mask = [mask] if mask is not None else None + + global_conds = [] + for key in ["seconds_start", "seconds_total"]: + t, _ = conditioning.get(key, (None, None)) + if t is not None: + global_conds.append(t.squeeze(1)) + if global_conds: + batch.extra["stable_audio_global_cond"] = torch.cat(global_conds, + dim=-1) + + return batch diff --git a/fastvideo/pipelines/stages/stable_audio_decoding.py b/fastvideo/pipelines/stages/stable_audio_decoding.py new file mode 100644 index 0000000000..09a05c3603 --- /dev/null +++ b/fastvideo/pipelines/stages/stable_audio_decoding.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio decoding stage: decode latents to audio via Oobleck VAE. 
+""" +from __future__ import annotations + +import torch + +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.pipelines.stages.base import PipelineStage +from fastvideo.utils import PRECISION_TO_TYPE + + +class StableAudioDecodingStage(PipelineStage): + """Decode Stable Audio latents to waveform using pretransform (Oobleck VAE).""" + + def __init__(self, pretransform) -> None: + super().__init__() + self.pretransform = pretransform + + @torch.no_grad() + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + latents = batch.latents + if latents is None: + raise ValueError("Latents must be provided before decoding.") + + device = get_local_torch_device() + self.pretransform = self.pretransform.to(device) + latents = latents.to(device) + + vae_dtype = PRECISION_TO_TYPE.get( + getattr(fastvideo_args.pipeline_config, "vae_precision", "fp32"), + torch.float32, + ) + latents = latents.to(vae_dtype) + + audio = self.pretransform.decode(latents) + batch.output = audio + return batch diff --git a/fastvideo/pipelines/stages/stable_audio_denoising.py b/fastvideo/pipelines/stages/stable_audio_denoising.py new file mode 100644 index 0000000000..3a6bca6d9e --- /dev/null +++ b/fastvideo/pipelines/stages/stable_audio_denoising.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio denoising stage using k-diffusion v-prediction sampling. +""" +from __future__ import annotations + +from typing import Any + +import torch + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.pipelines.stages.base import PipelineStage + +logger = init_logger(__name__) + + +class StableAudioDenoisingStage(PipelineStage): + """Run k-diffusion v-prediction sampling for Stable Audio.""" + + def __init__(self, transformer) -> None: + super().__init__() + self.transformer = transformer + + @torch.no_grad() + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + latents = batch.latents + if latents is None: + raise ValueError("Latents must be provided before denoising.") + + conditioning = batch.extra.get("stable_audio_conditioning") + if conditioning is None: + raise ValueError( + "Conditioning must be set by StableAudioConditioningStage.") + + cfg_scale = batch.guidance_scale or 6.0 + steps = batch.num_inference_steps or 250 + + cross_attn = conditioning.get("prompt", (None, None)) + cross_attn_cond, cross_attn_mask = cross_attn[0], cross_attn[1] + global_cond = batch.extra.get("stable_audio_global_cond") + + cond_inputs = { + "cross_attn_cond": cross_attn_cond, + "cross_attn_cond_mask": cross_attn_mask, + "global_embed": global_cond, + } + + device = latents.device + + def _to_device(x: Any) -> Any: + if x is None: + return x + if torch.is_tensor(x): + return x.to(device) + return x + + cond_inputs = {k: _to_device(v) for k, v in cond_inputs.items()} + + from fastvideo.models.stable_audio.sampling import sample_stable_audio + + sampled = sample_stable_audio( + self.transformer.model, + batch.latents, + steps=steps, + device=device, + cfg_scale=cfg_scale, + **cond_inputs, + ) + + batch.latents = sampled + logger.info("[StableAudio] Denoising done: steps=%d", steps) + return batch diff --git a/fastvideo/pipelines/stages/stable_audio_input_validation.py 
b/fastvideo/pipelines/stages/stable_audio_input_validation.py new file mode 100644 index 0000000000..853e422c4f --- /dev/null +++ b/fastvideo/pipelines/stages/stable_audio_input_validation.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio input validation: audio-specific checks. +""" +from __future__ import annotations + +import torch + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.pipelines.stages.input_validation import InputValidationStage + + +class StableAudioInputValidationStage(InputValidationStage): + """Input validation for Stable Audio: uses duration instead of height/width.""" + + def _generate_seeds(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs): + seed = batch.seed + num_videos_per_prompt = batch.num_videos_per_prompt + assert seed is not None + batch.seeds = [seed + i for i in range(num_videos_per_prompt)] + batch.generator = [ + torch.Generator("cpu").manual_seed(s) for s in batch.seeds + ] + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + self._generate_seeds(batch, fastvideo_args) + + if batch.prompt is None and batch.prompt_embeds is None: + raise ValueError( + "Either `prompt` or `prompt_embeds` must be provided") + + if batch.num_inference_steps <= 0: + raise ValueError( + f"num_inference_steps must be positive, got {batch.num_inference_steps}" + ) + + if batch.do_classifier_free_guidance and batch.guidance_scale <= 0: + raise ValueError( + f"guidance_scale must be positive when using CFG, got {batch.guidance_scale}" + ) + + duration = batch.duration_seconds or 10.0 + batch.extra["stable_audio_duration"] = duration + return batch diff --git a/fastvideo/pipelines/stages/stable_audio_latent_preparation.py b/fastvideo/pipelines/stages/stable_audio_latent_preparation.py new file mode 100644 index 0000000000..c6373bc2af --- /dev/null +++ b/fastvideo/pipelines/stages/stable_audio_latent_preparation.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Stable Audio latent preparation: sample initial noise for k-diffusion. 
+""" +from __future__ import annotations + +import torch + +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.pipelines.stages.base import PipelineStage + + +class StableAudioLatentPreparationStage(PipelineStage): + """Prepare initial noise latents for Stable Audio denoising.""" + + def __init__(self, pretransform) -> None: + super().__init__() + self.pretransform = pretransform + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + sample_rate = getattr(fastvideo_args.pipeline_config, "sample_rate", + 44100) or 44100 + duration = batch.duration_seconds or 10.0 + sample_size = int(duration * sample_rate) + latent_size = sample_size // self.pretransform.downsampling_ratio + latent_channels = self.pretransform.encoded_channels + + batch_size = 1 + if batch.prompt is not None: + batch_size = len(batch.prompt) if isinstance(batch.prompt, + list) else 1 + batch_size *= batch.num_videos_per_prompt + + device = get_local_torch_device() + dtype = next(self.pretransform.parameters()).dtype + + seed = batch.seeds[0] if batch.seeds else (batch.seed or 0) + generator = batch.generator + if isinstance(generator, list): + generator = generator[0] if generator else None + if generator is None or str(generator.device) != str(device): + generator = torch.Generator(device).manual_seed(seed) + + latents = torch.randn( + (batch_size, latent_channels, latent_size), + generator=generator, + device=device, + dtype=dtype, + ) + + batch.latents = latents + batch.raw_latent_shape = latents.shape + batch.extra["stable_audio_sample_size"] = sample_size + return batch diff --git a/fastvideo/registry.py b/fastvideo/registry.py index c520be15f7..cceac51696 100644 --- a/fastvideo/registry.py +++ b/fastvideo/registry.py @@ -25,6 +25,7 @@ from fastvideo.configs.pipelines.lingbotworld import LingBotWorldI2V480PConfig from fastvideo.configs.pipelines.longcat import LongCatT2V480PConfig from fastvideo.configs.pipelines.ltx2 import LTX2T2VConfig +from fastvideo.configs.pipelines.stable_audio import StableAudioPipelineConfig from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig from fastvideo.configs.pipelines.turbodiffusion import ( TurboDiffusionI2V_A14B_Config, @@ -62,6 +63,7 @@ from fastvideo.configs.sample.lingbotworld import LingBotWorld_SamplingParam from fastvideo.configs.sample.ltx2 import (LTX2BaseSamplingParam, LTX2DistilledSamplingParam) +from fastvideo.configs.sample.stable_audio import StableAudioSamplingParam from fastvideo.configs.sample.stepvideo import StepVideoT2VSamplingParam from fastvideo.configs.sample.turbodiffusion import ( TurboDiffusionI2V_A14B_SamplingParam, @@ -245,6 +247,20 @@ def _get_config_info( def _register_configs() -> None: + # Stable Audio + register_configs( + sampling_param_cls=StableAudioSamplingParam, + pipeline_config_cls=StableAudioPipelineConfig, + hf_model_paths=[ + "stabilityai/stable-audio-open-1.0", + "stable-audio-open-1.0", + ], + model_detectors=[ + lambda path: "stable-audio" in path.lower(), + lambda path: "stableaudio" in path.lower(), + ], + ) + # LTX-2 (base) register_configs( sampling_param_cls=LTX2BaseSamplingParam, diff --git a/mkdocs.yml b/mkdocs.yml index 7624f75d63..5e6376f45e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -128,6 +128,7 @@ nav: - Optimizations: inference/optimizations.md - ComfyUI: inference/comfyui.md - Support Matrix: inference/support_matrix.md + - 
Stable Audio: inference/stable_audio.md - CLI: inference/cli.md - Add Pipeline: inference/add_pipeline.md - Examples: diff --git a/pyproject.toml b/pyproject.toml index 3e82dd0613..1e88df4eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,12 @@ rocm = [ "amdsmi", ] +stable-audio = [ + "k-diffusion>=0.1.1", + "alias-free-torch>=0.0.6", + "einops-exts>=0.0.4", +] + [project.scripts] fastvideo = "fastvideo.entrypoints.cli.main:main" diff --git a/tests/local_tests/pipelines/test_stable_audio_pipeline_smoke.py b/tests/local_tests/pipelines/test_stable_audio_pipeline_smoke.py new file mode 100644 index 0000000000..0ebc455437 --- /dev/null +++ b/tests/local_tests/pipelines/test_stable_audio_pipeline_smoke.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Smoke test for Stable Audio pipeline. + +Loads the pipeline, runs one generate_audio call, and asserts: +- Output audio shape, dtype, sample_rate match config +- No exceptions +""" +import pytest +import torch + +from fastvideo import VideoGenerator + + +def _model_path() -> str: + return "stabilityai/stable-audio-open-1.0" + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Stable Audio smoke test requires CUDA.", +) +def test_stable_audio_pipeline_smoke() -> None: + model_path = _model_path() + sample_rate = 44100 + duration_seconds = 5.0 + expected_sample_size = int(duration_seconds * sample_rate) + expected_channels = 2 + downsampling_ratio = 2048 + + generator = VideoGenerator.from_pretrained( + model_path, + num_gpus=1, + use_fsdp_inference=False, + dit_cpu_offload=False, + vae_cpu_offload=False, + ) + + result = generator.generate_audio( + prompt="A short piano note.", + duration_seconds=duration_seconds, + num_inference_steps=10, + seed=42, + ) + + generator.shutdown() + + audio = result["audio"] + assert isinstance(audio, torch.Tensor), "Expected audio to be a tensor" + assert audio.ndim == 3, f"Expected (B, C, T), got {audio.shape}" + b, c, t = audio.shape + assert b >= 1, "Expected at least one batch" + assert c == expected_channels, f"Expected stereo ({expected_channels}), got {c}" + assert abs(t - expected_sample_size) <= downsampling_ratio, ( + f"Expected ~{expected_sample_size} samples ({duration_seconds}s @ {sample_rate}Hz), got {t}" + ) + + assert result["sample_rate"] == sample_rate, ( + f"Expected sample_rate {sample_rate}, got {result['sample_rate']}" + ) + + assert audio.dtype in (torch.float32, torch.float16, torch.bfloat16), ( + f"Expected float dtype, got {audio.dtype}" + ) + + assert "generation_time" in result + assert result["prompt"] == "A short piano note." diff --git a/tests/local_tests/stable_audio/test_parity.py b/tests/local_tests/stable_audio/test_parity.py new file mode 100644 index 0000000000..b4155e1da1 --- /dev/null +++ b/tests/local_tests/stable_audio/test_parity.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +""" +Parity tests for Stable Audio components (transformer, VAE, conditioner, scheduler). + +Verifies that FastVideo's Stable Audio implementations match stable-audio-tools +outputs when loading from unified model.safetensors. + +Run from project root: + python tests/local_tests/stable_audio/test_parity.py + python tests/local_tests/stable_audio/test_parity.py --test transformer # run specific test + +Uses official_weights/stable-audio-open-1.0/ (model.safetensors + model_config.json). +If missing, downloads from HuggingFace stabilityai/stable-audio-open-1.0 (set HF_TOKEN +for gated access). 
+""" +import argparse +import os +import sys + +import torch + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +) +sys.path.insert(0, REPO_ROOT) +SAT_PATH = os.path.join(REPO_ROOT, "stable-audio-tools") +if os.path.isdir(SAT_PATH): + sys.path.insert(0, SAT_PATH) + +HF_STABLE_AUDIO_ID = "stabilityai/stable-audio-open-1.0" +MODEL_ROOT = os.path.join(REPO_ROOT, "official_weights", "stable-audio-open-1.0") +CHECKPOINT_PATH = os.path.join(MODEL_ROOT, "model.safetensors") +CONFIG_PATH = os.path.join(MODEL_ROOT, "model_config.json") + + +def _ensure_model_downloaded() -> bool: + """If checkpoint missing, try downloading from HuggingFace. Returns True if ready.""" + if os.path.exists(CHECKPOINT_PATH) and os.path.exists(CONFIG_PATH): + return True + print(f" {CHECKPOINT_PATH} not found. Trying HF download ({HF_STABLE_AUDIO_ID})...") + try: + from huggingface_hub import snapshot_download + + os.makedirs(os.path.dirname(MODEL_ROOT), exist_ok=True) + token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + snapshot_download( + repo_id=HF_STABLE_AUDIO_ID, + local_dir=MODEL_ROOT, + ignore_patterns=["*.onnx", "*.msgpack"], + token=token, + ) + except Exception as e: + print(f" Download failed: {e}") + return False + if not os.path.exists(CHECKPOINT_PATH): + print(f" SKIP: After download, {CHECKPOINT_PATH} still missing (repo layout may differ).") + return False + print(f" Downloaded to {MODEL_ROOT}") + return True + + +def _checkpoint_exists() -> bool: + if not _ensure_model_downloaded(): + print(f"SKIP: {CHECKPOINT_PATH} not found. Download model or set HF_TOKEN.") + return False + return True + + +# --- Transformer --- + + +def _load_fastvideo_transformer(): + from fastvideo.configs.models.dits.stable_audio import StableAudioDiTConfig + from fastvideo.models.dits.stable_audio import StableAudioDiTModel + from fastvideo.models.loader.utils import get_param_names_mapping + from fastvideo.models.loader.weight_utils import safetensors_weights_iterator + + config = StableAudioDiTConfig() + config.arch_config.in_channels = 64 + config.arch_config.global_states_input_dim = 1536 + config.arch_config.cross_attention_dim = 768 + config.arch_config.num_layers = 24 + config.arch_config.num_attention_heads = 24 + + model = StableAudioDiTModel(config=config) + mapping_fn = get_param_names_mapping(config.arch_config.param_names_mapping) + weight_iter = safetensors_weights_iterator( + [CHECKPOINT_PATH], to_cpu=True, key_prefix="model.model." 
+ ) + state_dict = {mapping_fn(k)[0]: v for k, v in weight_iter} + missing, unexpected = model.load_state_dict(state_dict, strict=False) + assert len(missing) == 0, f"Should have no missing keys, got {len(missing)}" + return model + + +def _load_reference_transformer(): + import json + from stable_audio_tools.models.factory import create_model_from_config + from stable_audio_tools.models.utils import load_ckpt_state_dict + + with open(CONFIG_PATH) as f: + config = json.load(f) + full = create_model_from_config(config) + full.load_state_dict(load_ckpt_state_dict(CHECKPOINT_PATH), strict=False) + return full.model + + +def test_transformer(): + if not _checkpoint_exists(): + return + torch.manual_seed(42) + device = "cuda" if torch.cuda.is_available() else "cpu" + + fv_model = _load_fastvideo_transformer() + fv_inner = fv_model.model.to(device).eval() + B, C, T = 1, 64, 64 + x = torch.randn(B, C, T, device=device, dtype=torch.float32) + t = torch.rand(B, device=device, dtype=torch.float32) * 100 + + with torch.no_grad(): + out_fv = fv_inner(x, t) + print(f" Transformer: FastVideo output shape {out_fv.shape}") + + try: + sat_model = _load_reference_transformer().to(device).eval() + with torch.no_grad(): + out_sat = sat_model(x, t) + max_diff = (out_fv - out_sat).abs().max().item() + torch.testing.assert_close(out_fv, out_sat, atol=1e-5, rtol=1e-4) + print(f" Transformer: max_diff={max_diff:.6f} PASS") + except ImportError: + print(" Transformer: PASS (no stable-audio-tools comparison)") + + +# --- VAE / Pretransform --- + + +def _load_reference_pretransform(): + import json + from stable_audio_tools.models.factory import create_model_from_config + from stable_audio_tools.models.utils import load_ckpt_state_dict + + with open(CONFIG_PATH) as f: + config = json.load(f) + full = create_model_from_config(config) + full.load_state_dict(load_ckpt_state_dict(CHECKPOINT_PATH), strict=False) + return full.pretransform + + +def _load_fastvideo_pretransform(): + from fastvideo.models.stable_audio import StableAudioPretransform + + return StableAudioPretransform( + model_config=CONFIG_PATH, + checkpoint_path=CHECKPOINT_PATH, + ) + + +def test_vae(): + if not _checkpoint_exists(): + return + torch.manual_seed(42) + device = "cuda" if torch.cuda.is_available() else "cpu" + x = torch.randn(1, 2, 4096, device=device, dtype=torch.float32) + + ref = _load_reference_pretransform().to(device).eval() + fv = _load_fastvideo_pretransform().to(device).eval() + + with torch.no_grad(): + z_ref = ref.encode(x) + x_recon_ref = ref.decode(z_ref) + z_fv = fv.encode(x) + x_recon_fv = fv.decode(z_fv) + + torch.testing.assert_close(z_fv, z_ref, atol=0.02, rtol=0.02) + torch.testing.assert_close(x_recon_fv, x_recon_ref, atol=0.5, rtol=0.05) + print(f" VAE: z {z_fv.shape} recon {x_recon_fv.shape} PASS") + + +# --- Conditioner --- + + +def _load_reference_conditioner(): + import json + from stable_audio_tools.models.factory import create_model_from_config + from stable_audio_tools.models.utils import load_ckpt_state_dict + + with open(CONFIG_PATH) as f: + config = json.load(f) + full = create_model_from_config(config) + full.load_state_dict(load_ckpt_state_dict(CHECKPOINT_PATH), strict=False) + return full.conditioner + + +def _load_fastvideo_conditioner(): + from fastvideo.models.stable_audio import StableAudioConditioner + + return StableAudioConditioner( + model_config=CONFIG_PATH, + checkpoint_path=CHECKPOINT_PATH, + ) + + +def test_conditioner(): + if not _checkpoint_exists(): + return + device = "cuda" if 
torch.cuda.is_available() else "cpu" + metadata = [{"prompt": "Amen break 174 BPM", "seconds_start": 0, "seconds_total": 12}] + + ref = _load_reference_conditioner().to(device).eval() + fv = _load_fastvideo_conditioner().to(device).eval() + + with torch.no_grad(): + cond_ref = ref(metadata, device) + cond_fv = fv(metadata, device) + + for key in cond_ref: + r0, r1 = cond_ref[key] + f0, f1 = cond_fv[key] + if r0 is not None: + torch.testing.assert_close(r0, f0, atol=1e-4, rtol=1e-3) + if r1 is not None: + torch.testing.assert_close(r1, f1, atol=1e-5, rtol=1e-4) + print(" Conditioner: PASS") + + +# --- Scheduler / Sampling --- + + +def _load_reference_model(): + import json + from stable_audio_tools.models.factory import create_model_from_config + from stable_audio_tools.models.utils import load_ckpt_state_dict + + with open(CONFIG_PATH) as f: + config = json.load(f) + model = create_model_from_config(config) + model.load_state_dict(load_ckpt_state_dict(CHECKPOINT_PATH), strict=False) + return model + + +def test_scheduler(): + if not _checkpoint_exists(): + return + device = "cuda" if torch.cuda.is_available() else "cpu" + seed = 42 + steps = 10 + + model = _load_reference_model().to(device).eval() + metadata = [{"prompt": "Amen break", "seconds_start": 0, "seconds_total": 12}] + + with torch.no_grad(): + cond_tensors = model.conditioner(metadata, device) + cond_inputs = model.get_conditioning_inputs(cond_tensors) + + sample_size, latent_channels = 1024, 64 + noise_shape = (1, latent_channels, sample_size) + cond_inputs_cuda = {k: v.to(device) if v is not None else v for k, v in cond_inputs.items()} + + def model_fn(x, sigma, **kwargs): + return model.model(x, sigma, **cond_inputs_cuda) + + from stable_audio_tools.inference.sampling import sample_k + from fastvideo.models.stable_audio.sampling import sample_stable_audio + + torch.manual_seed(seed) + noise_ref = torch.randn(noise_shape, device=device, dtype=torch.float32) + sampled_ref = sample_k( + model_fn, noise_ref, init_data=None, steps=steps, + sampler_type="dpmpp-2m-sde", sigma_min=0.01, sigma_max=100, rho=1.0, + device=device, cfg_scale=6.0, batch_cfg=True, rescale_cfg=True, + ) + + torch.manual_seed(seed) + noise_fv = torch.randn(noise_shape, device=device, dtype=torch.float32) + sampled_fv = sample_stable_audio( + model_fn, noise_fv, steps=steps, device=device, + cfg_scale=6.0, batch_cfg=True, rescale_cfg=True, + ) + + torch.testing.assert_close(sampled_ref, sampled_fv, atol=1e-3, rtol=1e-2) + print(" Scheduler/sampling: PASS") + + +# --- Main --- + +TESTS = { + "transformer": test_transformer, + "vae": test_vae, + "conditioner": test_conditioner, + "scheduler": test_scheduler, +} + + +def main(): + parser = argparse.ArgumentParser(description="Stable Audio parity tests") + parser.add_argument( + "--test", + choices=list(TESTS) + ["all"], + default="all", + help="Which test to run (default: all)", + ) + args = parser.parse_args() + + if args.test == "all": + for name, fn in TESTS.items(): + print(f"\n[{name}]") + fn() + else: + print(f"\n[{args.test}]") + TESTS[args.test]() + + print("\nAll selected parity tests passed.") + + +if __name__ == "__main__": + main()
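
For reference, the latent sizing used by `StableAudioLatentPreparationStage` and checked by the smoke test reduces to simple integer arithmetic. The snippet below is an illustrative sketch only, not part of the pipeline code; it assumes the 44.1 kHz stereo configuration, the 2048x Oobleck downsampling ratio, and the 64 latent channels that the tests above exercise.

```python
# Illustrative sketch: latent sizing arithmetic mirroring
# StableAudioLatentPreparationStage and the smoke-test tolerance.
sample_rate = 44100          # default sample rate in the latent preparation stage
downsampling_ratio = 2048    # Oobleck pretransform ratio, per the smoke test
latent_channels = 64         # matches the DiT in_channels used in the parity test
duration_seconds = 10.0

sample_size = int(duration_seconds * sample_rate)   # 441000 waveform samples
latent_size = sample_size // downsampling_ratio     # 215 latent frames

# Shape of the initial noise sampled before denoising: (B, C_latent, T_latent)
print((1, latent_channels, latent_size))             # (1, 64, 215)

# The decoded waveform length is only approximately duration * sample_rate;
# the smoke test tolerates a deviation of up to one downsampling window.
decoded_samples = latent_size * downsampling_ratio   # 440320
assert abs(decoded_samples - sample_size) <= downsampling_ratio
```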