Skip to content

Commit 4cf6d70

Browse files
authored
Support Fun-CosyVoice3 eagle3 speculative decoding (#198)
1 parent a62fc6c commit 4cf6d70

35 files changed

Lines changed: 4587 additions & 43 deletions

angelslim/compressor/speculative/benchmark/pytorch/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .benchmark_engine import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
15+
from .benchmark_engine import (
16+
BenchmarkConfig,
17+
BenchmarkEngine,
18+
BenchmarkMode,
19+
TTSBenchmarkEngine,
20+
)
1621

17-
__all__ = ["BenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"]
22+
__all__ = ["BenchmarkEngine", "TTSBenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"]

angelslim/compressor/speculative/benchmark/pytorch/benchmark_engine.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from angelslim.utils.lazy_imports import fastchat, ray
2626

2727
from .generate_baseline_answer import get_model_answers as get_baseline_answers
28+
from .generate_baseline_answer import get_tts_answers as get_tts_baseline_answers
29+
from .generate_baseline_answer import get_tts_audios as get_tts_baseline_audios
2830
from .generate_eagle_answer import get_model_answers as get_eagle_answers
31+
from .generate_eagle_answer import get_tts_answers as get_tts_eagle_answers
32+
from .generate_eagle_answer import get_tts_audios as get_tts_eagle_audios
2933

3034

3135
class BenchmarkMode(Enum):
@@ -77,6 +81,10 @@ class BenchmarkConfig:
7781
# Batch settings
7882
batch_size: int = 1
7983

84+
# TTS settings
85+
is_tts: bool = False
86+
generate_audio: bool = False
87+
8088

8189
class BenchmarkEngine:
8290
"""Core benchmark engine for speculative decoding evaluation"""
@@ -343,6 +351,10 @@ def _create_args_namespace(self, mode: str) -> argparse.Namespace:
343351

344352
args.early_stop_method = self.config.early_stop_method
345353

354+
# TTS settings
355+
args.is_tts = self.config.is_tts
356+
args.generate_audio = self.config.generate_audio
357+
346358
return args
347359

348360
def _get_question_file_path(self) -> str:
@@ -397,3 +409,119 @@ def get_performance_summary(self) -> str:
397409
summary.append(f"Analysis Report: {self.analysis_file}")
398410

399411
return "\n".join(summary)
412+
413+
414+
class TTSBenchmarkEngine(BenchmarkEngine):
    """Benchmark engine that uses the TTS answer and audio generators.

    Overrides the eagle/baseline runs of ``BenchmarkEngine`` so that they call
    ``get_tts_eagle_answers`` / ``get_tts_baseline_answers`` and, when
    ``config.generate_audio`` is set, follow up with the matching
    ``get_tts_*_audios`` call.
    """

    def _run_eagle_benchmark(self):
        """Run Eagle speculative decoding benchmark"""
        self._run_tts_answers(
            mode="eagle",
            answers_func=get_tts_eagle_answers,
            answer_file=self.eagle_file,
            result_key="eagle_file",
        )

    def _run_baseline_benchmark(self):
        """Run baseline benchmark"""
        self._run_tts_answers(
            mode="baseline",
            answers_func=get_tts_baseline_answers,
            answer_file=self.baseline_file,
            result_key="baseline_file",
        )

    def _run_tts_answers(self, mode, answers_func, answer_file, result_key):
        """Generate answers for ``mode`` and record the output file.

        Args:
            mode: Mode string forwarded to ``_create_args_namespace`` and
                ``_generate_audio`` ("eagle" or "baseline").
            answers_func: Answer generator invoked once per question chunk.
            answer_file: Path handed to ``answers_func`` and re-organised
                afterwards with ``_reorg_answer_file``.
            result_key: Key under which ``answer_file`` is stored in
                ``self.results``.
        """
        args = self._create_args_namespace(mode)

        questions = fastchat.llm_judge.common.load_questions(
            self._get_question_file_path(),
            self.config.question_begin,
            self.config.question_end,
        )

        # Clamp to one worker: when num_gpus_total < num_gpus_per_model the
        # quotient is 0, which previously caused a ZeroDivisionError below
        # even though the non-ray path was already selected.
        num_workers = max(
            1, self.config.num_gpus_total // self.config.num_gpus_per_model
        )
        use_ray = num_workers > 1
        if use_ray:
            get_answers_func = ray.remote(num_gpus=self.config.num_gpus_per_model)(
                answers_func
            ).remote
        else:
            get_answers_func = answers_func

        # With fewer questions than workers the floor division yields 0 and
        # range() rejects a zero step; fall back to one question per chunk.
        chunk_size = max(1, len(questions) // num_workers)
        model_tag = f"{self.config.model_id}-temperature-{self.config.temperature}"
        ans_handles = [
            get_answers_func(
                model_tag,
                questions[start : start + chunk_size],
                answer_file,
                self.config.num_choices,
                self.config.temperature,
                args,
            )
            for start in range(0, len(questions), chunk_size)
        ]

        if use_ray:
            # Block until every remote worker has finished.
            ray.get(ans_handles)

        self._reorg_answer_file(answer_file)
        self.results[result_key] = answer_file

        if self.config.generate_audio:
            self._generate_audio(mode)

    def _calculate_metrics(self) -> Dict[str, Any]:
        """Calculate acceptance length from the Eagle answer file.

        Returns:
            A dict holding ``"acceptance_length"`` when ``self.eagle_file``
            exists, otherwise an empty dict.
        """
        metrics = {}

        # Calculate acceptance length from Eagle results
        if os.path.exists(self.eagle_file):
            metrics["acceptance_length"] = self._calculate_acceptance_length(
                self.eagle_file
            )

        return metrics

    def _generate_audio(self, mode):
        """Generate audio from the answers previously written for ``mode``.

        Args:
            mode: "baseline" selects ``get_tts_baseline_audios``; any other
                value selects ``get_tts_eagle_audios``.
        """
        args = self._create_args_namespace(mode)

        # NOTE(review): load_questions is reused to read the answer file, so
        # question_begin/question_end are applied to the answer records as
        # well - confirm this is the intended slicing.
        answers = fastchat.llm_judge.common.load_questions(
            args.answer_file,
            self.config.question_begin,
            self.config.question_end,
        )

        if mode == "baseline":
            get_tts_baseline_audios(answers, args.answer_file, args)
        else:
            get_tts_eagle_audios(answers, args.answer_file, args)

0 commit comments

Comments
 (0)