Skip to content

Commit f65d69b

Browse files
authored
Refactor Synthesizer initialization parameters
1 parent d75cb36 commit f65d69b

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

pythaitts/pretrained/khanomtan_tts.py

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -47,28 +47,22 @@ def load_synthesizer(self, mode):
4747
if mode=="best_model":
4848
self.best_model_path = hf_hub_download(repo_id="wannaphong/khanomtan-tts-v{0}".format(self.version),filename=self.best_model_path_name,force_filename="best_model-v{0}.pth".format(self.version))
4949
self.synthesizer = Synthesizer(
50-
self.best_model_path,
51-
self.config_path,
52-
self.speakers_path,
53-
self.languages_path,
54-
None,
55-
None,
56-
self.speaker_encoder_model_path,
57-
self.speaker_encoder_config_path,
58-
False
50+
tts_checkpoint=self.best_model_path,
51+
tts_config_path=self.config_path,
52+
tts_speakers_file=self.speakers_path,
53+
tts_languages_file=self.languages_path,
54+
encoder_checkpoint=self.speaker_encoder_model_path,
55+
encoder_config=self.speaker_encoder_config_path
5956
)
6057
else:
6158
self.last_checkpoint_model_path = hf_hub_download(repo_id="wannaphong/khanomtan-tts-v{0}".format(self.version),filename=self.last_checkpoint_model_path_name,force_filename="last_checkpoint-v{0}.pth".format(self.version))
6259
self.synthesizer = Synthesizer(
63-
self.last_checkpoint_model_path,
64-
self.config_path,
65-
self.speakers_path,
66-
self.languages_path,
67-
None,
68-
None,
69-
self.speaker_encoder_model_path,
70-
self.speaker_encoder_config_path,
71-
False
60+
tts_checkpoint=self.last_checkpoint_model_path,
61+
tts_config_path=self.config_path,
62+
tts_speakers_file=self.speakers_path,
63+
tts_languages_file=self.languages_path,
64+
encoder_checkpoint=self.speaker_encoder_model_path,
65+
encoder_config=self.speaker_encoder_config_path
7266
)
7367
def __call__(self, text: str, speaker_idx: str, language_idx: str, return_type: str = "file", filename: str = None):
7468
wavs = self.synthesizer.tts(text, speaker_idx, language_idx)
@@ -80,4 +74,4 @@ def __call__(self, text: str, speaker_idx: str, language_idx: str, return_type:
8074
else:
8175
with tempfile.NamedTemporaryFile(suffix = ".wav", delete = False) as fp:
8276
self.synthesizer.save_wav(wavs, fp)
83-
return fp.name
77+
return fp.name

0 commit comments

Comments
 (0)