|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Benchmark compiled vs eager duplex inference on the same video. |
| 3 | +
|
| 4 | +Loads the model once (with compile=True), then runs the same omni full-duplex |
| 5 | +session twice: once with compiled modules, once with eager modules. |
| 6 | +Prints a side-by-side timing comparison at the end. |
| 7 | +
|
| 8 | +Usage: |
| 9 | + CUDA_VISIBLE_DEVICES=0 TORCHINDUCTOR_CACHE_DIR=./torch_compile_cache \ |
| 10 | + PYTHONPATH=. .venv/base/bin/python test_compile_bench.py |
| 11 | +""" |
| 12 | + |
| 13 | +import os |
| 14 | +import sys |
| 15 | +import time |
| 16 | +import logging |
| 17 | +import torch |
| 18 | +from config import get_config |
| 19 | + |
| 20 | +logging.basicConfig( |
| 21 | + level=logging.INFO, |
| 22 | + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", |
| 23 | +) |
| 24 | +logger = logging.getLogger("compile_bench") |
| 25 | + |
| 26 | +VIDEO_PATH = os.path.join( |
| 27 | + os.path.dirname(os.path.abspath(__file__)), |
| 28 | + "assets", "samples", "compile.mp4", |
| 29 | +) |
| 30 | +MAX_CHUNKS = 8 |
| 31 | + |
| 32 | + |
| 33 | +def module_type_label(mod) -> str: |
| 34 | + cls = type(mod).__name__ |
| 35 | + if cls == "OptimizedModule": |
| 36 | + return f"OptimizedModule (compiled)" |
| 37 | + return f"{cls} (eager)" |
| 38 | + |
| 39 | + |
| 40 | +def print_header(label: str, model): |
| 41 | + active = getattr(model, "_compile_active", "N/A") |
| 42 | + llm_label = module_type_label(model.llm.model) |
| 43 | + tts_label = module_type_label(model.tts.model) if hasattr(model.tts, "model") else "N/A" |
| 44 | + print(f"\n{'='*70}") |
| 45 | + print(f" {label}") |
| 46 | + print(f" _compile_active = {active}") |
| 47 | + print(f" llm.model = {llm_label}") |
| 48 | + print(f" tts.model = {tts_label}") |
| 49 | + print(f"{'='*70}") |
| 50 | + |
| 51 | + |
| 52 | +def run_bench(model, label: str) -> dict: |
| 53 | + print_header(label, model) |
| 54 | + t0 = time.time() |
| 55 | + result = model.benchmark( |
| 56 | + video_paths=[VIDEO_PATH], |
| 57 | + max_chunks_per_video=MAX_CHUNKS, |
| 58 | + ) |
| 59 | + elapsed = time.time() - t0 |
| 60 | + print(f" [{label}] done in {elapsed:.1f}s, " |
| 61 | + f"units={result.get('num_units', 0)}, " |
| 62 | + f"listen={result.get('listen_count', 0)}, " |
| 63 | + f"speak={result.get('speak_count', 0)}") |
| 64 | + return result |
| 65 | + |
| 66 | + |
| 67 | +def format_stats(stats: dict, key_path: str) -> str: |
| 68 | + keys = key_path.split(".") |
| 69 | + d = stats |
| 70 | + for k in keys: |
| 71 | + d = d.get(k, {}) |
| 72 | + if not d: |
| 73 | + return "N/A" |
| 74 | + return f"avg={d.get('avg', 0):.0f}ms min={d.get('min', 0):.0f}ms max={d.get('max', 0):.0f}ms" |
| 75 | + |
| 76 | + |
| 77 | +def print_comparison(compiled_result: dict, eager_result: dict): |
| 78 | + print("\n") |
| 79 | + print("=" * 70) |
| 80 | + print(" Compiled vs Eager 对比") |
| 81 | + print("=" * 70) |
| 82 | + |
| 83 | + rows = [ |
| 84 | + ("总用时", "total_time", "s", True), |
| 85 | + ] |
| 86 | + |
| 87 | + # top-level |
| 88 | + for label, key, unit, is_time in rows: |
| 89 | + cv = compiled_result.get(key, 0) |
| 90 | + ev = eager_result.get(key, 0) |
| 91 | + if is_time: |
| 92 | + diff_pct = ((ev - cv) / cv * 100) if cv > 0 else 0 |
| 93 | + print(f" {label:20s} compiled={cv:.1f}{unit} eager={ev:.1f}{unit} " |
| 94 | + f"差异={diff_pct:+.1f}%") |
| 95 | + else: |
| 96 | + print(f" {label:20s} compiled={cv} eager={ev}") |
| 97 | + |
| 98 | + print(f" {'单位数':20s} compiled={compiled_result.get('num_units',0)} " |
| 99 | + f"eager={eager_result.get('num_units',0)}") |
| 100 | + |
| 101 | + # per-decision-type stats |
| 102 | + for decision in ("listen", "speak"): |
| 103 | + cs = compiled_result.get(f"{decision}_stats", {}) |
| 104 | + es = eager_result.get(f"{decision}_stats", {}) |
| 105 | + cc = cs.get("count", 0) |
| 106 | + ec = es.get("count", 0) |
| 107 | + if cc == 0 and ec == 0: |
| 108 | + continue |
| 109 | + |
| 110 | + print(f"\n ── {decision.upper()} (compiled n={cc}, eager n={ec}) ──") |
| 111 | + |
| 112 | + metric_paths = [ |
| 113 | + ("prefill total", "prefill.total"), |
| 114 | + (" vision_process", "prefill.vision_process"), |
| 115 | + (" vision_embed", "prefill.vision_embed"), |
| 116 | + (" vision_feed", "prefill.vision_feed"), |
| 117 | + (" audio_process", "prefill.audio_process"), |
| 118 | + (" audio_embed", "prefill.audio_embed"), |
| 119 | + (" audio_feed", "prefill.audio_feed"), |
| 120 | + ("generate total", "generate.total"), |
| 121 | + (" llm", "generate.llm"), |
| 122 | + (" tts_prep", "generate.tts_prep"), |
| 123 | + (" tts", "generate.tts"), |
| 124 | + (" token2wav", "generate.token2wav"), |
| 125 | + ("unit_total", "unit_total"), |
| 126 | + ] |
| 127 | + |
| 128 | + for metric_label, path in metric_paths: |
| 129 | + keys = path.split(".") |
| 130 | + cd = cs |
| 131 | + for k in keys: |
| 132 | + cd = cd.get(k, {}) if isinstance(cd, dict) else {} |
| 133 | + ed = es |
| 134 | + for k in keys: |
| 135 | + ed = ed.get(k, {}) if isinstance(ed, dict) else {} |
| 136 | + |
| 137 | + c_avg = cd.get("avg", 0) if isinstance(cd, dict) else 0 |
| 138 | + e_avg = ed.get("avg", 0) if isinstance(ed, dict) else 0 |
| 139 | + |
| 140 | + if c_avg == 0 and e_avg == 0: |
| 141 | + continue |
| 142 | + |
| 143 | + diff_pct = ((e_avg - c_avg) / c_avg * 100) if c_avg > 0 else 0 |
| 144 | + arrow = "↑ slower" if diff_pct > 2 else ("↓ faster" if diff_pct < -2 else "≈") |
| 145 | + print(f" {metric_label:18s} compiled={c_avg:6.0f}ms eager={e_avg:6.0f}ms " |
| 146 | + f"{diff_pct:+6.1f}% {arrow}") |
| 147 | + |
| 148 | + print("=" * 70) |
| 149 | + |
| 150 | + |
| 151 | +def main(): |
| 152 | + cfg = get_config() |
| 153 | + |
| 154 | + print("=" * 70) |
| 155 | + print(" Compiled vs Eager Duplex Benchmark") |
| 156 | + print("=" * 70) |
| 157 | + print(f" Model: {cfg.model.model_path}") |
| 158 | + print(f" Video: {VIDEO_PATH}") |
| 159 | + print(f" Max chunks: {MAX_CHUNKS}") |
| 160 | + print() |
| 161 | + |
| 162 | + from core.processors.unified import UnifiedProcessor |
| 163 | + |
| 164 | + logger.info("加载模型 (compile=True)...") |
| 165 | + t0 = time.time() |
| 166 | + processor = UnifiedProcessor( |
| 167 | + model_path=cfg.model.model_path, |
| 168 | + pt_path=cfg.model.pt_path, |
| 169 | + ref_audio_path=cfg.ref_audio_path, |
| 170 | + compile=True, |
| 171 | + chat_vocoder=cfg.chat_vocoder, |
| 172 | + attn_implementation=cfg.attn_implementation, |
| 173 | + ) |
| 174 | + logger.info(f"模型加载完成 ({time.time() - t0:.1f}s)") |
| 175 | + |
| 176 | + model = processor.model |
| 177 | + |
| 178 | + # ── Round 1: Compiled ── |
| 179 | + model.set_compile_enabled(True) |
| 180 | + compiled_result = run_bench(model, "COMPILED") |
| 181 | + |
| 182 | + # ── Reset state between runs ── |
| 183 | + torch.cuda.empty_cache() |
| 184 | + |
| 185 | + # ── Round 2: Eager ── |
| 186 | + model.set_compile_enabled(False) |
| 187 | + eager_result = run_bench(model, "EAGER") |
| 188 | + |
| 189 | + # ── Comparison ── |
| 190 | + print_comparison(compiled_result, eager_result) |
| 191 | + |
| 192 | + |
| 193 | +if __name__ == "__main__": |
| 194 | + main() |
0 commit comments