diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp
index c8a77a42fbc..91c756b3132 100644
--- a/examples/parakeet-cli/parakeet-cli.cpp
+++ b/examples/parakeet-cli/parakeet-cli.cpp
@@ -35,6 +35,52 @@ static char * requires_value_error(const std::string & arg) {
     exit(1);
 }
 
+// Strips leading and trailing spaces, tabs, carriage returns and newlines from
+// `line`; returns an empty string when nothing else remains.
+static std::string trim(const std::string & line) {
+    static const char * const blanks = " \t\r\n";
+
+    const size_t first = line.find_first_not_of(blanks);
+    if (first == std::string::npos) {
+        return "";
+    }
+
+    const size_t last = line.find_last_not_of(blanks);
+    return line.substr(first, last - first + 1);
+}
+
+// Appends the audio paths listed in the text file `fname` to params.fname_inp.
+// The file holds one path per line; blank lines and lines whose first
+// non-blank character is '#' are skipped. Prints an error and exits with
+// status 1 when the file cannot be opened.
+static void parakeet_params_add_file_list(const std::string & fname, parakeet_params & params) {
+    std::ifstream fin(fname);
+    if (!fin.is_open()) {
+        fprintf(stderr, "error: failed to open file list '%s'\n", fname.c_str());
+        exit(1);
+    }
+
+    bool is_first_line = true;
+    std::string line;
+    while (std::getline(fin, line)) {
+        // Some editors (notably on Windows) save UTF-8 text with a leading
+        // byte-order mark; drop it so it does not become part of the first path.
+        if (is_first_line) {
+            is_first_line = false;
+            if (line.compare(0, 3, "\xEF\xBB\xBF") == 0) {
+                line.erase(0, 3);
+            }
+        }
+
+        line = trim(line);
+        if (line.empty() || line[0] == '#') {
+            continue;
+        }
+
+        params.fname_inp.push_back(line);
+    }
+}
+
 static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) {
     if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) {
         params.gpu_device = std::stoi(env_device);
@@ -64,6 +110,7 @@ static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & para
         else if (arg == "-rc" || arg == "--right-context") { params.right_context_ms = std::stoi(ARGV_NEXT); }
         else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
         else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
+        else if (arg == "-fl" || arg == "--file-list") { parakeet_params_add_file_list(ARGV_NEXT, params); }
         else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
         else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); }
         else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = false; }
@@ -95,6 +142,7 @@ static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_para
     fprintf(stderr, " -rc N, --right-context N [%-7d] right context in milliseconds\n", params.right_context_ms);
     fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str());
     fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", "");
+    fprintf(stderr, " -fl, --file-list FILE [%-7s] text file containing one input audio path per line\n", "");
     fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
     fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device);
     fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
@@ -139,6 +187,32 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!params.output_file.empty() && params.fname_inp.size() > 1) {
+        fprintf(stderr, "error: --output-file cannot be used with multiple input files\n");
+        return 1;
+    }
+
+    struct parakeet_context_params ctx_params = parakeet_context_default_params();
+    ctx_params.use_gpu = params.use_gpu;
+    ctx_params.flash_attn = params.flash_attn;
+    ctx_params.gpu_device = params.gpu_device;
+
+    if (!params.no_prints) {
+        fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
+    }
+
+    struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
+    if (pctx == nullptr) {
+        fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
+        return 1;
+    }
+
+    if (!params.no_prints) {
+        fprintf(stderr, "Successfully loaded Parakeet model\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
+    }
+
     // Process each input file
     for (const auto & fname : params.fname_inp) {
         if (!params.no_prints) {
@@ -158,24 +232,6 @@ int main(int argc, char ** argv) {
         }
 
         if (!params.no_prints) {
-            fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
-        }
-
-        struct parakeet_context_params ctx_params = parakeet_context_default_params();
-        ctx_params.use_gpu = params.use_gpu;
-        ctx_params.flash_attn = params.flash_attn;
-        ctx_params.gpu_device = params.gpu_device;
-
-        struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
-        if (pctx == nullptr) {
-            fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
-            return 1;
-        }
-
-        if (!params.no_prints) {
-            fprintf(stderr, "Successfully loaded Parakeet model\n");
-            fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                    params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
             fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n",
                     pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE);
         }
@@ -197,7 +253,6 @@ int main(int argc, char ** argv) {
 
         if (ret != 0) {
             fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str());
-            parakeet_free(pctx);
             continue;
         }
 
@@ -222,10 +277,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (!params.no_prints) {
-            parakeet_print_timings(pctx);
-        }
-
         if (params.print_segments) {
             const int n_segments = parakeet_full_n_segments(pctx);
             fprintf(stderr, "\nSegments (%d):\n", n_segments);
@@ -259,8 +310,12 @@ int main(int argc, char ** argv) {
             }
         }
 
-        parakeet_free(pctx);
     }
 
+    if (!params.no_prints) {
+        parakeet_print_timings(pctx);
+    }
+    parakeet_free(pctx);
+
     return 0;
 }
diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore
index 838bfeae9db..0f08cd7b36a 100644
--- a/tests/librispeech-parakeet/.gitignore
+++ b/tests/librispeech-parakeet/.gitignore
@@ -1,6 +1,7 @@
 __pycache__
 *.tar.gz
 *.txt
+!requirements.txt
 eval.conf
 venv
 LibriSpeech
diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md
index e09cba405ef..0aa06fa374b 100644
--- a/tests/librispeech-parakeet/README.md
+++ b/tests/librispeech-parakeet/README.md
@@ -17,13 +17,7 @@ performance of parakeet.cpp on LibriSpeech corpus.
    $ cmake --build build --config Release
    ```
 
-2. Download the audio files from LibriSpeech project.
-
-   ```
-   $ make get-audio
-   ```
-
-3. Set up the environment to compute WER score.
+2. Set up the environment to compute WER score.
 
    ```
    $ pip install -r requirements.txt
@@ -37,21 +31,49 @@ performance of parakeet.cpp on LibriSpeech corpus.
    $ pip install -r requirements.txt
    ```
 
-4. Run the benchmark test.
+3. Run the benchmark test.
+
+   ```
+   $ python run_eval.py --download --cli ../../build/bin/parakeet-cli
+   ```
+
+   On Windows with a Visual Studio build, the CLI path usually includes the
+   configuration directory:
 
    ```
-   $ make
+   > python run_eval.py --download --cli ..\..\build\bin\Release\parakeet-cli.exe
    ```
 
+   For a CUDA build, pass the CUDA build's `parakeet-cli` explicitly:
+
+   ```
+   > python run_eval.py --download --force --cli ..\..\build-cuda\bin\Release\parakeet-cli.exe
+   ```
+
+   `parakeet-cli` loads the model once for all pending LibriSpeech files.
+   Without `--force`, existing `*.flac.txt` transcript files are reused.
+
+## Makefile runner
+
+The checked-in `Makefile`/`eval.mk` runner uses `make` plus commands such as
+`wget`, `tar`, `mv`, and `rm`. It invokes the configured `parakeet-cli` once
+per audio file and then runs `eval.py`.
+
 ## How-to guides
 
 ### How to change the inference parameters
 
-Create `eval.conf` and override variables.
+With `run_eval.py`, pass Parakeet CLI options after `--`.
+
+```
+$ python run_eval.py --download --cli ../../build/bin/parakeet-cli -- --threads 8 --no-flash-attn
+```
+
+With the Makefile runner, create `eval.conf` and override variables.
 
 ```
 PARAKEET_MODEL = parakeet-tdt-0.6b-v3
-PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt
+PARAKEET_FLAGS = --no-prints --threads 8 --output-txt
 ```
 
 Check out `eval.mk` for more details.
diff --git a/tests/librispeech-parakeet/requirements.txt b/tests/librispeech-parakeet/requirements.txt
new file mode 100644
index 00000000000..0e69d036202
--- /dev/null
+++ b/tests/librispeech-parakeet/requirements.txt
@@ -0,0 +1,6 @@
+# This is the minimal set of dependencies we need to compute
+# WER score. Read Section 3.2. of the original paper
+# (https://arxiv.org/abs/2212.04356) for more context.
+jiwer
+regex
+more-itertools
diff --git a/tests/librispeech-parakeet/run_eval.py b/tests/librispeech-parakeet/run_eval.py
new file mode 100644
index 00000000000..3e4dd6de910
--- /dev/null
+++ b/tests/librispeech-parakeet/run_eval.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+"""Run parakeet-cli on LibriSpeech test-clean and compute WER.
+
+With --download the test-clean archive is fetched and extracted first. Every
+.flac file that still needs a transcript is then passed to a single
+parakeet-cli invocation, and eval.py is run in the work directory afterwards.
+"""
+import argparse
+import subprocess
+import sys
+import tarfile
+import urllib.request
+from pathlib import Path
+from typing import List
+
+
+DEFAULT_URL = "https://www.openslr.org/resources/12/test-clean.tar.gz"
+
+
+def eprint(*args):
+    """Print the given values to stderr."""
+    print(*args, file=sys.stderr)
+
+
+def format_bytes(value: int) -> str:
+    """Format a byte count with one decimal in 1024-based B/KB/MB/GB units."""
+    units = ["B", "KB", "MB", "GB"]
+    amount = float(value)
+    for unit in units[:-1]:
+        if amount < 1024.0:
+            return f"{amount:.1f} {unit}"
+        amount /= 1024.0
+    # Whatever does not fit a smaller unit is reported in the largest one.
+    return f"{amount:.1f} {units[-1]}"
+
+
+def path_is_inside(path: Path, directory: Path) -> bool:
+    """Return True when path is directory itself or located below it."""
+    try:
+        path.relative_to(directory)
+    except ValueError:
+        return False
+    return True
+
+
+def download_file(url: str, destination: Path) -> None:
+    """Download url to destination unless destination already exists.
+
+    The data is streamed into a sibling "NAME.part" file that is renamed to
+    destination only once the transfer has finished, so an interrupted download
+    is never mistaken for a complete archive on the next run.
+    """
+    if destination.exists():
+        eprint(f"using existing archive: {destination}")
+        return
+
+    partial = destination.with_name(destination.name + ".part")
+
+    eprint(f"downloading {url}")
+    try:
+        with urllib.request.urlopen(url) as response, partial.open("wb") as output:
+            total_header = response.headers.get("Content-Length")
+            total = int(total_header) if total_header else None
+            downloaded = 0
+
+            while True:
+                chunk = response.read(1024 * 1024)
+                if not chunk:
+                    break
+
+                output.write(chunk)
+                downloaded += len(chunk)
+
+                if total:
+                    percent = downloaded * 100.0 / total
+                    progress = f"{format_bytes(downloaded)} / {format_bytes(total)} ({percent:.1f}%)"
+                else:
+                    progress = format_bytes(downloaded)
+
+                print(f"\rdownloaded {progress}", end="", file=sys.stderr, flush=True)
+
+        print(file=sys.stderr)
+        partial.replace(destination)
+    except BaseException:
+        # Includes KeyboardInterrupt: never leave a truncated file behind.
+        partial.unlink(missing_ok=True)
+        raise
+
+    eprint(f"downloaded archive: {destination}")
+
+
+def extract_archive(archive: Path, work_dir: Path) -> None:
+    """Extract archive into work_dir unless LibriSpeech/test-clean exists there.
+
+    Every member must resolve to a location inside work_dir; RuntimeError is
+    raised before anything is extracted if one does not.
+    """
+    extracted = work_dir / "LibriSpeech" / "test-clean"
+    if extracted.exists():
+        eprint(f"using existing dataset: {extracted}")
+        return
+
+    eprint(f"extracting {archive}")
+    root = work_dir.resolve()
+    with tarfile.open(archive, "r:gz") as tar:
+        for member in tar.getmembers():
+            target = (work_dir / member.name).resolve()
+            if not path_is_inside(target, root):
+                raise RuntimeError(f"archive member escapes destination: {member.name}")
+
+        # tarfile.data_filter only exists on Python versions that support
+        # extraction filters. Detect the feature instead of catching TypeError,
+        # which would also swallow an unrelated TypeError raised while
+        # extracting and then retry without the filter.
+        if hasattr(tarfile, "data_filter"):
+            tar.extractall(work_dir, filter="data")
+        else:
+            tar.extractall(work_dir)
+
+
+def find_flac_files(dataset: Path) -> List[Path]:
+    """Return the .flac files under dataset/SPEAKER/CHAPTER/ in sorted order."""
+    files: List[Path] = []
+    for speaker_dir in sorted(path for path in dataset.iterdir() if path.is_dir()):
+        for chapter_dir in sorted(path for path in speaker_dir.iterdir() if path.is_dir()):
+            files.extend(sorted(chapter_dir.glob("*.flac")))
+    return files
+
+
+def hypothesis_path(audio: Path) -> Path:
+    """Return the transcript path for audio: the audio path with ".txt" appended."""
+    return Path(str(audio) + ".txt")
+
+
+def run_parakeet_batch(args, files: List[Path]) -> int:
+    """Run parakeet-cli once over files and return its exit code.
+
+    The paths are handed over through a temporary --file-list file in the work
+    directory, which is removed again when the process has finished.
+    """
+    file_list = args.work_dir / ".parakeet-files.txt"
+    file_list.write_text("".join(f"{audio}\n" for audio in files), encoding="utf-8")
+
+    command = [
+        str(args.cli),
+        "--output-txt",
+        "--model",
+        str(args.model),
+        "--file-list",
+        str(file_list),
+    ]
+
+    if args.no_prints:
+        command.append("--no-prints")
+
+    if args.threads is not None:
+        command.extend(["--threads", str(args.threads)])
+
+    command.extend(args.extra_args)
+
+    try:
+        return subprocess.run(command, cwd=args.repo_root).returncode
+    finally:
+        file_list.unlink(missing_ok=True)
+
+
+def transcribe(args, files: List[Path]) -> None:
+    """Transcribe the files that need a hypothesis (all of them with --force).
+
+    Raises RuntimeError when parakeet-cli exits with a non-zero status or when
+    a hypothesis file is still missing after the run.
+    """
+    pending = [audio for audio in files if args.force or not hypothesis_path(audio).exists()]
+
+    eprint(f"audio files: {len(files)}")
+    eprint(f"pending transcriptions: {len(pending)}")
+
+    if not pending:
+        return
+
+    # With --force, hypotheses from an earlier run may still be present. Remove
+    # them first: parakeet-cli skips an input it fails on and still exits with
+    # status 0, so a stale file would otherwise pass the check below and be
+    # scored as if it were fresh.
+    for audio in pending:
+        hypothesis_path(audio).unlink(missing_ok=True)
+
+    eprint(f"run: {len(pending)} file(s)")
+    ret = run_parakeet_batch(args, pending)
+    if ret != 0:
+        raise RuntimeError(f"parakeet-cli failed with exit code {ret}")
+
+    missing = [audio for audio in pending if not hypothesis_path(audio).exists()]
+    if missing:
+        first = missing[0]
+        raise RuntimeError(f"missing hypothesis after transcription: {first}")
+
+
+def score(args) -> int:
+    """Run eval.py in the work directory, echo its output and return its exit code.
+
+    On success the captured output is also written to --result-file if given.
+    """
+    command = [str(args.python), "eval.py"]
+    completed = subprocess.run(
+        command,
+        cwd=args.work_dir,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+    )
+    if completed.stdout:
+        print(completed.stdout, end="")
+
+    if completed.returncode != 0:
+        return completed.returncode
+
+    if args.result_file:
+        args.result_file.write_text(completed.stdout, encoding="utf-8")
+        eprint(f"wrote result: {args.result_file}")
+
+    return 0
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse the command line and resolve every path option to an absolute path."""
+    script_dir = Path(__file__).resolve().parent
+    repo_root = script_dir.parent.parent
+
+    parser = argparse.ArgumentParser(
+        description="Run parakeet-cli on LibriSpeech test-clean and compute WER."
+    )
+    parser.add_argument("--repo-root", type=Path, default=repo_root)
+    parser.add_argument("--work-dir", type=Path, default=script_dir)
+    parser.add_argument(
+        "--dataset",
+        type=Path,
+        default=script_dir / "LibriSpeech" / "test-clean",
+        help="Directory containing LibriSpeech .flac files.",
+    )
+    parser.add_argument(
+        "--cli",
+        type=Path,
+        required=True,
+        help="Path to parakeet-cli.",
+    )
+    parser.add_argument(
+        "--model",
+        type=Path,
+        default=repo_root / "models" / "ggml-parakeet-tdt-0.6b-v3.bin",
+        help="Path to converted Parakeet ggml model.",
+    )
+    parser.add_argument("--python", default=sys.executable)
+    parser.add_argument("--threads", type=int, default=None)
+    parser.add_argument("--no-prints", action="store_true", help="Pass --no-prints to parakeet-cli.")
+    parser.add_argument("--force", action="store_true", help="Regenerate existing .flac.txt files.")
+    parser.add_argument("--download", action="store_true", help="Download and extract test-clean if needed.")
+    parser.add_argument("--url", default=DEFAULT_URL)
+    parser.add_argument("--result-file", type=Path, default=None)
+    parser.add_argument(
+        "extra_args",
+        nargs=argparse.REMAINDER,
+        help="Extra arguments passed to parakeet-cli after '--'.",
+    )
+
+    args = parser.parse_args()
+
+    # The '--' separator may be left at the front of the remainder; drop it so
+    # that it is not forwarded to parakeet-cli.
+    if args.extra_args and args.extra_args[0] == "--":
+        args.extra_args = args.extra_args[1:]
+
+    args.repo_root = args.repo_root.resolve()
+    args.work_dir = args.work_dir.resolve()
+    args.dataset = args.dataset.resolve()
+    args.cli = args.cli.resolve()
+    args.model = args.model.resolve()
+    if args.result_file is not None:
+        args.result_file = args.result_file.resolve()
+
+    return args
+
+
+def main() -> int:
+    """Entry point; returns the process exit code."""
+    args = parse_args()
+
+    if args.download:
+        # Report download/extraction problems like every other failure
+        # instead of letting them end the script with a traceback.
+        try:
+            args.work_dir.mkdir(parents=True, exist_ok=True)
+            archive = args.work_dir / Path(args.url).name
+            download_file(args.url, archive)
+            extract_archive(archive, args.work_dir)
+        except (OSError, RuntimeError, tarfile.TarError) as exc:
+            eprint(f"error: failed to prepare dataset: {exc}")
+            return 1
+
+    if not args.cli.exists():
+        eprint(f"error: parakeet-cli not found: {args.cli}")
+        return 2
+
+    if not args.model.exists():
+        eprint(f"error: model not found: {args.model}")
+        return 2
+
+    if not args.dataset.exists():
+        eprint(f"error: dataset not found: {args.dataset}")
+        eprint("hint: pass --download or extract test-clean under tests/librispeech-parakeet/LibriSpeech")
+        return 2
+
+    eprint(f"repo root: {args.repo_root}")
+    eprint(f"work dir: {args.work_dir}")
+    eprint(f"dataset: {args.dataset}")
+    eprint(f"cli: {args.cli}")
+    eprint(f"model: {args.model}")
+
+    files = find_flac_files(args.dataset)
+    if not files:
+        eprint(f"error: no .flac files found under {args.dataset}")
+        return 2
+
+    try:
+        transcribe(args, files)
+    except RuntimeError as exc:
+        eprint(f"error: {exc}")
+        return 1
+
+    return score(args)
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())