Skip to content

Commit 5169e63

Browse files
committed
Support switching between compiled and non-compiled (eager) mode; by default only omni full-duplex runs compiled
1 parent 40e7da2 commit 5169e63

6 files changed

Lines changed: 397 additions & 11 deletions

File tree

MiniCPMO45/modeling_minicpmo_unified.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -443,23 +443,23 @@ def apply_torch_compile(
443443
compiled_modules: list = []
444444
skipped_modules: list = []
445445

446-
if hasattr(self, "vpm") and "vpm" not in skip:
447-
self.vpm = torch.compile(self.vpm, **compile_kwargs)
448-
compiled_modules.append("vpm")
449-
elif "vpm" in skip:
450-
skipped_modules.append("vpm")
446+
# if hasattr(self, "vpm") and "vpm" not in skip:
447+
# self.vpm = torch.compile(self.vpm, **compile_kwargs)
448+
# compiled_modules.append("vpm")
449+
# elif "vpm" in skip:
450+
# skipped_modules.append("vpm")
451451

452452
if hasattr(self, "llm") and "llm.model" not in skip:
453453
self.llm.model = torch.compile(self.llm.model, **compile_kwargs)
454454
compiled_modules.append("llm.model")
455455
elif "llm.model" in skip:
456456
skipped_modules.append("llm.model")
457457

458-
if hasattr(self, "resampler") and "resampler" not in skip:
459-
self.resampler = torch.compile(self.resampler, **compile_kwargs)
460-
compiled_modules.append("resampler")
461-
elif "resampler" in skip:
462-
skipped_modules.append("resampler")
458+
# if hasattr(self, "resampler") and "resampler" not in skip:
459+
# self.resampler = torch.compile(self.resampler, **compile_kwargs)
460+
# compiled_modules.append("resampler")
461+
# elif "resampler" in skip:
462+
# skipped_modules.append("resampler")
463463

464464
if hasattr(self, "tts") and hasattr(self.tts, "model") and "tts.model" not in skip:
465465
self.tts.model = torch.compile(self.tts.model, **compile_kwargs)
@@ -472,6 +472,7 @@ def apply_torch_compile(
472472

473473
elapsed = _time.time() - t0
474474
self._compiled = True
475+
self._compile_active = True
475476
logger.info(
476477
f"[torch.compile] Wrapping done ({elapsed:.2f}s), "
477478
f"compiled: {compiled_modules}"
@@ -480,6 +481,51 @@ def apply_torch_compile(
480481
)
481482
return self
482483

484+
def set_compile_enabled(self, enabled: bool) -> None:
    """Switch the compiled sub-modules between compiled and eager execution.

    The call is a no-op unless ``_compiled`` is set on this object, and also
    when ``enabled`` already equals ``_compile_active`` (treated as ``True``
    when the attribute is missing).

    Disabling replaces ``llm.model`` / ``tts.model`` with the module found in
    the wrapper's ``_orig_mod`` attribute and remembers the wrapper on that
    module as ``_compiled_ref``; enabling puts the remembered wrapper back.
    Objects are only re-pointed here: nothing is copied or recompiled.

    Args:
        enabled: ``True`` to install the compiled wrappers, ``False`` to
            install the original (eager) modules.
    """
    if not getattr(self, "_compiled", False):
        return
    if enabled == getattr(self, "_compile_active", True):
        return

    # (owner, label) pairs; each owner exposes the swappable module as `.model`.
    targets: list = []
    if hasattr(self, "llm"):
        targets.append((self.llm, "llm.model"))
    if hasattr(self, "tts") and hasattr(self.tts, "model"):
        targets.append((self.tts, "tts.model"))

    swapped: list = []
    for owner, label in targets:
        cur = owner.model
        if enabled:
            replacement = getattr(cur, "_compiled_ref", None)
        else:
            replacement = getattr(cur, "_orig_mod", None)
            if replacement is not None:
                # Plain `replacement._compiled_ref = cur` would go through
                # nn.Module.__setattr__ and register the wrapper as a child of
                # the very module it wraps. That cycle makes train()/eval()/
                # to()/state_dict() recurse without end, so store the
                # back-reference as an ordinary instance attribute instead.
                object.__setattr__(replacement, "_compiled_ref", cur)
        if replacement is None:
            # Nothing to swap to (module was never compiled / never disabled).
            continue
        owner.model = replacement
        swapped.append(label)

    self._compile_active = enabled
    logger.info(f"[torch.compile] {'enabled' if enabled else 'disabled'} → swapped {swapped}")
528+
483529
def warmup_compile(
484530
self,
485531
warmup_video_path: Optional[str] = None,

TODO.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[x] compile coverage, test?
2+
[] add tts case in audio chat
3+
[] add custom voice in audio chat
4+
[] calibration dataset for quantization

core/processors/unified.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,11 @@ def _release_resources(self) -> None:
14321432

14331433
# ==================== Mode Switching ====================
14341434

1435+
def _sync_compile_state(self, want_compiled: bool) -> None:
    """Forward the desired compile state to the loaded model.

    Does nothing when this processor's ``compile`` flag is falsy or when no
    model is loaded; otherwise calls ``model.set_compile_enabled``.

    Args:
        want_compiled: Value passed through to ``set_compile_enabled``.
    """
    if not self.compile or self.model is None:
        return
    self.model.set_compile_enabled(want_compiled)
1439+
14351440
def set_chat_mode(self) -> ChatView:
14361441
"""Switch to Chat mode.
14371442
@@ -1442,6 +1447,7 @@ def set_chat_mode(self) -> ChatView:
14421447

14431448
if self._current_mode != ProcessorMode.CHAT:
14441449
start = time.time()
1450+
self._sync_compile_state(False)
14451451
self.model.set_mode(ModelProcessorMode.CHAT)
14461452
self._current_mode = ProcessorMode.CHAT
14471453
logger.info(f"Switched to CHAT mode in {(time.time()-start)*1000:.1f}ms")
@@ -1458,6 +1464,7 @@ def set_half_duplex_mode(self) -> HalfDuplexView:
14581464

14591465
if self._current_mode != ProcessorMode.HALF_DUPLEX:
14601466
start = time.time()
1467+
self._sync_compile_state(False)
14611468
self.model.set_mode(ModelProcessorMode.STREAMING)
14621469
self._current_mode = ProcessorMode.HALF_DUPLEX
14631470
logger.info(f"Switched to HALF_DUPLEX mode in {(time.time()-start)*1000:.1f}ms")
@@ -1474,6 +1481,7 @@ def set_duplex_mode(self) -> DuplexView:
14741481

14751482
if self._current_mode != ProcessorMode.DUPLEX:
14761483
start = time.time()
1484+
self._sync_compile_state(True)
14771485
self.model.set_mode(ModelProcessorMode.DUPLEX)
14781486
self._current_mode = ProcessorMode.DUPLEX
14791487
logger.info(f"Switched to DUPLEX mode in {(time.time()-start)*1000:.1f}ms")

static/audio-duplex/audio-duplex-app.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import { initRefAudio } from '../duplex/ui/ref-audio-init.js';
3838
const SAMPLE_RATE_IN = 16000;
3939
const SAMPLE_RATE_OUT = 24000;
4040
const CHUNK_MS = 1000;
41-
const FILE_MAX_DURATION = 120; // 2 minutes
41+
const FILE_MAX_DURATION = 300; // 5 minutes
4242

4343
let currentMode = 'live';
4444
let session = null;

tests/test_compile_bench.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
#!/usr/bin/env python3
2+
"""Benchmark compiled vs eager duplex inference on the same video.
3+
4+
Loads the model once (with compile=True), then runs the same omni full-duplex
5+
session twice: once with compiled modules, once with eager modules.
6+
Prints a side-by-side timing comparison at the end.
7+
8+
Usage:
9+
CUDA_VISIBLE_DEVICES=0 TORCHINDUCTOR_CACHE_DIR=./torch_compile_cache \
10+
PYTHONPATH=. .venv/base/bin/python tests/test_compile_bench.py
11+
"""
12+
13+
import os
14+
import sys
15+
import time
16+
import logging
17+
import torch
18+
from config import get_config
19+
20+
# Route benchmark log lines through the root logger at INFO level.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("compile_bench")

# Sample clip, resolved relative to the directory that contains this script.
VIDEO_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "samples", "compile.mp4")
# Passed to model.benchmark() as max_chunks_per_video.
MAX_CHUNKS = 8
31+
32+
33+
def module_type_label(mod) -> str:
    """Return a label with *mod*'s class name and whether it looks compiled.

    An object counts as compiled when its class is named ``OptimizedModule``;
    every other object is labelled eager.
    """
    cls = type(mod).__name__
    if cls == "OptimizedModule":
        # Plain literal: there is nothing to interpolate here.
        return "OptimizedModule (compiled)"
    return f"{cls} (eager)"
38+
39+
40+
def print_header(label: str, model):
    """Print a banner describing which execution mode *model* is in.

    Shows *label*, the model's ``_compile_active`` flag and a label for the
    objects currently installed as ``model.llm.model`` and ``model.tts.model``.
    Anything that is absent is reported as ``N/A`` instead of raising, so the
    banner also works for models that lack an ``llm`` or ``tts`` part.

    Args:
        label: Title line of the banner.
        model: Object whose compile state is described.
    """
    active = getattr(model, "_compile_active", "N/A")

    def describe(owner_name: str) -> str:
        # hasattr(None, "model") is False, so a missing owner also yields N/A.
        owner = getattr(model, owner_name, None)
        if not hasattr(owner, "model"):
            return "N/A"
        return module_type_label(owner.model)

    llm_label = describe("llm")
    tts_label = describe("tts")

    rule = "=" * 70
    print(f"\n{rule}")
    print(f" {label}")
    print(f" _compile_active = {active}")
    print(f" llm.model = {llm_label}")
    print(f" tts.model = {tts_label}")
    print(rule)
50+
51+
52+
def run_bench(model, label: str) -> dict:
    """Run one benchmark pass on the sample video and return its result.

    Prints the mode banner, calls ``model.benchmark`` on ``VIDEO_PATH`` limited
    to ``MAX_CHUNKS`` chunks, then prints a one-line summary of the pass.

    Args:
        model: Object providing ``benchmark(video_paths=..., max_chunks_per_video=...)``.
        label: Name of the pass, used in the banner and the summary line.

    Returns:
        Whatever ``model.benchmark`` returned; it is read here through
        ``.get`` for the ``num_units``, ``listen_count`` and ``speak_count`` keys.
    """
    print_header(label, model)
    # perf_counter is monotonic and high resolution, so the measured duration
    # cannot be distorted by wall-clock adjustments during the run.
    started = time.perf_counter()
    result = model.benchmark(
        video_paths=[VIDEO_PATH],
        max_chunks_per_video=MAX_CHUNKS,
    )
    elapsed = time.perf_counter() - started
    print(f" [{label}] done in {elapsed:.1f}s, "
          f"units={result.get('num_units', 0)}, "
          f"listen={result.get('listen_count', 0)}, "
          f"speak={result.get('speak_count', 0)}")
    return result
65+
66+
67+
def format_stats(stats: dict, key_path: str) -> str:
    """Format the avg/min/max entry found at a dotted path inside *stats*.

    Args:
        stats: Nested dictionary of timing statistics.
        key_path: Dot-separated keys, e.g. ``"prefill.total"``.

    Returns:
        ``"avg=…ms min=…ms max=…ms"`` (missing fields shown as 0), or ``"N/A"``
        when the path is missing, empty, or runs through a non-dict value.
    """
    node = stats
    for key in key_path.split("."):
        # A scalar in the middle of the path means the entry does not exist.
        if not isinstance(node, dict):
            return "N/A"
        node = node.get(key, {})
    if not isinstance(node, dict) or not node:
        return "N/A"
    return f"avg={node.get('avg', 0):.0f}ms min={node.get('min', 0):.0f}ms max={node.get('max', 0):.0f}ms"
75+
76+
77+
def _avg_at(stats, path: str):
    """Return the ``avg`` entry at dotted *path* in *stats*, or 0 if unavailable."""
    node = stats
    for key in path.split("."):
        if not isinstance(node, dict):
            return 0
        node = node.get(key, {})
    return node.get("avg", 0) if isinstance(node, dict) else 0


def print_comparison(compiled_result: dict, eager_result: dict):
    """Print a side-by-side report of a compiled and an eager benchmark result.

    The report has the total time, the unit count and, for each of the
    ``listen`` / ``speak`` decision types that has a non-zero ``count`` in at
    least one result, the average of every known timing metric.

    Percentages are ``(eager - compiled) / compiled``: a positive value (and
    the "↑ slower" marker) means the eager run took longer than the compiled
    one. When the compiled value is zero or missing there is no baseline, so
    ``n/a`` is printed instead of a misleading ``+0.0%``.

    Args:
        compiled_result: Result dictionary of the compiled run.
        eager_result: Result dictionary of the eager run.
    """
    rule = "=" * 70
    print("\n")
    print(rule)
    print(" Compiled vs Eager 对比")
    print(rule)

    # (label, result key, unit suffix, whether the value is a duration)
    rows = [
        ("总用时", "total_time", "s", True),
    ]

    # Top-level figures.
    for label, key, unit, is_time in rows:
        cv = compiled_result.get(key, 0)
        ev = eager_result.get(key, 0)
        if is_time:
            diff_text = f"{(ev - cv) / cv * 100:+.1f}%" if cv > 0 else "n/a"
            print(f" {label:20s} compiled={cv:.1f}{unit} eager={ev:.1f}{unit} "
                  f"差异={diff_text}")
        else:
            print(f" {label:20s} compiled={cv} eager={ev}")

    print(f" {'单位数':20s} compiled={compiled_result.get('num_units',0)} "
          f"eager={eager_result.get('num_units',0)}")

    # (display label, dotted path inside the per-decision stats dictionary)
    metric_paths = (
        ("prefill total", "prefill.total"),
        (" vision_process", "prefill.vision_process"),
        (" vision_embed", "prefill.vision_embed"),
        (" vision_feed", "prefill.vision_feed"),
        (" audio_process", "prefill.audio_process"),
        (" audio_embed", "prefill.audio_embed"),
        (" audio_feed", "prefill.audio_feed"),
        ("generate total", "generate.total"),
        (" llm", "generate.llm"),
        (" tts_prep", "generate.tts_prep"),
        (" tts", "generate.tts"),
        (" token2wav", "generate.token2wav"),
        ("unit_total", "unit_total"),
    )

    # Per-decision-type statistics.
    for decision in ("listen", "speak"):
        # `or {}` also covers a stats entry that is present but None.
        cs = compiled_result.get(f"{decision}_stats") or {}
        es = eager_result.get(f"{decision}_stats") or {}
        cc = cs.get("count", 0)
        ec = es.get("count", 0)
        if cc == 0 and ec == 0:
            continue

        print(f"\n ── {decision.upper()} (compiled n={cc}, eager n={ec}) ──")

        for metric_label, path in metric_paths:
            c_avg = _avg_at(cs, path)
            e_avg = _avg_at(es, path)

            # Metric not recorded by either run: nothing to show.
            if c_avg == 0 and e_avg == 0:
                continue

            if c_avg > 0:
                diff_pct = (e_avg - c_avg) / c_avg * 100
                arrow = "↑ slower" if diff_pct > 2 else ("↓ faster" if diff_pct < -2 else "≈")
                delta = f"{diff_pct:+6.1f}% {arrow}"
            else:
                # No compiled baseline to compare against.
                delta = "   n/a"
            print(f" {metric_label:18s} compiled={c_avg:6.0f}ms eager={e_avg:6.0f}ms "
                  f"{delta}")

    print(rule)
149+
150+
151+
def main():
    """Load the model once with compile=True, then benchmark compiled vs eager.

    Prints the run configuration, builds a ``UnifiedProcessor``, runs
    ``run_bench`` with the compiled modules enabled and again with them
    disabled, and finally prints the comparison of both results.
    """
    cfg = get_config()
    rule = "=" * 70

    print(rule)
    print(" Compiled vs Eager Duplex Benchmark")
    print(rule)
    print(f" Model: {cfg.model.model_path}")
    print(f" Video: {VIDEO_PATH}")
    print(f" Max chunks: {MAX_CHUNKS}")
    print()

    # Imported here rather than at module top, as in the original script.
    from core.processors.unified import UnifiedProcessor

    logger.info("加载模型 (compile=True)...")
    load_started = time.time()
    processor = UnifiedProcessor(
        model_path=cfg.model.model_path,
        pt_path=cfg.model.pt_path,
        ref_audio_path=cfg.ref_audio_path,
        compile=True,
        chat_vocoder=cfg.chat_vocoder,
        attn_implementation=cfg.attn_implementation,
    )
    load_seconds = time.time() - load_started
    logger.info(f"模型加载完成 ({load_seconds:.1f}s)")

    model = processor.model

    # Round 1: compiled modules.
    model.set_compile_enabled(True)
    with_compile = run_bench(model, "COMPILED")

    # Release cached CUDA memory between the two rounds.
    torch.cuda.empty_cache()

    # Round 2: eager modules.
    model.set_compile_enabled(False)
    without_compile = run_bench(model, "EAGER")

    print_comparison(with_compile, without_compile)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)