88See more: https://github.com/VYNCX/VachanaTTS2
99"""
1010import tempfile
11+ import wave
12+ import numpy as np
13+ import os
1114
1215
1316class VachanaTTS :
17+ # Supported voice options
18+ SUPPORTED_VOICES = ["th_f_1" , "th_m_1" , "th_f_2" , "th_m_2" ]
19+
1420 def __init__ (self ) -> None :
1521 """
1622 Initialize VachanaTTS model.
@@ -35,6 +41,12 @@ def __call__(self, text: str, speaker_idx: str = "th_f_1", return_type: str = "f
3541 :param kwargs: Additional parameters (volume, speed, noise_scale, noise_w_scale)
3642 :return: File path if return_type is "file", otherwise audio waveform data
3743 """
44+ # Validate speaker_idx
45+ if speaker_idx not in self .SUPPORTED_VOICES :
46+ raise ValueError (
47+ f"Unsupported voice '{ speaker_idx } '. Supported voices are: { ', ' .join (self .SUPPORTED_VOICES )} "
48+ )
49+
3850 # Extract additional parameters with defaults
3951 volume = kwargs .get ('volume' , 1.0 )
4052 speed = kwargs .get ('speed' , 1.0 )
@@ -43,39 +55,43 @@ def __call__(self, text: str, speaker_idx: str = "th_f_1", return_type: str = "f
4355
4456 if return_type == "waveform" :
4557 # For waveform return, we need to generate to a temp file then read it
46- import wave
47- import numpy as np
48- with tempfile .NamedTemporaryFile (suffix = ".wav" , delete = False ) as fp :
49- temp_filename = fp .name
50-
51- # Generate the audio file
52- self .tts_func (
53- text ,
54- voice = speaker_idx ,
55- output = temp_filename ,
56- volume = volume ,
57- speed = speed ,
58- noise_scale = noise_scale ,
59- noise_w_scale = noise_w_scale
60- )
61-
62- # Read the waveform from the file
63- with wave .open (temp_filename , 'rb' ) as wav_file :
64- n_frames = wav_file .getnframes ()
65- audio_data = wav_file .readframes (n_frames )
66- # Convert bytes to numpy array
67- import struct
68- sample_width = wav_file .getsampwidth ()
69- if sample_width == 2 :
70- waveform = np .frombuffer (audio_data , dtype = np .int16 )
71- else :
72- waveform = np .frombuffer (audio_data , dtype = np .int8 )
73-
74- # Clean up temp file
75- import os
76- os .unlink (temp_filename )
77-
78- return waveform
58+ temp_filename = None
59+ try :
60+ with tempfile .NamedTemporaryFile (suffix = ".wav" , delete = False ) as fp :
61+ temp_filename = fp .name
62+
63+ # Generate the audio file
64+ self .tts_func (
65+ text ,
66+ voice = speaker_idx ,
67+ output = temp_filename ,
68+ volume = volume ,
69+ speed = speed ,
70+ noise_scale = noise_scale ,
71+ noise_w_scale = noise_w_scale
72+ )
73+
74+ # Read the waveform from the file
75+ with wave .open (temp_filename , 'rb' ) as wav_file :
76+ n_frames = wav_file .getnframes ()
77+ audio_data = wav_file .readframes (n_frames )
78+ sample_width = wav_file .getsampwidth ()
79+
80+ # Convert bytes to numpy array based on sample width
81+ if sample_width == 1 :
82+ waveform = np .frombuffer (audio_data , dtype = np .int8 )
83+ elif sample_width == 2 :
84+ waveform = np .frombuffer (audio_data , dtype = np .int16 )
85+ elif sample_width == 4 :
86+ waveform = np .frombuffer (audio_data , dtype = np .int32 )
87+ else :
88+ raise ValueError (f"Unsupported sample width: { sample_width } bytes" )
89+
90+ return waveform
91+ finally :
92+ # Clean up temp file
93+ if temp_filename and os .path .exists (temp_filename ):
94+ os .unlink (temp_filename )
7995 else :
8096 # File output
8197 if filename is None :
0 commit comments