Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@
sudo yum install sox sox-devel
```

### macOS Apple Silicon (M1/M2/M3/M4)

For Apple Silicon Macs, use the dedicated setup script:

``` sh
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
cd CosyVoice
bash setup_macos.sh
```

Or manually:

``` sh
conda create -n cosyvoice -y python=3.10
conda activate cosyvoice
conda install -c conda-forge pynini==2.1.5 -y
pip install torch torchaudio
pip install -r requirements.txt
```

**Apple Silicon notes:**
- Inference runs on MPS (Metal Performance Shaders) — typically much faster than CPU
- TensorRT and vLLM are not available (CUDA-only)
- Training with DeepSpeed/DDP is not supported
- For CUDA environments (Linux), use `pip install -r requirements-cuda.txt` instead

### Model download

We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
Expand Down
32 changes: 23 additions & 9 deletions cosyvoice/cli/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
from cosyvoice.utils.device import is_cuda, is_gpu_available


class CosyVoice:
Expand All @@ -44,9 +45,12 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_co
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
if not is_cuda() and load_trt:
load_trt = False
logging.warning('TensorRT requires CUDA, disabling load_trt')
if not is_gpu_available() and (load_jit or fp16):
load_jit, fp16 = False, False
logging.warning('no GPU device, disabling load_jit/fp16')
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
Expand Down Expand Up @@ -156,9 +160,16 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, f
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
if not is_cuda():
if load_trt:
load_trt = False
logging.warning('TensorRT requires CUDA, disabling load_trt')
if load_vllm:
load_vllm = False
logging.warning('vLLM requires CUDA, disabling load_vllm')
if not is_gpu_available() and (load_jit or fp16):
load_jit, fp16 = False, False
logging.warning('no GPU device, disabling load_jit/fp16')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
Expand Down Expand Up @@ -206,9 +217,12 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
load_trt, fp16 = False, False
logging.warning('no cuda device, set load_trt/fp16 to False')
if not is_cuda() and load_trt:
load_trt = False
logging.warning('TensorRT requires CUDA, disabling load_trt')
if not is_gpu_available() and fp16:
fp16 = False
logging.warning('no GPU device, disabling fp16')
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
Expand Down
12 changes: 9 additions & 3 deletions cosyvoice/cli/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import inflect
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
from cosyvoice.utils.device import get_device


class CosyVoiceFrontEnd:
Expand All @@ -38,14 +39,19 @@ def __init__(self,
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.feat_extractor = feat_extractor
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = get_device()
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
if torch.cuda.is_available():
tokenizer_providers = ["CUDAExecutionProvider"]
elif "CoreMLExecutionProvider" in onnxruntime.get_available_providers():
tokenizer_providers = ["CoreMLExecutionProvider"]
else:
tokenizer_providers = ["CPUExecutionProvider"]
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"])
providers=tokenizer_providers)
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else:
Expand Down
29 changes: 13 additions & 16 deletions cosyvoice/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from cosyvoice.utils.device import get_device, get_stream_context, get_autocast_context, empty_cache


class CosyVoiceModel:
Expand All @@ -33,7 +34,7 @@ def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = get_device()
self.llm = llm
self.flow = flow
self.hift = hift
Expand All @@ -52,7 +53,7 @@ def __init__(self,
# rtf and decoding related
self.stream_scale_factor = 1
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.llm_context = get_stream_context(self.device)
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_trt_kwargs(self):

def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
cur_silent_token_num, max_silent_token_num = 0, 5
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
with self.llm_context, get_autocast_context(self.fp16 is True and hasattr(self.llm, 'vllm') is False, self.device):
if isinstance(text, Generator):
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
token_generator = self.llm.inference_bistream(text=text,
Expand Down Expand Up @@ -133,7 +134,7 @@ def vc_job(self, source_speech_token, uuid):
self.llm_end_dict[uuid] = True

def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
with get_autocast_context(self.fp16, self.device):
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
Expand Down Expand Up @@ -237,9 +238,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
empty_cache(self.device)


class CosyVoice2Model(CosyVoiceModel):
Expand All @@ -249,7 +248,7 @@ def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = get_device()
self.llm = llm
self.flow = flow
self.hift = hift
Expand All @@ -266,7 +265,7 @@ def __init__(self,
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.llm_context = get_stream_context(self.device)
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
Expand All @@ -290,7 +289,7 @@ def load_vllm(self, model_dir):
del self.llm.llm.model.model.layers

def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
with get_autocast_context(self.fp16, self.device):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
Expand Down Expand Up @@ -389,9 +388,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
empty_cache(self.device)


class CosyVoice3Model(CosyVoice2Model):
Expand All @@ -401,7 +398,7 @@ def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = get_device()
self.llm = llm
self.flow = flow
self.hift = hift
Expand All @@ -413,7 +410,7 @@ def __init__(self,
self.stream_scale_factor = 2
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.llm_context = get_stream_context(self.device)
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
Expand All @@ -423,7 +420,7 @@ def __init__(self,
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]

def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
with get_autocast_context(self.fp16, self.device):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
Expand Down
3 changes: 2 additions & 1 deletion cosyvoice/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def set_all_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
Expand Down
80 changes: 80 additions & 0 deletions cosyvoice/utils/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified device management for CUDA, MPS (Apple Silicon), and CPU backends."""

import random
from contextlib import nullcontext

import numpy as np
import torch


def get_device() -> torch.device:
    """Pick the preferred compute device, in priority order cuda > mps > cpu."""
    candidates = (
        ('cuda', torch.cuda.is_available()),
        ('mps', hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()),
    )
    for backend, available in candidates:
        if available:
            return torch.device(backend)
    return torch.device('cpu')


def is_cuda() -> bool:
    """Whether a CUDA-capable GPU is visible to this torch build."""
    return bool(torch.cuda.is_available())


def is_mps() -> bool:
    """Whether the Apple Metal (MPS) backend is usable by this torch build."""
    # getattr guards older torch builds where torch.backends has no 'mps'.
    mps_backend = getattr(torch.backends, 'mps', None)
    return mps_backend is not None and mps_backend.is_available()


def is_gpu_available() -> bool:
    """Whether any GPU backend (CUDA or Apple MPS) is usable."""
    if torch.cuda.is_available():
        return True
    return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()


def get_stream_context(device: torch.device):
    """Context manager that runs enclosed work on a fresh CUDA stream.

    On any non-CUDA device there is no stream concept, so a no-op
    ``nullcontext`` is returned instead.
    """
    if device.type != 'cuda':
        return nullcontext()
    return torch.cuda.stream(torch.cuda.Stream(device))


def get_autocast_context(enabled: bool, device: torch.device):
    """Return the mixed-precision autocast context appropriate for *device*.

    Args:
        enabled: whether fp16 autocast should be active at all.
        device: target device; only ``cuda`` and ``mps`` get a real autocast
            context, everything else (e.g. ``cpu``) runs in full precision.

    Returns:
        A context manager: ``torch.autocast`` when enabled on a GPU backend,
        otherwise ``contextlib.nullcontext``.
    """
    if not enabled:
        return nullcontext()
    if device.type == 'cuda':
        # torch.cuda.amp.autocast is deprecated (removed warning-free use
        # since torch 2.4); torch.amp.autocast with an explicit device_type
        # is the supported replacement and behaves identically.
        return torch.amp.autocast(device_type='cuda', enabled=True)
    if device.type == 'mps':
        return torch.autocast(device_type='mps', dtype=torch.float16)
    return nullcontext()


def empty_cache(device: torch.device):
    """Release cached allocator memory and synchronize on GPU backends.

    A no-op on CPU (or any other backend without a caching allocator).
    """
    backend = device.type
    if backend == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.current_stream().synchronize()
    elif backend == 'mps':
        # Guard each call: older torch builds may lack these torch.mps helpers.
        release = getattr(torch.mps, 'empty_cache', None)
        if release is not None:
            release()
        sync = getattr(torch.mps, 'synchronize', None)
        if sync is not None:
            sync()


def set_all_random_seed(seed: int):
    """Seed every RNG backend: python, numpy, torch CPU, and torch CUDA."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA keeps per-device generators; seed them all when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
5 changes: 5 additions & 0 deletions requirements-cuda.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# CUDA-specific requirements (Linux with NVIDIA GPU)
# Install with: pip install -r requirements-cuda.txt
--extra-index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
-r requirements.txt
6 changes: 2 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
conformer==0.3.2
deepspeed==0.15.1; sys_platform == 'linux'
diffusers==0.29.0
Expand Down Expand Up @@ -33,8 +31,8 @@ tensorboard==2.14.0
tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
torch==2.3.1
torchaudio==2.3.1
torch>=2.3.1
torchaudio>=2.3.1
transformers==4.51.3
x-transformers==2.11.24
uvicorn==0.30.0
Expand Down
Loading