Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions cosyvoice/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
Expand Down Expand Up @@ -57,6 +56,7 @@ def __init__(self,
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.token_condition_dict = {}
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
Expand Down Expand Up @@ -125,12 +125,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
continue
else:
cur_silent_token_num = 0
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
with self.lock:
self.tts_speech_token_dict[uuid].append(i)
self.token_condition_dict[uuid].notify()
with self.lock:
self.llm_end_dict[uuid] = True
self.token_condition_dict[uuid].notify()

def vc_job(self, source_speech_token, uuid):
    """Voice-conversion producer: publish the entire source token stream at once.

    Unlike the incremental LLM producer (which appends tokens one by one),
    voice conversion already has every speech token available, so the whole
    list is installed in a single locked update and the waiting consumer is
    woken exactly once.

    Args:
        source_speech_token: tensor of speech tokens; flattened to a plain
            Python list before being stored. (Assumes a 1-to-N shaped tensor
            whose flatten order is the playback order — TODO confirm.)
        uuid: session identifier keying the per-session state dicts.
    """
    with self.lock:
        # Replace (not append to) this session's token list with all tokens.
        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
        # Mark generation as finished so the consumer's wait loop can drain
        # the remaining tokens and exit.
        self.llm_end_dict[uuid] = True
        # Wake the consumer blocked on this session's condition variable.
        # The Condition was built around self.lock, so calling notify() here
        # (with the lock held) is valid.
        self.token_condition_dict[uuid].notify()

def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
Expand Down Expand Up @@ -181,6 +187,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.token_condition_dict[this_uuid] = threading.Condition(self.lock)
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
Expand All @@ -192,10 +199,18 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
with self.lock:
while len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len and \
self.llm_end_dict[this_uuid] is False:
self.token_condition_dict[this_uuid].wait()
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token_slice = self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]
elif self.llm_end_dict[this_uuid] is True:
break
else:
continue
this_tts_speech_token = torch.tensor(this_tts_speech_token_slice).unsqueeze(dim=0)
if this_tts_speech_token.shape[1] != 0:
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
Expand All @@ -207,8 +222,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
Expand All @@ -234,6 +247,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.token_condition_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
Expand Down Expand Up @@ -271,6 +285,7 @@ def __init__(self,
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.token_condition_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []

Expand All @@ -287,6 +302,10 @@ def load_vllm(self, model_dir):
gpu_memory_utilization=0.2)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
self.llm.lock = threading.Lock()
self.llm.vllm_step_condition = threading.Condition(self.llm.lock)
self.llm.vllm_step_thread = None
self.llm.vllm_background_error = None
self.llm._ensure_vllm_runtime()
del self.llm.llm.model.model.layers

def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
Expand Down Expand Up @@ -334,6 +353,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.token_condition_dict[this_uuid] = threading.Condition(self.lock)
self.hift_cache_dict[this_uuid] = None
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
Expand All @@ -344,10 +364,19 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
token_offset = 0
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
required_token_len = token_offset + this_token_hop_len + self.flow.pre_lookahead_len
with self.lock:
while len(self.tts_speech_token_dict[this_uuid]) < required_token_len and self.llm_end_dict[this_uuid] is False:
self.token_condition_dict[this_uuid].wait()
if len(self.tts_speech_token_dict[this_uuid]) >= required_token_len:
this_tts_speech_token_slice = self.tts_speech_token_dict[this_uuid][:required_token_len]
elif self.llm_end_dict[this_uuid] is True:
break
else:
continue
this_tts_speech_token = torch.tensor(this_tts_speech_token_slice).unsqueeze(dim=0)
if this_tts_speech_token.shape[1] != 0:
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
Expand All @@ -359,8 +388,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
token_offset += this_token_hop_len
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
Expand All @@ -370,6 +397,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
stream=stream,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
Expand All @@ -388,9 +416,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.token_condition_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
# torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()


Expand Down Expand Up @@ -418,6 +447,7 @@ def __init__(self,
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.token_condition_dict = {}
self.hift_cache_dict = {}
# FSQ silent and breath token
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
Expand Down
13 changes: 9 additions & 4 deletions cosyvoice/flow/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
Expand Down Expand Up @@ -128,9 +130,11 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else:
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
# NOTE need to synchronize when switching stream
torch.cuda.current_stream().synchronize()
with stream:
stream_context = stream if stream is not None else nullcontext()
if stream is not None:
# NOTE only synchronize when switching to a dedicated TRT stream.
torch.cuda.current_stream().synchronize()
with stream_context:
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
Expand All @@ -148,7 +152,8 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
if stream is not None:
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator, stream)
return x

Expand Down
Loading