1+ # -*- coding: utf-8 -*-
2+ """
3+ Lunarlist TTS model (ONNX)
4+
5+ You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
6+
7+ ONNX port: `https://github.com/PyThaiNLP/thaitts-onnx <https://github.com/PyThaiNLP/thaitts-onnx>`_
8+ """
9+ import tempfile
10+ import numpy as np
11+ import onnxruntime as ort
12+ from huggingface_hub import hf_hub_download
13+
14+
# Character set taken from https://huggingface.co/lunarlist/tts-thai-last-step
# (Thai consonants, vowel signs, tone marks, and space, in model vocabulary order).
index_list = [
    'ก', 'ข', 'ค', 'ฆ', 'ง', 'จ', 'ฉ', 'ช', 'ซ', 'ฌ', 'ญ', 'ฎ', 'ฏ', 'ฐ',
    'ฑ', 'ฒ', 'ณ', 'ด', 'ต', 'ถ', 'ท', 'ธ', 'น', 'บ', 'ป', 'ผ', 'ฝ', 'พ',
    'ฟ', 'ภ', 'ม', 'ย', 'ร', 'ฤ', 'ล', 'ว', 'ศ', 'ษ', 'ส', 'ห', 'ฬ', 'อ',
    'ฮ', 'ะ', 'ั', 'า', 'ำ', 'ิ', 'ี', 'ึ', 'ื', 'ุ', 'ู', 'เ', 'แ', 'โ',
    'ใ', 'ไ', 'ๅ', '็', '่', '้', '๊', '๋', '์', ' ',
]
# Reverse lookup: character -> integer id fed to the encoder.
dict_idx = {character: position for position, character in enumerate(index_list)}
18+
def clean(text):
    """Encode *text* into the integer id sequence the Tacotron2 encoder expects.

    The sequence is wrapped with the start id 66 and the end id 67 (the
    character vocabulary itself occupies ids 0..len(index_list)-1).

    :param str text: Thai text to encode
    :return: tuple ``(seq, seq_len)`` — the id sequence with shape (1, n)
        and its length as a 1-element array
    """
    # Fix: the original filter was ``if i``, which is always true for a
    # character, so any out-of-vocabulary character raised KeyError.
    # Filtering on vocabulary membership skips unknown characters instead.
    seq = np.array([[66] + [dict_idx[ch] for ch in text if ch in dict_idx] + [67]])
    _s = np.array([len(seq[0])])
    return seq, _s
23+
# Tacotron2 dimensions used to size the initial decoder state arrays below.
# NOTE(review): presumably these mirror the config the ONNX graphs were
# exported with — confirm against the upstream model repository.
n_mel_channels = 80          # mel-spectrogram bins per decoder output frame
n_frames_per_step = 1        # mel frames produced per decoder step
attention_rnn_dim = 1024     # width of the attention RNN hidden/cell state
decoder_rnn_dim = 1024       # width of the decoder RNN hidden/cell state
encoder_embedding_dim = 512  # width of the attention context vector
29+
def initialize_decoder_states(memory):
    """Create zeroed attention/decoder state arrays for a new utterance.

    :param memory: encoder output; shape (batch, time, features)
    :return: tuple of seven float32 arrays — attention hidden, attention
        cell, decoder hidden, decoder cell, attention weights, cumulative
        attention weights, attention context
    """
    batch_size = memory.shape[0]
    max_time = memory.shape[1]

    def _zeros(width):
        # Every piece of state starts at zero; only the width differs.
        return np.zeros((batch_size, width), dtype=np.float32)

    return (
        _zeros(attention_rnn_dim),      # attention_hidden
        _zeros(attention_rnn_dim),      # attention_cell
        _zeros(decoder_rnn_dim),        # decoder_hidden
        _zeros(decoder_rnn_dim),        # decoder_cell
        _zeros(max_time),               # attention_weights
        _zeros(max_time),               # attention_weights_cum
        _zeros(encoder_embedding_dim),  # attention_context
    )
53+
54+
def get_go_frame(memory):
    """Return the all-zero <GO> mel frame that seeds the decoder loop."""
    batch_size = memory.shape[0]
    frame_width = n_mel_channels * n_frames_per_step
    return np.zeros((batch_size, frame_width), dtype=np.float32)
59+
60+
def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``.

    Computed in log space so large-magnitude inputs cannot overflow
    ``exp`` the way the naive formulation would.
    """
    # log(1 + exp(-x)) == logaddexp(0, -x); sigmoid is exp of its negation.
    log_one_plus_exp = np.logaddexp(0, -x)
    return np.exp(-log_one_plus_exp)
63+
64+
def parse_decoder_outputs(mel_outputs, gate_outputs, alignments):
    """Stack the per-step decoder output lists into batch-first arrays.

    Each argument is a list of per-step arrays with a trailing time axis of
    size 1 (as built by ``inference``).  Returns ``(mel, gates, aligns)``
    with the time axis moved next to the batch axis and the mel array
    reshaped to (batch, n_mel_channels, frames).
    """
    # Time-major stacks become batch-major via one transpose each.
    stacked_aligns = np.stack(alignments).transpose((1, 0, 2, 3))
    # squeeze(-1) (not a bare squeeze) keeps the batch axis when batch == 1.
    stacked_gates = np.stack(gate_outputs).squeeze(-1).transpose((1, 0, 2))
    stacked_mel = np.stack(mel_outputs).transpose((1, 0, 2, 3))
    # Fold the frames-per-step axis into the time axis...
    batch_size = stacked_mel.shape[0]
    stacked_mel = stacked_mel.reshape(batch_size, -1, n_mel_channels)
    # ...then put mel channels before time: (B, n_mel_channels, T_out).
    stacked_mel = stacked_mel.transpose((0, 2, 1))

    return stacked_mel, stacked_gates, stacked_aligns
80+
81+
# only numpy operations
def inference(text, encoder, decoder_iter, postnet):
    """Run the Tacotron2 ONNX pipeline (encoder -> decoder loop -> postnet).

    :param str text: Thai input text
    :param encoder: ONNX session for the encoder graph
    :param decoder_iter: ONNX session computing one decoder step
    :param postnet: ONNX session for the post-net graph
    :return: the postnet session's output list (refined mel spectrogram)
    """
    sequences, sequence_lengths = clean(text)

    # Encoder: character ids -> memory / processed memory.
    memory, processed_memory, _ = encoder.run(
        None, {"seq": sequences, "seq_len": sequence_lengths}
    )

    batch_size = memory.shape[0]
    mel_lengths = np.zeros([batch_size], dtype=np.int32)
    not_finished = np.ones([batch_size], dtype=np.int32)
    mel_outputs, gate_outputs, alignments = [], [], []
    gate_threshold = 0.5
    max_decoder_steps = 5000

    (
        attention_hidden,
        attention_cell,
        decoder_hidden,
        decoder_cell,
        attention_weights,
        attention_weights_cum,
        attention_context,
    ) = initialize_decoder_states(memory)

    decoder_input = get_go_frame(memory)

    # Autoregressive loop: feed the previous mel frame back in until every
    # sequence in the batch fires its stop gate (or the step cap is hit).
    while True:
        step_inputs = {
            "decoder_input": decoder_input,
            "attention_hidden": attention_hidden,
            "attention_cell": attention_cell,
            "decoder_hidden": decoder_hidden,
            "decoder_cell": decoder_cell,
            "attention_weights": attention_weights,
            "attention_weights_cum": attention_weights_cum,
            "attention_context": attention_context,
            "memory": memory,
            "processed_memory": processed_memory,
        }
        (
            mel_output,
            gate_output,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        ) = decoder_iter.run(None, step_inputs)

        # Collect this step's outputs with an explicit trailing time axis.
        # (The original code special-cased the first iteration, but
        # assigning a fresh one-element list to an empty list is the same
        # as appending, so a plain append suffices.)
        mel_outputs.append(np.expand_dims(mel_output, 2))
        gate_outputs.append(np.expand_dims(gate_output, 2))
        alignments.append(np.expand_dims(attention_weights, 2))

        # A sequence keeps running while its stop-gate probability stays
        # below the threshold; finished sequences stop accumulating length.
        still_running = np.squeeze(
            np.less(sigmoid(gate_output), gate_threshold), axis=1
        )
        not_finished = not_finished * still_running
        mel_lengths += not_finished

        if not_finished.sum() == 0:
            break
        if len(mel_outputs) == max_decoder_steps:
            # Safety cap so a misbehaving gate can never loop forever.
            break

        decoder_input = mel_output

    mel_outputs, gate_outputs, alignments = parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments
    )

    # Post-net refinement of the raw mel spectrogram.
    return postnet.run(None, {"mel_spec": mel_outputs})
168+
class LunarlistONNX:
    """Thai text-to-speech (Tacotron2 + vocoder) running on ONNX Runtime.

    Model files are downloaded from the ``pythainlp/thaitts-onnx``
    repository on the Hugging Face Hub when the class is instantiated.
    """

    def __init__(self) -> None:
        repo_id = "pythainlp/thaitts-onnx"  # single source for all four models

        def _session(filename):
            # Download (or reuse the HF cache) and open one ONNX graph.
            return ort.InferenceSession(
                hf_hub_download(repo_id=repo_id, filename=filename)
            )

        self.encoder = _session("tacotron2encoder-th.onnx")
        self.decoder = _session("tacotron2decoder-th.onnx")
        self.postnet = _session("tacotron2postnet-th.onnx")
        self.hifi = _session("vocoder.onnx")

    def tts(self, text: str):
        """Synthesize *text* and return the raw vocoder output list."""
        mel = inference(text, self.encoder, self.decoder, self.postnet)
        return self.hifi.run(None, {"spec": mel[0]})

    def __call__(self, text: str, return_type: str = "file", filename: str = None):
        """Synthesize speech from *text*.

        :param str text: Thai text to speak
        :param str return_type: ``"file"`` to write a 22050 Hz WAV and
            return its path, ``"waveform"`` to return the sample array
        :param str filename: output path when ``return_type="file"``;
            a temporary ``.wav`` file is created when omitted
        :return: waveform array, or the path of the written WAV file
        """
        wavs = self.tts(text)
        waveform = wavs[0][0, 0, :]
        if return_type == "waveform":
            return waveform
        import soundfile as sf  # local import keeps soundfile optional

        if filename is not None:  # fixed: was the non-idiomatic ``!= None``
            sf.write(filename, waveform, 22050)
            return filename
        # delete=False: the caller owns (and must remove) the temp file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            sf.write(fp.name, waveform, 22050)
            return fp.name