Skip to content

Commit 3508e48

Browse files
committed
PyThaiTTS v0.3.0
- Add Lunarlist TTS model (ONNX) - Change default model to Lunarlist TTS model
1 parent c51b102 commit 3508e48

8 files changed

Lines changed: 520 additions & 21 deletions

File tree

notebook/cat.wav

48.5 KB
Binary file not shown.

notebook/use_lunarlist_model_onnx.ipynb

Lines changed: 307 additions & 0 deletions
Large diffs are not rendered by default.

pythaitts/__init__.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
"""
33
PyThaiTTS
44
"""
5-
__version__ = "0.2.1"
5+
__version__ = "0.3.0"
66

77

88
class TTS:
9-
def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
9+
def __init__(self, pretrained="lunarlist_onnx", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
1010
"""
11-
:param str pretrained: TTS pretrained (khanomtan, lunarlist)
12-
:param str mode: pretrained mode
11+
:param str pretrained: TTS pretrained (lunarlist_onnx, khanomtan, lunarlist)
12+
:param str mode: pretrained mode (not used by lunarlist_onnx)
1313
:param str version: model version (default is 1.0 or 1.1)
14+
:param str device: device for running the model (lunarlist_onnx supports CPU only)
1415
1516
**Options for mode**
1617
* *last_checkpoint* (default) - last checkpoint of model
@@ -21,6 +22,11 @@ def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0"
2122
2223
For the lunarlist TTS model, you must install NeMo before using the model: pip install nemo_toolkit['tts'].
2324
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
25+
26+
For lunarlist_onnx tts model, \
27+
You can see more about lunarlist tts at `https://github.com/PyThaiNLP/thaitts-onnx <https://github.com/PyThaiNLP/thaitts-onnx>`_
28+
29+
2430
2531
"""
2632
self.pretrained = pretrained
@@ -32,11 +38,14 @@ def load_pretrained(self,version):
3238
"""
3339
Load pretrained
3440
"""
35-
if self.pretrained == "khanomtan":
36-
from pythaitts.pretrained import KhanomTan
41+
if self.pretrained == "lunarlist_onnx":
42+
from pythaitts.pretrained.lunarlist_onnx import LunarlistONNX
43+
self.model = LunarlistONNX()
44+
elif self.pretrained == "khanomtan":
45+
from pythaitts.pretrained.khanomtan_tts import KhanomTan
3746
self.model = KhanomTan(mode=self.mode, version=version)
3847
elif self.pretrained == "lunarlist":
39-
from pythaitts.pretrained import LunarlistModel
48+
from pythaitts.pretrained.lunarlist_model import LunarlistModel
4049
self.model = LunarlistModel(mode=self.mode, device=self.device)
4150
else:
4251
raise NotImplemented(
@@ -53,7 +62,7 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
5362
:param str return_type: return type (default is file)
5463
:param str filename: path filename for save wav file if return_type is file.
5564
"""
56-
if self.pretrained == "lunarlist":
65+
if self.pretrained == "lunarlist" or self.pretrained == "lunarlist_onnx":
5766
return self.model(text=text,return_type=return_type,filename=filename)
5867
return self.model(
5968
text=text,

pythaitts/pretrained/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +0,0 @@
1-
# -*- coding: utf-8 -*-
2-
from pythaitts.pretrained.khanomtan_tts import KhanomTan
3-
from pythaitts.pretrained.lunarlist_model import LunarlistModel
4-
5-
__all__ = [
6-
"KhanomTan",
7-
"LunarlistModel"
8-
]

pythaitts/pretrained/lunarlist_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
66
"""
77
import tempfile
8-
import torch
8+
try:
9+
import torch
10+
except ImportError:
11+
raise ImportError("You must to install torch before use this model.")
912

1013

1114
class LunarlistModel:
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Lunarlist TTS model (ONNX)
4+
5+
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
6+
7+
ONNX port: `https://github.com/PyThaiNLP/thaitts-onnx <https://github.com/PyThaiNLP/thaitts-onnx>`_
8+
"""
9+
import tempfile
10+
import numpy as np
11+
import onnxruntime as ort
12+
from huggingface_hub import hf_hub_download
13+
14+
15+
# Character inventory taken from https://huggingface.co/lunarlist/tts-thai-last-step
# 66 symbols (ids 0-65); ids 66 and 67 — the first two ids past the
# vocabulary — are used below as start/end-of-sequence markers.
index_list=['ก', 'ข', 'ค', 'ฆ', 'ง', 'จ', 'ฉ', 'ช', 'ซ', 'ฌ', 'ญ', 'ฎ', 'ฏ', 'ฐ', 'ฑ', 'ฒ', 'ณ', 'ด', 'ต', 'ถ', 'ท', 'ธ', 'น', 'บ', 'ป', 'ผ', 'ฝ', 'พ', 'ฟ', 'ภ', 'ม', 'ย', 'ร', 'ฤ', 'ล', 'ว', 'ศ', 'ษ', 'ส', 'ห', 'ฬ', 'อ', 'ฮ', 'ะ', 'ั', 'า', 'ำ', 'ิ', 'ี', 'ึ', 'ื', 'ุ', 'ู', 'เ', 'แ', 'โ', 'ใ', 'ไ', 'ๅ', '็', '่', '้', '๊', '๋', '์', ' ']
dict_idx = {ch: i for i, ch in enumerate(index_list)}

def clean(text):
    """Encode *text* as a model-ready id sequence.

    :param str text: Thai input text.
    :return: tuple ``(seq, seq_len)`` — ``seq`` is an int array of shape
        (1, T) framed by the start id 66 and end id 67; ``seq_len`` is a
        1-element array holding T.

    Characters outside the model's inventory are silently skipped
    (the previous implementation raised ``KeyError`` on unknown input).
    """
    ids = [dict_idx[ch] for ch in text if ch in dict_idx]
    seq = np.array([[66] + ids + [67]])
    seq_len = np.array([len(seq[0])])
    return seq, seq_len
23+
24+
# Tacotron2 architecture dimensions (fixed by the exported ONNX graphs).
n_mel_channels = 80
n_frames_per_step = 1
attention_rnn_dim = 1024
decoder_rnn_dim = 1024
encoder_embedding_dim = 512


def initialize_decoder_states(memory):
    """Build the zero-filled initial decoder state for a batch.

    :param memory: encoder output; only ``shape[0]`` (batch) and
        ``shape[1]`` (MAX_TIME) are read here.
    :return: 7-tuple of float32 zero arrays:
        (attention_hidden, attention_cell, decoder_hidden, decoder_cell,
        attention_weights, attention_weights_cum, attention_context).
    """
    batch_size, max_time = memory.shape[0], memory.shape[1]

    def zeros(width):
        # Every state starts at zero; only the widths differ.
        return np.zeros((batch_size, width), dtype=np.float32)

    return (
        zeros(attention_rnn_dim),        # attention_hidden
        zeros(attention_rnn_dim),        # attention_cell
        zeros(decoder_rnn_dim),          # decoder_hidden
        zeros(decoder_rnn_dim),          # decoder_cell
        zeros(max_time),                 # attention_weights
        zeros(max_time),                 # attention_weights_cum
        zeros(encoder_embedding_dim),    # attention_context
    )


def get_go_frame(memory):
    """Return the all-zero <GO> mel frame that seeds decoding.

    :param memory: encoder output; only ``shape[0]`` (batch) is read.
    :return: float32 zeros of shape (B, n_mel_channels * n_frames_per_step).
    """
    batch_size = memory.shape[0]
    return np.zeros(
        (batch_size, n_mel_channels * n_frames_per_step), dtype=np.float32
    )
59+
60+
61+
def sigmoid(x):
    """Numerically stable elementwise logistic function.

    Computes 1 / (1 + exp(-x)) as exp(-log(1 + exp(-x))) via
    ``np.logaddexp`` so large-magnitude inputs cannot overflow.
    """
    log_one_plus_exp_neg = np.logaddexp(0, -x)  # log(1 + e^{-x})
    return np.exp(-log_one_plus_exp_neg)
63+
64+
65+
def parse_decoder_outputs(mel_outputs, gate_outputs, alignments):
    """Stack per-step decoder outputs into batch-first arrays.

    Each argument is a list of one array per decoder step, as collected by
    ``inference`` (every element was wrapped with ``np.expand_dims(..., 2)``
    there — assumed shape (B, F, 1, ...); TODO confirm against the exported
    decoder graph).

    :return: tuple ``(mel_outputs, gate_outputs, alignments)`` with
        ``mel_outputs`` reshaped to (B, n_mel_channels, T_out).
    """
    # Stack over steps then move the step axis after batch:
    # (T_out, B, ...) -> (B, T_out, ...)
    alignments = np.stack(alignments).transpose((1, 0, 2, 3))
    # squeeze(-1) (not plain squeeze) so a batch of size 1 is never
    # collapsed; then make the result batch-first.
    gate_outputs = np.stack(gate_outputs).squeeze(-1).transpose((1, 0, 2))
    # (T_out, B, n_mel_channels, 1) -> (B, T_out, n_mel_channels, 1)
    mel_outputs = np.stack(mel_outputs).transpose((1, 0, 2, 3))
    # Decouple frames per step: flatten step/frame axes into one time axis.
    mel_outputs = mel_outputs.reshape(mel_outputs.shape[0], -1, n_mel_channels)
    # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
    mel_outputs = mel_outputs.transpose((0, 2, 1))

    return mel_outputs, gate_outputs, alignments
80+
81+
82+
# only numpy operations — the whole decode loop runs on CPU via onnxruntime.
def inference(text, encoder, decoder_iter, postnet):
    """Run the full Tacotron2 pipeline (encoder -> autoregressive decoder
    -> postnet) over *text* and return the post-net mel spectrogram.

    :param str text: Thai input text, encoded with :func:`clean`.
    :param encoder: onnxruntime session for the Tacotron2 encoder.
    :param decoder_iter: onnxruntime session for one decoder step.
    :param postnet: onnxruntime session for the post-net.
    :return: the raw ``postnet.run`` output (a list; element 0 is the mel
        spectrogram — presumably (B, n_mel_channels, T); TODO confirm
        against the exported graph).
    """
    sequences, sequence_lengths = clean(text)

    # Encode the id sequence once.
    inputs = {"seq": sequences, "seq_len": sequence_lengths}
    memory, processed_memory, _ = encoder.run(None, inputs)

    # Per-batch-item bookkeeping: how many frames each item produced, and
    # a 0/1 mask of items still decoding.
    mel_lengths = np.zeros([memory.shape[0]], dtype=np.int32)
    not_finished = np.ones([memory.shape[0]], dtype=np.int32)
    mel_outputs, gate_outputs, alignments = [], [], []
    gate_threshold = 0.5       # stop-token probability cut-off
    max_decoder_steps = 5000   # hard cap to guarantee termination
    first_iter = True

    (
        attention_hidden,
        attention_cell,
        decoder_hidden,
        decoder_cell,
        attention_weights,
        attention_weights_cum,
        attention_context,
    ) = initialize_decoder_states(memory)

    decoder_input = get_go_frame(memory)

    # Autoregressive loop: each step consumes the previous mel frame and
    # the recurrent state, and returns the updated state.
    while True:
        inputs = {
            "decoder_input": decoder_input,
            "attention_hidden": attention_hidden,
            "attention_cell": attention_cell,
            "decoder_hidden": decoder_hidden,
            "decoder_cell": decoder_cell,
            "attention_weights": attention_weights,
            "attention_weights_cum": attention_weights_cum,
            "attention_context": attention_context,
            "memory": memory,
            "processed_memory": processed_memory,
        }
        (
            mel_output,
            gate_output,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        ) = decoder_iter.run(None, inputs)

        # Collect step outputs with an extra axis so they can be stacked
        # and re-shaped by parse_decoder_outputs later.
        if first_iter:
            mel_outputs = [np.expand_dims(mel_output, 2)]
            gate_outputs = [np.expand_dims(gate_output, 2)]
            alignments = [np.expand_dims(attention_weights, 2)]
            first_iter = False
        else:
            mel_outputs += [np.expand_dims(mel_output, 2)]
            gate_outputs += [np.expand_dims(gate_output, 2)]
            alignments += [np.expand_dims(attention_weights, 2)]

        # An item keeps decoding while its stop-token probability is
        # below the threshold; once finished it stays finished (mask is
        # multiplicative).
        dec = np.less(sigmoid(gate_output), gate_threshold)
        dec = np.squeeze(dec, axis=1)
        not_finished = not_finished * dec
        mel_lengths += not_finished

        if not_finished.sum() == 0:
            # Every batch item emitted its stop token.
            break
        if len(mel_outputs) == max_decoder_steps:
            # Safety cap reached — stop even though items are unfinished.
            break

        # Feed this frame back as the next step's input.
        decoder_input = mel_output

    mel_outputs, gate_outputs, alignments = parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments
    )

    # Refine the mel spectrogram with the post-net.
    inputs = {"mel_spec": mel_outputs}
    mel_outputs_postnet = postnet.run(None, inputs)

    return mel_outputs_postnet
168+
169+
class LunarlistONNX:
    """Lunarlist Thai TTS (Tacotron2 + neural vocoder) on onnxruntime (CPU).

    The four ONNX graphs (encoder, decoder step, post-net, vocoder) are
    downloaded from the ``pythainlp/thaitts-onnx`` Hugging Face repository
    on first construction and cached by ``huggingface_hub``.
    """

    def __init__(self) -> None:
        # NOTE: requires network access on the first run (hf_hub_download).
        self.encoder = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="tacotron2encoder-th.onnx"))
        self.decoder = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="tacotron2decoder-th.onnx"))
        self.postnet = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="tacotron2postnet-th.onnx"))
        self.hifi = ort.InferenceSession(hf_hub_download(repo_id="pythainlp/thaitts-onnx",filename="vocoder.onnx"))

    def tts(self, text: str):
        """Synthesize *text* and return the raw vocoder output.

        :param str text: Thai input text.
        :return: the ``run`` output of the vocoder session (a list whose
            element 0 holds the waveform).
        """
        mel = inference(text, self.encoder, self.decoder, self.postnet)
        return self.hifi.run(None, {"spec": mel[0]})

    def __call__(self, text: str, return_type: str = "file", filename: str = None):
        """Synthesize *text* to a waveform or a 22050 Hz WAV file.

        :param str text: Thai input text.
        :param str return_type: ``"waveform"`` returns the raw sample
            array; anything else writes a WAV file and returns its path.
        :param str filename: target path for the WAV file; when ``None``
            a temporary file is created (and not deleted afterwards).
        """
        wavs = self.tts(text)
        if return_type == "waveform":
            return wavs[0][0, 0, :]
        # soundfile is only needed on the file-writing path.
        import soundfile as sf
        if filename is not None:  # fixed: was the non-idiomatic ``!= None``
            sf.write(filename, wavs[0][0, 0, :], 22050)
            return filename
        else:
            # delete=False: the caller receives the path and owns cleanup.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
                sf.write(fp.name, wavs[0][0, 0, :], 22050)
                return fp.name

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
TTS>=0.8.0
2-
pythainlp>=3.0.0
31
huggingface_hub
4-
torch
2+
numpy>=1.22
3+
onnxruntime

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setup(
1111
name="PyThaiTTS",
12-
version="0.2.1",
12+
version="0.3.0",
1313
description="Open Source Thai Text-to-speech library in Python",
1414
long_description=readme,
1515
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)