Skip to content

Commit 5be47c2

Browse files
Copilot and wannaphong committed
Address code review feedback: improve error handling, validation, and test coverage
Co-authored-by: wannaphong <8536487+wannaphong@users.noreply.github.com>
1 parent 336041d commit 5be47c2

3 files changed

Lines changed: 91 additions & 34 deletions

File tree

pythaitts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ def load_pretrained(self,version):
5555
from pythaitts.pretrained.vachana_tts import VachanaTTS
5656
self.model = VachanaTTS()
5757
else:
58-
raise NotImplemented(
58+
raise NotImplementedError(
5959
"PyThaiTTS doesn't support %s pretrained." % self.pretrained
6060
)
6161

pythaitts/pretrained/vachana_tts.py

Lines changed: 49 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,15 @@
88
See more: https://github.com/VYNCX/VachanaTTS2
99
"""
1010
import tempfile
11+
import wave
12+
import numpy as np
13+
import os
1114

1215

1316
class VachanaTTS:
17+
# Supported voice options
18+
SUPPORTED_VOICES = ["th_f_1", "th_m_1", "th_f_2", "th_m_2"]
19+
1420
def __init__(self) -> None:
1521
"""
1622
Initialize VachanaTTS model.
@@ -35,6 +41,12 @@ def __call__(self, text: str, speaker_idx: str = "th_f_1", return_type: str = "f
3541
:param kwargs: Additional parameters (volume, speed, noise_scale, noise_w_scale)
3642
:return: File path if return_type is "file", otherwise audio waveform data
3743
"""
44+
# Validate speaker_idx
45+
if speaker_idx not in self.SUPPORTED_VOICES:
46+
raise ValueError(
47+
f"Unsupported voice '{speaker_idx}'. Supported voices are: {', '.join(self.SUPPORTED_VOICES)}"
48+
)
49+
3850
# Extract additional parameters with defaults
3951
volume = kwargs.get('volume', 1.0)
4052
speed = kwargs.get('speed', 1.0)
@@ -43,39 +55,43 @@ def __call__(self, text: str, speaker_idx: str = "th_f_1", return_type: str = "f
4355

4456
if return_type == "waveform":
4557
# For waveform return, we need to generate to a temp file then read it
46-
import wave
47-
import numpy as np
48-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
49-
temp_filename = fp.name
50-
51-
# Generate the audio file
52-
self.tts_func(
53-
text,
54-
voice=speaker_idx,
55-
output=temp_filename,
56-
volume=volume,
57-
speed=speed,
58-
noise_scale=noise_scale,
59-
noise_w_scale=noise_w_scale
60-
)
61-
62-
# Read the waveform from the file
63-
with wave.open(temp_filename, 'rb') as wav_file:
64-
n_frames = wav_file.getnframes()
65-
audio_data = wav_file.readframes(n_frames)
66-
# Convert bytes to numpy array
67-
import struct
68-
sample_width = wav_file.getsampwidth()
69-
if sample_width == 2:
70-
waveform = np.frombuffer(audio_data, dtype=np.int16)
71-
else:
72-
waveform = np.frombuffer(audio_data, dtype=np.int8)
73-
74-
# Clean up temp file
75-
import os
76-
os.unlink(temp_filename)
77-
78-
return waveform
58+
temp_filename = None
59+
try:
60+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
61+
temp_filename = fp.name
62+
63+
# Generate the audio file
64+
self.tts_func(
65+
text,
66+
voice=speaker_idx,
67+
output=temp_filename,
68+
volume=volume,
69+
speed=speed,
70+
noise_scale=noise_scale,
71+
noise_w_scale=noise_w_scale
72+
)
73+
74+
# Read the waveform from the file
75+
with wave.open(temp_filename, 'rb') as wav_file:
76+
n_frames = wav_file.getnframes()
77+
audio_data = wav_file.readframes(n_frames)
78+
sample_width = wav_file.getsampwidth()
79+
80+
# Convert bytes to numpy array based on sample width
81+
if sample_width == 1:
82+
waveform = np.frombuffer(audio_data, dtype=np.int8)
83+
elif sample_width == 2:
84+
waveform = np.frombuffer(audio_data, dtype=np.int16)
85+
elif sample_width == 4:
86+
waveform = np.frombuffer(audio_data, dtype=np.int32)
87+
else:
88+
raise ValueError(f"Unsupported sample width: {sample_width} bytes")
89+
90+
return waveform
91+
finally:
92+
# Clean up temp file
93+
if temp_filename and os.path.exists(temp_filename):
94+
os.unlink(temp_filename)
7995
else:
8096
# File output
8197
if filename is None:

tests/test_vachana.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
import unittest
66
from unittest.mock import Mock, patch, MagicMock
7+
import numpy as np
78
from pythaitts import TTS
89

910

@@ -67,6 +68,46 @@ def test_vachana_with_preprocessing(self, mock_vachana_class):
6768
self.assertIn("ห้า", processed_text)
6869
self.assertIn("คนคน", processed_text)
6970

71+
@patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
72+
def test_vachana_all_supported_voices(self, mock_vachana_class):
73+
"""Test that all supported voices work correctly"""
74+
# Setup mock
75+
mock_instance = Mock()
76+
mock_instance.return_value = "/tmp/output.wav"
77+
mock_vachana_class.return_value = mock_instance
78+
79+
# Create TTS instance
80+
tts = TTS(pretrained="vachana")
81+
82+
# Test all supported voices
83+
supported_voices = ["th_f_1", "th_m_1", "th_f_2", "th_m_2"]
84+
for voice in supported_voices:
85+
mock_instance.reset_mock()
86+
result = tts.tts("สวัสดี", speaker_idx=voice)
87+
88+
# Verify the voice was passed correctly
89+
call_args = mock_instance.call_args
90+
self.assertEqual(call_args.kwargs['speaker_idx'], voice)
91+
92+
@patch('pythaitts.pretrained.vachana_tts.VachanaTTS')
93+
def test_vachana_waveform_return(self, mock_vachana_class):
94+
"""Test waveform return type functionality"""
95+
# Setup mock
96+
mock_instance = Mock()
97+
mock_waveform = np.array([0.1, 0.2, 0.3, 0.4])
98+
mock_instance.return_value = mock_waveform
99+
mock_vachana_class.return_value = mock_instance
100+
101+
# Create TTS instance
102+
tts = TTS(pretrained="vachana")
103+
104+
# Call tts method with waveform return type
105+
result = tts.tts("สวัสดี", speaker_idx="th_f_1", return_type="waveform")
106+
107+
# Verify the return type was set correctly
108+
call_args = mock_instance.call_args
109+
self.assertEqual(call_args.kwargs['return_type'], "waveform")
110+
70111

71112
if __name__ == '__main__':
72113
unittest.main()

0 commit comments

Comments
 (0)