Skip to content

Commit 01d26c2

Browse files
committed
easy test_zero_shot.py
1 parent cdece6b commit 01d26c2

1 file changed

Lines changed: 12 additions & 11 deletions

File tree

test/test_zero_shot.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,21 @@
1414
"--cosyvoice_version", type=int, choices=[2, 3], default=3, help="CosyVoice version: 2 or 3 (default: 3)"
1515
)
1616
parser.add_argument("--stream", action="store_true", default=False, help="是否使用流式推理 (default: True)")
17-
parser.add_argument("--num", type=int, default=5, help="测试数量 (default: 5)")
17+
parser.add_argument("--num", type=int, default=1, help="测试数量 (default: 5)")
1818
args = parser.parse_args()
1919

2020
url = f"http://0.0.0.0:{args.port}/inference_zero_shot"
21-
num = args.num
21+
2222
# 准备要发送的文本和音频文件
2323
path = "../cosyvoice/asset/zero_shot_prompt.wav"
24+
your_text = "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
25+
# 根据 cosyvoice_version 设置 prompt_text
26+
if args.cosyvoice_version == 3:
27+
prompt_text = "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"
28+
else:
29+
prompt_text = "希望你以后能够做的比我还好呦。"
30+
31+
num = args.num
2432
stream = args.stream # 是否使用流式推理
2533
with open("test_texts.json", "r") as f:
2634
all_inputs = json.load(f)
@@ -31,16 +39,9 @@
3139
def get_file(index):
3240
files = {"prompt_wav": ("sample.wav", open(path, "rb"), "audio/wav")}
3341
# inputs = random.choice(all_inputs)
34-
inputs = all_inputs[0]
3542
# inputs = all_inputs[2]
3643

37-
# 根据 cosyvoice_version 设置 prompt_text
38-
if args.cosyvoice_version == 3:
39-
prompt_text = "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"
40-
else:
41-
prompt_text = "希望你以后能够做的比我还好呦。"
42-
43-
data = {"tts_text": inputs, "prompt_text": prompt_text, "stream": stream}
44+
data = {"tts_text": your_text, "prompt_text": prompt_text, "stream": stream}
4445
start_time = time.time()
4546

4647
response = requests.post(url, files=files, data=data, stream=True)
@@ -74,7 +75,7 @@ def get_file(index):
7475
output_wav = f"./outs/output{'_stream' if stream else ''}_{index}.wav"
7576
sf.write(output_wav, audio_np, samplerate=sample_rate, subtype="PCM_16")
7677
print(
77-
f"{inputs} saved as {output_wav}, time cost: {cost_time:.2f} s"
78+
f"{your_text} saved as {output_wav}, time cost: {cost_time:.2f} s"
7879
+ f", rtf: {cost_time / speech_len}, ttft: {ttft:.2f} s"
7980
)
8081
else:

0 commit comments

Comments
 (0)