Skip to content

Commit b003a47

Browse files
RuixiangMadg845github-actions[bot]
authored
[Feat] support AutoPipelineForText2Audio (#13511)
* [Feat] support AutoPipelineForText2Audio Signed-off-by: Lancer <maruixiang6688@gmail.com> * upd Signed-off-by: Lancer <maruixiang6688@gmail.com> * Apply style fixes --------- Signed-off-by: Lancer <maruixiang6688@gmail.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent c943837 commit b003a47

6 files changed

Lines changed: 297 additions & 2 deletions

File tree

docs/source/en/api/pipelines/auto_pipeline.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@ The `AutoPipeline` is designed to make it easy to load a checkpoint for a task w
3737
- all
3838
- from_pretrained
3939
- from_pipe
40+
41+
## AutoPipelineForText2Audio
42+
43+
[[autodoc]] AutoPipelineForText2Audio
44+
- all
45+
- from_pretrained
46+
- from_pipe

docs/source/en/tutorials/autopipeline.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
6262
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
6363
```
6464

65-
There are three types of [AutoPipeline](../api/models/auto_model) classes, [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`] and [`AutoPipelineForInpainting`]. Each of these classes have a predefined mapping, linking a pipeline to their task-specific subclass.
65+
There are four types of [AutoPipeline](../api/pipelines/auto_pipeline) classes:
66+
67+
- [`AutoPipelineForText2Image`]
68+
- [`AutoPipelineForImage2Image`]
69+
- [`AutoPipelineForInpainting`]
70+
- [`AutoPipelineForText2Audio`]
71+
72+
Each of these classes has a predefined mapping that links a pipeline to its task-specific subclass.
6673

6774
When [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping.

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@
345345
"AudioPipelineOutput",
346346
"AutoPipelineForImage2Image",
347347
"AutoPipelineForInpainting",
348+
"AutoPipelineForText2Audio",
348349
"AutoPipelineForText2Image",
349350
"ConsistencyModelPipeline",
350351
"DanceDiffusionPipeline",
@@ -1179,6 +1180,7 @@
11791180
AudioPipelineOutput,
11801181
AutoPipelineForImage2Image,
11811182
AutoPipelineForInpainting,
1183+
AutoPipelineForText2Audio,
11821184
AutoPipelineForText2Image,
11831185
BlipDiffusionControlNetPipeline,
11841186
BlipDiffusionPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"AutoPipelineForImage2Image",
4646
"AutoPipelineForInpainting",
4747
"AutoPipelineForText2Image",
48+
"AutoPipelineForText2Audio",
4849
]
4950
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
5051
_import_structure["ddim"] = ["DDIMPipeline"]
@@ -557,6 +558,7 @@
557558
from .auto_pipeline import (
558559
AutoPipelineForImage2Image,
559560
AutoPipelineForInpainting,
561+
AutoPipelineForText2Audio,
560562
AutoPipelineForText2Image,
561563
)
562564
from .consistency_models import ConsistencyModelPipeline

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 263 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..models.controlnets import ControlNetUnionModel
2222
from ..utils import is_sentencepiece_available
2323
from .anyflow import AnyFlowFARPipeline, AnyFlowPipeline
24+
from .audioldm2 import AudioLDM2Pipeline
2425
from .aura_flow import AuraFlowPipeline
2526
from .chroma import ChromaPipeline
2627
from .cogview3 import CogView3PlusPipeline
@@ -76,6 +77,7 @@
7677
)
7778
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
7879
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
80+
from .longcat_audio_dit import LongCatAudioDiTPipeline
7981
from .lumina import LuminaPipeline
8082
from .lumina2 import Lumina2Pipeline
8183
from .nucleusmoe_image import NucleusMoEImagePipeline
@@ -110,6 +112,7 @@
110112
QwenImagePipeline,
111113
)
112114
from .sana import SanaPipeline
115+
from .stable_audio import StableAudioPipeline
113116
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
114117
from .stable_diffusion import (
115118
StableDiffusionImg2ImgPipeline,
@@ -193,6 +196,14 @@
193196
]
194197
)
195198

199+
# Registry of text-to-audio pipeline classes, keyed by a short model-family name.
# `AutoPipelineForText2Audio` passes this mapping to `_get_task_class` to resolve the concrete pipeline class.
AUTO_TEXT2AUDIO_PIPELINES_MAPPING = OrderedDict(
200+
[
201+
("audioldm2", AudioLDM2Pipeline),
202+
("stable-audio", StableAudioPipeline),
203+
("longcat-audio-dit", LongCatAudioDiTPipeline),
204+
]
205+
)
206+
196207
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
197208
[
198209
("stable-diffusion", StableDiffusionImg2ImgPipeline),
@@ -305,6 +316,7 @@
305316
AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
306317
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
307318
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
319+
AUTO_TEXT2AUDIO_PIPELINES_MAPPING,
308320
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
309321
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
310322
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
@@ -851,7 +863,6 @@ def from_pipe(cls, pipeline, **kwargs):
851863

852864
original_config = dict(pipeline.config)
853865
original_cls_name = pipeline.__class__.__name__
854-
855866
# derive the pipeline class to instantiate
856867
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name)
857868

@@ -1239,3 +1250,254 @@ def from_pipe(cls, pipeline, **kwargs):
12391250
model.register_to_config(**unused_original_config)
12401251

12411252
return model
1253+
1254+
1255+
class AutoPipelineForText2Audio(ConfigMixin):
    r"""

    [`AutoPipelineForText2Audio`] is a generic pipeline class that instantiates a text-to-audio pipeline class. The
    specific underlying pipeline class is automatically selected from either the
    [`~AutoPipelineForText2Audio.from_pretrained`] or [`~AutoPipelineForText2Audio.from_pipe`] methods.

    This class cannot be instantiated using `__init__()` (throws an error).

    Class attributes:

        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.

    """

    config_name = "model_index.json"

    def __init__(self, *args, **kwargs):
        # Direct construction is always rejected; the classmethods below return an instance of the resolved
        # task-specific pipeline class instead of an instance of this class.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
        r"""
        Instantiates a text-to-audio Pytorch diffusion pipeline from pretrained pipeline weight.

        The from_pretrained() method takes care of returning the correct pipeline class instance by:
            1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
               config object
            2. Find the text-to-audio pipeline linked to the pipeline class using pattern matching on pipeline class
               name.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *repo id* (for example `stabilityai/stable-audio-open-1.0`) of a pretrained
                      pipeline hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using [`~DiffusionPipeline.save_pretrained`].
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`str | os.PathLike`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.

            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            custom_revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `dict[str, int | str | torch.device]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping each device identifier to its maximum memory. Will default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if device_map contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                class). The overwritten components are passed directly to the pipelines `__init__` method. See example
                below for more information.

        Returns:
            The object returned by the resolved text-to-audio pipeline class's own `from_pretrained` method.

        > [!TIP]
        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
        > `hf auth login`.

        Examples:

        ```py
        >>> import torch
        >>> import soundfile as sf
        >>> from diffusers import AutoPipelineForText2Audio

        >>> pipeline = AutoPipelineForText2Audio.from_pretrained(
        ...     "stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16
        ... )
        >>> pipeline = pipeline.to("cuda")

        >>> output = pipeline(
        ...     "Generate a male voice reading a paragraph",
        ...     num_inference_steps=200,
        ...     audio_end_in_s=10.0,
        ... )
        >>> audio = output.audios[0].T.float().cpu().numpy()
        >>> sf.write("audio.wav", audio, pipeline.vae.sampling_rate)
        ```
        """
        # Hub/cache options are popped so they can be handed to `load_config` explicitly; they are merged back
        # into `kwargs` below so the concrete pipeline's `from_pretrained` receives them as well.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        # Only the config is loaded here; its `_class_name` entry names the pipeline class the checkpoint was
        # saved with, which is then matched against the text-to-audio mapping.
        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        text_2_audio_cls = _get_task_class(AUTO_TEXT2AUDIO_PIPELINES_MAPPING, orig_class_name)

        # Caller-supplied kwargs come last so they win over the popped defaults on any key clash.
        kwargs = {**load_config_kwargs, **kwargs}
        return text_2_audio_cls.from_pretrained(pretrained_model_or_path, **kwargs)

    @classmethod
    def from_pipe(cls, pipeline, **kwargs):
        r"""
        Instantiates a text-to-audio Pytorch diffusion pipeline from another instantiated diffusion pipeline class.

        The from_pipe() method takes care of returning the correct pipeline class instance by finding the text-to-audio
        pipeline linked to the pipeline class using pattern matching on pipeline class name.

        All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
        additional memory.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pipeline (`DiffusionPipeline`):
                an instantiated `DiffusionPipeline` object
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Components or optional `__init__` arguments of the target pipeline class. A value passed here takes
                precedence over the corresponding component or config value of `pipeline`.

        Returns:
            An instance of the text-to-audio pipeline class resolved for `pipeline`, built from its components.

        Raises:
            `ValueError`: If a required (non-optional) component of the target pipeline class is neither passed in
            `kwargs` nor available from `pipeline`.

        Examples:

        ```py
        >>> import torch
        >>> import soundfile as sf
        >>> from diffusers import AutoPipelineForText2Audio, StableAudioPipeline

        >>> pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)

        >>> pipe_audio = AutoPipelineForText2Audio.from_pipe(pipe)
        >>> output = pipe_audio(
        ...     "Generate a sound",
        ...     num_inference_steps=200,
        ...     audio_end_in_s=10.0,
        ... )
        >>> audio = output.audios[0].T.float().cpu().numpy()
        >>> sf.write("audio.wav", audio, pipe_audio.vae.sampling_rate)
        ```
        """

        # Work on a copy so popping keys below never mutates the source pipeline's config.
        original_config = dict(pipeline.config)
        original_cls_name = pipeline.__class__.__name__

        # derive the pipeline class to instantiate
        text_2_audio_cls = _get_task_class(AUTO_TEXT2AUDIO_PIPELINES_MAPPING, original_cls_name)

        # names of the components and of the optional arguments accepted by the target class
        expected_modules, optional_kwargs = text_2_audio_cls._get_signature_keys(text_2_audio_cls)

        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)

        # components explicitly passed by the caller override the ones held by `pipeline`
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        # `pipeline.components` is read once and its values reused, rather than re-evaluating the attribute
        # for every key (NOTE(review): it appears to be a computed property upstream - confirm).
        original_class_obj = {
            k: v
            for k, v in pipeline.components.items()
            if k in expected_modules and k not in passed_class_obj
        }

        # optional arguments: caller-supplied values first, then values recorded in the original config
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        original_pipe_kwargs = {
            k: v for k, v in original_config.items() if k in optional_kwargs and k not in passed_pipe_kwargs
        }

        # Config entries stored under a leading underscore whose bare name is an optional argument of the
        # target class are moved back into the kwargs (and removed from the leftover config).
        additional_pipe_kwargs = [
            k[1:]
            for k in original_config.keys()
            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
        ]
        for k in additional_pipe_kwargs:
            original_pipe_kwargs[k] = original_config.pop(f"_{k}")

        text_2_audio_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}

        # Whatever the target class does not consume is kept in the new config under an underscore-prefixed key.
        unused_original_config = {
            f"{'' if k.startswith('_') else '_'}{k}": v
            for k, v in original_config.items()
            if k not in text_2_audio_kwargs
        }

        missing_modules = (
            set(expected_modules) - set(text_2_audio_cls._optional_components) - set(text_2_audio_kwargs.keys())
        )

        if missing_modules:
            raise ValueError(
                f"Pipeline {text_2_audio_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
            )

        model = text_2_audio_cls(**text_2_audio_kwargs)
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        model.register_to_config(**unused_original_config)

        return model

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2492,6 +2492,21 @@ def from_pretrained(cls, *args, **kwargs):
24922492
requires_backends(cls, ["torch"])
24932493

24942494

2495+
# Placeholder for `AutoPipelineForText2Audio` in the dummy-objects module: every entry point only calls
# `requires_backends(..., ["torch"])`.
# NOTE(review): presumably this raises an informative error when torch is not installed - confirm against
# `requires_backends` / `DummyObject`.
class AutoPipelineForText2Audio(metaclass=DummyObject):
2496+
_backends = ["torch"]
2497+
2498+
def __init__(self, *args, **kwargs):
2499+
requires_backends(self, ["torch"])
2500+
2501+
@classmethod
2502+
def from_config(cls, *args, **kwargs):
2503+
requires_backends(cls, ["torch"])
2504+
2505+
@classmethod
2506+
def from_pretrained(cls, *args, **kwargs):
2507+
requires_backends(cls, ["torch"])
2508+
2509+
24952510
class AutoPipelineForText2Image(metaclass=DummyObject):
24962511
_backends = ["torch"]
24972512

0 commit comments

Comments
 (0)