diff --git a/app.py b/app.py
index 1141b62..e3e9e80 100644
--- a/app.py
+++ b/app.py
@@ -22,12 +22,25 @@ import uvicorn
 from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse
+from starlette.concurrency import iterate_in_threadpool
 from moss_tts_nano_runtime import (
     DEFAULT_AUDIO_TOKENIZER_PATH,
     DEFAULT_CHECKPOINT_PATH,
     DEFAULT_OUTPUT_DIR,
     NanoTTSService,
+    resolve_device,
+)
+from openai_audio_api import (
+    FORMAT_CONTENT_TYPE,
+    SpeechRequest,
+    iter_pcm_audio,
+    make_error_response,
+    preprocess_tts_input,
+    resolve_voice,
+    start_opus_encoder,
+    _resample_pcm,
+    _wav_header_bytes,
 )
 from text_normalization_pipeline import (
     TextNormalizationSnapshot as SharedTextNormalizationSnapshot,
@@ -2175,6 +2188,7 @@ def _build_app(
     warmup_manager: WarmupManager,
     text_normalizer_manager: WeTextProcessingManager | None,
     root_path: str | None,
+    runtime_device: str = "cpu",
 ) -> FastAPI:
     app = FastAPI(title="MOSS-TTS-Nano Demo", root_path=root_path or "")
     stream_jobs = StreamingJobManager()
@@ -2833,10 +2847,372 @@ def _synthesize(selected_runtime: NanoTTSService):
             _maybe_delete_file(generated_audio_path)
             _maybe_delete_file(prompt_audio_cleanup_path)
 
+    # ------------------------------------------------------------------
+    # OpenAI-compatible endpoint: POST /v1/audio/speech
+    # ------------------------------------------------------------------
+
+    @app.post("/v1/audio/speech")
+    async def openai_audio_speech(request: Request):
+        import pydantic
+
+        # 1. Parse & validate request body
+        raw_body: bytes = await request.body()
+        try:
+            body = await request.json()
+            speech_req = SpeechRequest(**body)
+        except pydantic.ValidationError as exc:
+            # Return the first validation error as OpenAI-format JSON
+            first_err = exc.errors()[0]
+            param = first_err.get("loc", [None])[-1]
+            logging.warning(
+                "OpenAI /v1/audio/speech validation error: headers=%s body=%s errors=%s",
+                dict(request.headers),
+                raw_body.decode("utf-8", errors="replace")[:512],
+                exc.errors(),
+            )
+            return JSONResponse(
+                status_code=400,
+                content=make_error_response(
+                    message=first_err.get("msg", "Invalid request"),
+                    param=str(param) if param else None,
+                )[0],
+            )
+        except Exception as exc:
+            logging.warning(
+                "OpenAI /v1/audio/speech body parse error: headers=%s body=%s error=%s",
+                dict(request.headers),
+                raw_body.decode("utf-8", errors="replace")[:512],
+                exc,
+            )
+            return JSONResponse(
+                status_code=400,
+                content=make_error_response(
+                    message=f"Invalid JSON body: {exc}",
+                )[0],
+            )
+
+        request_start = time.monotonic()
+
+        logging.info(
+            "OpenAI /v1/audio/speech request: headers=%s body=%s",
+            dict(request.headers),
+            raw_body.decode("utf-8", errors="replace")[:1024],
+        )
+
+        # 2. Preprocess input: strip emoji/kaomoji, normalize newlines to pauses
+        original_input = speech_req.input
+        speech_req = speech_req.model_copy(
+            update={"input": preprocess_tts_input(speech_req.input)},
+        )
+        logging.info(
+            "OpenAI /v1/audio/speech stage=preprocess before=%r after=%r",
+            original_input,
+            speech_req.input,
+        )
+
+        # 3. Validate input length
+        if len(speech_req.input) > 4096:
+            err_body, status = make_error_response(
+                message="input text exceeds maximum length of 4096 characters.",
+                param="input",
+            )
+            return JSONResponse(status_code=status, content=err_body)
+
+        # 4. Resolve voice name (OpenAI name → MOSS preset, or passthrough)
+        moss_voice = resolve_voice(speech_req.voice)
+
+        # 5. Prepare text
+        try:
+            prepared_texts = shared_prepare_tts_request_texts(
+                text=speech_req.input,
+                enable_wetext=True,
+                enable_normalize_tts_text=True,
+                text_normalizer_manager=text_normalizer_manager,
+            )
+        except Exception:
+            logging.exception("Text normalization failed for OpenAI endpoint")
+            return JSONResponse(
+                status_code=500,
+                content=make_error_response(
+                    message="Text normalization failed.",
+                    error_type="server_error",
+                    status_code=500,
+                )[0],
+            )
+
+        tts_text = str(prepared_texts["text"])
+        logging.info(
+            "OpenAI /v1/audio/speech stage=final_text len=%d method=%s lang=%s text=%r",
+            len(tts_text),
+            prepared_texts.get("normalization_method", "?"),
+            prepared_texts.get("text_normalization_language", "?"),
+            tts_text,
+        )
+
+        # 6. Pre-split the text into chunks and log them. The runtime applies
+        #    the same split internally (voice_clone_max_text_tokens=30), so
+        #    logging the chunks here makes any content loss at chunk
+        #    boundaries visible in the logs.
+        tts_text_chunks = runtime_manager.default_runtime.split_voice_clone_text(
+            text=tts_text, voice_clone_max_text_tokens=30,
+        )
+        logging.info(
+            "OpenAI /v1/audio/speech stage=chunks count=%d chunks=%r",
+            len(tts_text_chunks), tts_text_chunks,
+        )
+
+        # 7. Ensure warmup
+        warmup_snapshot = warmup_manager.snapshot()
+        if not warmup_snapshot.ready:
+            warmup_snapshot = warmup_manager.ensure_ready()
+        if not warmup_snapshot.ready:
+            return JSONResponse(
+                status_code=503,
+                content=make_error_response(
+                    message="Model is still warming up. Please retry later.",
+                    error_type="server_error",
+                    status_code=503,
+                )[0],
+            )
+
+        # 8. Build streaming response via background thread + queue.
+        #    This avoids holding _cpu_execution_lock inside the ASGI
+        #    streaming iterator (which can deadlock on client disconnect).
+        response_format = speech_req.response_format
+
+        audio_queue: queue.Queue[bytes | None] = queue.Queue(maxsize=64)
+        client_disconnected = threading.Event()
+
+        def _put(chunk: bytes) -> bool:
+            """Put a chunk into the queue, returning False on timeout/disconnect."""
+            deadline = time.monotonic() + 30  # bail after 30s of queue full
+            while not client_disconnected.is_set() and time.monotonic() < deadline:
+                try:
+                    audio_queue.put(chunk, timeout=0.5)
+                    return True
+                except queue.Full:
+                    continue
+            client_disconnected.set()
+            return False
+
+        def _run_tts():
+            events_gen = None
+            # Initialise counters before anything that can raise, so the
+            # timing log in the finally-block below never hits an unbound name.
+            audio_chunks = 0
+            tts_gen_start = time.monotonic()
+            first_pcm_at: float | None = None
+            last_pcm_at: float | None = None
+            speed = speech_req.speed
+            try:
+                events_gen = runtime_manager.iter_with_runtime(
+                    requested_execution_device=runtime_device,
+                    cpu_threads=0,
+                    factory=lambda rt: rt.synthesize_stream(
+                        text=str(prepared_texts["text"]),
+                        mode="voice_clone",
+                        voice=moss_voice,
+                        prompt_audio_path=None,
+                        voice_clone_max_text_tokens=30,
+                        voice_clone_max_memory_per_sample_gb=0.6,
+                        max_new_frames=200,
+                        tts_max_batch_size=7,
+                        codec_max_batch_size=7,
+                    ),
+                )
+                if response_format == "wav":
+                    header_sent = False
+                    for pcm, sample_rate, channels in iter_pcm_audio(events_gen):
+                        if client_disconnected.is_set():
+                            return
+                        audio_chunks += 1
+                        now = time.monotonic()
+                        if first_pcm_at is None:
+                            first_pcm_at = now
+                        last_pcm_at = now
+                        if not header_sent:
+                            if not _put(_wav_header_bytes(sample_rate, channels)):
+                                return
+                            header_sent = True
+                        if not _put(_resample_pcm(pcm, speed, channels)):
+                            return
+                elif response_format == "mp3":
+                    encoder = None
+                    for pcm, sample_rate, channels in iter_pcm_audio(events_gen):
+                        if client_disconnected.is_set():
+                            return
+                        audio_chunks += 1
+                        now = time.monotonic()
+                        if first_pcm_at is None:
+                            first_pcm_at = now
+                        last_pcm_at = now
+                        if encoder is None:
+                            import lameenc
+                            encoder = lameenc.Encoder()
+                            encoder.set_bit_rate(128)
+                            encoder.set_in_sample_rate(sample_rate)
+                            encoder.set_channels(channels)
+                            encoder.set_quality(2)
+                        if not _put(bytes(encoder.encode(_resample_pcm(pcm, speed, channels)))):
+                            return
+                    if encoder is not None:
+                        flush = encoder.flush()
+                        if flush:
+                            _put(bytes(flush))  # best effort; client may have disconnected
+                elif response_format == "opus":
+                    import subprocess as _sp
+                    import threading as _threading_mod
+
+                    opus_proc = None
+                    opus_reader_thread = None
+                    opus_stderr_chunks: list[bytes] = []
+                    opus_output_queue: queue.Queue[bytes | None] = queue.Queue(maxsize=64)
+
+                    def _read_opus_stdout(proc: _sp.Popen):
+                        try:
+                            while True:
+                                block = proc.stdout.read(8192)
+                                if not block:
+                                    break
+                                opus_output_queue.put(block)
+                        finally:
+                            opus_output_queue.put(None)
+
+                    def _read_opus_stderr(proc: _sp.Popen):
+                        try:
+                            while True:
+                                chunk = proc.stderr.read(4096)
+                                if not chunk:
+                                    break
+                                opus_stderr_chunks.append(chunk)
+                        except Exception:
+                            pass
+
+                    try:
+                        for pcm, sample_rate, channels in iter_pcm_audio(events_gen):
+                            if client_disconnected.is_set():
+                                break
+                            audio_chunks += 1
+                            now = time.monotonic()
+                            if first_pcm_at is None:
+                                first_pcm_at = now
+                            last_pcm_at = now
+                            if opus_proc is None:
+                                opus_proc = start_opus_encoder(sample_rate, channels, speed=speed)
+                                opus_reader_thread = _threading_mod.Thread(
+                                    target=_read_opus_stdout, args=(opus_proc,), daemon=True,
+                                )
+                                opus_reader_thread.start()
+                                _stderr_thread = _threading_mod.Thread(
+                                    target=_read_opus_stderr, args=(opus_proc,), daemon=True,
+                                )
+                                _stderr_thread.start()
+                            try:
+                                opus_proc.stdin.write(pcm)
+                            except BrokenPipeError:
+                                break
+
+                        # Drain and finalize only if the encoder was ever
+                        # started; otherwise the sentinel never arrives and
+                        # the 30s get() below would just time out.
+                        if opus_proc is not None:
+                            try:
+                                opus_proc.stdin.close()
+                            except BrokenPipeError:
+                                pass
+
+                            while True:
+                                block = opus_output_queue.get(timeout=30)
+                                if block is None:
+                                    break
+                                if not _put(block):
+                                    break
+
+                            if opus_reader_thread is not None:
+                                opus_reader_thread.join(timeout=10)
+                            rc = opus_proc.wait(timeout=10)
+                            if rc != 0:
+                                stderr_output = b"".join(opus_stderr_chunks).decode("utf-8", errors="replace")
+                                logging.error(
+                                    "Opus ffmpeg exited with rc=%d stderr=%s",
+                                    rc, stderr_output[:500],
+                                )
+                    except Exception:
+                        logging.exception("Opus encoding failed")
+                        raise
+                else:  # pcm
+                    for pcm, _, pcm_ch in iter_pcm_audio(events_gen):
+                        if client_disconnected.is_set():
+                            return
+                        audio_chunks += 1
+                        now = time.monotonic()
+                        if first_pcm_at is None:
+                            first_pcm_at = now
+                        last_pcm_at = now
+                        if not _put(_resample_pcm(pcm, speed, pcm_ch)):
+                            return
+            except Exception:
+                logging.exception("TTS thread failed for OpenAI /v1/audio/speech")
+            finally:
+                elapsed = time.monotonic() - request_start
+                gen_elapsed = (last_pcm_at or time.monotonic()) - tts_gen_start
+                first_pcm_elapsed = (first_pcm_at - tts_gen_start) if first_pcm_at else None
+                logging.info(
+                    "OpenAI /v1/audio/speech complete: format=%s audio_chunks=%d elapsed=%.2fs tts_gen=%.2fs first_pcm=%s",
+                    response_format, audio_chunks, elapsed, gen_elapsed,
+                    f"{first_pcm_elapsed:.2f}s" if first_pcm_elapsed else "n/a",
+                )
+                # Explicitly close the events generator to release
+                # _cpu_execution_lock held by iter_with_runtime.
+                if events_gen is not None:
+                    events_gen.close()
+                # Always push sentinel, even on error/disconnect
+                try:
+                    audio_queue.put(None, timeout=1.0)
+                except queue.Full:
+                    pass
+
+        tts_thread = threading.Thread(target=_run_tts, daemon=True)
+        tts_thread.start()
+
+        def _audio_from_queue():
+            while True:
+                chunk = audio_queue.get()
+                if chunk is None:
+                    break
+                yield chunk
+
+        content_type = FORMAT_CONTENT_TYPE.get(response_format, "application/octet-stream")
+
+        return StreamingResponse(
+            iterate_in_threadpool(_audio_from_queue()),
+            media_type=content_type,
+        )
+
     return app
 
 
+def _patch_torchaudio_backend() -> None:
+    """Patch torchaudio to avoid the SoX backend, which segfaults on some systems."""
+    try:
+        import torchaudio
+        _original_load = torchaudio.load
+        _original_save = torchaudio.save
+
+        def _load_with_soundfile(uri, *args, backend=None, **kwargs):
+            if backend is None:
+                backend = "soundfile"
+            return _original_load(uri, *args, backend=backend, **kwargs)
+
+        def _save_with_soundfile(uri, src, sample_rate, *args, backend=None, **kwargs):
+            if backend is None:
+                backend = "soundfile"
+            return _original_save(uri, src, sample_rate, *args, backend=backend, **kwargs)
+
+        torchaudio.load = _load_with_soundfile
+        torchaudio.save = _save_with_soundfile
+    except ImportError:
+        pass
+
+
 def main(argv: Optional[Sequence[str]] = None) -> None:
+    _patch_torchaudio_backend()
+
     parser = argparse.ArgumentParser(description="MOSS-TTS-Nano web demo")
     parser.add_argument("--checkpoint-path", "--checkpoint_path", dest="checkpoint_path", type=str, default=str(DEFAULT_CHECKPOINT_PATH))
     parser.add_argument(
@@ -2847,7 +3223,7 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
         default=str(DEFAULT_AUDIO_TOKENIZER_PATH),
     )
     parser.add_argument("--output-dir", "--output_dir", dest="output_dir", type=str, default=str(DEFAULT_OUTPUT_DIR))
-    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "auto"])
+    parser.add_argument("--device", type=str, default="auto", choices=["cpu", "auto", "cuda"])
     parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])
     parser.add_argument(
         "--attn-implementation",
@@ -2867,9 +3243,8 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
         level=logging.INFO,
     )
 
-    resolved_runtime_device = "cpu"
-    if args.device != "cpu":
-        logging.info("CPU-only app mode: ignoring --device=%s and forcing cpu.", args.device)
+    resolved_runtime_device = str(resolve_device(args.device))
+    logging.info("resolved device=%s", resolved_runtime_device)
 
     runtime = NanoTTSService(
         checkpoint_path=args.checkpoint_path,
@@ -2890,13 +3265,46 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
     if args.share:
         logging.warning("--share is ignored by the FastAPI-based Nano-TTS app.")
 
-    app = _build_app(runtime, warmup_manager, text_normalizer_manager, root_path)
+    app = _build_app(runtime, warmup_manager, text_normalizer_manager, root_path, resolved_runtime_device)
     uvicorn.run(
         app,
         host=args.host,
         port=args.port,
         log_level="info",
         root_path=root_path or "",
+        log_config={
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "default": {
+                    "()": "uvicorn.logging.DefaultFormatter",
+                    "fmt": "%(asctime)s %(levelprefix)s %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+                "access": {
+                    "()": "uvicorn.logging.AccessFormatter",
+                    "fmt": "%(asctime)s %(levelprefix)s %(client_addr)s - \"%(request_line)s\" %(status_code)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "default": {
+                    "formatter": "default",
+                    "class": "logging.StreamHandler",
+                    "stream": "ext://sys.stderr",
+                },
+                "access": {
+                    "formatter": "access",
+                    "class": "logging.StreamHandler",
+                    "stream": "ext://sys.stdout",
+                },
+            },
+            "loggers": {
+                "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
+                "uvicorn.error": {"level": "INFO"},
+                "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False},
+            },
+        },
     )
diff --git a/app_onnx.py b/app_onnx.py
index 3811175..1103ad5 100644
--- a/app_onnx.py
+++ b/app_onnx.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import argparse
+import concurrent.futures
 import logging
 import os
 import queue
@@ -22,6 +23,7 @@
     _merge_audio_channels,
     _write_waveform_to_wav,
 )
+from ort_cpu_runtime import CodecStreamingDecodeSession
 from text_normalization_pipeline import WeTextProcessingManager
 
 APP_DIR = Path(__file__).resolve().parent
@@ -68,6 +70,23 @@ def __init__(
         self.checkpoint_path = self.runtime.tts_meta_path.parent.resolve()
         self.audio_tokenizer_path = self.runtime.codec_meta_path.parent.resolve()
         self.thread_count = max(1, int(cpu_threads))
+        # Build session pool for parallel chunk processing.
+        # Each worker gets its own ONNX sessions to avoid contention.
+        total_cpus = os.cpu_count() or 4
+        per_worker_threads = max(1, int(cpu_threads))
+        self._parallel_workers = max(1, total_cpus // per_worker_threads)
+        self._parallel_sessions: list[dict[str, object]] = []
+        self._parallel_rlocks: list[threading.RLock] = []
+        if self._parallel_workers > 1:
+            worker_threads = max(1, per_worker_threads)
+            for _ in range(self._parallel_workers):
+                sessions = self.runtime._create_sessions_with_threads(worker_threads)
+                self._parallel_sessions.append(sessions)
+                self._parallel_rlocks.append(threading.RLock())
+            logging.info(
+                "ONNX parallel pool workers=%d per_worker_threads=%d total_cpus=%d",
+                self._parallel_workers, worker_threads, total_cpus,
+            )
 
     def get_model(self) -> "OnnxNanoTTSServiceAdapter":
         return self
@@ -230,8 +249,9 @@ def synthesize_stream(
         mode: str,
         voice: str | None,
         prompt_audio_path: str | None,
-        max_new_frames: int,
-        voice_clone_max_text_tokens: int,
+        max_new_frames: int = 375,
+        voice_clone_max_text_tokens: int = 75,
+        voice_clone_max_memory_per_sample_gb: float = 1.0,
         tts_max_batch_size: int = 0,
         codec_max_batch_size: int = 0,
         attn_implementation: str = "model_default",
@@ -245,9 +265,91 @@ def synthesize_stream(
         audio_repetition_penalty: float = 1.2,
         seed: int | None = None,
     ) -> Iterator[dict[str, object]]:
-        del mode, tts_max_batch_size, codec_max_batch_size
+        del mode, tts_max_batch_size, codec_max_batch_size, voice_clone_max_memory_per_sample_gb
         event_queue: "queue.Queue[dict[str, object] | None]" = queue.Queue(maxsize=128)
 
+        def _process_single_chunk(
+            chunk_index: int,
+            chunk_text: str,
+            total_chunks: int,
+            prompt_audio_codes: list[list[int]],
+            rng: np.random.Generator,
+            pool_index: int,
+        ) -> dict[str, object]:
+            """Process one text chunk using pooled sessions for thread-safety."""
+            _chunk_t0 = time.perf_counter()
+
+            if pool_index >= 0 and self._parallel_sessions:
+                # Acquire pool slot and swap sessions + rng atomically
+                with self._parallel_rlocks[pool_index]:
+                    original_sessions = self.runtime.sessions
+                    original_rng = self.runtime.rng
+                    self.runtime.sessions = self._parallel_sessions[pool_index]
+                    self.runtime.rng = rng
+                    try:
+                        return _run_chunk(chunk_index, chunk_text, total_chunks, _chunk_t0, pool_index, prompt_audio_codes)
+                    finally:
+                        self.runtime.sessions = original_sessions
+                        self.runtime.rng = original_rng
+            else:
+                # No pool — use runtime sessions directly (single-threaded path)
+                original_rng = self.runtime.rng
+                self.runtime.rng = rng
+                try:
+                    return _run_chunk(chunk_index, chunk_text, total_chunks, _chunk_t0, -1, prompt_audio_codes)
+                finally:
+                    self.runtime.rng = original_rng
+
+        def _run_chunk(
+            chunk_index: int,
+            chunk_text: str,
+            total_chunks: int,
+            _chunk_t0: float,
+            pool_index: int,
+            prompt_audio_codes: list[list[int]],
+        ) -> dict[str, object]:
+            text_token_ids = self.runtime.encode_text(chunk_text)
+            request_rows = self.runtime.build_voice_clone_request_rows(prompt_audio_codes, text_token_ids)
+            codec_session = CodecStreamingDecodeSession(
+                codec_meta=self.runtime.codec_meta,
+                session=self.runtime.sessions["codec_decode_step"],
+            )
+            generated_frames: list[list[int]] = []
+
+            def _on_frame(_gf: list[list[int]], _si: int, frame: list[int]) -> None:
+                generated_frames.append(list(frame))
+
+            _gen_t0 = time.perf_counter()
+            all_frames = self.runtime.generate_audio_frames(request_rows, on_frame=_on_frame)
+            _gen_elapsed = time.perf_counter() - _gen_t0
+
+            _dec_t0 = time.perf_counter()
+            decoded_waveforms: list[np.ndarray] = []
+            if all_frames:
+                decoded = codec_session.run_frames(all_frames)
+                if decoded is not None:
+                    audio, audio_length = decoded
+                    if audio_length > 0:
+                        waveform = _merge_audio_channels(
+                            [audio[0, ch, :audio_length] for ch in range(audio.shape[1])]
+                        )
+                        decoded_waveforms.append(waveform)
+            _decode_elapsed = time.perf_counter() - _dec_t0
+            _chunk_elapsed = time.perf_counter() - _chunk_t0
+
+            chunk_waveform = _concat_waveforms(decoded_waveforms) if decoded_waveforms else np.zeros((0,), dtype=np.float32)
+            logging.info(
+                "ONNX timing chunk=%d/%d generate=%.3fs decode=%.3fs frames=%d total=%.3fs pool=%d text=%r",
+                chunk_index + 1, total_chunks, _gen_elapsed, _decode_elapsed,
+                len(all_frames), _chunk_elapsed, pool_index, chunk_text,
+            )
+            return {
+                "chunk_index": chunk_index,
+                "chunk_text": chunk_text,
+                "waveform": chunk_waveform,
+                "frames": len(all_frames),
+            }
+
         def _worker() -> None:
             try:
                 resolved_sample_mode = self._resolve_sample_mode(attn_implementation, do_sample=do_sample)
@@ -265,94 +367,149 @@ def _worker() -> None:
                     seed=seed,
                 )
                 start_time = time.perf_counter()
+                _t0 = time.perf_counter()
                 prompt_audio_codes = self.runtime.resolve_prompt_audio_codes(voice=voice, prompt_audio_path=prompt_audio_path)
+                logging.info("ONNX timing resolve_prompt_audio_codes %.3fs", time.perf_counter() - _t0)
+                _t0 = time.perf_counter()
                 text_chunks = self.runtime.split_voice_clone_text(str(text or ""), max_tokens=int(voice_clone_max_text_tokens))
+                logging.info("ONNX timing split_voice_clone_text %.3fs count=%d chunks=%r", time.perf_counter() - _t0, len(text_chunks), text_chunks)
                 sample_rate = int(self.runtime.codec_meta["codec_config"]["sample_rate"])
                 channels = int(self.runtime.codec_meta["codec_config"]["channels"])
-                emitted_samples_total = 0
-                first_audio_emitted_at_perf: float | None = None
-                all_waveforms: list[np.ndarray] = []
-                all_generated_frames: list[list[int]] = []
-
-                for chunk_index, chunk_text in enumerate(text_chunks):
-                    text_token_ids = self.runtime.encode_text(chunk_text)
-                    request_rows = self.runtime.build_voice_clone_request_rows(prompt_audio_codes, text_token_ids)
-                    pending_decode_frames: list[list[int]] = []
-                    emitted_chunks: list[np.ndarray] = []
-                    self.runtime.codec_streaming_session.reset()
-
-                    def _emit_waveform(waveform: np.ndarray, *, is_pause: bool) -> None:
-                        nonlocal emitted_samples_total, first_audio_emitted_at_perf
-                        audio_length = int(waveform.shape[0])
-                        if first_audio_emitted_at_perf is None and not is_pause:
-                            first_audio_emitted_at_perf = time.perf_counter()
-                        emitted_samples_total += audio_length
-                        lead_seconds = 0.0
-                        if first_audio_emitted_at_perf is not None:
-                            elapsed_since_first_audio = max(0.0, time.perf_counter() - first_audio_emitted_at_perf)
-                            lead_seconds = (emitted_samples_total / float(sample_rate)) - elapsed_since_first_audio
-                        emitted_chunks.append(np.asarray(waveform, dtype=np.float32))
-                        event_queue.put(
-                            {
+
+                num_chunks = len(text_chunks)
+                logging.info(
+                    "ONNX dispatch num_chunks=%d parallel_workers=%d pool_sessions=%d",
+                    num_chunks, self._parallel_workers, len(self._parallel_sessions),
+                )
+                # Parallel chunk processing when multiple chunks
+                if num_chunks > 1 and self._parallel_workers > 1:
+                    max_workers = min(num_chunks, self._parallel_workers)
+                    logging.info(
+                        "ONNX parallel chunks=%d max_workers=%d pool_size=%d",
+                        num_chunks, max_workers, len(self._parallel_sessions),
+                    )
+                    # Create per-chunk RNGs from the current rng
+                    base_seed = int.from_bytes(self.runtime.rng.bytes(4), "little")
+                    chunk_rngs = [np.random.default_rng(base_seed + i) for i in range(num_chunks)]
+
+                    # Simple round-robin pool slot assignment
+                    _pool_counter = {"value": 0}
+                    _pool_lock = threading.Lock()
+
+                    def _next_pool_index() -> int:
+                        with _pool_lock:
+                            idx = _pool_counter["value"] % len(self._parallel_sessions)
+                            _pool_counter["value"] += 1
+                            return idx
+
+                    chunk_results: list[dict[str, object] | None] = [None] * num_chunks
+                    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                        futures = {
+                            executor.submit(
+                                _process_single_chunk,
+                                i, text_chunks[i], num_chunks,
+                                prompt_audio_codes, chunk_rngs[i],
+                                _next_pool_index(),
+                            ): i
+                            for i in range(num_chunks)
+                        }
+                        for future in concurrent.futures.as_completed(futures):
+                            idx = futures[future]
+                            try:
+                                chunk_results[idx] = future.result()
+                            except Exception as exc:
+                                chunk_results[idx] = {"chunk_index": idx, "error": str(exc)}
+
+                    # Emit results in original chunk order
+                    all_waveforms: list[np.ndarray] = []
+                    emitted_samples_total = 0
+                    for idx in range(num_chunks):
+                        result = chunk_results[idx]
+                        if result is None or "error" in result:
+                            error_msg = result.get("error", "unknown") if result else "no result"
+                            logging.error("ONNX chunk=%d failed: %s", idx + 1, error_msg)
+                            continue
+                        waveform = result["waveform"]
+                        if waveform.shape[0] > 0:
+                            all_waveforms.append(waveform)
+                            event_queue.put({
                                 "type": "audio",
                                 "waveform_numpy": np.asarray(waveform, dtype=np.float32),
                                 "sample_rate": sample_rate,
                                 "channels": channels,
-                                "chunk_index": chunk_index,
-                                "emitted_audio_seconds": emitted_samples_total / float(sample_rate),
-                                "lead_seconds": lead_seconds,
-                                "is_pause": bool(is_pause),
-                            }
-                        )
-
-                    def _decode_pending(force: bool) -> None:
-                        pending_count = len(pending_decode_frames)
-                        if pending_count <= 0:
-                            return
-                        decode_budget = _resolve_stream_decode_frame_budget(
-                            emitted_samples_total,
-                            sample_rate,
-                            first_audio_emitted_at_perf,
-                        )
-                        if not force and pending_count < max(1, decode_budget):
-                            return
-                        frame_budget = pending_count if force else min(pending_count, max(1, decode_budget))
-                        frame_chunk = pending_decode_frames[:frame_budget]
-                        del pending_decode_frames[:frame_budget]
-                        decoded = self.runtime.codec_streaming_session.run_frames(frame_chunk)
-                        if decoded is None:
-                            return
-                        audio, audio_length = decoded
-                        if audio_length <= 0:
-                            return
-                        waveform = _merge_audio_channels(
-                            [audio[0, channel_index, :audio_length] for channel_index in range(audio.shape[1])]
+                                "chunk_index": idx,
+                                "emitted_audio_seconds": (emitted_samples_total + waveform.shape[0]) / float(sample_rate),
+                                "lead_seconds": 0.0,
+                                "is_pause": False,
+                            })
+                            emitted_samples_total += waveform.shape[0]
+                        # Inter-chunk pause
+                        if idx < num_chunks - 1:
+                            pause_seconds = self.runtime.estimate_voice_clone_inter_chunk_pause_seconds(text_chunks[idx])
+                            pause_samples = max(0, int(round(sample_rate * pause_seconds)))
+                            if pause_samples > 0:
+                                pause_waveform = np.zeros((pause_samples, channels), dtype=np.float32)
+                                all_waveforms.append(pause_waveform)
+                    waveform = _concat_waveforms(all_waveforms)
+                else:
+                    # Single chunk or no pool: sequential path
+                    emitted_samples_total = 0
+                    all_waveforms: list[np.ndarray] = []
+                    all_generated_frames: list[list[int]] = []
+                    for chunk_index, chunk_text in enumerate(text_chunks):
+                        pool_idx = 0 if self._parallel_sessions else -1
+                        chunk_result = _process_single_chunk(
+                            chunk_index, chunk_text, num_chunks,
+                            prompt_audio_codes,
+                            np.random.default_rng(int.from_bytes(self.runtime.rng.bytes(4), "little") + chunk_index),
+                            pool_idx,
                         )
-                        _emit_waveform(waveform, is_pause=False)
-
-                    def _on_frame(_generated_frames: list[list[int]], _step_index: int, frame: list[int]) -> None:
-                        pending_decode_frames.append(list(frame))
-                        _decode_pending(False)
-
-                    try:
-                        generated_frames = self.runtime.generate_audio_frames(request_rows, on_frame=_on_frame)
-                        _decode_pending(True)
-                    finally:
-                        self.runtime.codec_streaming_session.reset()
-
-                    chunk_waveform = _concat_waveforms(emitted_chunks)
-                    all_waveforms.append(chunk_waveform)
-                    all_generated_frames.extend(generated_frames)
-
-                    if chunk_index < len(text_chunks) - 1:
-                        pause_seconds = self.runtime.estimate_voice_clone_inter_chunk_pause_seconds(chunk_text)
-                        pause_samples = max(0, int(round(sample_rate * pause_seconds)))
-                        if pause_samples > 0:
-                            pause_waveform = np.zeros((pause_samples, channels), dtype=np.float32)
-                            _emit_waveform(pause_waveform, is_pause=True)
-                            all_waveforms.append(pause_waveform)
-
-                waveform = _concat_waveforms(all_waveforms)
+                        if "error" in chunk_result:
+                            logging.error("ONNX chunk=%d failed: %s", chunk_index + 1, chunk_result.get("error"))
+                            continue
+                        waveform = chunk_result["waveform"]
+                        if waveform.shape[0] > 0:
+                            all_waveforms.append(waveform)
+                            event_queue.put({
+                                "type": "audio",
+                                "waveform_numpy": np.asarray(waveform, dtype=np.float32),
+                                "sample_rate": sample_rate,
+                                "channels": channels,
+                                "chunk_index": chunk_index,
+                                "emitted_audio_seconds": waveform.shape[0] / float(sample_rate),
+                                "lead_seconds": 0.0,
+                                "is_pause": False,
+                            })
+                            emitted_samples_total += waveform.shape[0]
+                        if chunk_index < num_chunks - 1:
+                            pause_seconds = self.runtime.estimate_voice_clone_inter_chunk_pause_seconds(chunk_text)
+                            pause_samples = max(0, int(round(sample_rate * pause_seconds)))
+                            if pause_samples > 0:
+                                pause_waveform = np.zeros((pause_samples, channels), dtype=np.float32)
+                                all_waveforms.append(pause_waveform)
+                    waveform = _concat_waveforms(all_waveforms)
+
+                _total_elapsed = time.perf_counter() - start_time
+                logging.info(
+                    "ONNX timing total %.3fs chunks=%d audio_seconds=%.2fs",
+                    _total_elapsed, num_chunks, emitted_samples_total / float(sample_rate) if sample_rate else 0,
+                )
                 output_path = _write_waveform_to_wav(
                     self.output_dir / "app_onnx_stream_output.wav",
                     waveform,
@@ -373,6 +530,7 @@ def _on_frame(_generated_frames: list[list[int]], _step_index: int, frame: list[
                     }
                 )
             except Exception as exc:
+                logging.exception("ONNX synthesize_stream error")
                 event_queue.put({"type": "error", "error": str(exc)})
             finally:
                 event_queue.put(None)
@@ -396,7 +554,7 @@ class OnnxRequestRuntimeManager:
     def __init__(self, default_runtime: OnnxNanoTTSServiceAdapter) -> None:
         self.default_runtime = default_runtime
-        self.default_cpu_threads = max(1, int(os.cpu_count() or 1))
+        self.default_cpu_threads = default_runtime.thread_count
         self._lock = threading.Lock()
         self._execution_lock = threading.Lock()
         self._cpu_runtimes: dict[int, OnnxNanoTTSServiceAdapter] = {default_runtime.thread_count: default_runtime}
@@ -596,7 +754,32 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
     return parser.parse_args(argv)
 
 
+def _patch_torchaudio_backend() -> None:
+    """Patch torchaudio to avoid the SoX backend, which segfaults on some systems."""
+    try:
+        import torchaudio
+        _original_load = torchaudio.load
+        _original_save = torchaudio.save
+
+        def _load_with_soundfile(uri, *args, backend=None, **kwargs):
+            if backend is None:
+                backend = "soundfile"
+            return _original_load(uri, *args, backend=backend, **kwargs)
+
+        def _save_with_soundfile(uri, src, sample_rate, *args, backend=None, **kwargs):
+            if backend is None:
+                backend = "soundfile"
+            return _original_save(uri, src, sample_rate, *args, backend=backend, **kwargs)
+
+        torchaudio.load = _load_with_soundfile
+        torchaudio.save = _save_with_soundfile
+    except ImportError:
+        pass
+
+
 def main(argv: Optional[Sequence[str]] = None) -> None:
+    _patch_torchaudio_backend()
+
     args = parse_args(argv)
     logging.basicConfig(
         format="%(asctime)s %(levelname)s %(name)s: %(message)s",
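
The parallel path above gives each pool slot its own ONNX session set plus a
lock, then hands slots out round-robin. A stripped-down sketch of that pattern,
with stand-in session objects rather than the runtime classes from this patch:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    class SessionPool:
        """Round-robin pool: one session object and one lock per slot."""

        def __init__(self, sessions):
            self._sessions = sessions
            self._locks = [threading.Lock() for _ in sessions]
            self._counter = 0
            self._counter_lock = threading.Lock()

        def next_slot(self) -> int:
            # Hand out slot indices in round-robin order (like _next_pool_index).
            with self._counter_lock:
                idx = self._counter % len(self._sessions)
                self._counter += 1
                return idx

        def run(self, slot, work):
            # Serialize access to one slot; different slots run in parallel.
            with self._locks[slot]:
                return work(self._sessions[slot])

    pool = SessionPool(["sess-a", "sess-b"])  # stand-ins for ORT session dicts
    with ThreadPoolExecutor(max_workers=2) as ex:
        futures = [ex.submit(pool.run, pool.next_slot(), lambda s: f"ran on {s}")
                   for _ in range(4)]
        print([f.result() for f in futures])

One difference worth noting: the sketch passes the slot's session into the
worker, while _process_single_chunk swaps the shared self.runtime.sessions
attribute under a per-slot lock, so concurrent chunks on different slots still
read that attribute through the same runtime object.
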
diff --git a/infer.py b/infer.py
index 99844b8..384e73d 100644
--- a/infer.py
+++ b/infer.py
@@ -297,7 +297,31 @@ def maybe_print_voice_clone_text_chunks(
     print()
 
 
+def _patch_torchaudio_backend() -> None:
+    """Patch torchaudio to avoid the SoX backend, which segfaults on some systems."""
+    try:
+        import torchaudio
+        _original_load = torchaudio.load
+        _original_save = torchaudio.save
+
+        def _load_with_soundfile(uri, *args, backend=None, **kwargs):
+            if backend is None:
+                backend = "soundfile"
+            return _original_load(uri, *args, backend=backend, **kwargs)
+
+        def _save_with_soundfile(uri, src, sample_rate, *args, backend=None, **kwargs):
+            if backend is None:
+                backend = "soundfile"
+            return _original_save(uri, src, sample_rate, *args, backend=backend, **kwargs)
+
+        torchaudio.load = _load_with_soundfile
+        torchaudio.save = _save_with_soundfile
+    except ImportError:
+        pass
+
+
 def main(argv: Optional[Sequence[str]] = None) -> dict[str, object]:
+    _patch_torchaudio_backend()
     set_logging()
     args = parse_args(argv)
     if args.debug == 1:
diff --git a/moss_tts_nano_runtime.py b/moss_tts_nano_runtime.py
index 37ea500..a2e7c37 100644
--- a/moss_tts_nano_runtime.py
+++ b/moss_tts_nano_runtime.py
@@ -25,21 +25,15 @@
 _DEFAULT_VOICE_FILES: dict[str, tuple[str, str]] = {
     "Junhao": ("zh_1.wav", "Chinese male voice A"),
-    "Zhiming": ("zh_2.wav", "Chinese male voice B"),
-    "Weiguo": ("zh_5.wav", "Chinese male voice C"),
     "Xiaoyu": ("zh_3.wav", "Chinese female voice A"),
     "Yuewen": ("zh_4.wav", "Chinese female voice B"),
     "Lingyu": ("zh_6.wav", "Chinese female voice C"),
-    "Trump": ("en_1.wav", "Trump reference voice"),
     "Ava": ("en_2.wav", "English female voice A"),
     "Bella": ("en_3.wav", "English female voice B"),
     "Adam": ("en_4.wav", "English male voice A"),
-    "Nathan": ("en_5.wav", "English male voice B"),
-    "Sakura": ("jp_1.mp3", "Japanese female voice A"),
     "Yui": ("jp_2.wav", "Japanese female voice B"),
-    "Aoi": ("jp_3.wav", "Japanese female voice C"),
-    "Hina": ("jp_4.wav", "Japanese female voice D"),
-    "Mei": ("jp_5.wav", "Japanese female voice E"),
+    "男播音": ("zh_10.wav", "Chinese male broadcaster voice"),
+    "杨幂": ("zh_11.wav", "Chinese female voice YangMi style"),
 }
 
 DEFAULT_VOICE = "Junhao"
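
The OpenAI-compatible endpoint maps OpenAI voice names onto this reduced preset
roster via _OPENAI_VOICE_MAP in openai_audio_api.py (shown later in this
patch); unknown names pass through unchanged:

    from openai_audio_api import resolve_voice

    print(resolve_voice("alloy"))   # -> "Junhao"
    print(resolve_voice("ballad"))  # -> "男播音"
    print(resolve_voice("Xiaoyu"))  # -> "Xiaoyu" (already a MOSS preset name)
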
diff --git a/onnx_tts_runtime.py b/onnx_tts_runtime.py
index 5ea0d7d..3f69499 100644
--- a/onnx_tts_runtime.py
+++ b/onnx_tts_runtime.py
@@ -490,9 +490,26 @@ def resolve_prompt_audio_codes(
             return self.encode_reference_audio(prompt_audio_path)
         resolved_voice = str(voice or self.list_builtin_voices()[0]["voice"])
         voice_row = next((item for item in self.list_builtin_voices() if item["voice"] == resolved_voice), None)
-        if voice_row is None:
-            raise ValueError(f"Built-in voice not found: {resolved_voice}")
-        return list(voice_row["prompt_audio_codes"])
+        if voice_row is not None:
+            return list(voice_row["prompt_audio_codes"])
+        # Fallback: try to find wav file in preset voices directory
+        from moss_tts_nano.defaults import DEFAULT_PROMPT_AUDIO_DIR
+        wav_candidates = [
+            DEFAULT_PROMPT_AUDIO_DIR / f"{resolved_voice}.wav",
+            DEFAULT_PROMPT_AUDIO_DIR / f"{resolved_voice}_reference.wav",
+        ]
+        # Also check the PyTorch preset map for wav filename
+        try:
+            from moss_tts_nano_runtime import _DEFAULT_VOICE_FILES
+            if resolved_voice in _DEFAULT_VOICE_FILES:
+                wav_candidates.insert(0, DEFAULT_PROMPT_AUDIO_DIR / _DEFAULT_VOICE_FILES[resolved_voice][0])
+        except ImportError:
+            pass
+        for wav_path in wav_candidates:
+            if wav_path.exists():
+                logging.info("ONNX voice %r not in builtin manifest, encoding from %s", resolved_voice, wav_path)
+                return self.encode_reference_audio(wav_path)
+        raise ValueError(f"Built-in voice not found: {resolved_voice}")
 
     def decode_full_audio_safe(self, generated_frames: list[list[int]]) -> np.ndarray:
         try:
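
For a voice missing from the ONNX builtin manifest, the fallback above probes,
in order: the PyTorch preset filename map, then <voice>.wav and
<voice>_reference.wav under the prompt-audio directory. A compact illustration
of the resulting candidate order (the directory path is a stand-in, not the
real packaged default):

    from pathlib import Path

    DEFAULT_PROMPT_AUDIO_DIR = Path("voices")  # stand-in for the real default
    _DEFAULT_VOICE_FILES = {"男播音": ("zh_10.wav", "Chinese male broadcaster voice")}

    voice = "男播音"
    candidates = [
        DEFAULT_PROMPT_AUDIO_DIR / f"{voice}.wav",
        DEFAULT_PROMPT_AUDIO_DIR / f"{voice}_reference.wav",
    ]
    if voice in _DEFAULT_VOICE_FILES:
        # The preset filename is inserted at the front, so it wins.
        candidates.insert(0, DEFAULT_PROMPT_AUDIO_DIR / _DEFAULT_VOICE_FILES[voice][0])
    print([str(p) for p in candidates])
    # On POSIX: ['voices/zh_10.wav', 'voices/男播音.wav', 'voices/男播音_reference.wav']
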
diff --git a/openai_audio_api.py b/openai_audio_api.py
new file mode 100644
index 0000000..3ea681e
--- /dev/null
+++ b/openai_audio_api.py
@@ -0,0 +1,477 @@
+from __future__ import annotations
+
+import logging
+import re
+import subprocess
+import struct
+import unicodedata
+from typing import Iterator
+
+import numpy as np
+from pydantic import BaseModel, Field
+
+# ---------------------------------------------------------------------------
+# Voice mapping: OpenAI voice names → MOSS-TTS-Nano preset names
+# ---------------------------------------------------------------------------
+
+_OPENAI_VOICE_MAP: dict[str, str] = {
+    "alloy": "Junhao",    # zh_1.wav
+    "echo": "Xiaoyu",     # zh_3.wav
+    "fable": "Yuewen",    # zh_4.wav
+    "onyx": "Adam",       # en_4.wav
+    "nova": "Lingyu",     # zh_6.wav
+    "shimmer": "Bella",   # en_3.wav
+    "ash": "Ava",         # en_2.wav
+    "sage": "Junhao",     # zh_1.wav
+    "coral": "Xiaoyu",    # zh_3.wav
+    "ballad": "男播音",    # zh_10.wav
+    "yangmi": "杨幂",      # zh_11.wav
+}
+
+
+def resolve_voice(voice: str) -> str:
+    """Map an OpenAI voice name to a MOSS-TTS-Nano preset, or pass through."""
+    return _OPENAI_VOICE_MAP.get(voice, voice)
+
+
+# ---------------------------------------------------------------------------
+# Emoji / kaomoji stripping
+# ---------------------------------------------------------------------------
+
+# Unicode emoji ranges (pictographs, symbols, modifiers, flags, etc.).
+# Kept to actual pictographic blocks: broad catch-all ranges such as
+# \U203C-\U3299 would also swallow CJK punctuation and kana.
+_EMOJI_RE = re.compile(
+    "["
+    "\U0001F600-\U0001F64F"  # emoticons
+    "\U0001F300-\U0001F5FF"  # misc symbols & pictographs
+    "\U0001F680-\U0001F6FF"  # transport & map
+    "\U0001F1E0-\U0001F1FF"  # flags
+    "\U00002702-\U000027B0"  # dingbats
+    "\U0001F900-\U0001F9FF"  # supplemental symbols
+    "\U0001FA00-\U0001FA6F"  # chess symbols
+    "\U0001FA70-\U0001FAFF"  # symbols extended-A
+    "\U00002600-\U000026FF"  # misc symbols
+    "\U0000FE00-\U0000FE0F"  # variation selectors
+    "\U0000200D"             # ZWJ
+    "\U00002B50"             # star
+    "\U00003030\U0000303D"   # wavy dash, part alternation mark
+    "\U00003297\U00003299"   # circled "congratulations" / "secret"
+    "]+",
+    re.UNICODE,
+)
+
+# Common kaomoji patterns: loose match for face-like punctuation clusters
+# Uses a heuristic: sequences of CJK symbols/punctuation + special chars
+# that are 3+ chars long and look like faces
+_KAOMOJI_RE = re.compile(
+    r"(?:"
+    # Shrug: ¯\_(...)_/¯
+    r"¯\\?_?\(.*?\)_?/?¯"
+    r"|"
+    # Face with brackets: (●◡●) (╯°□°)╯ etc.
+    r"\([^)]*[◉◎⊙●◡▼▽ᗜᴖᴗ◕‿◕°□※✧※][^)]*\)"
+    r"|"
+    # Table flip and arm-like gestures
+    r"\(╯[°□◉]\)?╯\s*[︵︶╰].*?(?:╰\s*[︶╯]\s*[°□◉]\s*╰\s*\))?"
+    r"|"
+    # Simple arm gestures: ヽ(…)ノ ヽ(。_。)ノ etc
+    r"[ヽヾ]\(.*?\)[ノヾ]"
+    r"|"
+    # Flipping tables: ┻━┻ ┣━┫ ┳━┳
+    r"[┣┻┳╚╗╔╝]\s*[━═]\s*[┫┻┳╚╗╔╝]"
+    r"|"
+    # Raised arms / action: (ノಠ益ಠ)ノ彡
+    r"\(ノ[^)]*ಠ[^)]*\)ノ[^)]*"
+    r")",
+    re.UNICODE,
+)
+
+
+def strip_emoji(text: str) -> str:
+    """Remove emoji, kaomoji, and zero-width modifiers from *text*."""
+    text = _KAOMOJI_RE.sub("", text)
+    out: list[str] = []
+    i = 0
+    while i < len(text):
+        ch = text[i]
+        cp = ord(ch)
+        cat = unicodedata.category(ch)
+
+        # Zero-width / variation selectors / combining marks that follow emoji
+        if cat == "Cf" or cp == 0x200D:  # ZWJ, variation selectors, etc.
+            i += 1
+            continue
+
+        # Emoji ranges: actual pictographic codepoints only
+        if (
+            0x1F600 <= cp <= 0x1F64F  # emoticons
+            or 0x1F300 <= cp <= 0x1F5FF  # misc symbols & pictographs
+            or 0x1F680 <= cp <= 0x1F6FF  # transport & map
+            or 0x1F900 <= cp <= 0x1F9FF  # supplemental symbols
+            or 0x1FA00 <= cp <= 0x1FAFF  # symbols extended-A
+            or 0x1F1E6 <= cp <= 0x1F1FF  # regional indicators (flags)
+            or 0x1F3FB <= cp <= 0x1F3FF  # skin tone modifiers
+            or 0x2600 <= cp <= 0x26FF  # misc symbols
+            or 0x2702 <= cp <= 0x27B0  # dingbats
+            or cp == 0x2B50  # star
+            or cp == 0x3030  # wavy dash
+            or cp == 0x303D  # part alternation mark
+            or cp == 0x3297  # circled "congratulations"
+            or cp == 0x3299  # circled "secret"
+            or cp == 0xFE0F  # variation selector-16
+        ):
+            i += 1
+            continue
+
+        out.append(ch)
+        i += 1
+
+    text = "".join(out)
+    # Clean up stray kaomoji fragments (╯ ︵ ╰ ┳ ┻ etc.)
+    text = re.sub(r"[╯╰︵︶┳┻┣┫━═彡ッツ]", "", text)
+    text = re.sub(r"[ \t]+", " ", text).strip()
+    return text
+
+
+def _number_to_chinese(n: int) -> str:
+    """Convert integer 0–100 to spoken Chinese. Falls back to digits for >100."""
+    _D = "零一二三四五六七八九"
+    if n == 0:
+        return "零"
+    if n < 10:
+        return _D[n]
+    if n == 10:
+        return "十"
+    if n < 20:
+        return "十" + (_D[n % 10] if n % 10 else "")
+    if n < 100:
+        tens, ones = divmod(n, 10)
+        return _D[tens] + "十" + (_D[ones] if ones else "")
+    if n == 100:
+        return "一百"
+    return str(n)
+
+
+def preprocess_tts_input(text: str) -> str:
+    """Strip emoji/kaomoji, normalize newlines, and convert symbols to readable text.
+
+    Translates unit symbols (°C, km/h, %, etc.) and range operators (~) into
+    Chinese text so the TTS model can pronounce them correctly.
+    """
+    text = strip_emoji(text)
+
+    # --- Time format: HH:MM → H点 / H点M分 (must run before the colon→, rewrite) ---
+    def _time_repl(m: re.Match) -> str:
+        hour = int(m.group(1))
+        minute = int(m.group(2))
+        if minute == 0:
+            return f"{hour}点"
+        return f"{hour}点{minute}分"
+
+    text = re.sub(r"(\d{1,2}):(\d{2})(?!\d)", _time_repl, text)
+
+    # --- Symbol-to-text: simple 1:1 replacements, no merging ---
+
+    # Convert ~ → 到
+    text = re.sub(r"\s*~\s*", "到", text)
+
+    # Convert units
+    text = re.sub(r"(\d)\s*°C", r"\1摄氏度", text)
+    text = re.sub(r"(\d)\s*°F", r"\1华氏度", text)
+    text = re.sub(r"(\d)\s*℃", r"\1摄氏度", text)
+    text = re.sub(r"(\d)\s*km/h\b", r"\1千米每小时", text)
+    text = re.sub(r"(\d)\s*m/s\b", r"\1米每秒", text)
+    text = re.sub(r"(\d)\s*mph\b", r"\1英里每小时", text)
+
+    # Percent: N% → 百分之N (Chinese reads "百分之" before the number)
+    def _pct_repl(m: re.Match) -> str:
+        val = float(m.group(1))
+        if val == int(val) and 0 <= int(val) <= 100:
+            return "百分之" + _number_to_chinese(int(val))
+        return "百分之" + m.group(1)
+
+    text = re.sub(r"(\d+(?:\.\d+)?)\s*%", _pct_repl, text)
+
+    # --- Punctuation normalization for TTS chunking ---
+
+    # Replace :(full-width colon) with ,(comma) so labels and values
+    # stay together as one natural phrase (e.g. "气温,十八摄氏度").
+    # Using 。was causing the model to repeat the last phrase at boundaries.
+    text = re.sub(r"[::]", ",", text)
+
+    # Replace em/en dashes with commas
+    text = re.sub(r"\s*[—–]\s*", ",", text)
+
+    # --- Newline normalization ---
+    text = text.replace("\r\n", "\n").replace("\r", "\n")
+    text = re.sub(r"\n{2,}", "\n", text)
+    text = re.sub(r"[ \t]+\n", "\n", text)
+    text = text.replace("\n", "。")
+    # Clean up consecutive punctuation like ,。 or 。,
+    text = re.sub(r"[,。]{2,}", lambda m: "。" if "。" in m.group() else ",", text)
+    text = text.lstrip("。")
+    text = re.sub(r"[ \t]+", " ", text).strip()
+    return text
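+
+
+# Example behaviour of preprocess_tts_input, derived from the replacement
+# rules above (doctest-style, for reference):
+#
+#     >>> preprocess_tts_input("今天14:30，气温18°C~25°C")
+#     '今天14点30分，气温18摄氏度到25摄氏度'
+#     >>> preprocess_tts_input("电量只剩5%")
+#     '电量只剩百分之五'
+#     >>> preprocess_tts_input("好的👍\n出发")
+#     '好的。出发'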
+
+
+# ---------------------------------------------------------------------------
+# Request / response models
+# ---------------------------------------------------------------------------
+
+class SpeechRequest(BaseModel):
+    """OpenAI-compatible ``POST /v1/audio/speech`` request body."""
+
+    model: str = "tts-1"
+    input: str
+    voice: str
+    response_format: str = Field(default="wav", pattern=r"^(wav|mp3|pcm|opus)$")
+    speed: float = Field(default=1.0, ge=0.25, le=4.0)
+
+
+# ---------------------------------------------------------------------------
+# Format helpers
+# ---------------------------------------------------------------------------
+
+def _audio_to_pcm16le(audio_array: np.ndarray) -> bytes:
+    """Convert float32 numpy audio to raw PCM signed-16-bit little-endian."""
+    audio_np = np.asarray(audio_array, dtype=np.float32)
+    if audio_np.ndim == 1:
+        audio_np = audio_np[:, None]
+    elif audio_np.ndim == 2 and audio_np.shape[0] <= 8 and audio_np.shape[0] < audio_np.shape[1]:
+        audio_np = audio_np.T
+    audio_np = np.clip(audio_np, -1.0, 1.0)
+    return (audio_np * 32767.0).astype(np.int16).tobytes()
+
+
+def _resample_pcm(pcm: bytes, speed: float, channels: int = 1) -> bytes:
+    """Resample PCM s16le audio by *speed* factor using linear interpolation.
+
+    speed > 1.0 = faster/shorter, speed < 1.0 = slower/longer.
+    Changes pitch (simple resampling). For pitch-preserving speed change,
+    use the ffmpeg atempo filter (used in opus path).
+    Handles multi-channel (interleaved) PCM correctly.
+    """
+    if speed == 1.0 or not pcm:
+        return pcm
+    samples = np.frombuffer(pcm, dtype=np.int16).astype(np.float32)
+    if channels > 1:
+        # De-interleave, resample each channel, re-interleave
+        per_ch = len(samples) // channels
+        reshaped = samples.reshape(per_ch, channels)
+        new_len = max(1, int(per_ch / speed))
+        indices = np.linspace(0, per_ch - 1, new_len)
+        resampled = np.column_stack([
+            np.interp(indices, np.arange(per_ch), reshaped[:, ch])
+            for ch in range(channels)
+        ])
+    else:
+        new_len = max(1, int(len(samples) / speed))
+        indices = np.linspace(0, len(samples) - 1, new_len)
+        resampled = np.interp(indices, np.arange(len(samples)), samples)
+    return resampled.astype(np.int16).tobytes()
+
+
+def _wav_header_bytes(sample_rate: int, channels: int, data_length: int = 0) -> bytes:
+    """Build a 44-byte RIFF WAV header.
+
+    When *data_length* is 0 the header uses ``0x7FFFFFFF`` as a placeholder
+    size so that streaming clients do not reject the file for being "too short".
+    """
+    bits_per_sample = 16
+    byte_rate = sample_rate * channels * bits_per_sample // 8
+    block_align = channels * bits_per_sample // 8
+    # If total data length unknown (streaming), use a large placeholder.
+    data_size = data_length if data_length > 0 else 0x7FFFFFFF
+    file_size = 36 + data_size
+
+    return struct.pack(
+        "<4sI4s4sIHHIIHH4sI",
+        b"RIFF",
+        file_size,
+        b"WAVE",
+        b"fmt ",
+        16,  # chunk size
+        1,  # PCM format
+        channels,
+        sample_rate,
+        byte_rate,
+        block_align,
+        bits_per_sample,
+        b"data",
+        data_size,
+    )
+
+
+def _wav_bytes_from_pcm(pcm: bytes, sample_rate: int, channels: int) -> bytes:
+    """Wrap a complete PCM buffer in a WAV container."""
+    header = _wav_header_bytes(sample_rate, channels, data_length=len(pcm))
+    return header + pcm
+
+
+# ---------------------------------------------------------------------------
+# MP3 encoding (lazy lameenc import)
+# ---------------------------------------------------------------------------
+
+def _encode_pcm_to_mp3(pcm: bytes, sample_rate: int, channels: int) -> bytes:
+    """Encode one PCM chunk to MP3 using *lameenc*."""
+    try:
+        import lameenc
+    except ImportError:
+        raise RuntimeError(
+            "MP3 encoding requires the 'lameenc' package. "
+            "Install it with: pip install lameenc"
+        )
+
+    encoder = lameenc.Encoder()
+    encoder.set_bit_rate(128)
+    encoder.set_in_sample_rate(sample_rate)
+    encoder.set_channels(channels)
+    encoder.set_quality(2)  # high quality
+    return bytes(encoder.encode(pcm)) + bytes(encoder.flush())
+
+
+# ---------------------------------------------------------------------------
+# Streaming generators
+# ---------------------------------------------------------------------------
+
+def iter_pcm_audio(
+    events: Iterator[tuple[dict, str, int]],
+) -> Iterator[tuple[bytes, int, int]]:
+    """Yield ``(pcm_bytes, sample_rate, channels)`` from synthesize_stream events.
+
+    *events* comes from ``RequestRuntimeManager.iter_with_runtime`` which yields
+    ``(event_dict, execution_device, cpu_threads)`` tuples.
+    """
+    for item in events:
+        # iter_with_runtime yields (event, device, threads) tuples
+        event = item[0] if isinstance(item, tuple) else item
+        event_type = str(event.get("type", ""))
+        if event_type != "audio":
+            continue
+        waveform = np.asarray(event["waveform_numpy"], dtype=np.float32)
+        sample_rate = int(event["sample_rate"])
+        channels = 1 if waveform.ndim == 1 else int(waveform.shape[1])
+        pcm = _audio_to_pcm16le(waveform)
+        if pcm:
+            yield bytes(pcm), sample_rate, channels
+
+
+def generate_wav_stream(events: Iterator[dict]) -> Iterator[bytes]:
+    """Yield WAV-formatted chunks: header first, then raw PCM data."""
+    header_sent = False
+    for pcm, sample_rate, channels in iter_pcm_audio(events):
+        if not header_sent:
+            yield _wav_header_bytes(sample_rate, channels)
+            header_sent = True
+        yield pcm
+
+
+def generate_pcm_stream(events: Iterator[dict]) -> Iterator[bytes]:
+    """Yield raw PCM bytes directly."""
+    for pcm, _, _ in iter_pcm_audio(events):
+        yield pcm
+
+
+def generate_mp3_stream(events: Iterator[dict]) -> Iterator[bytes]:
+    """Yield MP3 frames, feeding all PCM chunks through a single lameenc encoder."""
+    encoder = None
+    try:
+        import lameenc
+    except ImportError:
+        raise RuntimeError(
+            "MP3 encoding requires the 'lameenc' package. "
+            "Install it with: pip install lameenc"
+        )
+
+    for pcm, sample_rate, channels in iter_pcm_audio(events):
+        if encoder is None:
+            encoder = lameenc.Encoder()
+            encoder.set_bit_rate(128)
+            encoder.set_in_sample_rate(sample_rate)
+            encoder.set_channels(channels)
+            encoder.set_quality(2)
+        yield bytes(encoder.encode(pcm))
+
+    if encoder is not None:
+        flush = encoder.flush()
+        if flush:
+            yield bytes(flush)
+
+
+# ---------------------------------------------------------------------------
+# Opus encoding via ffmpeg subprocess
+# ---------------------------------------------------------------------------
+
+_OPUS_FRAME_SIZE = 960  # 20ms at 48kHz, the standard Opus frame
+
+
+def start_opus_encoder(sample_rate: int, channels: int, speed: float = 1.0) -> subprocess.Popen:
+    """Start an ffmpeg subprocess that accepts PCM on stdin, produces Ogg/Opus on stdout."""
+    audio_filters = []
+    if speed != 1.0:
+        # atempo range is [0.5, 100.0]; for values outside, chain multiple filters
+        remaining = speed
+        while remaining > 100.0:
+            audio_filters.append("atempo=100.0")
+            remaining /= 100.0
+        while remaining < 0.5:
+            audio_filters.append("atempo=0.5")
+            remaining /= 0.5
+        audio_filters.append(f"atempo={remaining:.4f}")
+
+    cmd = [
+        "ffmpeg", "-hide_banner", "-loglevel", "error",
+        "-f", "s16le", "-ar", str(sample_rate), "-ac", str(channels),
+        "-i", "-",
+    ]
+    if audio_filters:
+        cmd.extend(["-af", ",".join(audio_filters)])
+    cmd.extend([
+        "-c:a", "libopus", "-b:a", "65536",
+        "-f", "ogg", "-",
+    ])
+    return subprocess.Popen(
+        cmd,
+        stdin=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Error responses
+# ---------------------------------------------------------------------------
+
+def make_error_response(
+    message: str,
+    *,
+    param: str | None = None,
+    error_type: str = "invalid_request_error",
+    status_code: int = 400,
+) -> tuple[dict, int]:
+    """Return ``(body_dict, http_status)`` following OpenAI error schema."""
+    body = {
+        "error": {
+            "message": message,
+            "type": error_type,
+        }
+    }
+    if param is not None:
+        body["error"]["param"] = param
+    return body, status_code
+
+
+# ---------------------------------------------------------------------------
+# Content types
+# ---------------------------------------------------------------------------
+
+FORMAT_CONTENT_TYPE: dict[str, str] = {
+    "wav": "audio/wav",
+    "mp3": "audio/mpeg",
+    "pcm": "audio/pcm",
+    "opus": "audio/opus",
+}
diff --git a/ort_cpu_runtime.py b/ort_cpu_runtime.py
index e7093a4..da3a0a2 100644
--- a/ort_cpu_runtime.py
+++ b/ort_cpu_runtime.py
@@ -352,30 +352,41 @@ def _session(self, path_value: Path) -> ort.InferenceSession:
         return ort.InferenceSession(str(path_value), sess_options=options, providers=["CPUExecutionProvider"])
 
     def _create_sessions(self) -> dict[str, ort.InferenceSession]:
+        return self._create_sessions_with_threads(self.thread_count)
+
+    def _create_sessions_with_threads(self, thread_count: int) -> dict[str, ort.InferenceSession]:
+        """Create ONNX sessions with a specific intra-op thread count."""
+        def _make_session(path_value: Path) -> ort.InferenceSession:
+            options = ort.SessionOptions()
+            options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+            options.intra_op_num_threads = max(1, int(thread_count))
+            options.inter_op_num_threads = 1
+            return ort.InferenceSession(str(path_value), sess_options=options, providers=["CPUExecutionProvider"])
+
         tts_dir = self.tts_meta_path.parent
         codec_dir = self.codec_meta_path.parent
         return {
-            "prefill": self._session(tts_dir / self.tts_meta["files"]["prefill"]),
-            "decode": self._session(tts_dir / self.tts_meta["files"]["decode_step"]),
-            "local_decoder": self._session(tts_dir / self.tts_meta["files"]["local_decoder"]),
+            "prefill": _make_session(tts_dir / self.tts_meta["files"]["prefill"]),
+            "decode": _make_session(tts_dir / self.tts_meta["files"]["decode_step"]),
+            "local_decoder": _make_session(tts_dir / self.tts_meta["files"]["local_decoder"]),
             **(
-                {"local_greedy_frame": self._session(tts_dir / self.tts_meta["files"]["local_greedy_frame"])}
+                {"local_greedy_frame": _make_session(tts_dir / self.tts_meta["files"]["local_greedy_frame"])}
                 if self.tts_meta["files"].get("local_greedy_frame")
                 else {}
             ),
             **(
-                {"local_fixed_sampled_frame": self._session(tts_dir / self.tts_meta["files"]["local_fixed_sampled_frame"])}
+                {"local_fixed_sampled_frame": _make_session(tts_dir / self.tts_meta["files"]["local_fixed_sampled_frame"])}
                 if self.tts_meta["files"].get("local_fixed_sampled_frame")
                 else {}
             ),
             **(
-                {"local_cached_step": self._session(tts_dir / self.tts_meta["files"]["local_cached_step"])}
+                {"local_cached_step": _make_session(tts_dir / self.tts_meta["files"]["local_cached_step"])}
                 if self.tts_meta["files"].get("local_cached_step")
                 else {}
             ),
-            "codec_encode": self._session(codec_dir / self.codec_meta["files"]["encode"]),
-            "codec_decode": self._session(codec_dir / self.codec_meta["files"]["decode_full"]),
-            "codec_decode_step": self._session(codec_dir / self.codec_meta["files"]["decode_step"]),
+            "codec_encode": _make_session(codec_dir / self.codec_meta["files"]["encode"]),
+            "codec_decode": _make_session(codec_dir / self.codec_meta["files"]["decode_full"]),
+            "codec_decode_step": _make_session(codec_dir / self.codec_meta["files"]["decode_step"]),
         }
 
     def list_builtin_voices(self) -> list[dict[str, Any]]:
diff --git a/pyproject.toml b/pyproject.toml
index 91ceff7..e4817eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
     "transformers==4.57.1",
     "uvicorn>=0.29.0",
     "onnxruntime>=1.20.0",
+    "lameenc>=1.7.0",
 ]
 
 [project.urls]
@@ -39,6 +40,7 @@ py-modules = [
    "infer",
    "infer_onnx",
    "moss_tts_nano_runtime",
+    "openai_audio_api",
    "onnx_tts_runtime",
    "ort_cpu_runtime",
    "text_normalization_pipeline",
diff --git a/requirements.txt b/requirements.txt
index a3ab42c..3da49a8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ sentencepiece
 uvicorn>=0.29.0
 WeTextProcessing>=1.0.4.1
 soundfile
-onnxruntime>=1.20.0
\ No newline at end of file
+onnxruntime>=1.20.0
+lameenc>=1.7.0
\ No newline at end of file
diff --git a/text_normalization_pipeline.py b/text_normalization_pipeline.py
index f755d7f..064a172 100644
--- a/text_normalization_pipeline.py
+++ b/text_normalization_pipeline.py
@@ -210,12 +210,11 @@ def prepare_tts_request_texts(
     if enable_normalize_tts_text and enable_wetext:
         pre_robust_text = normalize_tts_text(raw_text)
         pre_robust_prompt_text = normalize_tts_text(raw_prompt_text) if raw_prompt_text else ""
-        if pre_robust_text != raw_text:
-            logging.info(
-                "normalized text chars_before=%d chars_after=%d stage=robust_pre",
-                len(raw_text),
-                len(pre_robust_text),
-            )
+        logging.info(
+            "TTS text pipeline stage=robust_pre before=%r after=%r",
+            raw_text,
+            pre_robust_text,
+        )
         if raw_prompt_text and pre_robust_prompt_text != raw_prompt_text:
             logging.info(
                 "normalized prompt_text chars_before=%d chars_after=%d stage=robust_pre",
@@ -237,9 +236,9 @@ def prepare_tts_request_texts(
         rewritten_wetext_input_prompt_text = _rewrite_hyphens_before_zh_wetext(wetext_input_prompt_text)
         if rewritten_wetext_input_text != wetext_input_text:
             logging.info(
-                "rewrote zh wetext text hyphens chars_before=%d chars_after=%d stage=zh_wetext_hyphen_guard",
-                len(wetext_input_text),
-                len(rewritten_wetext_input_text),
+                "TTS text pipeline stage=zh_wetext_hyphen_guard before=%r after=%r",
+                wetext_input_text,
+                rewritten_wetext_input_text,
             )
         if wetext_input_prompt_text and rewritten_wetext_input_prompt_text != wetext_input_prompt_text:
             logging.info(
@@ -256,10 +255,10 @@ def prepare_tts_request_texts(
         )
         if intermediate_text != wetext_input_text:
             logging.info(
-                "normalized text chars_before=%d chars_after=%d stage=wetext language=%s",
-                len(wetext_input_text),
-                len(intermediate_text),
+                "TTS text pipeline stage=wetext lang=%s before=%r after=%r",
                 normalization_language,
+                wetext_input_text,
+                intermediate_text,
             )
         if wetext_input_prompt_text and intermediate_prompt_text != wetext_input_prompt_text:
             logging.info(
@@ -279,10 +278,10 @@ def prepare_tts_request_texts(
 
         if final_text != intermediate_text:
             logging.info(
-                "normalized text chars_before=%d chars_after=%d stage=%s",
-                len(intermediate_text),
-                len(final_text),
+                "TTS text pipeline stage=%s before=%r after=%r",
                 robust_stage_name,
+                intermediate_text,
+                final_text,
             )
         if intermediate_prompt_text and final_prompt_text != intermediate_prompt_text:
             logging.info(
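
The new /v1/audio/speech endpoint streams audio while it is still being
synthesized, so a client can consume bytes incrementally. A sketch using only
the standard library; the host, port, and file name are illustrative and
assume a local `python app.py` run:

    import json
    import urllib.request

    payload = json.dumps({
        "model": "tts-1",
        "input": "你好，世界。今天14:30出发。",
        "voice": "alloy",          # remapped to the "Junhao" preset
        "response_format": "wav",  # streamed with a placeholder-size RIFF header
        "speed": 1.0,
    }).encode("utf-8")
    req = urllib.request.Request(
        "http://localhost:8000/v1/audio/speech",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp, open("speech.wav", "wb") as out:
        while True:
            chunk = resp.read(8192)  # arrives as synthesis progresses
            if not chunk:
                break
            out.write(chunk)
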