-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy pathserver.py
More file actions
121 lines (102 loc) · 4.75 KB
/
server.py
File metadata and controls
121 lines (102 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import io
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from fastapi import FastAPI, UploadFile, Form, File
from fastapi.responses import StreamingResponse, Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import numpy as np
import torch
import torchaudio
# Locate this file's directory, derive ROOT_DIR three levels above it, and
# append ROOT_DIR plus its third_party/Matcha-TTS subdirectory to sys.path so
# the imports that follow can be resolved.
CURR_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = CURR_DIR + '/../../..'
sys.path.extend([ROOT_DIR, ROOT_DIR + '/third_party/Matcha-TTS'])
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
# Model checkpoint directory, resolved relative to ROOT_DIR.
model_dir = ROOT_DIR + "/pretrained_models/CosyVoice2-0.5B"
# Pick the wrapper class from the directory name: paths containing
# 'CosyVoice2' load with CosyVoice2, anything else with CosyVoice.
if 'CosyVoice2' in model_dir:
    cosyvoice = CosyVoice2(model_dir)
else:
    cosyvoice = CosyVoice(model_dir)
app = FastAPI()

# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive CORS policy -- confirm this is intended for deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Non-streaming WAV data
def build_data(model_output, sample_rate=None):
    """Collect every generated chunk and encode the result as one WAV file.

    Args:
        model_output: Iterable of dicts, each carrying an audio tensor under
            the ``'tts_speech'`` key. The tensors are concatenated along
            dim 1 (assumes a (channels, samples) layout -- TODO confirm).
        sample_rate: Sample rate written to the WAV header. When ``None``,
            the loaded model's ``sample_rate`` attribute is used, falling
            back to 22050 if the model does not expose one.

    Returns:
        bytes: The complete WAV file contents.
    """
    if sample_rate is None:
        # Prefer the rate reported by the loaded model: a hard-coded 22050
        # yields audio with the wrong pitch and duration whenever the model
        # generates at a different rate.
        sample_rate = getattr(cosyvoice, 'sample_rate', 22050)
    tts_speeches = [chunk['tts_speech'] for chunk in model_output]
    output = torch.concat(tts_speeches, dim=1)
    buffer = io.BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    # getvalue() returns the whole buffer regardless of the stream position.
    return buffer.getvalue()
# Streaming PCM data
def generate_data(model_output):
    """Yield each generated chunk as raw 16-bit signed PCM bytes.

    Args:
        model_output: Iterable of dicts, each carrying a floating-point audio
            tensor under the ``'tts_speech'`` key (presumably normalized to
            [-1.0, 1.0] -- verify against the model).

    Yields:
        bytes: The chunk scaled by 2**15, clipped to the int16 range and
        serialized with ``ndarray.tobytes()``.
    """
    for chunk in model_output:
        # detach()/cpu() are no-ops for plain CPU tensors but keep .numpy()
        # from failing on tensors that live on a GPU or track gradients.
        scaled = chunk['tts_speech'].detach().cpu().numpy() * (2 ** 15)
        # Clip before casting: a sample of exactly +1.0 scales to 32768,
        # which is outside int16 and would otherwise overflow in the cast.
        pcm = np.clip(scaled, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
        yield pcm.tobytes()
# Registered for POST as well as GET: every parameter is read from the
# form-encoded request body, which many HTTP clients cannot send with GET.
@app.get("/inference_sft")
@app.post("/inference_sft")
async def inference_sft(
    tts_text: str = Form(),
    spk_id: str = Form(),
    stream: bool = Form(default=False),
    format: str = Form(default="pcm"),
):
    """Run ``cosyvoice.inference_sft`` and return the generated audio.

    Args:
        tts_text: Form field forwarded to the model as its first argument.
        spk_id: Form field forwarded to the model as its second argument.
        stream: Forwarded to the model as the ``stream`` keyword.
        format: ``"pcm"`` streams the chunks produced by ``generate_data``;
            any other value returns the output of ``build_data`` as a single
            ``audio/wav`` response.

    Returns:
        A ``StreamingResponse`` for ``"pcm"``, otherwise a ``Response``.
    """
    model_output = cosyvoice.inference_sft(tts_text, spk_id, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    return Response(build_data(model_output), media_type="audio/wav")
# Registered for POST as well as GET: the parameters (including the uploaded
# prompt file) are read from the request body, which many HTTP clients cannot
# send with GET.
@app.get("/inference_zero_shot")
@app.post("/inference_zero_shot")
async def inference_zero_shot(
    tts_text: str = Form(),
    prompt_text: str = Form(),
    prompt_wav: UploadFile = File(),
    stream: bool = Form(default=False),
    format: str = Form(default="pcm"),
):
    """Run ``cosyvoice.inference_zero_shot`` and return the generated audio.

    Args:
        tts_text: Form field forwarded to the model as its first argument.
        prompt_text: Form field forwarded to the model as its second argument.
        prompt_wav: Uploaded audio file; loaded with ``load_wav(..., 16000)``
            before being passed to the model.
        stream: Forwarded to the model as the ``stream`` keyword.
        format: ``"pcm"`` streams the chunks produced by ``generate_data``;
            any other value returns the output of ``build_data`` as a single
            ``audio/wav`` response.

    Returns:
        A ``StreamingResponse`` for ``"pcm"``, otherwise a ``Response``.
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_zero_shot(
        tts_text, prompt_text, prompt_speech_16k, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    return Response(build_data(model_output), media_type="audio/wav")
# Registered for POST as well as GET: the parameters (including the uploaded
# prompt file) are read from the request body, which many HTTP clients cannot
# send with GET.
@app.get("/inference_cross_lingual")
@app.post("/inference_cross_lingual")
async def inference_cross_lingual(
    tts_text: str = Form(),
    prompt_wav: UploadFile = File(),
    stream: bool = Form(default=False),
    format: str = Form(default="pcm"),
):
    """Run ``cosyvoice.inference_cross_lingual`` and return the audio.

    Args:
        tts_text: Form field forwarded to the model as its first argument.
        prompt_wav: Uploaded audio file; loaded with ``load_wav(..., 16000)``
            before being passed to the model.
        stream: Forwarded to the model as the ``stream`` keyword.
        format: ``"pcm"`` streams the chunks produced by ``generate_data``;
            any other value returns the output of ``build_data`` as a single
            ``audio/wav`` response.

    Returns:
        A ``StreamingResponse`` for ``"pcm"``, otherwise a ``Response``.
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_cross_lingual(
        tts_text, prompt_speech_16k, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    return Response(build_data(model_output), media_type="audio/wav")
# Registered for POST as well as GET: every parameter is read from the
# form-encoded request body, which many HTTP clients cannot send with GET.
@app.get("/inference_instruct")
@app.post("/inference_instruct")
async def inference_instruct(
    tts_text: str = Form(),
    spk_id: str = Form(),
    instruct_text: str = Form(),
    stream: bool = Form(default=False),
    format: str = Form(default="pcm"),
):
    """Run ``cosyvoice.inference_instruct`` and return the generated audio.

    Args:
        tts_text: Form field forwarded to the model as its first argument.
        spk_id: Form field forwarded to the model as its second argument.
        instruct_text: Form field forwarded to the model as its third
            argument.
        stream: Forwarded to the model as the ``stream`` keyword.
        format: ``"pcm"`` streams the chunks produced by ``generate_data``;
            any other value returns the output of ``build_data`` as a single
            ``audio/wav`` response.

    Returns:
        A ``StreamingResponse`` for ``"pcm"``, otherwise a ``Response``.
    """
    model_output = cosyvoice.inference_instruct(
        tts_text, spk_id, instruct_text, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    return Response(build_data(model_output), media_type="audio/wav")
# Registered for POST as well as GET: the parameters (including the uploaded
# prompt file) are read from the request body, which many HTTP clients cannot
# send with GET.
@app.get("/inference_instruct_v2")
@app.post("/inference_instruct_v2")
async def inference_instruct_v2(
    tts_text: str = Form(),
    instruct_text: str = Form(),
    prompt_wav: UploadFile = File(),
    stream: bool = Form(default=False),
    format: str = Form(default="pcm"),
):
    """Run ``cosyvoice.inference_instruct2`` and return the generated audio.

    Args:
        tts_text: Form field forwarded to the model as its first argument.
        instruct_text: Form field forwarded to the model as its second
            argument.
        prompt_wav: Uploaded audio file; loaded with ``load_wav(..., 16000)``
            before being passed to the model.
        stream: Forwarded to the model as the ``stream`` keyword.
        format: ``"pcm"`` streams the chunks produced by ``generate_data``;
            any other value returns the output of ``build_data`` as a single
            ``audio/wav`` response.

    Returns:
        A ``StreamingResponse`` for ``"pcm"``, otherwise a ``Response``.
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_instruct2(
        tts_text, instruct_text, prompt_speech_16k, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    return Response(build_data(model_output), media_type="audio/wav")
if __name__ == '__main__':
    # Read the listening port from the command line (default 50000), then
    # serve the app with uvicorn bound to all interfaces.
    cli = argparse.ArgumentParser()
    cli.add_argument('--port', type=int, default=50000)
    options = cli.parse_args()
    uvicorn.run(app, host="0.0.0.0", port=options.port)