Skip to content

Commit e4498c9

Browse files
committed
add Fun-CosyVoice3 eagle3 speculative decoding
1 parent d43bcbc commit e4498c9

36 files changed

Lines changed: 5189 additions & 31 deletions

angelslim/compressor/speculative/benchmark/pytorch/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .benchmark_engine import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
15+
from .benchmark_engine import (
16+
BenchmarkConfig,
17+
BenchmarkEngine,
18+
BenchmarkMode,
19+
TTSBenchmarkEngine,
20+
)
1621

17-
__all__ = ["BenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"]
22+
__all__ = ["BenchmarkEngine", "TTSBenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"]

angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from angelslim.utils.lazy_imports import fastchat, ray
2626

2727
from .generate_baseline_answer import get_model_answers as get_baseline_answers
28+
from .generate_baseline_answer import get_tts_answers as get_tts_baseline_answers
29+
from .generate_baseline_answer import get_tts_audios as get_tts_baseline_audios
2830
from .generate_eagle_answer import get_model_answers as get_eagle_answers
31+
from .generate_eagle_answer import get_tts_answers as get_tts_eagle_answers
32+
from .generate_eagle_answer import get_tts_audios as get_tts_eagle_audios
2933

3034

3135
class BenchmarkMode(Enum):
@@ -77,6 +81,9 @@ class BenchmarkConfig:
7781
# Batch settings
7882
batch_size: int = 1
7983

84+
# TTS settings
85+
generate_audio: bool = False
86+
8087

8188
class BenchmarkEngine:
8289
"""Core benchmark engine for speculative decoding evaluation"""
@@ -343,6 +350,9 @@ def _create_args_namespace(self, mode: str) -> argparse.Namespace:
343350

344351
args.early_stop_method = self.config.early_stop_method
345352

353+
# TTS settings
354+
args.generate_audio = self.config.generate_audio
355+
346356
return args
347357

348358
def _get_question_file_path(self) -> str:
@@ -397,3 +407,119 @@ def get_performance_summary(self) -> str:
397407
summary.append(f"Analysis Report: {self.analysis_file}")
398408

399409
return "\n".join(summary)
410+
411+
412+
class TTSBenchmarkEngine(BenchmarkEngine):
    """Benchmark engine for TTS speculative decoding evaluation.

    Overrides the Eagle and baseline runs of ``BenchmarkEngine`` so that they
    call the TTS answer generators, and optionally runs the TTS audio
    generators afterwards when ``config.generate_audio`` is true.
    """

    def _run_eagle_benchmark(self):
        """Run Eagle speculative decoding benchmark"""
        self._run_tts_answer_generation(
            "eagle", get_tts_eagle_answers, self.eagle_file
        )
        self.results["eagle_file"] = self.eagle_file

        if self.config.generate_audio:
            self._generate_audio("eagle")

    def _run_baseline_benchmark(self):
        """Run baseline benchmark"""
        self._run_tts_answer_generation(
            "baseline", get_tts_baseline_answers, self.baseline_file
        )
        self.results["baseline_file"] = self.baseline_file

        if self.config.generate_audio:
            self._generate_audio("baseline")

    def _run_tts_answer_generation(self, mode, answers_func, answer_file):
        """Generate answers for all questions and reorganize the answer file.

        Args:
            mode: Mode string forwarded to ``_create_args_namespace``
                (``"eagle"`` or ``"baseline"``).
            answers_func: Answer generator to call for each chunk of
                questions; it is wrapped in a Ray remote task when more than
                one model replica fits on the configured GPUs.
            answer_file: Path passed to the generator and then to
                ``_reorg_answer_file``.
        """
        args = self._create_args_namespace(mode)

        questions = fastchat.llm_judge.common.load_questions(
            self._get_question_file_path(),
            self.config.question_begin,
            self.config.question_end,
        )

        num_workers = self.config.num_gpus_total // self.config.num_gpus_per_model
        use_ray = num_workers > 1
        get_answers_func = (
            ray.remote(num_gpus=self.config.num_gpus_per_model)(answers_func).remote
            if use_ray
            else answers_func
        )

        # Clamp to 1: with fewer questions than workers (or no questions) the
        # integer division yields 0, and range() rejects a zero step.
        chunk_size = max(1, len(questions) // num_workers)
        ans_handles = [
            get_answers_func(
                f"{self.config.model_id}-temperature-{self.config.temperature}",
                questions[i : i + chunk_size],
                answer_file,
                self.config.num_choices,
                self.config.temperature,
                args,
            )
            for i in range(0, len(questions), chunk_size)
        ]

        if use_ray:
            # Block until every remote task has finished writing its answers.
            ray.get(ans_handles)

        self._reorg_answer_file(answer_file)

    def _calculate_metrics(self) -> Dict[str, Any]:
        """Calculate acceptance length from the Eagle results.

        Returns:
            A dict that contains ``"acceptance_length"`` when the Eagle answer
            file exists, and is empty otherwise.
        """
        metrics = {}

        # Calculate acceptance length from Eagle results
        if os.path.exists(self.eagle_file):
            metrics["acceptance_length"] = self._calculate_acceptance_length(
                self.eagle_file
            )

        return metrics

    def _generate_audio(self, mode):
        """Run the TTS audio generator for the answers of the given mode.

        Args:
            mode: ``"baseline"`` selects the baseline audio generator; any
                other value selects the Eagle audio generator.
        """
        args = self._create_args_namespace(mode)

        # NOTE(review): ``args.answer_file`` is assumed to be set by
        # ``_create_args_namespace`` -- confirm against that method.
        answers = fastchat.llm_judge.common.load_questions(
            args.answer_file,
            self.config.question_begin,
            self.config.question_end,
        )

        if mode == "baseline":
            get_tts_baseline_audios(answers, args.answer_file, args)
        else:
            get_tts_eagle_audios(answers, args.answer_file, args)

0 commit comments

Comments
 (0)