Skip to content

Commit 4ac3e90

Browse files
committed
[Feat] support AutoPipelineForText2Audio
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent c8c8401 commit 4ac3e90

6 files changed

Lines changed: 293 additions & 2 deletions

File tree

docs/source/en/api/pipelines/auto_pipeline.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@ The `AutoPipeline` is designed to make it easy to load a checkpoint for a task w
3737
- all
3838
- from_pretrained
3939
- from_pipe
40+
41+
## AutoPipelineForText2Audio
42+
43+
[[autodoc]] AutoPipelineForText2Audio
44+
- all
45+
- from_pretrained
46+
- from_pipe

docs/source/en/tutorials/autopipeline.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,6 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
6262
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
6363
```
6464

65-
There are three types of [AutoPipeline](../api/models/auto_model) classes, [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`] and [`AutoPipelineForInpainting`]. Each of these classes have a predefined mapping, linking a pipeline to their task-specific subclass.
65+
There are four types of [AutoPipeline](../api/models/auto_model) classes, [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], [`AutoPipelineForInpainting`] and [`AutoPipelineForText2Audio`]. Each of these classes has a predefined mapping, linking a pipeline to its task-specific subclass.
6666

6767
When [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping.

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@
337337
"AutoPipelineForImage2Image",
338338
"AutoPipelineForInpainting",
339339
"AutoPipelineForText2Image",
340+
"AutoPipelineForText2Audio",
340341
"ConsistencyModelPipeline",
341342
"DanceDiffusionPipeline",
342343
"DDIMPipeline",
@@ -1142,6 +1143,7 @@
11421143
AutoPipelineForImage2Image,
11431144
AutoPipelineForInpainting,
11441145
AutoPipelineForText2Image,
1146+
AutoPipelineForText2Audio,
11451147
BlipDiffusionControlNetPipeline,
11461148
BlipDiffusionPipeline,
11471149
CLIPImageProjection,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"AutoPipelineForImage2Image",
4646
"AutoPipelineForInpainting",
4747
"AutoPipelineForText2Image",
48+
"AutoPipelineForText2Audio",
4849
]
4950
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
5051
_import_structure["ddim"] = ["DDIMPipeline"]
@@ -539,6 +540,7 @@
539540
AutoPipelineForImage2Image,
540541
AutoPipelineForInpainting,
541542
AutoPipelineForText2Image,
543+
AutoPipelineForText2Audio,
542544
)
543545
from .consistency_models import ConsistencyModelPipeline
544546
from .ddim import DDIMPipeline

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 266 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..configuration_utils import ConfigMixin
2121
from ..models.controlnets import ControlNetUnionModel
2222
from ..utils import is_sentencepiece_available
23+
from .audioldm2 import AudioLDM2Pipeline
2324
from .aura_flow import AuraFlowPipeline
2425
from .chroma import ChromaPipeline
2526
from .cogview3 import CogView3PlusPipeline
@@ -75,6 +76,7 @@
7576
)
7677
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
7778
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
79+
from .longcat_audio_dit import LongCatAudioDiTPipeline
7880
from .lumina import LuminaPipeline
7981
from .lumina2 import Lumina2Pipeline
8082
from .nucleusmoe_image import NucleusMoEImagePipeline
@@ -109,6 +111,7 @@
109111
QwenImagePipeline,
110112
)
111113
from .sana import SanaPipeline
114+
from .stable_audio import StableAudioPipeline
112115
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
113116
from .stable_diffusion import (
114117
StableDiffusionImg2ImgPipeline,
@@ -192,6 +195,14 @@
192195
]
193196
)
194197

198+
# Maps a config-name pattern to its text-to-audio pipeline class; consumed by
# AutoPipelineForText2Audio through `_get_task_class` (insertion order matters,
# mirroring the other AUTO_*_PIPELINES_MAPPING tables in this module).
AUTO_TEXT2AUDIO_PIPELINES_MAPPING = OrderedDict(
    [
        ("audioldm2", AudioLDM2Pipeline),
        ("stable-audio", StableAudioPipeline),
        ("longcat-audio-dit", LongCatAudioDiTPipeline),
    ]
)
205+
195206
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
196207
[
197208
("stable-diffusion", StableDiffusionImg2ImgPipeline),
@@ -301,6 +312,7 @@
301312
AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
302313
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
303314
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
315+
AUTO_TEXT2AUDIO_PIPELINES_MAPPING,
304316
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
305317
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
306318
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
@@ -847,7 +859,6 @@ def from_pipe(cls, pipeline, **kwargs):
847859

848860
original_config = dict(pipeline.config)
849861
original_cls_name = pipeline.__class__.__name__
850-
851862
# derive the pipeline class to instantiate
852863
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name)
853864

@@ -1235,3 +1246,257 @@ def from_pipe(cls, pipeline, **kwargs):
12351246
model.register_to_config(**unused_original_config)
12361247

12371248
return model
1249+
1250+
class AutoPipelineForText2Audio(ConfigMixin):
    r"""

    [`AutoPipelineForText2Audio`] is a generic pipeline class that instantiates a text-to-audio pipeline class. The
    specific underlying pipeline class is automatically selected from either the
    [`~AutoPipelineForText2Audio.from_pretrained`] or [`~AutoPipelineForText2Audio.from_pipe`] methods.

    This class cannot be instantiated using `__init__()` (throws an error).

    Class attributes:

        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.

    """

    config_name = "model_index.json"

    def __init__(self, *args, **kwargs):
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
        r"""
        Instantiates a text-to-audio Pytorch diffusion pipeline from pretrained pipeline weight.

        The from_pretrained() method takes care of returning the correct pipeline class instance by:
            1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
               config object
            2. Find the text-to-audio pipeline linked to the pipeline class using pattern matching on pipeline class
               name.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *repo id* (for example `stabilityai/stable-audio-open-1.0`) of a pretrained
                      pipeline hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using [`~DiffusionPipeline.save_pretrained`].
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`str | os.PathLike`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            custom_revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `dict[str, int | str | torch.device]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if device_map contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                class). The overwritten components are passed directly to the pipelines `__init__` method. See example
                below for more information.
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.

        > [!TIP]
        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with
        > `hf auth login`.

        Examples:

        ```py
        >>> import torch
        >>> import soundfile as sf
        >>> from diffusers import AutoPipelineForText2Audio

        >>> pipeline = AutoPipelineForText2Audio.from_pretrained(
        ...     "stabilityai/stable-audio-open-1.0",
        ...     torch_dtype=torch.float16
        ... )
        >>> pipeline = pipeline.to("cuda")

        >>> output = pipeline(
        ...     "Generate a male voice reading a paragraph",
        ...     num_inference_steps=200,
        ...     audio_end_in_s=10.0,
        ... )
        >>> audio = output.audios[0].T.float().cpu().numpy()
        >>> sf.write("audio.wav", audio, pipeline.vae.sampling_rate)
        ```
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        # The config's `_class_name` identifies the concrete pipeline saved at
        # this path; map it to its text-to-audio counterpart by name pattern.
        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        text_2_audio_cls = _get_task_class(AUTO_TEXT2AUDIO_PIPELINES_MAPPING, orig_class_name)

        kwargs = {**load_config_kwargs, **kwargs}
        return text_2_audio_cls.from_pretrained(pretrained_model_or_path, **kwargs)

    @classmethod
    def from_pipe(cls, pipeline, **kwargs):
        r"""
        Instantiates a text-to-audio Pytorch diffusion pipeline from another instantiated diffusion pipeline class.

        The from_pipe() method takes care of returning the correct pipeline class instance by finding the text-to-audio
        pipeline linked to the pipeline class using pattern matching on pipeline class name.

        All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
        additional memory.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pipeline (`DiffusionPipeline`):
                an instantiated `DiffusionPipeline` object

        ```py
        >>> import torch
        >>> import soundfile as sf
        >>> from diffusers import AutoPipelineForText2Audio, StableAudioPipeline

        >>> pipe = StableAudioPipeline.from_pretrained(
        ...     "stabilityai/stable-audio-open-1.0",
        ...     torch_dtype=torch.float16
        ... )

        >>> pipe_audio = AutoPipelineForText2Audio.from_pipe(pipe)
        >>> output = pipe_audio(
        ...     "Generate a sound",
        ...     num_inference_steps=200,
        ...     audio_end_in_s=10.0,
        ... )
        >>> audio = output.audios[0].T.float().cpu().numpy()
        >>> sf.write("audio.wav", audio, pipe_audio.vae.sampling_rate)
        ```
        """

        original_config = dict(pipeline.config)
        original_cls_name = pipeline.__class__.__name__

        # derive the pipeline class to instantiate
        text_2_audio_cls = _get_task_class(AUTO_TEXT2AUDIO_PIPELINES_MAPPING, original_cls_name)

        expected_modules, optional_kwargs = text_2_audio_cls._get_signature_keys(text_2_audio_cls)

        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)

        # components explicitly passed by the caller take precedence over the
        # ones carried over from the source pipeline
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        original_class_obj = {
            k: pipeline.components[k]
            for k, v in pipeline.components.items()
            if k in expected_modules and k not in passed_class_obj
        }

        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        original_pipe_kwargs = {
            k: original_config[k]
            for k, v in original_config.items()
            if k in optional_kwargs and k not in passed_pipe_kwargs
        }

        # config entries stored with a leading underscore (e.g. `_requires_safety_checker`)
        # map onto optional kwargs of the target pipeline
        additional_pipe_kwargs = [
            k[1:]
            for k in original_config.keys()
            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
        ]
        for k in additional_pipe_kwargs:
            original_pipe_kwargs[k] = original_config.pop(f"_{k}")

        text_2_audio_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}

        # keep leftover config entries around (underscore-prefixed) so nothing is lost
        unused_original_config = {
            f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
            for k, v in original_config.items()
            if k not in text_2_audio_kwargs
        }

        missing_modules = (
            set(expected_modules) - set(text_2_audio_cls._optional_components) - set(text_2_audio_kwargs.keys())
        )

        if len(missing_modules) > 0:
            raise ValueError(
                f"Pipeline {text_2_audio_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
            )

        model = text_2_audio_cls(**text_2_audio_kwargs)
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        model.register_to_config(**unused_original_config)

        return model

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2372,6 +2372,21 @@ def from_pretrained(cls, *args, **kwargs):
23722372
requires_backends(cls, ["torch"])
23732373

23742374

2375+
class AutoPipelineForText2Audio(metaclass=DummyObject):
    # Import-time placeholder used when torch is not installed: any attempt to
    # construct or load the pipeline raises via `requires_backends`.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
2388+
2389+
23752390
class AutoPipelineForText2Image(metaclass=DummyObject):
23762391
_backends = ["torch"]
23772392

0 commit comments

Comments
 (0)