|
| 1 | +#include "parakeet.h" |
| 2 | +#include "common-whisper.h" |
| 3 | + |
| 4 | +#include <cstdio> |
| 5 | +#include <string> |
| 6 | +#include <thread> |
| 7 | +#include <vector> |
| 8 | +#include <cstring> |
| 9 | + |
| 10 | +// command-line parameters |
| 11 | +struct parakeet_params { |
| 12 | + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); |
| 13 | + int32_t chunk_length_ms = 10000; |
| 14 | + int32_t left_context_ms = 10000; |
| 15 | + int32_t right_context_ms = 4960; |
| 16 | + |
| 17 | + bool use_gpu = true; |
| 18 | + bool flash_attn = true; |
| 19 | + int32_t gpu_device = 0; |
| 20 | + |
| 21 | + bool print_segments = false; |
| 22 | + |
| 23 | + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; |
| 24 | + std::vector<std::string> fname_inp = {}; |
| 25 | +}; |
| 26 | + |
| 27 | +static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params); |
| 28 | + |
| 29 | +static char * requires_value_error(const std::string & arg) { |
| 30 | + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); |
| 31 | + exit(1); |
| 32 | +} |
| 33 | + |
| 34 | +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) { |
| 35 | + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { |
| 36 | + params.gpu_device = std::stoi(env_device); |
| 37 | + } |
| 38 | + |
| 39 | + for (int i = 1; i < argc; i++) { |
| 40 | + std::string arg = argv[i]; |
| 41 | + |
| 42 | + if (arg == "-"){ |
| 43 | + params.fname_inp.push_back(arg); |
| 44 | + continue; |
| 45 | + } |
| 46 | + |
| 47 | + if (arg[0] != '-') { |
| 48 | + params.fname_inp.push_back(arg); |
| 49 | + continue; |
| 50 | + } |
| 51 | + |
| 52 | + if (arg == "-h" || arg == "--help") { |
| 53 | + parakeet_print_usage(argc, argv, params); |
| 54 | + exit(0); |
| 55 | + } |
| 56 | + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) |
| 57 | + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } |
| 58 | + else if (arg == "-cl" || arg == "--chunk-length") { params.chunk_length_ms = std::stoi(ARGV_NEXT); } |
| 59 | + else if (arg == "-lc" || arg == "--left-context") { params.left_context_ms = std::stoi(ARGV_NEXT); } |
| 60 | + else if (arg == "-rc" || arg == "--right-context") { params.right_context_ms = std::stoi(ARGV_NEXT); } |
| 61 | + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } |
| 62 | + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } |
| 63 | + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } |
| 64 | + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } |
| 65 | + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } |
| 66 | + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } |
| 67 | + else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } |
| 68 | + else { |
| 69 | + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); |
| 70 | + parakeet_print_usage(argc, argv, params); |
| 71 | + exit(1); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + return true; |
| 76 | +} |
| 77 | + |
| 78 | +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) { |
| 79 | + fprintf(stderr, "\n"); |
| 80 | + fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]); |
| 81 | + fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); |
| 82 | + fprintf(stderr, "\n"); |
| 83 | + fprintf(stderr, "options:\n"); |
| 84 | + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); |
| 85 | + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); |
| 86 | + fprintf(stderr, " -cl N, --chunk-length N [%-7d] chunk length in milliseconds\n", params.chunk_length_ms); |
| 87 | + fprintf(stderr, " -lc N, --left-context N [%-7d] left context in milliseconds\n", params.left_context_ms); |
| 88 | + fprintf(stderr, " -rc N, --right-context N [%-7d] right context in milliseconds\n", params.right_context_ms); |
| 89 | + fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str()); |
| 90 | + fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); |
| 91 | + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); |
| 92 | + fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); |
| 93 | + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); |
| 94 | + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", !params.flash_attn ? "true" : "false"); |
| 95 | + fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); |
| 96 | + fprintf(stderr, "\n"); |
| 97 | +} |
| 98 | + |
| 99 | +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { |
| 100 | + static bool is_first = true; |
| 101 | + |
| 102 | + const char * token_str = parakeet_token_to_str(ctx, token_data->id); |
| 103 | + char text_buf[256]; |
| 104 | + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); |
| 105 | + printf("%s", text_buf); |
| 106 | + fflush(stdout); |
| 107 | + |
| 108 | + is_first = false; |
| 109 | +} |
| 110 | + |
| 111 | + |
| 112 | +int main(int argc, char ** argv) { |
| 113 | + parakeet_params params; |
| 114 | + |
| 115 | + if (parakeet_params_parse(argc, argv, params) == false) { |
| 116 | + return 1; |
| 117 | + } |
| 118 | + |
| 119 | + if (params.fname_inp.empty()) { |
| 120 | + fprintf(stderr, "error: no input files specified\n"); |
| 121 | + parakeet_print_usage(argc, argv, params); |
| 122 | + return 1; |
| 123 | + } |
| 124 | + |
| 125 | + // Process each input file |
| 126 | + for (const auto & fname : params.fname_inp) { |
| 127 | + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); |
| 128 | + |
| 129 | + std::vector<float> pcmf32; |
| 130 | + std::vector<std::vector<float>> pcmf32s; |
| 131 | + if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) { |
| 132 | + fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str()); |
| 133 | + continue; |
| 134 | + } |
| 135 | + |
| 136 | + if (pcmf32.empty()) { |
| 137 | + fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str()); |
| 138 | + continue; |
| 139 | + } |
| 140 | + |
| 141 | + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); |
| 142 | + |
| 143 | + struct parakeet_context_params ctx_params = parakeet_context_default_params(); |
| 144 | + ctx_params.use_gpu = params.use_gpu; |
| 145 | + ctx_params.flash_attn = params.flash_attn; |
| 146 | + ctx_params.gpu_device = params.gpu_device; |
| 147 | + |
| 148 | + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(params.model.c_str(), ctx_params); |
| 149 | + if (pctx == nullptr) { |
| 150 | + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); |
| 151 | + return 1; |
| 152 | + } |
| 153 | + struct parakeet_state * state = parakeet_init_state(pctx); |
| 154 | + if (state == nullptr) { |
| 155 | + fprintf(stderr, "error: failed to initialize parakeet state\n"); |
| 156 | + parakeet_free(pctx); |
| 157 | + return 2; |
| 158 | + } |
| 159 | + |
| 160 | + fprintf(stderr, "Successfully loaded Parakeet model\n"); |
| 161 | + fprintf(stderr, "Processing audio (%zu samples, %.2f seconds)\n", |
| 162 | + pcmf32.size(), (float)pcmf32.size() / PARAKEET_SAMPLE_RATE); |
| 163 | + |
| 164 | + struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); |
| 165 | + full_params.n_threads = params.n_threads; |
| 166 | + full_params.chunk_length_ms = params.chunk_length_ms; |
| 167 | + full_params.left_context_ms = params.left_context_ms; |
| 168 | + full_params.right_context_ms = params.right_context_ms; |
| 169 | + full_params.new_token_callback = token_callback; |
| 170 | + full_params.new_token_callback_user_data = nullptr; |
| 171 | + |
| 172 | + const int mel_frames = pcmf32.size() / PARAKEET_HOP_LENGTH; |
| 173 | + const int model_max_ctx = parakeet_n_audio_ctx(pctx); |
| 174 | + const bool fits_single_chunk = mel_frames <= model_max_ctx; |
| 175 | + |
| 176 | + int ret; |
| 177 | + if (fits_single_chunk) { |
| 178 | + ret = parakeet_chunk(pctx, state, full_params, pcmf32.data(), pcmf32.size()); |
| 179 | + } else { |
| 180 | + ret = parakeet_full_with_state(pctx, state, full_params, pcmf32.data(), pcmf32.size()); |
| 181 | + } |
| 182 | + |
| 183 | + if (ret != 0) { |
| 184 | + fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); |
| 185 | + parakeet_free_state(state); |
| 186 | + parakeet_free(pctx); |
| 187 | + continue; |
| 188 | + } |
| 189 | + |
| 190 | + printf("\n"); |
| 191 | + |
| 192 | + if (params.print_segments) { |
| 193 | + const int n_segments = parakeet_full_n_segments_from_state(state); |
| 194 | + fprintf(stderr, "\nSegments (%d):\n", n_segments); |
| 195 | + |
| 196 | + for (int i = 0; i < n_segments; i++) { |
| 197 | + const char * text = parakeet_full_get_segment_text_from_state(state, i); |
| 198 | + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); |
| 199 | + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); |
| 200 | + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); |
| 201 | + |
| 202 | + fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); |
| 203 | + fprintf(stderr, "Tokens [%d]:\n", n_tokens); |
| 204 | + |
| 205 | + for (int j = 0; j < n_tokens; j++) { |
| 206 | + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); |
| 207 | + const char * token_str = parakeet_token_to_str(pctx, token_data.id); |
| 208 | + |
| 209 | + fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n", |
| 210 | + j, |
| 211 | + token_data.id, |
| 212 | + token_data.frame_index, |
| 213 | + token_data.duration_idx, |
| 214 | + token_data.duration_value, |
| 215 | + token_data.p, |
| 216 | + token_data.plog, |
| 217 | + (long long)token_data.t0, |
| 218 | + (long long)token_data.t1, |
| 219 | + token_data.is_word_start ? "true": "false", |
| 220 | + token_str); |
| 221 | + } |
| 222 | + } |
| 223 | + } |
| 224 | + |
| 225 | + parakeet_free_state(state); |
| 226 | + parakeet_free(pctx); |
| 227 | + } |
| 228 | + |
| 229 | + return 0; |
| 230 | +} |
0 commit comments