99import riva .client .proto .riva_tts_pb2_grpc as rtts_srv
1010from riva .client import Auth
1111from riva .client .proto .riva_audio_pb2 import AudioEncoding
12-
12+ import wave
1313
1414class SpeechSynthesisService :
1515 """
@@ -34,20 +34,27 @@ def synthesize(
3434 language_code : str = 'en-US' ,
3535 encoding : AudioEncoding = AudioEncoding .LINEAR_PCM ,
3636 sample_rate_hz : int = 44100 ,
37+ audio_prompt_file : Optional [str ] = None ,
38+ audio_prompt_encoding : AudioEncoding = AudioEncoding .LINEAR_PCM ,
39+ quality : int = 20 ,
3740 future : bool = False ,
3841 ) -> Union [rtts .SynthesizeSpeechResponse , _MultiThreadedRendezvous ]:
3942 """
4043 Synthesizes an entire audio for text :param:`text`.
4144
4245 Args:
43- text (:obj:`str`): an input text.
44- voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
46+ text (:obj:`str`): An input text.
47+ voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
4548 available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
4649 a server will select the first available model with correct :param:`language_code` value.
4750 language_code (:obj:`str`): a language to use.
48- encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
49- sample_rate_hz (:obj:`int`): number of frames per second in output audio.
50- future (:obj:`bool`, defaults to :obj:`False`): whether to return an async result instead of usual
51+ encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
52+ sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
53+ audio_prompt_file (:obj:`str`): An audio prompt file location for zero shot model.
54+ audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
55+ quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
56+ audio but also takes longer to generate the audio. Ranges between 1-40.
57+ future (:obj:`bool`, defaults to :obj:`False`): Whether to return an async result instead of usual
5158 response. You can get a response by calling ``result()`` method of the future object.
5259
5360 Returns:
@@ -64,6 +71,16 @@ def synthesize(
6471 )
6572 if voice_name is not None :
6673 req .voice_name = voice_name
74+ if audio_prompt_file is not None :
75+ with wave .open (str (audio_prompt_file ), 'rb' ) as wf :
76+ rate = wf .getframerate ()
77+ req .zero_shot_data .sample_rate = rate
78+ with audio_prompt_file .open ('rb' ) as wav_f :
79+ audio_data = wav_f .read ()
80+ req .zero_shot_data .audio_prompt = audio_data
81+ req .zero_shot_data .encoding = audio_prompt_encoding
82+ req .zero_shot_data .quality = quality
83+
6784 func = self .stub .Synthesize .future if future else self .stub .Synthesize
6885 return func (req , metadata = self .auth .get_auth_metadata ())
6986
@@ -74,19 +91,26 @@ def synthesize_online(
7491 language_code : str = 'en-US' ,
7592 encoding : AudioEncoding = AudioEncoding .LINEAR_PCM ,
7693 sample_rate_hz : int = 44100 ,
94+ audio_prompt_file : Optional [str ] = None ,
95+ audio_prompt_encoding : AudioEncoding = AudioEncoding .LINEAR_PCM ,
96+ quality : int = 20 ,
7797 ) -> Generator [rtts .SynthesizeSpeechResponse , None , None ]:
7898 """
7999 Synthesizes and yields output audio chunks for text :param:`text` as the chunks
80100 becoming available.
81101
82102 Args:
83- text (:obj:`str`): an input text.
84- voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
103+ text (:obj:`str`): An input text.
104+ voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
85105 available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
86106 a server will select the first available model with correct :param:`language_code` value.
87- language_code (:obj:`str`): a language to use.
88- encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
89- sample_rate_hz (:obj:`int`): number of frames per second in output audio.
107+ language_code (:obj:`str`): A language to use.
108+ encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
109+ sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
110+ audio_prompt_file (:obj:`str`): An audio prompt file location for zero shot model.
111+ audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
112+ quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
113+ audio but also takes longer to generate the audio. Ranges between 1-40.
90114
91115 Yields:
92116 :obj:`riva.client.proto.riva_tts_pb2.SynthesizeSpeechResponse`: a response with output. You may find
@@ -103,4 +127,15 @@ def synthesize_online(
103127 )
104128 if voice_name is not None :
105129 req .voice_name = voice_name
130+
131+ if audio_prompt_file is not None :
132+ with wave .open (str (audio_prompt_file ), 'rb' ) as wf :
133+ rate = wf .getframerate ()
134+ req .zero_shot_data .sample_rate = rate
135+ with audio_prompt_file .open ('rb' ) as wav_f :
136+ audio_data = wav_f .read ()
137+ req .zero_shot_data .audio_prompt = audio_data
138+ req .zero_shot_data .encoding = audio_prompt_encoding
139+ req .zero_shot_data .quality = quality
140+
106141 return self .stub .SynthesizeOnline (req , metadata = self .auth .get_auth_metadata ())
0 commit comments