Skip to content

Commit 8c78024

Browse files
committed
tune config to speedup
1 parent c9cfef1 commit 8c78024

4 files changed

Lines changed: 16 additions & 6 deletions

File tree

cosyvoice/cosyvoice/cli/frontend.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,16 @@ def __init__(self,
6464
if self.use_ttsfrd:
6565
self.frd = ttsfrd.TtsFrontendEngine()
6666
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
67-
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
68-
'failed to initialize ttsfrd resource'
67+
resource_paths = [
68+
'{}/../../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR),
69+
'{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)
70+
]
71+
initialized = False
72+
for path in resource_paths:
73+
if self.frd.initialize(path):
74+
initialized = True
75+
break
76+
assert initialized, 'failed to initialize ttsfrd resource'
6977
self.frd.set_lang_type('pinyinvg')
7078
else:
7179
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"32": {"1": {"BLOCK_SEQ": 16, "BLOCK_N": 16, "stage1_num_warps": 1, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "2": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 1}, "4": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "stage1_num_warps": 4, "stage1_num_stages": 1, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "64": {"1": {"BLOCK_SEQ": 16, "BLOCK_N": 16, "stage1_num_warps": 1, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 1}, "2": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}, "4": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "stage1_num_warps": 4, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "128": {"1": {"BLOCK_SEQ": 16, "BLOCK_N": 16, "stage1_num_warps": 1, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}, "2": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 1, "stage2_num_warps": 1, "stage2_num_stages": 1}, "4": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "256": {"1": {"BLOCK_SEQ": 16, "BLOCK_N": 16, "stage1_num_warps": 1, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 1}, "2": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}, "4": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 8, 
"stage2_num_warps": 1, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 7, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 16, "BLOCK_N": 16, "stage1_num_warps": 1, "stage1_num_stages": 7, "stage2_num_warps": 1, "stage2_num_stages": 1}, "2": {"BLOCK_SEQ": 16, "BLOCK_N": 16, "stage1_num_warps": 1, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 1}, "4": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "stage1_num_warps": 1, "stage1_num_stages": 8, "stage2_num_warps": 1, "stage2_num_stages": 1}}}

light_tts/models/llama/triton_kernel/flash_decoding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from frozendict import frozendict
55
from typing import Dict
66

7+
78
class LlamaFlashDecodingStage1KernelConfig(KernelConfigs):
89
kernel_name: str = "triton_flashdecoding"
910

@@ -44,7 +45,7 @@ def try_to_get_best_config(
4445
"stage2_num_stages": 2,
4546
}
4647
return config
47-
48+
4849
@classmethod
4950
def save_config(cls, *args, **kwargs) -> None:
5051
key_params = {
@@ -57,13 +58,13 @@ def save_config(cls, *args, **kwargs) -> None:
5758

5859
cls.store_config(key_params, kwargs["store_json_ans"])
5960

61+
6062
def token_decode_attention_flash_decoding(
6163
q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty
6264
):
6365
batch_size = infer_state.batch_size
64-
avg_seq_len_in_batch = infer_state.total_token_num // batch_size
6566
run_config = LlamaFlashDecodingStage1KernelConfig.try_to_get_best_config(
66-
batch_size, avg_seq_len_in_batch, head_dim, q_head_num, cache_k.shape[1], torch.float16
67+
batch_size, infer_state.max_len_in_batch, head_dim, q_head_num, cache_k.shape[1], torch.float16
6768
)
6869
BLOCK_SEQ = run_config["BLOCK_SEQ"]
6970

light_tts/server/api_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
116116
parser.add_argument(
117117
"--graph_max_len_in_batch",
118118
type=int,
119-
default=32768,
119+
default=2048,
120120
help="""Maximum sequence length that can be captured by the cuda graph for decoding stage.
121121
The default value is 2048. It will fall back to eager mode if it encounters a larger value. """,
122122
)

0 commit comments

Comments
 (0)