diff --git a/angelslim/compressor/speculative/benchmark/pytorch/__init__.py b/angelslim/compressor/speculative/benchmark/pytorch/__init__.py index ac8a4b0f..0db6e4f5 100644 --- a/angelslim/compressor/speculative/benchmark/pytorch/__init__.py +++ b/angelslim/compressor/speculative/benchmark/pytorch/__init__.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .benchmark_engine import BenchmarkConfig, BenchmarkEngine, BenchmarkMode +from .benchmark_engine import ( + BenchmarkConfig, + BenchmarkEngine, + BenchmarkMode, + TTSBenchmarkEngine, +) -__all__ = ["BenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"] +__all__ = ["BenchmarkEngine", "TTSBenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"] diff --git a/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py b/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py index 4444996a..5b35d45a 100644 --- a/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py +++ b/angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py @@ -25,7 +25,11 @@ from angelslim.utils.lazy_imports import fastchat, ray from .generate_baseline_answer import get_model_answers as get_baseline_answers +from .generate_baseline_answer import get_tts_answers as get_tts_baseline_answers +from .generate_baseline_answer import get_tts_audios as get_tts_baseline_audios from .generate_eagle_answer import get_model_answers as get_eagle_answers +from .generate_eagle_answer import get_tts_answers as get_tts_eagle_answers +from .generate_eagle_answer import get_tts_audios as get_tts_eagle_audios class BenchmarkMode(Enum): @@ -77,6 +81,10 @@ class BenchmarkConfig: # Batch settings batch_size: int = 1 + # TTS settings + is_tts: bool = False + generate_audio: bool = False + class BenchmarkEngine: """Core benchmark engine for speculative decoding evaluation""" @@ -343,6 +351,10 @@ def _create_args_namespace(self, mode: str) -> argparse.Namespace: args.early_stop_method = self.config.early_stop_method + # TTS settings + args.is_tts = self.config.is_tts + args.generate_audio = self.config.generate_audio + return args def _get_question_file_path(self) -> str: @@ -397,3 +409,119 @@ def get_performance_summary(self) -> str: summary.append(f"Analysis Report: {self.analysis_file}") return "\n".join(summary) + + +class TTSBenchmarkEngine(BenchmarkEngine): + """Core benchmark engine for speculative decoding evaluation""" + + def _run_eagle_benchmark(self): + """Run Eagle speculative decoding benchmark""" + args = self._create_args_namespace("eagle") + + questions = fastchat.llm_judge.common.load_questions( + self._get_question_file_path(), + self.config.question_begin, + self.config.question_end, + ) + + use_ray = self.config.num_gpus_total // self.config.num_gpus_per_model > 1 + get_answers_func = ( + ray.remote(num_gpus=self.config.num_gpus_per_model)( + get_tts_eagle_answers + ).remote + if use_ray + else get_tts_eagle_answers + ) + + chunk_size = len(questions) // ( + self.config.num_gpus_total // self.config.num_gpus_per_model + ) + ans_handles = [ + get_answers_func( + f"{self.config.model_id}-temperature-{self.config.temperature}", + questions[i : i + chunk_size], + self.eagle_file, + self.config.num_choices, + self.config.temperature, + args, + ) + for i in range(0, len(questions), chunk_size) + ] + + if use_ray: + ray.get(ans_handles) + + self._reorg_answer_file(self.eagle_file) + self.results["eagle_file"] = self.eagle_file + + if self.config.generate_audio: + self._generate_audio("eagle") + + def _run_baseline_benchmark(self): + """Run baseline benchmark""" + args = self._create_args_namespace("baseline") + + questions = fastchat.llm_judge.common.load_questions( + self._get_question_file_path(), + self.config.question_begin, + self.config.question_end, + ) + + use_ray = self.config.num_gpus_total // self.config.num_gpus_per_model > 1 + get_answers_func = ( + ray.remote(num_gpus=self.config.num_gpus_per_model)( + get_tts_baseline_answers + ).remote + if use_ray + else get_tts_baseline_answers + ) + + chunk_size = len(questions) // ( + self.config.num_gpus_total // self.config.num_gpus_per_model + ) + ans_handles = [ + get_answers_func( + f"{self.config.model_id}-temperature-{self.config.temperature}", + questions[i : i + chunk_size], + self.baseline_file, + self.config.num_choices, + self.config.temperature, + args, + ) + for i in range(0, len(questions), chunk_size) + ] + + if use_ray: + ray.get(ans_handles) + + self._reorg_answer_file(self.baseline_file) + self.results["baseline_file"] = self.baseline_file + + if self.config.generate_audio: + self._generate_audio("baseline") + + def _calculate_metrics(self) -> Dict[str, Any]: + """Calculate acceptance length and speedup ratio""" + metrics = {} + + # Calculate acceptance length from Eagle results + if os.path.exists(self.eagle_file): + metrics["acceptance_length"] = self._calculate_acceptance_length( + self.eagle_file + ) + + return metrics + + def _generate_audio(self, mode): + args = self._create_args_namespace(mode) + + answers = fastchat.llm_judge.common.load_questions( + args.answer_file, + self.config.question_begin, + self.config.question_end, + ) + + if mode == "baseline": + get_tts_baseline_audios(answers, args.answer_file, args) + else: + get_tts_eagle_audios(answers, args.answer_file, args) diff --git a/angelslim/compressor/speculative/benchmark/pytorch/generate_baseline_answer.py b/angelslim/compressor/speculative/benchmark/pytorch/generate_baseline_answer.py index a0c72a50..dba559e4 100644 --- a/angelslim/compressor/speculative/benchmark/pytorch/generate_baseline_answer.py +++ b/angelslim/compressor/speculative/benchmark/pytorch/generate_baseline_answer.py @@ -17,15 +17,18 @@ import os import random import time -from typing import Any, Dict, List +from typing import Any, Dict, Generator, List import numpy as np import shortuuid import torch from tqdm import tqdm -from angelslim.compressor.speculative.inference.models import Eagle3Model -from angelslim.utils.lazy_imports import fastchat, ray +from angelslim.compressor.speculative.inference.models import ( + CosyVoice3Eagle3Model, + Eagle3Model, +) +from angelslim.utils.lazy_imports import fastchat, ray, torchaudio SYSTEM_PROMPT = { "role": "system", @@ -56,6 +59,7 @@ def __init__(self, args: argparse.Namespace): self.total_token = args.total_token self.depth = args.depth self.top_k = args.top_k + self.generate_audio = args.generate_audio def _get_question_file_path(self, args: argparse.Namespace) -> str: script_dir = os.path.dirname(__file__) @@ -99,6 +103,24 @@ def initialize_model(config: EvaluationConfig) -> Eagle3Model: return model +def initialize_cosycoice3_model(config: EvaluationConfig) -> CosyVoice3Eagle3Model: + """Initialize and return the Eagle3 model""" + model = CosyVoice3Eagle3Model.from_pretrained( + base_model_path=config.base_model_path, + eagle_model_path=config.eagle_model_path, + total_token=config.total_token, + depth=config.depth, + top_k=config.top_k, + device_map="auto", + torch_dtype="auto", + generate_audio=config.generate_audio, + ) + model.eval() + print(f"Model training state: {model.training}") + print(f'CUDA VISIBLE DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")}') + return model + + def process_conversation_turn( model: Eagle3Model, tokenizer: Any, @@ -146,6 +168,107 @@ def process_conversation_turn( } +def process_tts_conversation_turn( + model: Eagle3Model, + model_id: str, + qs: str, + temperature: float, + path: str, + is_cosyvoice3: bool = False, + max_token_text_ratio: float = 20.0, + min_token_text_ratio: float = 2.0, +) -> Dict[str, Any]: + """Process a single question""" + if is_cosyvoice3: + prompt_text = model.base_model.frontend.text_normalize( + qs["prompt_text"], split=False, text_frontend=True + ) + prompt_wav = os.path.normpath(os.path.join(path, qs["prompt_wav"])) + for i in tqdm( + model.base_model.frontend.text_normalize( + qs["tts_text"], split=True, text_frontend=True + ) + ): + if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text): + print( + "synthesis text {} too short than prompt text {}, this may lead to bad performance".format( # noqa: E501 + i, prompt_text + ) + ) + model_input = model.base_model.frontend.frontend_zero_shot( + i, prompt_text, prompt_wav, model.base_model.sample_rate, "" + ) + + torch.cuda.synchronize() + start_time = time.time() + + dtype = model_input["text"].dtype + device = model_input["text"].device + + input_ids = torch.concat( + [ + model.base_model.model.llm.sos_id.unsqueeze(dim=0) + .to(dtype) + .to(device), + model_input["prompt_text"], + model_input["text"], + model.base_model.model.llm.task_token.unsqueeze(dim=0) + .to(dtype) + .to(device), + model_input["llm_prompt_speech_token"], + ], + dim=1, + ) + + # concat llm input embedding + text = torch.concat( + [model_input["prompt_text"], model_input["text"]], dim=1 + ) + text_emb = model.base_model.model.llm.llm.model.model.embed_tokens(text) + sos_emb = model.base_model.model.llm.speech_embedding.weight[ + model.base_model.model.llm.sos + ].reshape(1, 1, -1) + task_id_emb = model.base_model.model.llm.speech_embedding.weight[ + model.base_model.model.llm.task_id + ].reshape(1, 1, -1) + if model_input["llm_prompt_speech_token_len"][0].item() != 0: + prompt_speech_token_emb = model.base_model.model.llm.speech_embedding( + model_input["llm_prompt_speech_token"] + ) + else: + prompt_speech_token_emb = torch.zeros( + 1, 0, model.base_model.model.llm.llm_input_size, dtype=text.dtype + ).to(device) + inputs_embeds = torch.concat( + [sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1 + ) + min_len = int(model_input["text"].shape[1] * min_token_text_ratio) + max_decode_steps = int(model_input["text"].shape[1] * max_token_text_ratio) + + output_ids, new_token, idx = model.naive_generate( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + temperature=temperature, + log=True, + min_len=min_len, + max_decode_steps=max_decode_steps, + ) + + torch.cuda.synchronize() + total_time = time.time() - start_time + output_ids = output_ids[0][-new_token:] + + return { + "tts_text": qs["tts_text"], + "prompt_text": qs["prompt_text"], + "prompt_wav": prompt_wav, + "output_audio_tokens": output_ids, + "idx": int(idx), + "new_token": int(new_token), + "wall_time": total_time, + } + + def generate_answer_for_question( model: Eagle3Model, tokenizer: Any, @@ -183,6 +306,45 @@ def generate_answer_for_question( return choices +def generate_answer_for_question_tts( + model: Eagle3Model, + model_id: str, + question: Dict[str, Any], + num_choices: int, + temperature: float, + path: str, + is_cosyvoice3: bool = False, +) -> List[Dict[str, Any]]: + """Generate answers for a single question with multiple choices""" + choices = [] + for i in range(num_choices): + torch.manual_seed(i) + + result = process_tts_conversation_turn( + model, + model_id, + question, + temperature, + path, + is_cosyvoice3, + ) + + choices.append( + { + "index": i, + "tts_text": result["tts_text"], + "prompt_text": result["prompt_text"], + "prompt_wav": result["prompt_wav"], + "output_audio_tokens": result["output_audio_tokens"].tolist(), + "idxs": result["idx"], + "new_tokens": result["new_token"], + "wall_time": result["wall_time"], + } + ) + + return choices + + def warmup_model( model: Eagle3Model, tokenizer: Any, question: Dict[str, Any], temperature: float ) -> None: @@ -195,6 +357,23 @@ def warmup_model( print("Warmup done") +def warmup_tts_lm( + model: Eagle3Model, + model_id: str, + question: Dict[str, Any], + temperature: float, + path: str, + is_cosyvoice3: bool = False, +) -> None: + """Warm up the model before actual evaluation""" + for _ in range(3): + torch.manual_seed(0) + process_tts_conversation_turn( + model, model_id, question, temperature, path, is_cosyvoice3 + ) + print("Warmup done") + + @torch.inference_mode() def get_model_answers( model_id: str, @@ -230,6 +409,111 @@ def get_model_answers( fout.write(json.dumps(ans_json) + "\n") +@torch.inference_mode() +def get_tts_answers( + model_id: str, + questions: List[Dict[str, Any]], + answer_file: str, + num_choices: int, + temperature: float, + args: argparse.Namespace, +) -> None: + """Generate answers for a batch of questions""" + config = EvaluationConfig(args) + is_cosyvoice3 = False + if os.path.exists(os.path.join(args.base_model_path, "cosyvoice3.yaml")): + model = initialize_cosycoice3_model(config) + is_cosyvoice3 = True + + if questions: + current_file = os.path.abspath(__file__) + project_root = current_file.split("/AngelSlim/")[0] + "/AngelSlim" + warmup_tts_lm( + model, + model_id, + questions[0], + temperature, + os.path.join(project_root, "dataset", args.bench_name), + is_cosyvoice3, + ) + + os.makedirs(os.path.dirname(answer_file), exist_ok=True) + + i = 0 + for question in tqdm(questions): + choices = generate_answer_for_question_tts( + model, + model_id, + question, + num_choices, + temperature, + os.path.join(project_root, "dataset", args.bench_name), + is_cosyvoice3, + ) + + with open(os.path.expanduser(answer_file), "a") as fout: + ans_json = { + "question_id": i, + "answer_id": shortuuid.uuid(), + "model_id": model_id, + "choices": choices, + "tstamp": time.time(), + } + fout.write(json.dumps(ans_json) + "\n") + + i += 1 + + +def get_tts_audios( + answers: List[Dict[str, Any]], + answer_file: str, + args: argparse.Namespace, +) -> None: + """Generate audios for a batch of audio tokens""" + config = EvaluationConfig(args) + if os.path.exists(os.path.join(args.base_model_path, "cosyvoice3.yaml")): + model = initialize_cosycoice3_model(config) + + for answer in tqdm(answers): + prompt_text = model.base_model.frontend.text_normalize( + answer["choices"][0]["prompt_text"], split=False, text_frontend=True + ) + prompt_wav = answer["choices"][0]["prompt_wav"] + for i in tqdm( + model.base_model.frontend.text_normalize( + answer["choices"][0]["tts_text"], split=True, text_frontend=True + ) + ): + model_input = model.base_model.frontend.frontend_zero_shot( + i, prompt_text, prompt_wav, model.base_model.sample_rate, "" + ) + + tts_speech_token = answer["choices"][0]["output_audio_tokens"] + while tts_speech_token[-1] == model.base_model.model.llm.eos_token: + del tts_speech_token[-1] + this_tts_speech_token = torch.tensor(tts_speech_token).unsqueeze(dim=0) + this_tts_speech = model.base_model.model.token2wav( + token=this_tts_speech_token, + prompt_token=model_input["flow_prompt_speech_token"], + prompt_feat=model_input["prompt_speech_feat"], + embedding=model_input["flow_embedding"], + token_offset=0, + uuid="", + finalize=True, + speed=1.0, + ) + this_tts_speech = this_tts_speech.cpu() + directory = os.path.dirname(answer_file) + os.makedirs(f"{directory}/baseline", exist_ok=True) + torchaudio.save( + f"{directory}/baseline/eval_{answer['question_id']}.wav", + this_tts_speech, + model.base_model.sample_rate, + ) + else: + raise NotImplementedError("Model not supported") + + def run_evaluation(config: EvaluationConfig, args: argparse.Namespace) -> None: """Run the evaluation with optional distributed processing""" questions = fastchat.llm_judge.common.load_questions( diff --git a/angelslim/compressor/speculative/benchmark/pytorch/generate_eagle_answer.py b/angelslim/compressor/speculative/benchmark/pytorch/generate_eagle_answer.py index 9451b742..2b789b7a 100644 --- a/angelslim/compressor/speculative/benchmark/pytorch/generate_eagle_answer.py +++ b/angelslim/compressor/speculative/benchmark/pytorch/generate_eagle_answer.py @@ -17,15 +17,18 @@ import os import random import time -from typing import Any, Dict, List +from typing import Any, Dict, Generator, List import numpy as np import shortuuid import torch from tqdm import tqdm -from angelslim.compressor.speculative.inference.models import Eagle3Model -from angelslim.utils.lazy_imports import fastchat, ray +from angelslim.compressor.speculative.inference.models import ( + CosyVoice3Eagle3Model, + Eagle3Model, +) +from angelslim.utils.lazy_imports import fastchat, ray, torchaudio SYSTEM_PROMPT = { "role": "system", @@ -57,6 +60,7 @@ def __init__(self, args: argparse.Namespace): self.depth = args.depth self.top_k = args.top_k self.early_stop_method = args.early_stop_method + self.generate_audio = args.generate_audio def _get_question_file_path(self, args: argparse.Namespace) -> str: script_dir = os.path.dirname(__file__) @@ -99,6 +103,25 @@ def initialize_model(config: EvaluationConfig) -> Eagle3Model: return model +def initialize_cosycoice3_model(config: EvaluationConfig) -> CosyVoice3Eagle3Model: + """Initialize and return the Eagle3 model""" + model = CosyVoice3Eagle3Model.from_pretrained( + base_model_path=config.base_model_path, + eagle_model_path=config.eagle_model_path, + total_token=config.total_token, + depth=config.depth, + top_k=config.top_k, + device_map="auto", + torch_dtype="auto", + early_stop_method=config.early_stop_method, + generate_audio=config.generate_audio, + ) + model.eval() + print(f"Model training state: {model.training}") + print(f'CUDA VISIBLE DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")}') + return model + + def process_conversation_turn( model: Eagle3Model, tokenizer: Any, @@ -147,6 +170,103 @@ def process_conversation_turn( } +def process_tts_conversation_turn( + model: Eagle3Model, + model_id: str, + qs: str, + temperature: float, + path: str, + is_cosyvoice3: bool = False, +) -> Dict[str, Any]: + """Process a single question""" + if is_cosyvoice3: + prompt_text = model.base_model.frontend.text_normalize( + qs["prompt_text"], split=False, text_frontend=True + ) + prompt_wav = os.path.normpath(os.path.join(path, qs["prompt_wav"])) + for i in tqdm( + model.base_model.frontend.text_normalize( + qs["tts_text"], split=True, text_frontend=True + ) + ): + if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text): + print( + "synthesis text {} too short than prompt text {}, this may lead to bad performance".format( # noqa: E501 + i, prompt_text + ) + ) + model_input = model.base_model.frontend.frontend_zero_shot( + i, prompt_text, prompt_wav, model.base_model.sample_rate, "" + ) + + torch.cuda.synchronize() + start_time = time.time() + + dtype = model_input["text"].dtype + device = model_input["text"].device + + input_ids = torch.concat( + [ + model.base_model.model.llm.sos_id.unsqueeze(dim=0) + .to(dtype) + .to(device), + model_input["prompt_text"], + model_input["text"], + model.base_model.model.llm.task_token.unsqueeze(dim=0) + .to(dtype) + .to(device), + model_input["llm_prompt_speech_token"], + ], + dim=1, + ) + + # concat llm input embedding + text = torch.concat( + [model_input["prompt_text"], model_input["text"]], dim=1 + ) + text_emb = model.base_model.model.llm.llm.model.model.embed_tokens(text) + sos_emb = model.base_model.model.llm.speech_embedding.weight[ + model.base_model.model.llm.sos + ].reshape(1, 1, -1) + task_id_emb = model.base_model.model.llm.speech_embedding.weight[ + model.base_model.model.llm.task_id + ].reshape(1, 1, -1) + if model_input["llm_prompt_speech_token_len"][0].item() != 0: + prompt_speech_token_emb = model.base_model.model.llm.speech_embedding( + model_input["llm_prompt_speech_token"] + ) + else: + prompt_speech_token_emb = torch.zeros( + 1, 0, model.base_model.model.llm.llm_input_size, dtype=text.dtype + ).to(device) + inputs_embeds = torch.concat( + [sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1 + ) + + output_ids, new_token, idx, accept_length_list = model.eagle_generate( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + temperature=temperature, + log=True, + is_cosyvoice3=True, + ) + + torch.cuda.synchronize() + total_time = time.time() - start_time + output_ids = output_ids[0][-new_token:] + + return { + "tts_text": qs["tts_text"], + "prompt_text": qs["prompt_text"], + "prompt_wav": prompt_wav, + "output_audio_tokens": output_ids, + "idx": int(idx), + "new_token": int(new_token), + "wall_time": total_time, + "accept_length_list": accept_length_list, + } + + def generate_answer_for_question( model: Eagle3Model, tokenizer: Any, @@ -187,6 +307,46 @@ def generate_answer_for_question( return choices +def generate_answer_for_question_tts( + model: Eagle3Model, + model_id: str, + question: Dict[str, Any], + num_choices: int, + temperature: float, + path: str, + is_cosyvoice3: bool = False, +) -> List[Dict[str, Any]]: + """Generate answers for a single question with multiple choices""" + choices = [] + for i in range(num_choices): + torch.manual_seed(i) + + result = process_tts_conversation_turn( + model, + model_id, + question, + temperature, + path, + is_cosyvoice3, + ) + + choices.append( + { + "index": i, + "tts_text": result["tts_text"], + "prompt_text": result["prompt_text"], + "prompt_wav": result["prompt_wav"], + "output_audio_tokens": result["output_audio_tokens"].tolist(), + "idxs": result["idx"], + "new_tokens": result["new_token"], + "wall_time": result["wall_time"], + "accept_length": result["accept_length_list"], + } + ) + + return choices + + def warmup_model( model: Eagle3Model, tokenizer: Any, question: Dict[str, Any], temperature: float ) -> None: @@ -199,6 +359,23 @@ def warmup_model( print("Warmup done") +def warmup_tts_lm( + model: Eagle3Model, + model_id: str, + question: Dict[str, Any], + temperature: float, + path: str, + is_cosyvoice3: bool = False, +) -> None: + """Warm up the model before actual evaluation""" + for _ in range(3): + torch.manual_seed(0) + process_tts_conversation_turn( + model, model_id, question, temperature, path, is_cosyvoice3 + ) + print("Warmup done") + + @torch.inference_mode() def get_model_answers( model_id: str, @@ -234,6 +411,111 @@ def get_model_answers( fout.write(json.dumps(ans_json) + "\n") +@torch.inference_mode() +def get_tts_answers( + model_id: str, + questions: List[Dict[str, Any]], + answer_file: str, + num_choices: int, + temperature: float, + args: argparse.Namespace, +) -> None: + """Generate answers for a batch of questions""" + config = EvaluationConfig(args) + is_cosyvoice3 = False + if os.path.exists(os.path.join(args.base_model_path, "cosyvoice3.yaml")): + model = initialize_cosycoice3_model(config) + is_cosyvoice3 = True + + if questions: + current_file = os.path.abspath(__file__) + project_root = current_file.split("/AngelSlim/")[0] + "/AngelSlim" + warmup_tts_lm( + model, + model_id, + questions[0], + temperature, + os.path.join(project_root, "dataset", args.bench_name), + is_cosyvoice3, + ) + + os.makedirs(os.path.dirname(answer_file), exist_ok=True) + + i = 0 + for question in tqdm(questions): + choices = generate_answer_for_question_tts( + model, + model_id, + question, + num_choices, + temperature, + os.path.join(project_root, "dataset", args.bench_name), + is_cosyvoice3, + ) + + with open(os.path.expanduser(answer_file), "a") as fout: + ans_json = { + "question_id": i, + "answer_id": shortuuid.uuid(), + "model_id": model_id, + "choices": choices, + "tstamp": time.time(), + } + fout.write(json.dumps(ans_json) + "\n") + + i += 1 + + +def get_tts_audios( + answers: List[Dict[str, Any]], + answer_file: str, + args: argparse.Namespace, +) -> None: + """Generate audios for a batch of audio tokens""" + config = EvaluationConfig(args) + if os.path.exists(os.path.join(args.base_model_path, "cosyvoice3.yaml")): + model = initialize_cosycoice3_model(config) + + for answer in tqdm(answers): + prompt_text = model.base_model.frontend.text_normalize( + answer["choices"][0]["prompt_text"], split=False, text_frontend=True + ) + prompt_wav = answer["choices"][0]["prompt_wav"] + for i in tqdm( + model.base_model.frontend.text_normalize( + answer["choices"][0]["tts_text"], split=True, text_frontend=True + ) + ): + model_input = model.base_model.frontend.frontend_zero_shot( + i, prompt_text, prompt_wav, model.base_model.sample_rate, "" + ) + + tts_speech_token = answer["choices"][0]["output_audio_tokens"] + while tts_speech_token[-1] == model.base_model.model.llm.eos_token: + del tts_speech_token[-1] + this_tts_speech_token = torch.tensor(tts_speech_token).unsqueeze(dim=0) + this_tts_speech = model.base_model.model.token2wav( + token=this_tts_speech_token, + prompt_token=model_input["flow_prompt_speech_token"], + prompt_feat=model_input["prompt_speech_feat"], + embedding=model_input["flow_embedding"], + token_offset=0, + uuid="", + finalize=True, + speed=1.0, + ) + this_tts_speech = this_tts_speech.cpu() + directory = os.path.dirname(answer_file) + os.makedirs(f"{directory}/eagle", exist_ok=True) + torchaudio.save( + f"{directory}/eagle/eval_{answer['question_id']}.wav", + this_tts_speech, + model.base_model.sample_rate, + ) + else: + raise NotImplementedError("Model not supported") + + def run_evaluation(config: EvaluationConfig, args: argparse.Namespace) -> None: """Run the evaluation with optional distributed processing""" questions = fastchat.llm_judge.common.load_questions( diff --git a/angelslim/compressor/speculative/inference/models/__init__.py b/angelslim/compressor/speculative/inference/models/__init__.py index a6696119..c03e73d7 100644 --- a/angelslim/compressor/speculative/inference/models/__init__.py +++ b/angelslim/compressor/speculative/inference/models/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .eagle3 import Eagle3Model +from .eagle3 import CosyVoice3Eagle3Model, Eagle3Model -__all__ = ["Eagle3Model"] +__all__ = ["Eagle3Model", "CosyVoice3Eagle3Model"] diff --git a/angelslim/compressor/speculative/inference/models/eagle3/__init__.py b/angelslim/compressor/speculative/inference/models/eagle3/__init__.py index 291fcd7a..e16f74ab 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/__init__.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .eagle3_model import Eagle3Model +from .eagle3_model import Eagle3Model, CosyVoice3Eagle3Model # isort: skip -__all__ = ["Eagle3Model"] +__all__ = ["Eagle3Model", "CosyVoice3Eagle3Model"] diff --git a/angelslim/compressor/speculative/inference/models/eagle3/draft/__init__.py b/angelslim/compressor/speculative/inference/models/eagle3/draft/__init__.py index a4bc5eda..341b0b4e 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/draft/__init__.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/draft/__init__.py @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .llama3_eagle3 import Llama3Eagle3Drafter +# isort: off +from .llama3_eagle3 import ( + Llama3Eagle3Drafter, + CosyVoice3Llama3Eagle3Drafter, +) -__all__ = ["Llama3Eagle3Drafter"] +# isort: on + +__all__ = ["Llama3Eagle3Drafter", "CosyVoice3Llama3Eagle3Drafter"] diff --git a/angelslim/compressor/speculative/inference/models/eagle3/draft/base_model.py b/angelslim/compressor/speculative/inference/models/eagle3/draft/base_model.py index cd4108c8..775023c1 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/draft/base_model.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/draft/base_model.py @@ -177,6 +177,7 @@ def topK_genrate( self, hidden_states: Tensor, input_ids: Tensor, + inputs_embeds: Optional[Tensor] = None, logits_processor: Optional[Any] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ @@ -204,18 +205,24 @@ def topK_genrate( sample_token = input_ids[:, -1] input_ids = input_ids[:, 1:] self.initial_position_id = input_ids.shape[1] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[:, 1:] + assert input_ids.shape[1] == inputs_embeds.shape[1] self.reset() # Generate initial hidden states and tokens last_hidden, past_key_values, early_stop_signal = self._get_initial_hidden( - hidden_states, input_ids + hidden_states, input_ids, inputs_embeds ) self.stable_kv = past_key_values # Generate first level of tokens topk_index, scores = self._get_topk_tokens(last_hidden) - scores_list.append(scores[None]) + if len(scores.shape) == 1: + scores_list.append(scores[None]) + else: + scores_list.append(scores) parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device)) # Handle vocabulary mapping if needed @@ -272,7 +279,7 @@ def topK_genrate( ) def _get_initial_hidden( - self, hidden_states: Tensor, input_ids: Tensor + self, hidden_states: Tensor, input_ids: Tensor, inputs_embeds: Tensor = None ) -> Tuple[Tensor, Any]: """Get initial hidden states and past key values.""" if hasattr(self, "stable_kv") and self.stable_kv is not None: @@ -280,11 +287,19 @@ def _get_initial_hidden( outputs = self( hidden_states, input_ids=input_ids[:, kv_len:], + inputs_embeds=( + inputs_embeds[:, kv_len:] if inputs_embeds is not None else None + ), past_key_values=self.stable_kv, use_cache=True, ) else: - outputs = self(hidden_states, input_ids=input_ids, use_cache=True) + outputs = self( + hidden_states, + input_ids=input_ids, + inputs_embeds=inputs_embeds, + use_cache=True, + ) out_hidden, past_key_values, early_stop_signal = outputs return out_hidden[:, -1], past_key_values, early_stop_signal @@ -332,7 +347,10 @@ def _process_tree_level( # Get top-k tokens for this level topk_index, topk_p = self._get_topk_tokens(out_hidden[0]) - cu_scores = topk_p + scores[:, None] + if len(scores.shape) == 1: + cu_scores = topk_p + scores[:, None] + else: + cu_scores = topk_p + scores # Select best candidates topk_cs = torch.topk(cu_scores.view(-1), self.top_k, dim=-1) diff --git a/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py b/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py index a05fb5ec..ba819c66 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py @@ -13,11 +13,13 @@ # limitations under the License. import math +import os from typing import List, Optional, Tuple import torch import torch.nn.functional as F -from torch import nn +from huggingface_hub import snapshot_download +from torch import Tensor, nn from transformers.activations import ACT2FN from .base_model import BaseEagle3Drafter @@ -637,8 +639,9 @@ def forward( seq_length_with_past = seq_length past_key_values_length = 0 - with torch.no_grad(): - inputs_embeds = self.embed_tokens(input_ids) + if inputs_embeds is None: + with torch.no_grad(): + inputs_embeds = self.embed_tokens(input_ids) if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] @@ -672,7 +675,9 @@ def forward( past_key_values_length, ) - inputs_embeds = inputs_embeds.to(hidden_states.dtype) + dtype = self.fc.weight.dtype + inputs_embeds = inputs_embeds.to(dtype) + hidden_states = hidden_states.to(dtype) if hidden_states.shape[-1] != inputs_embeds.shape[-1]: hidden_states = self.fc(hidden_states) early_stop_signal: Optional[torch.Tensor] = None @@ -710,3 +715,25 @@ def forward( hidden_states = layer_outputs[0] return hidden_states, next_decoder_cache, early_stop_signal + + +class CosyVoice3Llama3Eagle3Drafter(Llama3Eagle3Drafter): + + def load_embed(self, path: str) -> None: + # Handle HuggingFace model identifier + if not os.path.exists(path): + path = snapshot_download(repo_id=path) + + # Try loading embedding weights + tensor = torch.load("{}/llm.pt".format(path)) + embed_tokens_weight = tensor["speech_embedding.weight"] + + with torch.no_grad(): + self.embed_tokens.weight.copy_(embed_tokens_weight) + + def _get_topk_tokens(self, hidden: Tensor) -> Tuple[Tensor, Tensor]: + """Get top-k tokens from hidden states.""" + logits = self.lm_head(self.norm(hidden)) + probs = self.logsoftmax(logits) + topk = torch.topk(probs, self.top_k, dim=-1) + return topk.indices, topk.values diff --git a/angelslim/compressor/speculative/inference/models/eagle3/eagle3_model.py b/angelslim/compressor/speculative/inference/models/eagle3/eagle3_model.py index 113ec450..c9283722 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/eagle3_model.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/eagle3_model.py @@ -36,7 +36,8 @@ update_inference_inputs, ) from .configuration_eagle3_model import Eagle3Config -from .draft import Llama3Eagle3Drafter +from .draft import CosyVoice3Llama3Eagle3Drafter, Llama3Eagle3Drafter +from .target import CosyVoice3 as KVCosyVoice3 from .target import LlamaForCausalLM as KVLlamaForCausalLM from .target import Qwen3ForCausalLM as KVQwen3ForCausalLM @@ -72,20 +73,30 @@ class ModelLoader: SUPPORTED_ARCHITECTURES = { "LlamaForCausalLM": KVLlamaForCausalLM, "Qwen3ForCausalLM": KVQwen3ForCausalLM, + "CosyVoice3": KVCosyVoice3, } @classmethod def load_base_model(cls, base_model_path: str, **kwargs) -> nn.Module: """Load base model based on architecture""" - config = AutoConfig.from_pretrained(base_model_path) - if not getattr(config, "architectures", None): - raise ValueError("Base model config missing 'architectures' field") + try: + config = AutoConfig.from_pretrained(base_model_path) + if not getattr(config, "architectures", None): + raise ValueError("Base model config missing 'architectures' field") + arch = config.architectures[0] + except ValueError: + if os.path.exists(os.path.join(base_model_path, "cosyvoice3.yaml")): + arch = "CosyVoice3" + else: + raise ValueError - arch = config.architectures[0] if arch not in cls.SUPPORTED_ARCHITECTURES: raise NotImplementedError(f"Model {arch} not supported") model_class = cls.SUPPORTED_ARCHITECTURES[arch] + if arch == "CosyVoice3": + model = model_class(base_model_path, kwargs["generate_audio"]) + return model return model_class.from_pretrained(base_model_path, **kwargs) @classmethod @@ -225,6 +236,73 @@ def get_padding_token(self, device: torch.device) -> torch.Tensor: return self._padding_token +class CosyVoice3GenerationManager(GenerationManager): + + def prepare_generation( + self, model: "Eagle3Model", input_ids: torch.Tensor, config: GenerationConfig + ) -> GenerationState: + """Prepare all necessary components for generation""" + stop_token_id = model.base_model.model.llm.stop_token_ids + + logits_processor = ( + prepare_logits_processor( + temperature=config.temperature, top_p=config.top_p, top_k=config.top_k + ) + if config.temperature > 1e-5 + else None + ) + + input_ids = input_ids.clone() + model.eagle_layer.reset_kv() + + if hasattr(model, "past_key_values"): + past_key_values = model.past_key_values + model.current_length_data.zero_() + else: + past_key_values, past_key_values_data, current_length_data = ( + initialize_past_key_values( + model.base_model.model.llm.llm.model, max_length=config.max_length + ) + ) + model.past_key_values = past_key_values + model.past_key_values_data = past_key_values_data + model.current_length_data = current_length_data + + # reset_tree_mode(model) + model.base_model.model.llm.llm.model.model.tree_mask = None + model.base_model.model.llm.llm.model.model.tree_mode = None + + return GenerationState( + stop_token_id=stop_token_id, + logits_processor=logits_processor, + input_ids=input_ids, + past_key_values=past_key_values, + input_len=input_ids.shape[1], + ) + + def should_stop( + self, + input_ids: torch.Tensor, + input_len: int, + new_token: int, + config: GenerationConfig, + stop_token_id: Optional[int], + ) -> bool: + """Check if generation should stop""" + if stop_token_id is not None: + stop_tensor = torch.tensor(stop_token_id, device=input_ids.device) + if torch.any(torch.isin(input_ids[0, input_len:], stop_tensor)): + return True + + if new_token > config.max_new_tokens: + return True + + if input_ids.shape[1] > config.max_length: + return True + + return False + + class Eagle3Model(nn.Module): """ EAGLE3 Model for speculative decoding with improved structure and maintainability @@ -353,6 +431,7 @@ def forward( def eagle_generate( self, input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, temperature: float = 0.0, top_p: float = 0.0, top_k: float = 0.0, @@ -360,6 +439,7 @@ def eagle_generate( max_length: int = 2048, log: bool = False, is_llama3: bool = False, + is_cosyvoice3: bool = False, early_stop_smooth_type: str = "ewma", ) -> Union[torch.Tensor, Tuple[torch.Tensor, int, int, List[int]]]: """Generate text using EAGLE speculative decoding""" @@ -379,7 +459,11 @@ def eagle_generate( # Prefill phase draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, _, _ = ( initialize_tree( - input_ids, self, state.past_key_values, state.logits_processor + input_ids, + inputs_embeds, + self, + state.past_key_values, + state.logits_processor, ) ) @@ -407,7 +491,10 @@ def eagle_generate( draft_tokens = draft_tokens.to(input_ids.device) tree_position_ids = tree_position_ids.to(input_ids.device) - self.base_model.model.tree_mask = tree_mask + if is_cosyvoice3: + self.base_model.model.llm.llm.model.model.tree_mask = tree_mask + else: + self.base_model.model.tree_mask = tree_mask # Target model forward pass logits, hidden_state_new, _ = tree_decoding( @@ -467,6 +554,7 @@ def eagle_generate( # Update inference inputs ( state.input_ids, + inputs_embeds, draft_tokens, retrieve_indices, tree_mask, @@ -475,6 +563,7 @@ def eagle_generate( early_stop_signal, ) = update_inference_inputs( input_ids=state.input_ids, + inputs_embeds=inputs_embeds, candidates=candidates, best_candidate=best_candidate, accept_length=accept_length, @@ -562,3 +651,179 @@ def naive_generate( break return (state.input_ids, state.new_token, step) if log else state.input_ids + + +class CosyVoice3Eagle3Model(Eagle3Model): + """ + CosyVoice3 EAGLE3 Model for speculative decoding + """ + + def __init__( + self, + base_model: nn.Module, + tokenizer: AutoTokenizer, + eagle_layer: nn.Module, + early_stop_method: Optional[str] = None, + ): + super().__init__(base_model, tokenizer, eagle_layer, early_stop_method) + self.generation_manager = CosyVoice3GenerationManager(tokenizer) + + @classmethod + def from_pretrained( + cls, + base_model_path: Optional[str] = None, + eagle_model_path: Optional[str] = None, + total_token: int = 60, + depth: int = 7, + top_k: int = 10, + threshold: float = 1.0, + enable_benchmark: bool = False, + early_stop_method: Optional[str] = None, + stop_think_token: str = "", + step_split_tokens: Optional[List[str]] = None, + **kwargs, + ) -> "CosyVoice3Eagle3Model": + """Create CosyVoice3Eagle3Model from pretrained components""" + # Load base model and tokenizer + if not step_split_tokens: + step_split_tokens = [ + "\n\n", + "\n\n\n", + ".\n\n", + ".\n\n\n", + " \n\n", + " \n\n\n", + ] + base_model = ModelLoader.load_base_model(base_model_path, **kwargs) + tokenizer = base_model.frontend.tokenizer + tokenizer.stop_think_id = tokenizer.encode( + stop_think_token, add_special_tokens=False + )[0] + tokenizer.step_split_ids = [] + for s in step_split_tokens: + t = tokenizer.encode(s, add_special_tokens=False) + if len(t) > 1: + continue + tokenizer.step_split_ids.append(t[0]) + # Load configuration + config_path = ModelLoader.ensure_config_path(eagle_model_path) + config = Eagle3Config.from_pretrained(config_path) + + # Initialize EAGLE layer + device = next(base_model.model.llm.parameters()).device + eagle_state_dict = ModelLoader.load_eagle_state_dict(eagle_model_path, device) + + # TODO: Implement factory pattern for different drafter types + eagle_layer = CosyVoice3Llama3Eagle3Drafter( + config, + total_tokens=total_token, + depth=depth, + top_k=top_k, + threshold=threshold, + path=base_model_path, + load_emb=True, + early_stop_method=early_stop_method, + ) + + # Clean up unused components + if config.vocab_size == config.draft_vocab_size: + del eagle_layer.d2t + del eagle_layer.t2d + + eagle_layer.load_state_dict(eagle_state_dict, strict=False) + eagle_layer.to(device=device, dtype=base_model.dtype) + eagle_layer.init_tree() + + # Auto-select optimal token count if needed + if total_token == -1 and enable_benchmark: + total_token = PerformanceBenchmark.auto_select_total_token( + base_model, config.vocab_size + ) + eagle_layer.total_tokens = total_token - 1 + + return cls(base_model, tokenizer, eagle_layer, early_stop_method) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Any] = None, + output_orig: bool = False, + position_ids: Optional[torch.Tensor] = None, + ) -> Union[Tuple[Any, torch.Tensor], Tuple[Any, torch.Tensor, torch.Tensor]]: + """Forward pass through the model""" + outputs, orig = self.base_model.model.llm( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + ) + return outputs, orig, None + + @torch.no_grad() + def naive_generate( + self, + input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + temperature: float = 0.0, + top_p: float = 0.0, + top_k: float = 0.0, + max_new_tokens: int = 512, + max_length: int = 2048, + log: bool = False, + is_llama3: bool = False, + min_len: Optional[int] = None, + max_decode_steps: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, int, int]]: + """Generate text using naive (non-speculative) decoding""" + config = GenerationConfig( + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + max_length=max_length, + log=log, + is_llama3=is_llama3, + ) + + state = self.generation_manager.prepare_generation(self, input_ids, config) + + _, logits = self.base_model.model.llm( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=state.past_key_values, + ) + + out_tokens = [] + + for step in range(max_decode_steps): # noqa: B007 + input_id = self.base_model.model.llm.sampling_ids( + logits[:, -1].squeeze(dim=0), + out_tokens, + ignore_eos=True if step < min_len else False, + ) + out_tokens.append(input_id) + input_id = ( + torch.tensor(input_id, device=state.input_ids.device) + .unsqueeze(0) + .unsqueeze(0) + ) + + _, logits = self.base_model.model.llm( + input_id, past_key_values=state.past_key_values + ) + state.input_ids = torch.cat([state.input_ids, input_id], dim=-1) + state.new_token += 1 + + if self.generation_manager.should_stop( + state.input_ids, + state.input_len, + state.new_token, + config, + state.stop_token_id, + ): + break + + return (state.input_ids, state.new_token, step) if log else state.input_ids diff --git a/angelslim/compressor/speculative/inference/models/eagle3/target/__init__.py b/angelslim/compressor/speculative/inference/models/eagle3/target/__init__.py index 530d0478..733a55e5 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/target/__init__.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/target/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .modeling_cosyvoice3_kv import CosyVoice3 from .modeling_llama_kv import LlamaForCausalLM from .modeling_qwen3_kv import Qwen3ForCausalLM -__all__ = ["LlamaForCausalLM", "Qwen3ForCausalLM"] +__all__ = ["LlamaForCausalLM", "Qwen3ForCausalLM", "CosyVoice3"] diff --git a/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_cosyvoice3_kv.py b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_cosyvoice3_kv.py new file mode 100644 index 00000000..d0398f0c --- /dev/null +++ b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_cosyvoice3_kv.py @@ -0,0 +1,937 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from https://github.com/FunAudioLLM/CosyVoice for AngelSlim project + +import functools +import os +import re +from functools import partial +from typing import Any, Callable, Generator, List, Optional + +import numpy as np +import regex +import torch +from torch import nn +from torch.nn import functional as F +from transformers import AutoTokenizer +from transformers.configuration_utils import PretrainedConfig + +from .......utils.lazy_imports import ( + inflect, + librosa, + onnxruntime, + torchaudio, + wetext, + whisper, +) +from .modeling_qwen2_kv import Qwen2ForCausalLM + +IGNORE_ID = -1 +# cosyvoice3 fixed params +use_ttsfrd = False +sample_rate = 24000 +llm_input_size = 896 +llm_output_size = 896 +spk_embed_dim = 192 +token_frame_rate = 25 +token_mel_ratio = 2 +# stream related params +chunk_size = 25 # streaming inference chunk size, in token +num_decoding_left_chunks = ( + -1 +) # streaming inference flow decoder left chunk size, <0 means use all left chunks + + +# Repetition Aware Sampling in VALL-E 2 +def ras_sampling( + weighted_scores, + decoded_tokens, + sampling, + top_p=0.8, + top_k=25, + win_size=10, + tau_r=0.1, +): + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = ( + (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids) + .sum() + .item() + ) + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) + return top_ids + + +def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] + cum_prob = 0.0 + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort( + descending=True, stable=True + ) + for i in range(len(sorted_idx)): + # sampling both top-p and numbers. + if cum_prob < top_p and len(prob) < top_k: + cum_prob += sorted_value[i] + prob.append(sorted_value[i]) + indices.append(sorted_idx[i]) + else: + break + prob = torch.tensor(prob).to(weighted_scores) + indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) + top_ids = indices[prob.multinomial(1, replacement=True)].item() + return top_ids + + +def random_sampling(weighted_scores, decoded_tokens, sampling): + top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item() + return top_ids + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask + + +def get_qwen_tokenizer( + token_path: str, skip_special_tokens: bool, version: str = "cosyvoice3" +): + if version == "cosyvoice3": + return CosyVoice3Tokenizer( + token_path=token_path, skip_special_tokens=skip_special_tokens + ) + else: + raise ValueError + + +mel_basis = {} +hann_window = {} + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa.filters.mel( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[str(fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+") + + +# whether contain chinese character +def contains_chinese(text): + return bool(chinese_char_pattern.search(text)) + + +# replace special symbol +def replace_corner_mark(text): + text = text.replace("²", "平方") + text = text.replace("³", "立方") + return text + + +# remove meaningless symbol +def remove_bracket(text): + text = text.replace("(", "").replace(")", "") + text = text.replace("【", "").replace("】", "") + text = text.replace("`", "").replace("`", "") + text = text.replace("——", " ") + return text + + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st:i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return "".join(new_text) + + +def split_paragraph( + text: str, + tokenize, + lang="zh", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, +): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"] + else: + pounc = [".", "?", "!", ";", ":"] + if comma_split: + pounc.extend([",", ","]) + + if text[-1] not in pounc: + if lang == "zh": + text += "。" + else: + text += "." + + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st:i]) > 0: + utts.append(text[st:i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', "”"]: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + + final_utts = [] + cur_utt = "" + for utt in utts: + if ( + calc_utt_length(cur_utt + utt) > token_max_n + and calc_utt_length(cur_utt) > token_min_n + ): + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if (text[i + 1].isascii() and text[i + 1] != " ") and ( + text[i - 1].isascii() and text[i - 1] != " " + ): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) + + +def is_only_punctuation(text): + # Regular expression: Match strings that consist only of punctuation marks or are empty. + punctuation_pattern = r"^[\p{P}\p{S}]*$" + return bool(regex.fullmatch(punctuation_pattern, text)) + + +def load_wav(wav, target_sr, min_sr=16000): + speech, sample_rate = torchaudio.load(wav, backend="soundfile") + speech = speech.mean(dim=0, keepdim=True) + if sample_rate != target_sr: + assert ( + sample_rate >= min_sr + ), "wav sample rate {} must be greater than {}".format(sample_rate, target_sr) + speech = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=target_sr + )(speech) + return speech + + +class Qwen2Encoder(torch.nn.Module): + def __init__(self, pretrain_path): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained( + pretrain_path, attn_implementation="eager" + ) + + def forward_one_step( + self, + xs, + masks=None, + past_key_values=None, + position_ids=None, + output_hidden_states=False, + return_hidden_states=False, + ): + if masks is not None: + input_masks = masks[:, -1, :] + else: + input_masks = None + outs = self.model( + inputs_embeds=xs, + attention_mask=input_masks, + output_hidden_states=output_hidden_states, + past_key_values=past_key_values, + position_ids=position_ids, + ) + xs = outs.hidden_states[-1] + + if return_hidden_states: + return xs, outs["hidden_states"][:-1] + return xs + + +class CosyVoice3Tokenizer: + def __init__(self, token_path, skip_special_tokens=True): + # NOTE: non-chat model, all these special tokens keep randomly initialized. + # fmt: off + # flake8: noqa + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + "", "", + "[hissing]", "[sigh]", "[vocalized-noise]", + "[lipsmack]", "[mn]", "<|endofsystem|>", + "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]", + "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]", + "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]", + "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]", + "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]", + "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]", + "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]", + "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]", + "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]", + "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]", + "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]", + "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]", + "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]", + "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]", + "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]", + "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]", + "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]", + "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]", + "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]", + "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]" + ] + } + # fmt: on + self.special_tokens = special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + def encode(self, text, **kwargs): + tokens = self.tokenizer([text], return_tensors="pt") + tokens = tokens["input_ids"][0].cpu().tolist() + return tokens + + def decode(self, tokens): + tokens = torch.tensor(tokens, dtype=torch.int64) + text = self.tokenizer.batch_decode( + [tokens], skip_special_tokens=self.skip_special_tokens + )[0] + return text + + +class CosyVoiceFrontEnd: + + def __init__( + self, + get_tokenizer: Callable, + feat_extractor: Callable, + campplus_model: str, + speech_tokenizer_model: str, + spk2info: str = "", + allowed_special: str = "all", + ): + self.tokenizer = get_tokenizer() + self.feat_extractor = feat_extractor + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession( + campplus_model, sess_options=option, providers=["CPUExecutionProvider"] + ) + self.speech_tokenizer_session = onnxruntime.InferenceSession( + speech_tokenizer_model, + sess_options=option, + providers=[ + ( + "CUDAExecutionProvider" + if torch.cuda.is_available() + else "CPUExecutionProvider" + ) + ], + ) + if os.path.exists(spk2info): + self.spk2info = torch.load(spk2info, map_location=self.device) + else: + self.spk2info = {} + self.allowed_special = allowed_special + self.zh_tn_model = wetext.Normalizer(remove_erhua=False) + self.en_tn_model = wetext.Normalizer() + self.inflect_parser = inflect.engine() + + def text_normalize(self, text, split=True, text_frontend=True): + if isinstance(text, Generator): + print("get tts_text generator, will skip text_normalize!") + return [text] + # NOTE skip text_frontend when ssml symbol in text + if "<|" in text and "|>" in text: + text_frontend = False + if text_frontend is False or text == "": + return [text] if split is True else text + text = text.strip() + if contains_chinese(text): + text = self.zh_tn_model.normalize(text) + text = text.replace("\n", "") + text = replace_blank(text) + text = replace_corner_mark(text) + text = text.replace(".", "。") + text = text.replace(" - ", ",") + text = remove_bracket(text) + text = re.sub(r"[,,、]+$", "。", text) + texts = list( + split_paragraph( + text, + partial( + self.tokenizer.encode, allowed_special=self.allowed_special + ), + "zh", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, + ) + ) + else: + text = self.en_tn_model.normalize(text) + text = spell_out_number(text, self.inflect_parser) + texts = list( + split_paragraph( + text, + partial( + self.tokenizer.encode, allowed_special=self.allowed_special + ), + "en", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, + ) + ) + texts = [i for i in texts if not is_only_punctuation(i)] + return texts if split is True else text + + def _extract_text_token(self, text): + if isinstance(text, Generator): + print("get tts_text generator, will return _extract_text_token_generator!") + # NOTE add a dummy text_token_len for compatibility + return self._extract_text_token_generator(text), torch.tensor( + [0], dtype=torch.int32 + ).to(self.device) + else: + text_token = self.tokenizer.encode( + text, allowed_special=self.allowed_special + ) + text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) + text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to( + self.device + ) + return text_token, text_token_len + + def _extract_text_token_generator(self, text_generator): + for text in text_generator: + text_token, _ = self._extract_text_token(text) + for i in range(text_token.shape[1]): + yield text_token[:, i : i + 1] + + def _extract_speech_token(self, prompt_wav): + speech = load_wav(prompt_wav, 16000) + assert ( + speech.shape[1] / 16000 <= 30 + ), "do not support extract speech token for audio longer than 30s" + feat = whisper.log_mel_spectrogram(speech, n_mels=128) + speech_token = ( + self.speech_tokenizer_session.run( + None, + { + self.speech_tokenizer_session.get_inputs()[0] + .name: feat.detach() + .cpu() + .numpy(), + self.speech_tokenizer_session.get_inputs()[1].name: np.array( + [feat.shape[2]], dtype=np.int32 + ), + }, + )[0] + .flatten() + .tolist() + ) + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to( + self.device + ) + return speech_token, speech_token_len + + def _extract_spk_embedding(self, prompt_wav): + speech = load_wav(prompt_wav, 16000) + feat = torchaudio.compliance.kaldi.fbank( + speech, num_mel_bins=80, dither=0, sample_frequency=16000 + ) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ( + self.campplus_session.run( + None, + { + self.campplus_session.get_inputs()[0] + .name: feat.unsqueeze(dim=0) + .cpu() + .numpy() + }, + )[0] + .flatten() + .tolist() + ) + embedding = torch.tensor([embedding]).to(self.device) + return embedding + + def _extract_speech_feat(self, prompt_wav): + speech = load_wav(prompt_wav, sample_rate) + speech_feat = ( + self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device) + ) + speech_feat = speech_feat.unsqueeze(dim=0) + speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to( + self.device + ) + return speech_feat, speech_feat_len + + def frontend_zero_shot( + self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id + ): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + if zero_shot_spk_id == "": + prompt_text_token, prompt_text_token_len = self._extract_text_token( + prompt_text + ) + speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav) + speech_token, speech_token_len = self._extract_speech_token(prompt_wav) + if resample_rate == 24000: + # cosyvoice2, force speech_feat % speech_token = 2 + token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1]) + speech_feat, speech_feat_len[:] = ( + speech_feat[:, : 2 * token_len], + 2 * token_len, + ) + speech_token, speech_token_len[:] = ( + speech_token[:, :token_len], + token_len, + ) + embedding = self._extract_spk_embedding(prompt_wav) + model_input = { + "prompt_text": prompt_text_token, + "prompt_text_len": prompt_text_token_len, + "llm_prompt_speech_token": speech_token, + "llm_prompt_speech_token_len": speech_token_len, + "flow_prompt_speech_token": speech_token, + "flow_prompt_speech_token_len": speech_token_len, + "prompt_speech_feat": speech_feat, + "prompt_speech_feat_len": speech_feat_len, + "llm_embedding": embedding, + "flow_embedding": embedding, + } + else: + model_input = self.spk2info[zero_shot_spk_id] + model_input["text"] = tts_text_token + model_input["text_len"] = tts_text_token_len + return model_input + + +class CosyVoice3LM(torch.nn.Module): + def __init__( + self, + model_path, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + # 2. build speech token language model related modules + self.sos = speech_token_size + 0 + self.sos_id = torch.tensor([self.sos]) + self.eos_token = speech_token_size + 1 + self.task_id = speech_token_size + 2 + self.task_token = torch.tensor([self.task_id]) + self.fill_token = speech_token_size + 3 + + self.llm = Qwen2Encoder(os.path.join(model_path, "CosyVoice-BlankEN")) + self.llm_decoder = nn.Linear( + llm_output_size, speech_token_size + 200, bias=False + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding( + speech_token_size + 200, llm_input_size + ) + + # 4. sampling method + self.sampling = functools.partial( + ras_sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1 + ) + + self.stop_token_ids = [speech_token_size + i for i in range(200)] + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + sampling: int = 25, + ignore_eos: bool = True, + ): + num_trials, max_trials = 0, 100 + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + if (not ignore_eos) or (top_ids < self.speech_token_size): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError( + "sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!".format( + max_trials + ) + ) + return top_ids + + @torch.inference_mode() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values=None, + ) -> List[int]: + if inputs_embeds is None: + inputs_embeds = self.speech_embedding.weight[ + input_ids.squeeze(0).tolist() + ].unsqueeze(0) + # prefill + y_pred, hidden_states = self.llm.forward_one_step( + inputs_embeds, + masks=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + output_hidden_states=True, + return_hidden_states=True, + ) + logp = self.llm_decoder(y_pred).log_softmax(dim=-1) + + outputs = {"hidden_states": hidden_states} + + return outputs, logp + + +class CosyVoice3Model: + def __init__( + self, + llm: torch.nn.Module, + flow: Optional[torch.nn.Module], + hift: Optional[torch.nn.Module], + ): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.llm = llm + self.flow = flow + self.hift = hift + + def load(self, llm_model, flow_model, hift_model): + self.llm.load_state_dict( + torch.load(llm_model, map_location=self.device), strict=True + ) + self.llm.to(self.device).eval() + if self.flow is not None: + self.flow.load_state_dict( + torch.load(flow_model, map_location=self.device), strict=True + ) + self.flow.to(self.device).eval() + if self.hift is not None: + # in case hift_model is a hifigan model + hift_state_dict = { + k.replace("generator.", ""): v + for k, v in torch.load(hift_model, map_location=self.device).items() + } + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + def token2wav( + self, + token, + prompt_token, + prompt_feat, + embedding, + token_offset, + uuid, + stream=False, + finalize=False, + speed=1.0, + ): + tts_mel, _ = self.flow.inference( + token=token.to(self.device, dtype=torch.int32), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor( + [prompt_token.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to( + self.device + ), + embedding=embedding.to(self.device), + streaming=stream, + finalize=finalize, + ) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio :] + if speed != 1.0: + assert ( + token_offset == 0 and finalize is True + ), "speed change only support non-stream inference mode" + tts_mel = F.interpolate( + tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear" + ) + tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize) + return tts_speech + + +class CosyVoice3: + + def __init__(self, model_dir, generate_audio=False): + self.config = PretrainedConfig.from_pretrained( + os.path.join(model_dir, "CosyVoice-BlankEN") + ) + self.config.model_type = "cosyvoice3" + self.config.txt_tokenizer_path = os.path.join(model_dir, "CosyVoice-BlankEN") + self.dtype = self.config.torch_dtype + + self.model_dir = model_dir + self.frontend = CosyVoiceFrontEnd( + partial( + get_qwen_tokenizer, + token_path=self.config.txt_tokenizer_path, + skip_special_tokens=True, + ), + partial( + mel_spectrogram, + n_fft=1920, + num_mels=80, + sampling_rate=sample_rate, + hop_size=480, + win_size=1920, + fmin=0, + fmax=None, + center=False, + ), + os.path.join(model_dir, "campplus.onnx"), + os.path.join(model_dir, "speech_tokenizer_v3.onnx"), + os.path.join(model_dir, "spk2info.pt"), + allowed_special="all", + ) + self.sample_rate = sample_rate + llm = CosyVoice3LM( + model_dir, + llm_input_size=llm_input_size, + llm_output_size=llm_output_size, + speech_token_size=6561, + ) + + llm_path, flow_path, hift_path = os.path.join(model_dir, "llm.pt"), "", "" + flow, hift = None, None + self.generate_audio = generate_audio + if self.generate_audio: + from cosyvoice.flow.DiT.dit import DiT + from cosyvoice.flow.flow import CausalMaskedDiffWithDiT + from cosyvoice.flow.flow_matching import CausalConditionalCFM + from cosyvoice.hifigan.f0_predictor import CausalConvRNNF0Predictor + from cosyvoice.hifigan.generator import CausalHiFTGenerator + from cosyvoice.transformer.upsample_encoder import PreLookaheadLayer + from omegaconf import DictConfig + + pre_lookahead_layer = PreLookaheadLayer( + in_channels=80, channels=1024, pre_lookahead_len=3 + ) + config_dict = { + "sigma_min": 1e-06, + "solver": "euler", + "t_scheduler": "cosine", + "training_cfg_rate": 0.2, + "inference_cfg_rate": 0.7, + "reg_loss_type": "l1", + } + cfm_params = DictConfig(content=config_dict) + estimator = DiT( + dim=1024, + depth=22, + heads=16, + dim_head=64, + ff_mult=2, + mel_dim=80, + mu_dim=80, + spk_dim=80, + out_channels=80, + static_chunk_size=chunk_size * token_mel_ratio, + num_decoding_left_chunks=num_decoding_left_chunks, + ) + decoder = CausalConditionalCFM( + in_channels=240, + n_spks=1, + spk_emb_dim=80, + cfm_params=cfm_params, + estimator=estimator, + ) + flow = CausalMaskedDiffWithDiT( + input_size=80, + output_size=80, + spk_embed_dim=spk_embed_dim, + output_type="mel", + vocab_size=6561, + input_frame_rate=token_frame_rate, + only_mask_loss=True, + token_mel_ratio=token_mel_ratio, + pre_lookahead_len=3, + pre_lookahead_layer=pre_lookahead_layer, + decoder=decoder, + ) + f0_predictor = CausalConvRNNF0Predictor( + num_class=1, in_channels=80, cond_channels=512 + ) + hift = CausalHiFTGenerator( + in_channels=80, + base_channels=512, + nb_harmonics=8, + sampling_rate=sample_rate, + nsf_alpha=0.1, + nsf_sigma=0.003, + nsf_voiced_threshold=10, + upsample_rates=[8, 5, 3], + upsample_kernel_sizes=[16, 11, 7], + istft_params={"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes=[7, 7, 11], + source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + lrelu_slope=0.1, + audio_limit=0.99, + conv_pre_look_right=4, + f0_predictor=f0_predictor, + ) + + flow_path, hift_path = os.path.join(model_dir, "flow.pt"), os.path.join( + model_dir, "hift.pt" + ) + + self.model = CosyVoice3Model(llm, flow, hift) + self.model.load(llm_path, flow_path, hift_path) diff --git a/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen2_kv.py b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen2_kv.py new file mode 100644 index 00000000..93f6e850 --- /dev/null +++ b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen2_kv.py @@ -0,0 +1,1255 @@ +# This file is adapted from the Hugging Face Transformers library: +# Source: https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen2/modeling_qwen2.py # noqa: E501 +# Original Copyright: Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # noqa: E501 +# +# Modifications made for AngelSlim project: +# - Modified KV cache mechanism for preallocated GPU memory optimization +# - Added support for speculative decoding in EAGLE3 target model +# - Customized attention mask handling with tree_mask support for tree-based inference +# - Modified forward pass to support custom cache position handling +# - Other modifications are denoted by the symbol: [MODIFIED] +# flake8: noqa: E501 +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs # [MODIFIED] +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" +_CONFIG_FOR_DOC = "Qwen2Config" + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=True + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=False + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # [MODIFIED] + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + past_key, past_value = past_key_value[self.layer_idx] + key_states = past_key.cat(key_states) + value_states = past_value.cat(value_states) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get( + "output_attentions", False + ): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # [MODIFIED] + if cache_position is None: + past_seen_tokens = ( + past_key_values[0][0].current_length.item() + if past_key_values is not None + else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate( + self.layers[: self.config.num_hidden_layers] + ): + if output_hidden_states: + # [MODIFIED] + if ( + idx == len(self.layers) - 3 + or idx == len(self.layers) // 2 + or idx == 2 + ): + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = ( + attention_mask[:, -1].sum().item() != input_tensor.size()[0] + ) + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = ( + past_key_values[0][0].current_length.item() + if past_key_values is not None + else 0 + ) + + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length # [MODIFIED] + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + # @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + self, + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + if sequence_length == target_length: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if ( + not isinstance(past_key_values, SlidingWindowCache) + or sequence_length > target_length + ): + sliding_attend_mask = torch.arange( + target_length, device=device + ) <= (cache_position.reshape(-1, 1) - config.sliding_window) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand( + batch_size, 1, -1, -1 + ) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + else: + # [MODIFIED] + causal_mask = torch.zeros( + (sequence_length, target_length), dtype=dtype, device=device + ) + causal_mask = causal_mask[None, None, :, :].expand( + batch_size, 1, -1, -1 + ) + + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + causal_mask[:, :, -tree_len:, -tree_len:][ + tree_mask == 0 + ] = min_dtype + return causal_mask + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], # [MODIFIED] + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to( + logits.device, torch.int32 + ) + token_indices = torch.arange( + input_ids.shape[-1], device=logits.device, dtype=torch.int32 + ) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), last_non_pad_token + ] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + pooled_logits=pooled_logits, + config=self.config, + ) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function( + start_logits, end_logits, start_positions, end_positions, **kwargs + ) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/angelslim/compressor/speculative/train/configs/cosyvoice3-llm-eagle3.json b/angelslim/compressor/speculative/train/configs/cosyvoice3-llm-eagle3.json new file mode 100644 index 00000000..e0cba072 --- /dev/null +++ b/angelslim/compressor/speculative/train/configs/cosyvoice3-llm-eagle3.json @@ -0,0 +1,28 @@ +{ + "architectures": [ + "CosyVoice3Eagle3LlamaForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 896, + "initializer_range": 0.02, + "intermediate_size": 4864, + "max_position_embeddings": 32768, + "max_window_layers": 24, + "model_type": "llama", + "num_attention_heads": 14, + "num_hidden_layers": 24, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 6761, + "draft_vocab_size": 6761 +} \ No newline at end of file diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py index 591a309d..41514a36 100644 --- a/angelslim/compressor/speculative/train/data/data_utils.py +++ b/angelslim/compressor/speculative/train/data/data_utils.py @@ -22,6 +22,9 @@ "convert_ultrachat_data", "DataCollatorWithPadding", "VLMDataCollatorWithPadding", + "VLMHunyuanDataCollatorWithPadding", + "AudioDataCollatorWithPadding", + "CosyVoice3DataCollatorWithPadding", ] @@ -409,3 +412,71 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: [(item["input_features"]) for item in features] ) return batch + + +class CosyVoice3DataCollatorWithPadding: + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + max_length = max(item["text"].shape[-1] for item in features) + batch_text_tokens = torch.cat( + [ + paddingtensor2D(item["text"].unsqueeze(0), max_length) + for item in features + ] + ) + max_length = max(item["speech_token"].shape[-1] for item in features) + batch_speech_tokens = torch.cat( + [ + paddingtensor2D(item["speech_token"].unsqueeze(0), max_length) + for item in features + ] + ) + max_length = max(item["prompt_text"].shape[-1] for item in features) + batch_prompt_text = torch.cat( + [ + paddingtensor2D(item["prompt_text"].unsqueeze(0), max_length) + for item in features + ] + ) + max_length = max(item["prompt_speech_token"].shape[-1] for item in features) + batch_prompt_speech_tokens = torch.cat( + [ + paddingtensor2D(item["prompt_speech_token"].unsqueeze(0), max_length) + for item in features + ] + ) + batch_text_token_lens = torch.stack([item["text_len"] for item in features]) + batch_speech_token_lens = torch.stack( + [item["speech_token_len"] for item in features] + ) + batch_prompt_text_lens = torch.stack( + [item["prompt_text_len"] for item in features] + ) + batch_prompt_speech_token_lens = torch.stack( + [item["prompt_speech_token_len"] for item in features] + ) + + batch = { + "text": batch_text_tokens, + "text_len": batch_text_token_lens, + "speech_token": batch_speech_tokens, + "speech_token_len": batch_speech_token_lens, + "prompt_speech_token": batch_prompt_speech_tokens, + "prompt_speech_token_len": batch_prompt_speech_token_lens, + "prompt_text": batch_prompt_text, + "prompt_text_len": batch_prompt_text_lens, + "hidden_states": None, + "target_hiddens": None, + } + + # Check if both hidden_states and target_hiddens exist in all features + if all( + "hidden_states" in item and "target_hiddens" in item for item in features + ): + batch["hidden_states"] = torch.cat( + [paddingtensor(item["hidden_states"], max_length) for item in features] + ) + batch["target_hiddens"] = torch.cat( + [paddingtensor(item["target_hiddens"], max_length) for item in features] + ) + return batch diff --git a/angelslim/compressor/speculative/train/data/dataset.py b/angelslim/compressor/speculative/train/data/dataset.py index cd775fd6..12eeabac 100644 --- a/angelslim/compressor/speculative/train/data/dataset.py +++ b/angelslim/compressor/speculative/train/data/dataset.py @@ -84,6 +84,8 @@ def __init__( shuffle_seed=data_args.shuffle_seed, chat_template_type=chat_template_type, display=display, + target_model_name_or_path=data_args.target_model_name_or_path, + output_dir=data_args.output_dir, ) if data_args.training_mode == "offline": self.offline_dataset_builder = DatasetBuilderFactory.create( diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py b/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py index 8b0501aa..3073d9ae 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/__init__.py @@ -21,6 +21,7 @@ from .online_dataset_builder import ( OnlineAudioDatasetBuilder, OnlineLLMDatasetBuilder, + OnlineTTSDatasetBuilder, OnlineVLMDatasetBuilder, OnlineVLMHunyuanVLDatasetBuilder, ) @@ -28,6 +29,7 @@ __all__ = [ "OnlineLLMDatasetBuilder", "OnlineVLMDatasetBuilder", + "OnlineTTSDatasetBuilder", "OnlineVLMHunyuanVLDatasetBuilder", "OfflineLLMDatasetBuilder", "OfflineVLMDatasetBuilder", diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py index 3a160161..3ac1a77c 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py @@ -12,21 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import os +from functools import partial +from pathlib import Path from typing import Any, Dict, List, Optional, Union +import numpy as np import requests import torch from datasets import Features, Value, load_dataset +from huggingface_hub import snapshot_download from PIL import Image from torch.utils.data import Dataset +from tqdm import tqdm from transformers import AutoProcessor, AutoTokenizer from transformers.pipelines.audio_utils import ffmpeg_read from angelslim.utils import rank0_print +from ......utils.lazy_imports import onnxruntime, torchaudio, whisper +from ......utils.utils import decide_device_for_distributed +from ....inference.models.eagle3.target.modeling_cosyvoice3_kv import mel_spectrogram from ..chat_templates import ChatTemplateType from ..data_utils import ( AudioDataCollatorWithPadding, + CosyVoice3DataCollatorWithPadding, DataCollatorWithPadding, VLMDataCollatorWithPadding, VLMHunyuanDataCollatorWithPadding, @@ -810,3 +821,373 @@ def _process_single_conversation( except Exception as e: rank0_print(f"Error processing conversation: {e}") return None + + +@DatasetBuilderFactory.register("online", "TTS") +class OnlineTTSDatasetBuilder(OnlineDatasetBuilder): + def __init__( + self, + tokenizer: Union[AutoTokenizer, AutoProcessor], + max_length: int = 2048, + shuffle_seed: int = 42, + chat_template_type: ChatTemplateType = ChatTemplateType.QWEN3, + display: bool = False, + **kwargs: Any, + ): + super().__init__( + tokenizer, + max_length, + shuffle_seed, + chat_template_type, + display, + ) + self.world_size = int(os.getenv("WORLD_SIZE", 1)) + self.global_rank = int(os.getenv("RANK", -1)) + self.output_dir = kwargs["output_dir"] + self.device = decide_device_for_distributed() + + self.model_path = kwargs["target_model_name_or_path"] + if not os.path.exists(self.model_path): + self.model_path = snapshot_download(self.model_path) + + if os.path.exists(os.path.join(self.model_path, "cosyvoice3.yaml")): + self.model_name = "cosyvoice3" + onnx_path = os.path.join(self.model_path, "speech_tokenizer_v3.onnx") + self._init_audio_tokenizer_cosyvoice3(onnx_path) + self.feat_extractor = partial( + mel_spectrogram, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=None, + center=False, + ) + + def _init_audio_tokenizer_cosyvoice3(self, onnx_path) -> None: + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + providers = ["CUDAExecutionProvider"] + self.speech_tokenizer_session = onnxruntime.InferenceSession( + onnx_path, sess_options=option, providers=providers + ) + + def get_data_collator(self) -> Any: + if self.model_name == "cosyvoice3": + return CosyVoice3DataCollatorWithPadding() + + def read_jsonl_file(self, file_path: str) -> List[Dict[str, Any]]: + data = [] + try: + for file in file_path: + with open(file, "r", encoding="utf-8") as f: + for line in tqdm( + f, + desc=f"read data file {os.path.basename(file)}", + disable=self.global_rank > 0, + ): + try: + item = json.loads(line.strip()) + if isinstance(item, dict): + data.append(item) + except json.JSONDecodeError as e: + rank0_print( + f"JSON extract error: {e}, line: {line[:100]}..." + ) + continue + except Exception as e: + rank0_print(f"read data file {file_path} failed: {e}") + return data + + def build_dataset( + self, + datapath: str, + num_proc: int = 8, + shuffle: bool = True, + sample_num: Optional[int] = None, + ) -> Dataset: + try: + if not isinstance(datapath, list): + datapath = [datapath] + data_name = "_" + for path in datapath: + data_name += os.path.basename(path)[:-6] + os.makedirs(self.output_dir, exist_ok=True) + cache_path = os.path.join( + self.output_dir, f"processed{data_name}_merged_cache.jsonl" + ) + + if not os.path.exists(cache_path): + raw_data = self.read_jsonl_file(datapath) + chunk_size = len(raw_data) // self.world_size + start_idx = self.global_rank * chunk_size + end_idx = ( + start_idx + chunk_size + if self.global_rank < self.world_size - 1 + else len(raw_data) + ) + rank_data = raw_data[start_idx:end_idx] + processed_data = [] + count = 0 + for item in tqdm( + rank_data, + desc=f"Rank {self.global_rank} process data", + disable=self.global_rank > 0, + ): + if ( + sample_num is not None + and count == sample_num // self.world_size + ): + break + text = item.get("text", "") + audio_tokens = item.get("audio_tokens", None) + audio_path = item.get("audio_path", "") + instruct = item.get("instruct", "") + instruct_audio_path = item.get("instruct_audio_path", "") + + if self.model_name == "cosyvoice3": + processed = self._process_single_item_cosyvoice3( + text, + audio_tokens, + audio_path, + instruct, + instruct_audio_path, + ) + else: + raise NotImplementedError("This model is not implemented") + + processed_data.append(processed) + count += 1 + + # save for each rank + rank_file = os.path.join( + self.output_dir, + f"processed{data_name}_rank_{self.global_rank}.jsonl", + ) + with open(rank_file, "w", encoding="utf-8") as f: + for item in processed_data: + f.write(json.dumps(item, ensure_ascii=True) + "\n") + done_file = os.path.join( + self.output_dir, + f"processed{data_name}_rank_{self.global_rank}.done", + ) + Path(done_file).touch() + self._wait_for_all_ranks_done( + self.output_dir, data_name, self.world_size + ) + + # merge processed data on rank 0 + merge_done_file = os.path.join( + self.output_dir, f"processed{data_name}_merged_cache.done" + ) + if self.global_rank == 0: + all_processed_data = [] + for rank in range(self.world_size): + rank_data = [] + rank_tmp_file = os.path.join( + self.output_dir, f"processed{data_name}_rank_{rank}.jsonl" + ) + with open(rank_tmp_file, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + rank_data.append(json.loads(line.strip())) + all_processed_data.extend(rank_data) + + with open(cache_path, "w", encoding="utf-8") as f: + for item in all_processed_data: + f.write(json.dumps(item, ensure_ascii=True) + "\n") + + for rank in range(self.world_size): + rank_tmp_file = os.path.join( + self.output_dir, f"processed{data_name}_rank_{rank}.jsonl" + ) + rank_done_file = os.path.join( + self.output_dir, f"processed{data_name}_rank_{rank}.done" + ) + if os.path.exists(rank_tmp_file): + os.remove(rank_tmp_file) + if os.path.exists(rank_done_file): + os.remove(rank_done_file) + + with open(merge_done_file, "w") as f: + f.write("Merged done") + rank0_print("Rank 0: Created merge completion marker") + else: + merge_done = False + while not merge_done: + if os.path.exists(merge_done_file): + merge_done = True + break + + # Load dataset + processed_ds = load_dataset("json", data_files=cache_path) + + # Conditionally shuffle dataset + if shuffle: + processed_ds = processed_ds["train"].shuffle(seed=self.shuffle_seed) + else: + processed_ds = processed_ds["train"] + + # Filter out None results with multiprocessing support + processed_ds = processed_ds.filter( + lambda batch: [ids is not None for ids in batch["speech_token"]], + batched=True, + num_proc=num_proc, + desc="Filtering empty speech_token", + ) + + processed_ds.set_format(type="torch") + return processed_ds + + else: + # Load dataset + rank0_print(f"Loading cache data from {cache_path}") + ds = load_dataset("json", data_files=cache_path) + + # Conditionally shuffle dataset + if shuffle: + ds = ds["train"].shuffle(seed=self.shuffle_seed) + else: + ds = ds["train"] + + # Filter out None results with multiprocessing support + ds = ds.filter( + lambda batch: [ids is not None for ids in batch["speech_token"]], + batched=True, + num_proc=num_proc, + desc="Filtering empty speech_token", + ) + + ds.set_format(type="torch") + return ds + + except Exception as e: + raise RuntimeError(f"Dataset building failed for {datapath}") from e + + def _wait_for_all_ranks_done(self, output_dir, data_name, world_size): + all_done = False + while not all_done: + done_count = 0 + for rank in range(world_size): + done_file = os.path.join( + output_dir, f"processed{data_name}_rank_{rank}.done" + ) + if os.path.exists(done_file): + done_count += 1 + + if done_count == world_size: + all_done = True + break + + def _process_single_item_cosyvoice3( + self, + text: str, + audio_tokens: Optional[list], + audio_path: str, + instruct: Dict[str, Any], + instruct_audio_path: str, + ) -> Optional[Dict[str, Any]]: + text_token = self.tokenizer.encode(text) + instruct_token = self.tokenizer.encode(instruct) + prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat( + instruct_audio_path + ) + prompt_speech_token, prompt_speech_token_len = self._extract_speech_token( + instruct_audio_path + ) + + resample_rate = 24000 + if resample_rate == 24000: + token_len = min( + int(prompt_speech_feat.shape[1] / 2), prompt_speech_token.shape[1] + ) + prompt_speech_feat, prompt_speech_feat_len[:] = ( + prompt_speech_feat[:, : 2 * token_len], + 2 * token_len, + ) + prompt_speech_token, prompt_speech_token_len[:] = ( + prompt_speech_token[:, :token_len], + token_len, + ) + + if audio_tokens is not None: + return { + "text": text_token, + "text_len": len(text_token), + "speech_token": audio_tokens, + "speech_token_len": len(audio_tokens), + "prompt_speech_token": prompt_speech_token.squeeze(0).tolist(), + "prompt_speech_token_len": prompt_speech_token_len.item(), + "prompt_text": instruct_token, + "prompt_text_len": len(instruct_token), + } + + speech_token, speech_token_len = self._extract_speech_token(audio_path) + return { + "text": text_token, + "text_len": len(text_token), + "speech_token": speech_token.squeeze(0).tolist(), + "speech_token_len": speech_token_len.item(), + "prompt_speech_token": prompt_speech_token.squeeze(0).tolist(), + "prompt_speech_token_len": prompt_speech_token_len.item(), + "prompt_text": instruct_token, + "prompt_text_len": len(instruct_token), + } + + def _extract_speech_token(self, wav): + speech = self.load_wav(wav, 16000) + assert ( + speech.shape[1] / 16000 <= 30 + ), "do not support extract speech token for audio longer than 30s" + feat = whisper.log_mel_spectrogram(speech, n_mels=128) + speech_token = ( + self.speech_tokenizer_session.run( + None, + { + self.speech_tokenizer_session.get_inputs()[0] + .name: feat.detach() + .cpu() + .numpy(), + self.speech_tokenizer_session.get_inputs()[1].name: np.array( + [feat.shape[2]], dtype=np.int32 + ), + }, + )[0] + .flatten() + .tolist() + ) + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to( + self.device + ) + return speech_token, speech_token_len + + def _extract_speech_feat(self, wav): + speech = self.load_wav(wav, 24000) + speech_feat = ( + self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device) + ) + speech_feat = speech_feat.unsqueeze(dim=0) + speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to( + self.device + ) + return speech_feat, speech_feat_len + + def load_wav(self, wav, target_sr, min_sr=16000): + speech, sample_rate = torchaudio.load(wav, backend="soundfile") + speech = speech.mean(dim=0, keepdim=True) + if sample_rate != target_sr: + assert ( + sample_rate >= min_sr + ), "wav sample rate {} must be greater than {}".format( + sample_rate, target_sr + ) + speech = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=target_sr + )(speech) + return speech diff --git a/angelslim/compressor/speculative/train/models/draft/__init__.py b/angelslim/compressor/speculative/train/models/draft/__init__.py index d69e8261..1b1eb4b9 100644 --- a/angelslim/compressor/speculative/train/models/draft/__init__.py +++ b/angelslim/compressor/speculative/train/models/draft/__init__.py @@ -13,6 +13,11 @@ # limitations under the License. from .draft_model_factory import DraftModelConfig, create_draft_model -from .llama_eagle3 import Eagle3LlamaForCausalLM +from .llama_eagle3 import CosyVoice3Eagle3LlamaForCausalLM, Eagle3LlamaForCausalLM -__all__ = ["create_draft_model", "DraftModelConfig", "Eagle3LlamaForCausalLM"] +__all__ = [ + "create_draft_model", + "DraftModelConfig", + "Eagle3LlamaForCausalLM", + "CosyVoice3Eagle3LlamaForCausalLM", +] diff --git a/angelslim/compressor/speculative/train/models/draft/llama_eagle3.py b/angelslim/compressor/speculative/train/models/draft/llama_eagle3.py index 971d8db4..03afb6d5 100644 --- a/angelslim/compressor/speculative/train/models/draft/llama_eagle3.py +++ b/angelslim/compressor/speculative/train/models/draft/llama_eagle3.py @@ -13,16 +13,20 @@ # limitations under the License. import math +import os +from collections import Counter from typing import List, Optional, Tuple import torch import torch.nn.functional as F import torch.utils.checkpoint +from huggingface_hub import snapshot_download from torch import nn from transformers import LlamaConfig from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...data.data_utils import process_token_dict_to_mappings from ..model_utils import apply_rotary_pos_emb, apply_rotary_pos_emb_mrope, repeat_kv from .base_model import Eagle3BaseDraftModel from .draft_model_factory import DraftModelFactory @@ -698,3 +702,76 @@ def custom_forward(*inputs): logits = self.lm_head(hidden_states_out) logits = logits.float() return hidden_states, logits + + +@DraftModelFactory.register +class CosyVoice3Eagle3LlamaForCausalLM(Eagle3LlamaForCausalLM): + + def load_embed_weights(self, target_model_name_or_path, embed_weight_key): + """ + Load embedding weights from pretrained model. + + Args: + target_model_name_or_path: Local path or + HuggingFace model identifier (e.g., 'Qwen/Qwen2-7B') + embed_weight_key: Key for the embedding weights in the model file + """ + # Handle HuggingFace model identifier + if not os.path.exists(target_model_name_or_path): + target_model_name_or_path = snapshot_download( + repo_id=target_model_name_or_path + ) + + # Try loading embedding weights + tensor = torch.load("{}/llm.pt".format(target_model_name_or_path)) + speech_embedding_weight = tensor["speech_embedding.weight"] + + with torch.no_grad(): + self.embed_tokens.weight.copy_(speech_embedding_weight) + + def build_vocab_mapping(self, dataset, cache_path): + """ + Build vocab mapping from full vocabulary to draft vocabulary + based on token frequency. + + Args: + dataset: Preprocessed dataset containing 'input_ids' field + cache_path: Path to save/load the token mapping cache + num_processes: Number of processes for parallel processing + """ + if not os.path.exists(cache_path): + # we first count the frequency of effective tokens in the dataset + token_dict = Counter() + print(f"vocab len(dataset)={len(dataset)} type(dataset)={type(dataset)}") + # for item in tqdm(dataset, desc=f"Counting tokens for vocab mapping"): + + for _, item in enumerate(dataset): + input_ids = item["speech_token"] + unique_ids, counts = input_ids.unique(return_counts=True) + batch_token_dict = dict(zip(unique_ids.tolist(), counts.tolist())) + token_dict.update(batch_token_dict) + + # generate the d2t and t2d mapping + d2t, t2d = process_token_dict_to_mappings( + token_dict, + self.draft_vocab_size, + self.vocab_size, + ) + + vocab_mapping = { + "d2t": d2t, + "t2d": t2d, + } + + cache_parent_dir = os.path.dirname(cache_path) + os.makedirs(cache_parent_dir, exist_ok=True) + torch.save(vocab_mapping, cache_path) + print(f"Saved vocab mapping to: {cache_path}") + else: + # Load from cache + cache = torch.load(cache_path) + d2t = cache["d2t"] + t2d = cache["t2d"] + + self.t2d.copy_(t2d) + self.d2t.copy_(d2t) diff --git a/angelslim/compressor/speculative/train/models/target/cosyvoice3_llm.py b/angelslim/compressor/speculative/train/models/target/cosyvoice3_llm.py new file mode 100644 index 00000000..ab7ad341 --- /dev/null +++ b/angelslim/compressor/speculative/train/models/target/cosyvoice3_llm.py @@ -0,0 +1,283 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from https://github.com/FunAudioLLM/CosyVoice for AngelSlim project + +import os +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from transformers import AutoTokenizer + +from ....inference.models.eagle3.target.modeling_qwen2_kv import Qwen2ForCausalLM + +IGNORE_ID = -1 + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask + + +def get_qwen_tokenizer( + token_path: str, skip_special_tokens: bool, version: str = "cosyvoice3" +): + if version == "cosyvoice3": + return CosyVoice3Tokenizer( + token_path=token_path, skip_special_tokens=skip_special_tokens + ) + else: + raise ValueError + + +class Qwen2Encoder(torch.nn.Module): + def __init__(self, pretrain_path): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + + def forward( + self, xs: torch.Tensor, xs_lens: torch.Tensor, output_hidden_states: bool + ): + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T) + outs = self.model( + inputs_embeds=xs, + attention_mask=masks, + output_hidden_states=output_hidden_states, + ) + + return outs, masks.unsqueeze(1) + + +class CosyVoice3Tokenizer: + def __init__(self, token_path, skip_special_tokens=True): + # NOTE: non-chat model, all these special tokens keep randomly initialized. + # fmt: off + # flake8: noqa + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + "", "", + "[hissing]", "[sigh]", "[vocalized-noise]", + "[lipsmack]", "[mn]", "<|endofsystem|>", + "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]", + "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]", + "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]", + "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]", + "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]", + "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]", + "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]", + "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]", + "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]", + "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]", + "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]", + "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]", + "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]", + "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]", + "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]", + "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]", + "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]", + "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]", + "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]", + "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]" + ] + } + # fmt: on + self.special_tokens = special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + def encode(self, text, **kwargs): + tokens = self.tokenizer([text], return_tensors="pt") + tokens = tokens["input_ids"][0].cpu().tolist() + return tokens + + def decode(self, tokens): + tokens = torch.tensor(tokens, dtype=torch.int64) + text = self.tokenizer.batch_decode( + [tokens], skip_special_tokens=self.skip_special_tokens + )[0] + return text + + +class CosyVoice3LM(torch.nn.Module): + def __init__( + self, + model_path, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + # build speech token language model related modules + self.sos = speech_token_size + 0 + self.eos_token = speech_token_size + 1 + self.task_id = speech_token_size + 2 + self.fill_token = speech_token_size + 3 + + self.llm = Qwen2Encoder(os.path.join(model_path, "CosyVoice-BlankEN")) + self.llm_decoder = nn.Linear( + llm_output_size, speech_token_size + 200, bias=False + ) + + # [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding( + speech_token_size + 200, llm_input_size + ) + self.stop_token_ids = [speech_token_size + i for i in range(200)] + + # tokenizer + self.tokenizer = get_qwen_tokenizer( + os.path.join(model_path, "CosyVoice-BlankEN"), skip_special_tokens=True + ) + + def forward( + self, + text: torch.Tensor, + text_len: torch.Tensor, + speech_token: torch.Tensor, + speech_token_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + hidden_states: Optional[torch.Tensor], + output_hidden_states: bool = False, + **kwargs, + ): + device = text.device + text_token = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text_emb = self.llm.model.model.embed_tokens(text_token) + + # concat llm_input + sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1) + task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros( + 1, 0, self.llm_input_size, dtype=text_emb.dtype + ).to(device) + speech_token_emb = self.speech_embedding(speech_token) + + # prepare llm_input/target + lm_input, lm_input_len, loss_mask = self.prepare_lm_input_target( + sos_emb, + text_token, + text_emb, + text_len, + task_id_emb, + prompt_speech_token, + prompt_speech_token_emb, + prompt_speech_token_len, + speech_token, + speech_token_emb, + speech_token_len, + ) + + # run lm forward + outputs, lm_output_mask = self.llm( + lm_input, lm_input_len.to(device), output_hidden_states + ) + lm_output = outputs.hidden_states[-1] + logits = self.llm_decoder(lm_output) + hidden_states = torch.cat(outputs.hidden_states[:-1], dim=-1) + return hidden_states, logits, lm_input, loss_mask, lm_output_mask + + def prepare_lm_input_target( + self, + sos_emb, + text_token, + text_emb, + text_len, + task_id_emb, + prompt_speech_token, + prompt_speech_token_emb, + prompt_speech_token_len, + speech_token, + speech_token_emb, + speech_token_len, + ): + lm_target, lm_input = [], [] + text_token = unpad_sequence(text_token, text_len.cpu(), batch_first=True) + text_emb = unpad_sequence(text_emb, text_len.cpu(), batch_first=True) + prompt_speech_token = unpad_sequence( + prompt_speech_token, prompt_speech_token_len.cpu(), batch_first=True + ) + prompt_speech_token_emb = unpad_sequence( + prompt_speech_token_emb, prompt_speech_token_len.cpu(), batch_first=True + ) + speech_token = unpad_sequence( + speech_token, speech_token_len.cpu(), batch_first=True + ) + speech_token_emb = unpad_sequence( + speech_token_emb, speech_token_len.cpu(), batch_first=True + ) + for i in range(len(text_token)): + this_lm_target = torch.tensor( + [IGNORE_ID] * (1 + text_len[i] + prompt_speech_token_len[i]) + + speech_token[i].tolist() + + [self.eos_token] + ) + this_lm_input = torch.concat( + [ + sos_emb.squeeze(dim=0), + text_emb[i], + task_id_emb.squeeze(dim=0), + prompt_speech_token_emb[i], + speech_token_emb[i], + ], + dim=0, + ) + lm_input.append(this_lm_input) + lm_target.append(this_lm_target) + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID) + loss_mask = torch.ones_like(lm_target, device=lm_target.device) + loss_mask = loss_mask.masked_fill(lm_target == IGNORE_ID, 0) + return lm_input, lm_input_len, loss_mask diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py index c3087186..23b188ae 100644 --- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py +++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from abc import ABC, abstractmethod from typing import List, Optional, Tuple import torch +from huggingface_hub import snapshot_download from angelslim.utils import decide_device_for_distributed, print_with_rank +from .cosyvoice3_llm import CosyVoice3LM + class BaseBackend(ABC): """ @@ -722,6 +726,76 @@ def hook(module, args, kwargs): } +class TTSTransformersBackend(TransformersBackend): + """ + HuggingFace Transformers backend implementation. + + """ + + def load_model(self) -> None: + # Load and configure model + if not os.path.exists(self.model_path): + self.model_path = snapshot_download(self.model_path) + + # Determine device based on distributed environment + self.device = decide_device_for_distributed() + print_with_rank(f"Loading model to device: {self.device}") + + # Load model + if os.path.exists(os.path.join(self.model_path, "cosyvoice3.yaml")): + self.model_name = "cosyvoice3" + self._load_cosyvoice3() + else: + raise NotImplementedError("This model is not implemented") + + self._freeze_model_parameters() + self.model.eval() + + def _load_cosyvoice3(self) -> None: + """Load text tokenizer using HuggingFace Transformers.""" + + self.model = CosyVoice3LM( + self.model_path, + llm_input_size=896, + llm_output_size=896, + speech_token_size=6561, + ).to(self.device) + self.model.load_state_dict( + torch.load( + os.path.join(self.model_path, "llm.pt"), map_location=self.device + ), + strict=True, + ) + + # Load tokenizer + self.tokenizer = self.model.tokenizer + + def get_hidden_states_and_logits( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract hidden states and logits using Transformers backend. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + **kwargs: May contain 'aux_hidden_states_layer_ids' to specify custom layers + + Returns: + Tuple of (concatenated_hidden_states, logits) + """ + if self.model_name == "cosyvoice3": + with torch.no_grad(): + outputs = self.model( + **input_ids, + output_hidden_states=True, + ) + return outputs + + class TargetModelWrapper: """ Unified wrapper for target models in Eagle3 training. @@ -749,6 +823,7 @@ class TargetModelWrapper: BACKENDS = { ("hf", "LLM"): TransformersBackend, ("hf", "VLM"): VLMTransformersBackend, + ("hf", "TTS"): TTSTransformersBackend, ("hf", "Audio"): AudioTransformersBackend, } diff --git a/angelslim/compressor/speculative/train/trainer/__init__.py b/angelslim/compressor/speculative/train/trainer/__init__.py index b7012159..ad1c6729 100644 --- a/angelslim/compressor/speculative/train/trainer/__init__.py +++ b/angelslim/compressor/speculative/train/trainer/__init__.py @@ -13,13 +13,18 @@ # limitations under the License. from .offline_eagle3_trainer import OfflineEagle3Trainer, OfflineVLMEagle3Trainer -from .online_eagle3_trainer import OnlineEagle3Trainer, OnlineVLMEagle3Trainer +from .online_eagle3_trainer import ( + OnlineEagle3Trainer, + OnlineTTSEagle3Trainer, + OnlineVLMEagle3Trainer, +) from .trainer_factory import Eagle3TrainerFactory __all__ = [ "Eagle3TrainerFactory", "OnlineEagle3Trainer", "OnlineVLMEagle3Trainer", + "OnlineTTSEagle3Trainer", "OfflineEagle3Trainer", "OfflineVLMEagle3Trainer", ] diff --git a/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py b/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py index e0a91d6d..11f314fe 100644 --- a/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py +++ b/angelslim/compressor/speculative/train/trainer/online_eagle3_trainer.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple +import torch from torch import nn from ...utils import padding @@ -218,3 +219,227 @@ def prepare_data_for_draft_model(self, inputs): } ) return result_dict + + +@Eagle3TrainerFactory.register("online", "TTS") +class OnlineTTSEagle3Trainer(Eagle3Trainer): + """ + Online EAGLE3 Trainer for speculative decoding training. + + Implements training logic for EAGLE3 model using a draft model to predict + tokens based on hidden states from a target model. + """ + + def __init__( + self, + draft_model: nn.Module, + target_model: nn.Module, + length: int, + draft_model_config: Dict[str, Any], + **kwargs, + ): + """ + Initialize the OnlineEagle3Trainer. + Args: + draft_model: Draft model for token prediction + target_model: Target model for generating hidden states + length: Number of speculative decoding steps + draft_model_config: Configuration dictionary for draft model + **kwargs: Additional arguments passed to parent Trainer + """ + super().__init__(draft_model=draft_model, length=length, **kwargs) + self.target_model = target_model + + def prepare_data_for_draft_model(self, inputs): + if self.target_model.backend.model_name == "cosyvoice3": + data_for_draft_model = self._prepare_data_for_draft_model_cosyvoice3(inputs) + else: + raise NotImplementedError("This model is not implemented") + return data_for_draft_model + + def _prepare_data_for_draft_model_cosyvoice3(self, inputs): + ( + hidden_states, + target_logits, + inputs_embeds, + loss_mask, + attention_mask, + ) = self.target_model.get_hidden_states_and_logits(input_ids=inputs) + + device = inputs_embeds.device + dtype = self.draft_model.fc.weight.dtype + + target_logits = padding(target_logits, left=False).to(device) + inputs_embeds = padding(inputs_embeds, left=False) + loss_mask = loss_mask[..., None].to(device) + + return { + "hidden_states": hidden_states.to(dtype), + "target_logits": target_logits, + "inputs_embeds": inputs_embeds.to(dtype), + "loss_mask": loss_mask, + "position_ids": None, + "attention_mask": attention_mask.squeeze(0), + } + + def compute_loss( + self, + model: nn.Module, + inputs: Dict[str, torch.Tensor], + num_items_in_batch: Optional[int] = None, + return_outputs: bool = False, + ) -> Tuple[List[torch.Tensor], List, List[float]]: + """ + Compute the training loss for the model. + + Args: + model: The model for which to compute the loss + inputs: Input data dictionary with input_ids, attention_mask, + loss_mask, position_ids + num_items_in_batch: Number of items in batch (unused) + return_outputs: Whether to return model outputs (unused) + + Returns: + Tuple of (prediction_losses, value_losses, accuracies) for each step + """ + data_for_draft_model = self.prepare_data_for_draft_model(inputs) + + attention_mask = data_for_draft_model["attention_mask"] # Batch x Seq + position_ids = data_for_draft_model["position_ids"] # Batch x Seq + target_logits = data_for_draft_model["target_logits"] # Batch x Seq x Vocab + loss_mask = data_for_draft_model["loss_mask"] # Batch x Seq x 1 + hidden_states = data_for_draft_model["hidden_states"] # Batch x Seq x Hidden + inputs_embeds = data_for_draft_model["inputs_embeds"] + + hidden_states = self.down_project_hidden_states(hidden_states) + attention_mask, position_ids = self.prepare_attention_mask_and_position_ids( + hidden_states, attention_mask, position_ids + ) + loss = self.draft_model_training_time_test( + hidden_states, + attention_mask, + position_ids, + target_logits, + loss_mask, + inputs_embeds, + log_prefix="train", + ) + + return loss + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, torch.Tensor], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + """ + data_for_draft_model = self.prepare_data_for_draft_model(**inputs) + + attention_mask = data_for_draft_model["attention_mask"] # Batch x Seq + position_ids = data_for_draft_model["position_ids"] # Batch x Seq + target_logits = data_for_draft_model["target_logits"] # Batch x Seq x Vocab + loss_mask = data_for_draft_model["loss_mask"] # Batch x Seq x 1 + hidden_states = data_for_draft_model["hidden_states"] # Batch x Seq x Hidden + inputs_embeds = data_for_draft_model["inputs_embeds"] + + with torch.no_grad(): + hidden_states = self.down_project_hidden_states(hidden_states) + attention_mask, position_ids = self.prepare_attention_mask_and_position_ids( + hidden_states, attention_mask, position_ids + ) + loss = self.draft_model_training_time_test( + hidden_states, + attention_mask, + position_ids, + target_logits, + loss_mask, + inputs_embeds, + log_prefix="eval", + ) + return (loss, None, None) + + def draft_model_training_time_test( + self, + hidden_states, + attention_mask, + position_ids, + target_logits, + loss_mask, + inputs_embeds, + log_prefix="", + ): + # Step 6: Initialize containers for losses, accuracies and cache + plosses, acces = [], [] + cache_hidden = [[], []] + + # Step 7: Iterative speculative decoding training loop + for idx in range(self.length): + # Step 7.1: Get input embeddings with gradient tracking + if not inputs_embeds.requires_grad: + inputs_embeds.requires_grad = True + + # Step 7.2: Encode through draft model layers + hidden_states, cache_hidden = self.draft_model.encode_layers( + inputs_embeds=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=True, + ) + + # Step 7.3: Compute logits from hidden states + logits = self.draft_model.compute_logits(hidden_states) + + # Step 7.4: Compute target distribution and position mask + with torch.no_grad(): + target_max_token = target_logits.argmax(-1) + target_mask = self.draft_model.t2d[target_max_token][..., None].int() + position_mask = target_mask * loss_mask + + target_head = target_logits[..., self.draft_model.t2d].float() + target_p = nn.Softmax(dim=2)(target_head).detach() + + # Step 7.5: Compute loss + out_logp = nn.LogSoftmax(dim=2)(logits) + loss = -torch.sum(position_mask * target_p * out_logp, dim=2).mean() + + # Step 7.6: Compute accuracy + with torch.no_grad(): + correct = ( + logits.argmax(-1) == target_p.argmax(-1) + ) * position_mask.squeeze(-1) + accuracy = correct.sum().item() / (loss_mask.sum().item() + 1e-6) + + # Step 7.7: Store loss and accuracy + plosses.append(loss) + acces.append(accuracy) + + # Step 7.8: Update inputs for next iteration (skip on last step) + if idx < self.length - 1: + inputs_embeds = padding(inputs_embeds, left=False) + target_logits = padding(target_logits, left=False) + loss_mask = padding(loss_mask, left=False) + + # Step 8: Compute weighted loss + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + + log = { + f"{log_prefix}/acc_{i}": round(float(acces[i]), 3) + for i in range(len(acces)) + } + log.update( + { + f"{log_prefix}/ploss_{i}": round(float(plosses[i].item()), 3) + for i in range(len(plosses)) + } + ) + self.log(log) + + # Step 9: Return loss + return ploss diff --git a/angelslim/compressor/speculative/utils/util.py b/angelslim/compressor/speculative/utils/util.py index db13187f..aac4e1e2 100644 --- a/angelslim/compressor/speculative/utils/util.py +++ b/angelslim/compressor/speculative/utils/util.py @@ -90,9 +90,9 @@ def prepare_logits_processor( return processor_list -def initialize_tree(input_ids, model, past_key_values, logits_processor): +def initialize_tree(input_ids, inputs_embeds, model, past_key_values, logits_processor): outputs, orig, hidden_states = model( - input_ids, past_key_values=past_key_values, output_orig=True + input_ids, inputs_embeds, past_key_values=past_key_values, output_orig=True ) if logits_processor is not None: @@ -104,6 +104,11 @@ def initialize_tree(input_ids, model, past_key_values, logits_processor): token = torch.argmax(orig[:, -1]) token = token[None, None] input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) + # add embedding + if inputs_embeds is not None: + add_inputs_embeds = torch.cat( + [inputs_embeds, model.eagle_layer.embed_tokens(token)], dim=1 + ) # Clone the output hidden states eagle_device = next(model.eagle_layer.parameters()).device @@ -113,7 +118,9 @@ def initialize_tree(input_ids, model, past_key_values, logits_processor): ] hidden_states = torch.cat(outputs["hidden_states"], dim=-1) draft_tokens, retrieve_indices, tree_mask, tree_position_ids, _ = ( - model.eagle_layer.topK_genrate(hidden_states, input_ids, logits_processor) + model.eagle_layer.topK_genrate( + hidden_states, input_ids, add_inputs_embeds, logits_processor + ) ) return ( draft_tokens, @@ -144,9 +151,8 @@ def tree_decoding( position_ids = tree_position_ids + input_ids.shape[1] if position_ids is not None and position_ids.dim() == 1: position_ids = position_ids.unsqueeze(0) - # import pdb; pdb.set_trace() outputs, tree_logits, hidden_state = model( - tree_candidates, + input_ids=tree_candidates, output_orig=True, past_key_values=past_key_values, position_ids=position_ids, @@ -258,6 +264,7 @@ def evaluate_posterior( @torch.no_grad() def update_inference_inputs( input_ids, + inputs_embeds, candidates, best_candidate, accept_length, @@ -270,6 +277,8 @@ def update_inference_inputs( hidden_state_new, sample_token, ): + if inputs_embeds is not None: + assert input_ids.shape[1] == inputs_embeds.shape[1] prev_input_len = input_ids.shape[1] # Map the best candidate indices to the original indices in the sequence select_indices = ( @@ -283,6 +292,13 @@ def update_inference_inputs( ], dim=-1, ) + + # add embedding + if inputs_embeds is not None: + add_inputs_embeds = model.eagle_layer.embed_tokens.weight[ + candidates[None, best_candidate, : accept_length + 1].squeeze(0).tolist() + ].unsqueeze(0) + inputs_embeds = torch.cat([inputs_embeds, add_inputs_embeds], dim=1) # Update the past key values based on the selected tokens # Source tensor that contains relevant past information based # on the selected candidate @@ -305,10 +321,17 @@ def update_inference_inputs( :, best_candidate, : accept_length + 1 ] + # add embedding + if inputs_embeds is not None: + add_inputs_embeds = model.eagle_layer.embed_tokens.weight[ + sample_token.squeeze(0).tolist() + ].unsqueeze(0) + draft_tokens, retrieve_indices, tree_mask, tree_position_ids, early_stop_signal = ( model.eagle_layer.topK_genrate( accept_hidden_state_new, input_ids=torch.cat((input_ids, sample_token.to(input_ids.device)), dim=1), + inputs_embeds=torch.cat([inputs_embeds, add_inputs_embeds], dim=1), logits_processor=logits_processor, ) ) @@ -317,6 +340,7 @@ def update_inference_inputs( return ( input_ids, + inputs_embeds, draft_tokens, retrieve_indices, tree_mask, diff --git a/angelslim/engine.py b/angelslim/engine.py index 3ec3e3af..7ddcddf9 100644 --- a/angelslim/engine.py +++ b/angelslim/engine.py @@ -449,6 +449,8 @@ def setup_benchmark( config_dict.update(kwargs) self.config = self.BenchmarkConfig(**config_dict) + if self.config.is_tts: + self.BenchmarkEngine = pytorch_benchmark.TTSBenchmarkEngine self.benchmark_engine = self.BenchmarkEngine(self.config) return self.config diff --git a/angelslim/utils/lazy_imports.py b/angelslim/utils/lazy_imports.py index 246050e8..f544ca86 100644 --- a/angelslim/utils/lazy_imports.py +++ b/angelslim/utils/lazy_imports.py @@ -208,6 +208,12 @@ def __getattr__(self, name: str) -> Any: # --- multimodal related lazy imports --- qwen_vl_utils = LazyModule("qwen_vl_utils", "multimodal") qwen_omni_utils = LazyModule("qwen_omni_utils", "multimodal") +torchaudio = LazyModule("torchaudio", "multimodal") +whisper = LazyModule("whisper", "multimodal") +onnxruntime = LazyModule("onnxruntime", "multimodal") +inflect = LazyModule("inflect", "multimodal") +librosa = LazyModule("librosa", "multimodal") +wetext = LazyModule("wetext", "multimodal") # --- HunyuanVL related lazy imports --- HunYuanVLForConditionalGeneration = LazyAttribute( diff --git a/dataset/tts_fake_data/question.jsonl b/dataset/tts_fake_data/question.jsonl new file mode 100644 index 00000000..db76edaa --- /dev/null +++ b/dataset/tts_fake_data/question.jsonl @@ -0,0 +1,100 @@ +{"tts_text": "\"He says we're to start to morrow at daybreak.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"And all seems favourable for our attempt to morrow?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"I wish we could charge them boldly, and send them flying over the plains.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"I shouldn't be a bit surprised if we saw them over the way there-just one or two, scouting; and if we do I should be for a stand at arms all night, for it might mean an attack after dark.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Nothing, sir, but wait.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Exactly,\" replied the doctor.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Hurrah!\" cried Chris.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Yes; of course,\" said Chris, with a dubious look all the same.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Finally, just at dusk the animals can be driven in for food and water, and-\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Serve 'em right if they did, sir.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Look here, lads.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Never to come back again,\" said Ned sharply.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "The Indians can wait; we cannot, and they seem to know it.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"They seem to me to be hatching up some dodge or another,\" replied Griggs.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "GRIGGS IS STUBBORN.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"But we shan't, my lad.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Look here; if you say that again we shall quarrel.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Why not?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "CHAPTER FORTY NINE.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "It takes a deal to starve a redskin.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "We must get away from here to some good hunting ground.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Yes, yes, but you know what I mean.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"No, sir.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Not a bit of it, sir.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Oh, I'm only a little stiff still.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Well, I do.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "They'd have weeks of work before they could get their horses out but without horses they'd be out in a week.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Never been away at all, I believe.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"But the enemy won't be standing still,\" continued Griggs.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"But all the same we can be making our preparations.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "We're not waiting for you now.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I've just been trying that place again.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"I'm not afraid of them hitting me, my lad,\" said Griggs confidently. \"Being shot at by fellows with bows and arrows sounds bad enough, but there's not much risk here.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"You mean the shutting up the enemy here to starve?\" said Bourne.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Canter?", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "How did your pony go this morning?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "It's all settled, gentlemen.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Oh yes, I hear.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Oh, here you are, Griggs,\" cried the doctor.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Would be if we let them get the better of us, sir.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Oh!", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Why?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I shall be all right.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "We shall get strong more quickly journeying over the plains or climbing in and out among the mountains.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "But there, I don't want to make speeches.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Very well; I can do that,\" said Ned haughtily.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I don't like to bother my father any more, but what does he say?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Well, they showed themselves to me; I didn't want them,\" said Griggs dryly.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Right; I do, neighbour, and it's very handsome of you to offer me the chance to back out.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Had a good turn at scouting?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Yes, sir.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"It's all nonsense, Ned,\" cried Chris, \"for them to think they are staying on account of us.--Hullo, Griggs!", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "We went at a good swinging gallop.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Going to give up young Chris's plan?\" said Griggs slowly.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "The plan is to get out of this valley ourselves, where we are regularly locked in, and to put the redskins in our place, locking them in.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Don't you?", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Provisions can be packed in our wallets; in fact, everything held ready for a start.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"They're an artful lot.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Did you canter this morning?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "You propose offering yourself for a mark to the Indians' arrows, and-\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Splendid.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Serve you right.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Well, you must talk it over with father,\" said Chris.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Don't you see that we're playing a very ticklish game?", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I've no doubt about it now.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Griggs nodded his head.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "A thrill of excitement ran through Chris, and his heart began to beat. Then he was listening, so to speak, with all his might.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"That's right,\" cried Griggs.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "The Indians have shifted their quarters, and they're in about as awkward a position as they could contrive for our purpose.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Yes, my lad; but I want them to be planted farther back still.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Because I've seen Indians again.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Nonsense!", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Then what do you propose?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"We've been patient enough.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Don't be so petty, Ned.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Here, I want for us to be off.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "We start to night.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I'm going to take care they don't hit me.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Were you listening?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"There,\" he said, \"I've made up my mind.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Not allowed to go off again?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I have felt something of the kind, but I am convinced now that it will not, and that we must chance something and make it.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Too many redskins about, as I told you.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Hear that, Griggs?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "There's a bit I've been looking out quite a quarter of a mile farther off, and I'm going to propose it to the doctor as being safest.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Any time the doctor likes.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Not quite, my lads.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"I don't know about that,\" said Chris anxiously.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Let's see; we're going to have another look at the place this afternoon, aren't we?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"Because it is a very risky thing to do.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"What for?\" said Griggs sharply.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "We've got something else to think about besides teasing and bantering.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "I should be running fast and dodging in and out among the rocks and trees.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Just halted a little on the bad leg; but it's better than it was yesterday.\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"And what about you?\"", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"But what about the arrows?\" said Ned.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "Every one was busy, for the keeping watch regularly took up a good deal of time.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "You're always seeing Indians again.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "The days glided by, with the stiffness in Chris Lee's limbs growing less painful, and the pony recovering fast, for the clear mountain air seemed to act like a cure for wounds.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} +{"tts_text": "\"There, that's enough,\" cried Chris.", "prompt_wav": "./zero_shot_prompt.wav", "prompt_text": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"} \ No newline at end of file diff --git a/dataset/tts_fake_data/train.jsonl b/dataset/tts_fake_data/train.jsonl new file mode 100644 index 00000000..c3ba25a1 --- /dev/null +++ b/dataset/tts_fake_data/train.jsonl @@ -0,0 +1,2 @@ +{"text": "\"But you will get no more money out of me, I promise you.\"", "audio_path": "./asset/0.wav", "instruct": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。", "instruct_audio_path": "./asset/zero_shot_prompt.wav"} +{"text": "His ways might be affected and effeminate and his conversational powers indifferent; but his bandaged wrist was a constant reminder to all the nieces that he possessed courage and ready wit, and it was but natural that he became more interesting to them because just now he was to an extent helpless, and his crippled hand had been acquired in their service.", "audio_path": "./asset/1.wav", "instruct": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。", "instruct_audio_path": "./asset/zero_shot_prompt.wav"} \ No newline at end of file diff --git a/dataset/tts_fake_data/train_regenerate.jsonl b/dataset/tts_fake_data/train_regenerate.jsonl new file mode 100644 index 00000000..26ed1b7b --- /dev/null +++ b/dataset/tts_fake_data/train_regenerate.jsonl @@ -0,0 +1,2 @@ +{"text": "\"But you will get no more money out of me, I promise you.\"", "audio_tokens": [29, 29, 248, 503, 6361, 5100, 1869, 600, 4592, 1424, 1901, 1921, 3943, 1514, 1712, 3946, 6535, 5817, 5646, 3456, 367, 170, 2138, 6554, 5526, 1140, 708, 4809, 348, 159, 1212, 5343, 5663, 1778, 323, 3719, 5180, 6065, 1052, 323, 1295, 4697, 2834, 5100, 321, 3872, 4115, 4537, 4455, 4547, 2642, 2912, 4930, 5656, 1329, 597, 4839, 5505, 1946, 1217, 1052, 80, 3875, 3872, 1685, 929, 140, 4591, 4512, 4509, 4456, 4559, 4586, 6283, 1879, 494, 521, 4433, 5165, 6038, 2591, 2825, 2740, 3471, 566, 6546, 6300, 3861, 4887, 4968, 3565, 1901, 3863, 6051, 6126, 5155, 5153, 5114, 2216, 2, 29], "instruct": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。", "instruct_audio_path": "./dataset/tts_fake_data/zero_shot_prompt.wav"} +{"text": "His ways might be affected and effeminate and his conversational powers indifferent; but his bandaged wrist was a constant reminder to all the nieces that he possessed courage and ready wit, and it was but natural that he became more interesting to them because just now he was to an extent helpless, and his crippled hand had been acquired in their service.", "audio_tokens": [29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 28, 136, 3133, 6050, 6544, 6546, 4860, 4164, 1739, 6373, 5823, 5815, 6301, 3809, 1631, 1631, 5589, 3402, 80, 1781, 2348, 6310, 1147, 494, 1223, 1928, 1901, 5553, 5172, 4916, 2675, 5094, 5086, 197, 2411, 411, 2865, 6021, 6012, 6033, 1383, 5274, 4536, 4826, 966, 663, 6312, 3190, 2729, 2702, 5796, 5086, 1285, 2024, 3855, 240, 6039, 6013, 2352, 654, 2652, 109, 4510, 4509, 4510, 4591, 2269, 4798, 3150, 726, 3352, 6059, 5616, 2275, 4598, 4440, 2507, 3561, 645, 2702, 4283, 5566, 4968, 6516, 6554, 1919, 5057, 5756, 483, 5425, 6140, 503, 4431, 4601, 2669, 2906, 6139, 3677, 5966, 3878, 4051, 3402, 4537, 6275, 483, 375, 1383, 5815, 3544, 4916, 5594, 5984, 3126, 3096, 105, 654, 2895, 111, 2323, 2322, 2322, 4510, 4590, 4509, 2244, 494, 503, 4669, 2121, 6268, 3861, 4131, 1318, 512, 4219, 6041, 5824, 5085, 645, 1392, 6058, 1158, 662, 5759, 664, 3491, 1475, 3950, 6550, 6040, 5274, 4725, 2673, 654, 2653, 254, 3917, 6234, 3402, 5721, 843, 4544, 4435, 4429, 4685, 2747, 2825, 645, 513, 4887, 2835, 4297, 5310, 102, 2918, 5984, 3724, 323, 4694, 2510, 5097, 6545, 240, 1384, 6019, 5921, 664, 2650, 1843, 1757, 4375, 2188, 4384, 4907, 4934, 4934, 5663, 6382, 6373, 1957, 1959, 6312, 159, 1455, 4358, 1442, 2162, 4914, 2727, 5997, 6003, 3810, 1542, 5589, 2457, 501, 4227, 5313, 1411, 2405, 2153, 2171, 494, 2492, 3876, 4131, 4131, 4851, 4825, 5313, 4914, 4860, 678, 2625, 2357, 4597, 4509, 5579, 5821, 6461, 5948, 6015, 6004, 6003, 663, 5024, 5513, 4536, 4556, 4609, 240, 1385, 1715, 1466, 5587, 5580, 1384, 1658, 1901, 1469, 1712, 6375, 6553, 5311, 4536, 411, 675, 4836, 139, 2322, 4509, 4510, 2323, 2322, 4591, 109, 4455, 4826, 5337, 1428, 6553, 2188, 1721, 6394, 2025, 2676, 503, 5639, 5073, 165, 402, 483, 4851, 2909, 2900, 2904, 655, 5054, 4976, 3896, 5948, 5125, 6094, 984, 3417, 3129, 681, 2405, 1424, 1903, 1232, 1919, 170, 2405, 5329, 6302, 3871, 323, 1052, 6149, 5663, 5660, 5273, 5266, 6544, 726, 5057, 5705, 6064, 5049, 2673, 654, 6274, 209, 156, 705, 5538, 1554, 1230, 5310, 4583, 2968, 80, 809, 56, 4428, 2322, 2322, 4509, 4590, 35, 494, 4327, 1685, 2384, 5322, 6315, 1944, 5589, 664, 2837, 6277, 5089, 4998, 4860, 567, 483, 2184, 5095, 2911, 2668, 2401, 4426, 5870, 3701, 31, 4509, 4591, 2404, 109, 2405, 2171, 1841, 1721, 6144, 6186, 4860, 654, 5077, 1843, 1514, 5748, 483, 6556, 170, 4489, 4887, 4860, 654, 6519, 2899, 2668, 915, 102, 663, 4836, 4590, 4591, 4831, 5634, 4159, 494, 601, 1719, 3906, 4560, 6003, 4752, 4887, 4887, 4887, 31, 4509, 2322, 2323, 2404, 2296, 4536, 4826, 3153, 699, 4078, 6048, 5589, 2194, 2357, 4433, 3194, 6550, 1232, 2681, 6382, 4185, 1245, 4518, 4591, 5824, 6554, 6068, 5310, 2371, 3021, 483, 2865, 4509, 4843, 4062, 503, 3410, 6309, 1456, 5340, 89, 4409, 2243, 1001, 6383, 2822, 2903, 5581, 6023, 6007, 5939, 6029, 664, 3081, 2351, 6310, 159, 78, 5553, 6074, 4085, 5022, 4941, 6522, 6559, 6551, 3770, 3404, 5310, 5283, 5295, 4887, 4887, 4887, 4887, 2406, 110, 29, 29, 2], "instruct": "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。", "instruct_audio_path": "./dataset/tts_fake_data/zero_shot_prompt.wav"} \ No newline at end of file diff --git a/dataset/tts_fake_data/zero_shot_prompt.wav b/dataset/tts_fake_data/zero_shot_prompt.wav new file mode 100644 index 00000000..a7b9d954 Binary files /dev/null and b/dataset/tts_fake_data/zero_shot_prompt.wav differ diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt index 621f664f..6bf3bc5b 100644 --- a/requirements/requirements_multimodal.txt +++ b/requirements/requirements_multimodal.txt @@ -1,3 +1,9 @@ qwen_vl_utils==0.0.11 qwen_omni_utils -mistral_common \ No newline at end of file +torchaudio +openai-whisper +onnxruntime-gpu +inflect +wetext +librosa +mistral_common diff --git a/scripts/speculative/train_eagle3_tts_online.sh b/scripts/speculative/train_eagle3_tts_online.sh new file mode 100644 index 00000000..302ab470 --- /dev/null +++ b/scripts/speculative/train_eagle3_tts_online.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export CONFIG_DIR=angelslim/compressor/speculative/train/configs +export TARGET_MODEL_NAME_OR_PATH= +export DRAFT_MODEL_CONFIG_PATH=$CONFIG_DIR/cosyvoice3-llm-eagle3.json +export TRAIN_DATA_PATH= +export OUTPUT_DIR= +export RUN_NAME= +export MODEL_MAX_LENGTH= + +torchrun --nproc_per_node=8 tools/train_eagle3_online.py \ + --modal_type TTS \ + --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ + --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \ + --train_data_path $TRAIN_DATA_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 20 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --save_strategy "steps" \ + --save_steps 1000 \ + --learning_rate 1e-4 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "constant" \ + --logging_steps 20 \ + --model_max_length $MODEL_MAX_LENGTH \ + --training_time_test_length 4 \ + --deepspeed $CONFIG_DIR/deepspeed_zero3.json \ + --report_to wandb \ + --run_name $RUN_NAME \ \ No newline at end of file diff --git a/tools/spec_benchmark.py b/tools/spec_benchmark.py index f56da48e..af9f1b83 100644 --- a/tools/spec_benchmark.py +++ b/tools/spec_benchmark.py @@ -121,6 +121,12 @@ def parse_args() -> argparse.Namespace: default=1, help="Tensor parallel size for draft model (vllm only)", ) + parser.add_argument( + "--is-tts", action="store_true", help="whether or not TTS model" + ) + parser.add_argument( + "--generate-audio", action="store_true", help="whether or not generate audio" + ) return parser.parse_args() @@ -155,10 +161,12 @@ def main(): "top_p": args.top_p, "top_k": args.top_k, "depth": args.depth, + "is_tts": args.is_tts, + "generate_audio": args.generate_audio, } # Add backend-specific parameters - if args.deploy_backend == "pytorch": + if "pytorch" in args.deploy_backend: config_dict.update( { "total_token": args.total_token, diff --git a/tools/train_eagle3_online.py b/tools/train_eagle3_online.py index 02b1b2e2..f16d1612 100644 --- a/tools/train_eagle3_online.py +++ b/tools/train_eagle3_online.py @@ -40,7 +40,7 @@ def parse_args(): "--modal_type", type=str, default="LLM", - choices=["LLM", "VLM", "Audio"], + choices=["LLM", "VLM", "Audio", "TTS"], help="Modal type: LLM for language models, VLM for vision-language models", ) model_group.add_argument(