Skip to content

Commit fbbdd96

Browse files
authored
Merge pull request #13 from PyThaiNLP/copilot/add-support-for-vachanatts2
Add VachanaTTS2 model support
2 parents 15f5801 + 5be47c2 commit fbbdd96

5 files changed

Lines changed: 258 additions & 6 deletions

File tree

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@ file = tts.tts("ภาษาไทย ง่าย มาก มาก", filenam
2424
wave = tts.tts("ภาษาไทย ง่าย มาก มาก",return_type="waveform") # It will get waveform.
2525
```
2626

27+
### Using Different TTS Models
28+
29+
PyThaiTTS supports multiple TTS models. You can specify which model to use:
30+
31+
```python
32+
from pythaitts import TTS
33+
34+
# Use VachanaTTS (default voices: th_f_1, th_m_1, th_f_2, th_m_2)
35+
tts = TTS(pretrained="vachana")
36+
file = tts.tts("สวัสดีครับ", speaker_idx="th_f_1", filename="output.wav")
37+
38+
# Use Lunarlist ONNX (default)
39+
tts = TTS(pretrained="lunarlist_onnx")
40+
file = tts.tts("ภาษาไทย ง่าย มาก", filename="output.wav")
41+
42+
# Use KhanomTan
43+
tts = TTS(pretrained="khanomtan")
44+
file = tts.tts("ภาษาไทย", speaker_idx="Linda", filename="output.wav")
45+
```
46+
2747
### Text Preprocessing
2848

2949
PyThaiTTS includes automatic text preprocessing to improve TTS quality:

pythaitts/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
class TTS:
1111
def __init__(self, pretrained="lunarlist_onnx", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
1212
"""
13-
:param str pretrained: TTS pretrained (lunarlist_onnx, khanomtan, lunarlist)
14-
:param str mode: pretrained mode (lunarlist_onnx don't support)
13+
:param str pretrained: TTS pretrained (lunarlist_onnx, khanomtan, lunarlist, vachana)
14+
:param str mode: pretrained mode (lunarlist_onnx and vachana don't support)
1515
:param str version: model version (default is 1.0 or 1.1)
16-
:param str device: device for running model. (lunarlist_onnx support CPU only.)
16+
:param str device: device for running model. (lunarlist_onnx and vachana support CPU only.)
1717
1818
**Options for mode**
1919
* *last_checkpoint* (default) - last checkpoint of model
@@ -28,6 +28,8 @@ def __init__(self, pretrained="lunarlist_onnx", mode="last_checkpoint", version=
2828
For lunarlist_onnx tts model, \
2929
You can see more about lunarlist tts at `https://github.com/PyThaiNLP/thaitts-onnx <https://github.com/PyThaiNLP/thaitts-onnx>`_
3030
31+
For vachana tts model, \
32+
You can see more about vachana tts at `https://github.com/VYNCX/VachanaTTS2 <https://github.com/VYNCX/VachanaTTS2>`_
3133
3234
3335
"""
@@ -49,8 +51,11 @@ def load_pretrained(self,version):
4951
elif self.pretrained == "lunarlist":
5052
from pythaitts.pretrained.lunarlist_model import LunarlistModel
5153
self.model = LunarlistModel(mode=self.mode, device=self.device)
54+
elif self.pretrained == "vachana":
55+
from pythaitts.pretrained.vachana_tts import VachanaTTS
56+
self.model = VachanaTTS()
5257
else:
53-
raise NotImplemented(
58+
raise NotImplementedError(
5459
"PyThaiTTS doesn't support %s pretrained." % self.pretrained
5560
)
5661

@@ -59,7 +64,7 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
5964
speech synthesis
6065
6166
:param str text: text
62-
:param str speaker_idx: speaker (default is Linda)
67+
:param str speaker_idx: speaker (default is Linda for khanomtan, th_f_1 for vachana)
6368
:param str language_idx: language (default is th-th)
6469
:param str return_type: return type (default is file)
6570
:param str filename: path filename for save wav file if return_type is file.
@@ -72,6 +77,8 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
7277

7378
if self.pretrained == "lunarlist" or self.pretrained == "lunarlist_onnx":
7479
return self.model(text=text,return_type=return_type,filename=filename)
80+
elif self.pretrained == "vachana":
81+
return self.model(text=text,speaker_idx=speaker_idx,return_type=return_type,filename=filename)
7582
return self.model(
7683
text=text,
7784
speaker_idx=speaker_idx,
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
VachanaTTS2 model
4+
5+
VachanaTTS2 is a Thai text-to-speech model built on VITS architecture.
6+
It supports multiple Thai voices and is optimized for both CPU and GPU usage.
7+
8+
See more: https://github.com/VYNCX/VachanaTTS2
9+
"""
10+
import tempfile
11+
import wave
12+
import numpy as np
13+
import os
14+
15+
16+
class VachanaTTS:
    """Wrapper around the VachanaTTS2 Thai text-to-speech engine.

    VachanaTTS2 is a VITS-based Thai TTS model with several built-in voices.
    See https://github.com/VYNCX/VachanaTTS2 for the upstream project.
    """

    # Voice identifiers accepted by the upstream ``vachanatts`` package.
    SUPPORTED_VOICES = ["th_f_1", "th_m_1", "th_f_2", "th_m_2"]

    def __init__(self) -> None:
        """
        Initialize VachanaTTS model.

        The model weights are downloaded automatically from HuggingFace by the
        ``vachanatts`` package on first use.

        :raises ImportError: if the optional ``vachanatts`` package is not installed.
        """
        try:
            from vachanatts import TTS as VachanaTTS_TTS
            self.tts_func = VachanaTTS_TTS
        except ImportError as err:
            # Chain the original exception so the real import failure stays visible.
            raise ImportError(
                "vachanatts is not installed. Please install it with: pip install vachanatts"
            ) from err

    def _synthesize(self, text, speaker_idx, output, volume, speed, noise_scale, noise_w_scale):
        """Run the underlying engine once, writing a wav file to *output*."""
        self.tts_func(
            text,
            voice=speaker_idx,
            output=output,
            volume=volume,
            speed=speed,
            noise_scale=noise_scale,
            noise_w_scale=noise_w_scale
        )

    @staticmethod
    def _read_waveform(path):
        """Read a PCM wav file at *path* into a 1-D numpy array of raw samples.

        :raises ValueError: if the file's sample width is not 1, 2 or 4 bytes.
        """
        with wave.open(path, 'rb') as wav_file:
            n_frames = wav_file.getnframes()
            audio_data = wav_file.readframes(n_frames)
            sample_width = wav_file.getsampwidth()

        # Map bytes-per-sample to the matching signed integer dtype.
        dtypes = {1: np.int8, 2: np.int16, 4: np.int32}
        if sample_width not in dtypes:
            raise ValueError(f"Unsupported sample width: {sample_width} bytes")
        return np.frombuffer(audio_data, dtype=dtypes[sample_width])

    def __call__(self, text: str, speaker_idx: str = "th_f_1", return_type: str = "file", filename: str = None, **kwargs):
        """
        Generate speech from text using VachanaTTS.

        :param str text: Input text to synthesize
        :param str speaker_idx: Voice to use (th_f_1, th_m_1, th_f_2, th_m_2). Default is "th_f_1"
        :param str return_type: Return type ("file" or "waveform"); any other
            value falls through to file output, matching historical behavior.
        :param str filename: Output filename for the generated audio; a temp
            file is created when omitted.
        :param kwargs: Additional parameters (volume, speed, noise_scale, noise_w_scale)
        :return: File path if return_type is "file", otherwise audio waveform data
        :raises ValueError: if *speaker_idx* is not a supported voice.
        """
        # Fail fast on an unknown voice before touching the filesystem.
        if speaker_idx not in self.SUPPORTED_VOICES:
            raise ValueError(
                f"Unsupported voice '{speaker_idx}'. Supported voices are: {', '.join(self.SUPPORTED_VOICES)}"
            )

        # Optional engine tuning parameters with upstream defaults.
        volume = kwargs.get('volume', 1.0)
        speed = kwargs.get('speed', 1.0)
        noise_scale = kwargs.get('noise_scale', 0.667)
        noise_w_scale = kwargs.get('noise_w_scale', 0.8)

        if return_type == "waveform":
            # The engine only writes files, so synthesize to a temp wav,
            # load it back, and always clean the temp file up afterwards.
            temp_filename = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
                    temp_filename = fp.name
                self._synthesize(text, speaker_idx, temp_filename,
                                 volume, speed, noise_scale, noise_w_scale)
                return self._read_waveform(temp_filename)
            finally:
                if temp_filename and os.path.exists(temp_filename):
                    os.unlink(temp_filename)

        # File output (default). Generate a temp path when none was given.
        if filename is None:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
                filename = fp.name
        self._synthesize(text, speaker_idx, filename,
                         volume, speed, noise_scale, noise_w_scale)
        return filename

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
huggingface_hub
22
numpy>=1.22
3-
onnxruntime
3+
onnxruntime
4+
vachanatts

tests/test_vachana.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Unit tests for VachanaTTS integration
4+
"""
5+
import unittest
6+
from unittest.mock import Mock, patch, MagicMock
7+
import numpy as np
8+
from pythaitts import TTS
9+
10+
11+
class TestVachanaIntegration(unittest.TestCase):
    """Test VachanaTTS integration"""

    @staticmethod
    def _stub_backend(mock_cls, result="/tmp/output.wav"):
        """Wire the patched VachanaTTS class to return a Mock backend.

        The backend mock records how the TTS facade calls it and hands
        back *result* from every call.
        """
        backend = Mock()
        backend.return_value = result
        mock_cls.return_value = backend
        return backend

    @patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
    def test_vachana_model_initialization(self, mock_vachana):
        """Test that VachanaTTS model can be initialized"""
        engine = TTS(pretrained="vachana")

        # The facade should remember the backend name and hold a model object.
        self.assertEqual(engine.pretrained, "vachana")
        self.assertIsNotNone(engine.model)

    @patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
    def test_vachana_tts_call(self, mock_vachana_class):
        """Test calling tts method with vachana model"""
        backend = self._stub_backend(mock_vachana_class)
        engine = TTS(pretrained="vachana")

        engine.tts("สวัสดีครับ", speaker_idx="th_f_1", filename="/tmp/test.wav")

        # All user-facing arguments must be forwarded to the backend.
        backend.assert_called_once()
        forwarded = backend.call_args.kwargs
        self.assertEqual(forwarded['text'], "สวัสดีครับ")
        self.assertEqual(forwarded['speaker_idx'], "th_f_1")
        self.assertEqual(forwarded['filename'], "/tmp/test.wav")
        self.assertEqual(forwarded['return_type'], "file")

    @patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
    def test_vachana_with_preprocessing(self, mock_vachana_class):
        """Test that preprocessing works with vachana model"""
        backend = self._stub_backend(mock_vachana_class)
        engine = TTS(pretrained="vachana")

        engine.tts("มี 5 คนๆ", speaker_idx="th_f_1", preprocess=True)

        backend.assert_called_once()
        processed = backend.call_args.kwargs['text']

        # Digits should be spelled out and the repeat mark (ๆ) expanded.
        self.assertIn("ห้า", processed)
        self.assertIn("คนคน", processed)
        self.assertNotIn("5", processed)
        self.assertNotIn("ๆ", processed)

    @patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
    def test_vachana_all_supported_voices(self, mock_vachana_class):
        """Test that all supported voices work correctly"""
        backend = self._stub_backend(mock_vachana_class)
        engine = TTS(pretrained="vachana")

        # Every documented voice id must be forwarded unchanged.
        for voice in ("th_f_1", "th_m_1", "th_f_2", "th_m_2"):
            backend.reset_mock()
            engine.tts("สวัสดี", speaker_idx=voice)
            self.assertEqual(backend.call_args.kwargs['speaker_idx'], voice)

    @patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
    def test_vachana_waveform_return(self, mock_vachana_class):
        """Test waveform return type functionality"""
        backend = self._stub_backend(
            mock_vachana_class, result=np.array([0.1, 0.2, 0.3, 0.4])
        )
        engine = TTS(pretrained="vachana")

        engine.tts("สวัสดี", speaker_idx="th_f_1", return_type="waveform")

        self.assertEqual(backend.call_args.kwargs['return_type'], "waveform")


if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)