Skip to content

Commit 66623a1

Browse files
jiqing-feng and vasqu authored
Fix speecht5_tts pipeline (huggingface#42830)
* Fix speecht5_tts pipeline

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Update src/transformers/pipelines/text_to_audio.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
1 parent e17b1b8 commit 66623a1

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

src/transformers/pipelines/text_to_audio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
117117
else vocoder
118118
)
119119

120-
if self.model.config.model_type in ["musicgen"]:
121-
# MusicGen expect to use the tokenizer
120+
if self.model.config.model_type in ["musicgen", "speecht5"]:
121+
# MusicGen and SpeechT5 expect to use their tokenizer instead
122122
self.processor = None
123123

124124
self.sampling_rate = sampling_rate

tests/pipelines/test_pipelines_text_to_audio.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import unittest
1616

1717
import numpy as np
18+
import torch
1819

1920
from transformers import (
2021
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
@@ -40,6 +41,38 @@ class TextToAudioPipelineTests(unittest.TestCase):
4041
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
4142
# for now only test text_to_waveform and not text_to_spectrogram
4243

44+
@require_torch
45+
def test_small_speecht5_pt(self):
46+
audio_generator = pipeline(task="text-to-audio", model="microsoft/speecht5_tts")
47+
num_channels = 1 # model generates mono audio
48+
forward_params = {
49+
"do_sample": True,
50+
"semantic_max_new_tokens": 5,
51+
"speaker_embeddings": torch.rand(1, 512) * 0.2 - 0.1,
52+
}
53+
54+
outputs = audio_generator("This is a test", forward_params=forward_params)
55+
self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, outputs)
56+
self.assertEqual(len(outputs["audio"].shape), num_channels)
57+
58+
# test two examples side-by-side
59+
outputs = audio_generator(["This is a test", "This is a second test"], forward_params=forward_params)
60+
audio = [output["audio"] for output in outputs]
61+
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
62+
63+
# test batching, this time with parameterization in the forward pass
64+
audio_generator = pipeline(task="text-to-audio", model="microsoft/speecht5_tts")
65+
forward_params = {
66+
"do_sample": False,
67+
"max_new_tokens": 5,
68+
"speaker_embeddings": torch.rand(1, 512) * 0.2 - 0.1,
69+
}
70+
outputs = audio_generator(
71+
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
72+
)
73+
audio = [output["audio"] for output in outputs]
74+
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
75+
4376
@require_torch
4477
def test_small_musicgen_pt(self):
4578
music_generator = pipeline(

0 commit comments

Comments
 (0)