diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md index ccb8404f542..19a8e5a5381 100644 --- a/examples/parakeet-cli/README.md +++ b/examples/parakeet-cli/README.md @@ -28,6 +28,10 @@ options: -ng, --no-gpu [false ] disable GPU -dev N, --device N [0 ] GPU device to use -ps, --print-segments [false ] print segment information + --stream process audio in overlapping windows + -lc N, --left-context-ms N left context per stream window (ms) in multiple of 80ms (default: 10000) + -cs N, --chunk-ms N emitted audio per stream window (ms) in multiple of 80ms (default: 2000) + -rc N, --right-context-ms N right context per stream window (ms) in multiple of 80ms (default: 2000) ``` ### Example @@ -39,6 +43,13 @@ parakeet_decode: starting decode with n_frames=138 And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. ``` +Streaming mode encodes overlapping `[left | chunk | right]` windows and emits only tokens that begin in the chunk. Defaults are `[10000 | 2000 | 2000]` (ms): +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --stream --left-context-ms 10000 --chunk-ms 2000 --right-context-ms 2000 +``` + +This mode uses the existing encoder attention implementation. It does not reproduce NeMo's configurable limited-right-context attention. 
+ To print segment information: ```console $ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp index 03ddc7f8b8c..cafe4dcdaf5 100644 --- a/examples/parakeet-cli/parakeet-cli.cpp +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -18,6 +18,11 @@ struct parakeet_params { bool print_segments = false; bool output_txt = false; bool no_prints = false; + bool stream = false; + + int32_t left_context_ms = 10000; + int32_t chunk_ms = 2000; + int32_t right_context_ms = 2000; std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; std::string output_file = ""; @@ -63,6 +68,10 @@ static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & para else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else if (arg == "--stream") { params.stream = true; } + else if (arg == "-lc" || arg == "--left-context-ms") { params.left_context_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-cs" || arg == "--chunk-ms") { params.chunk_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-rc" || arg == "--right-context-ms") { params.right_context_ms = std::stoi(ARGV_NEXT); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); parakeet_print_usage(argc, argv, params); @@ -89,6 +98,10 @@ static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_para fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? 
"true" : "false"); + fprintf(stderr, " --stream [%-7s] process audio in overlapping windows\n", params.stream ? "true" : "false"); + fprintf(stderr, " -lc N, --left-context-ms N [%-7d] left context per stream window (ms) in multiple of 80ms\n", params.left_context_ms); + fprintf(stderr, " -cs N, --chunk-ms N [%-7d] emitted audio per stream window (ms) in multiple of 80ms\n", params.chunk_ms); + fprintf(stderr, " -rc N, --right-context-ms N [%-7d] right context per stream window (ms) in multiple of 80ms\n", params.right_context_ms); fprintf(stderr, "\n"); } @@ -129,6 +142,11 @@ int main(int argc, char ** argv) { ctx_params.use_gpu = params.use_gpu; ctx_params.gpu_device = params.gpu_device; + struct parakeet_stream_params stream_params = parakeet_stream_default_params(); + stream_params.left_context_ms = params.left_context_ms; + stream_params.chunk_ms = params.chunk_ms; + stream_params.right_context_ms = params.right_context_ms; + if (!params.no_prints) { fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); } @@ -171,7 +189,9 @@ int main(int argc, char ** argv) { full_params.new_token_callback_user_data = &is_first; const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); - int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); + const int ret = params.stream + ? parakeet_full_stream(pctx, full_params, stream_params, pcmf32.data(), pcmf32.size()) + : parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); if (ret != 0) { fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); diff --git a/include/parakeet.h b/include/parakeet.h index d35aa870adb..377b994d0d7 100644 --- a/include/parakeet.h +++ b/include/parakeet.h @@ -265,12 +265,21 @@ extern "C" { void * abort_callback_user_data; }; + // Parameters for parakeet_full_stream(). All durations are positive milliseconds. + // Values must be multiples of the encoder frame duration (80 ms). 
+ struct parakeet_stream_params { + int left_context_ms; + int chunk_ms; + int right_context_ms; + }; + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params() PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void); PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void); PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy); PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy); + PARAKEET_API struct parakeet_stream_params parakeet_stream_default_params (void); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Not thread safe for same context @@ -287,6 +296,24 @@ extern "C" { const float * samples, int n_samples); + // NVIDIA NeMo example of parakeet streaming + // https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py + // Example of a 10-2-3 window (10s left context, 2s chunk, 3s right context): encoder (full 15s) -> decoder (middle 2s) -> text (middle 2s) + PARAKEET_API int parakeet_full_stream( + struct parakeet_context * ctx, + struct parakeet_full_params params, + struct parakeet_stream_params stream_params, + const float * samples, + int n_samples); + + PARAKEET_API int parakeet_full_stream_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + struct parakeet_stream_params stream_params, + const float * samples, + int n_samples); + // Process a single chunk of audio data that fits within the model's audio context window. // This is more efficient than parakeet_full() for short audio clips. 
PARAKEET_API int parakeet_chunk( diff --git a/src/parakeet.cpp b/src/parakeet.cpp index b5da73e985c..a4fd7177d6d 100644 --- a/src/parakeet.cpp +++ b/src/parakeet.cpp @@ -2448,12 +2448,17 @@ static parakeet_token_data create_token_data( return token_data; } -static bool parakeet_decode( +static bool parakeet_decode_internal( parakeet_context & pctx, parakeet_state & pstate, parakeet_batch & batch, const int n_threads, - const parakeet_full_params * params = nullptr) { + const parakeet_full_params * params, + int frame_begin, + int frame_end, + int frame_offset, + int time_offset, + bool init_predictor_from_blank) { const auto & hparams = pctx.model.hparams; const auto & tdt_durations = pctx.model.tdt_durations; @@ -2463,32 +2468,41 @@ static bool parakeet_decode( const int n_vocab_logits = blank_id + 1; const int max_tokens_per_timestep = hparams.n_max_tokens; + if (frame_end < 0) { + frame_end = n_frames; + } + if (frame_begin < 0 || frame_begin > frame_end || frame_end > n_frames) { + PARAKEET_LOG_ERROR("%s: invalid decode range [%d, %d) for %d frames\n", __func__, frame_begin, frame_end, n_frames); + return false; + } + // time index into the encoder frame (current time frame) - int t = 0; + int t = frame_begin; // number of symbols emitted for the current time frame int tokens_emitted = 0; // Start with the blank token (8192) parakeet_token last_token = blank_id; - PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames); - - batch.n_tokens = 1; - batch.token[0] = last_token; - batch.logits[0] = 1; - batch.i_time[0] = 0; + PARAKEET_LOG_DEBUG("parakeet_decode: starting decode in [%d, %d) of %d frames\n", frame_begin, frame_end, n_frames); - // run the prediction network for the initial blank token. This will - // initialize the LSTM state and produce an initial hidden state that can - // be used in the joint network below. - if (!parakeet_predict(pctx, pstate, batch, n_threads, - params ? params->abort_callback : nullptr, - params ? 
params->abort_callback_user_data : nullptr)) { - return false; + // Control whether to reuse the predictor state from the prior chunk (streaming) + // or start with blank + if (init_predictor_from_blank) { + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 1; + batch.i_time[0] = frame_begin; + // Initialize the predictor state from the blank token. + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } } // process all time frames of the encoder output - while (t < n_frames) { + while (t < frame_end) { batch.n_tokens = 1; batch.i_time[0] = t; batch.logits[0] = 1; @@ -2552,6 +2566,9 @@ static bool parakeet_decode( parakeet_token_data token_data = create_token_data( pctx, pstate, best_token, best_duration_idx, duration, t, max_logit, n_vocab_logits); + token_data.frame_index += frame_offset; + token_data.t0 += time_offset; + token_data.t1 += time_offset; pstate.decoded_token_data.push_back(token_data); @@ -2589,6 +2606,34 @@ static bool parakeet_decode( return true; } +static bool parakeet_decode( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params = nullptr) { + return parakeet_decode_internal( + pctx, pstate, batch, n_threads, params, + 0, pstate.n_frames, 0, 0, true); +} + +static bool parakeet_decode_stream( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params, + int frame_begin, + int frame_end, + int frame_offset, + int time_offset, + bool init_predictor_from_blank) { + return parakeet_decode_internal( + pctx, pstate, batch, n_threads, params, + frame_begin, frame_end, frame_offset, time_offset, + init_predictor_from_blank); +} + // 500 -> 00:05.000 // 6000 -> 01:00.000 // naive Discrete Fourier Transform @@ -3504,9 +3549,21 @@ struct 
parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_ return result; } -static void parakeet_reset_state(struct parakeet_state * state) { +struct parakeet_stream_params parakeet_stream_default_params(void) { + return { + /*.left_context_ms =*/ 10000, + /*.chunk_ms =*/ 2000, + /*.right_context_ms =*/ 2000, + }; +} + +static void parakeet_clear_decoded_output(struct parakeet_state * state) { state->decoded_tokens.clear(); state->decoded_token_data.clear(); +} + +static void parakeet_reset_state(struct parakeet_state * state) { + parakeet_clear_decoded_output(state); if (state->lstm_state.buffer) { ggml_backend_buffer_clear(state->lstm_state.buffer, 0); @@ -3634,6 +3691,155 @@ int parakeet_full( return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); } +int parakeet_full_stream_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + struct parakeet_stream_params stream_params, + const float * samples, + int n_samples) { + + const int frame_stride_samples = PARAKEET_HOP_LENGTH * ctx->model.hparams.subsampling_factor; + const int frame_stride_ms = frame_stride_samples * 1000 / PARAKEET_SAMPLE_RATE; // 80ms + // Check if it is multiple of frame_stride_ms (80ms) + const auto is_valid_duration = [frame_stride_ms](int duration_ms) { + return duration_ms > 0 && duration_ms % frame_stride_ms == 0; + }; + + // Streaming slices the caller-provided PCM buffer with samples + buffer_start, + // non-null input buffer with at least one sample is required. 
+ if (!samples || n_samples <= 0 || + !is_valid_duration(stream_params.left_context_ms) || + !is_valid_duration(stream_params.chunk_ms) || + !is_valid_duration(stream_params.right_context_ms) || + params.audio_ctx != 0) { + PARAKEET_LOG_ERROR("%s: invalid streaming parameters\n", __func__); + return -1; + } + + const int left_samples = stream_params.left_context_ms * PARAKEET_SAMPLE_RATE / 1000; + const int chunk_samples = stream_params.chunk_ms * PARAKEET_SAMPLE_RATE / 1000; + const int right_samples = stream_params.right_context_ms * PARAKEET_SAMPLE_RATE / 1000; + const int total_samples = left_samples + chunk_samples + right_samples; + // Calculation derived from PyTorch torch.stft docs : `T = 1 + L // hop_length` if center=True + const int max_window_mel_frames = 1 + total_samples / PARAKEET_HOP_LENGTH ; + const int model_audio_ctx = parakeet_n_audio_ctx(ctx); + + if (model_audio_ctx > 0 && max_window_mel_frames > model_audio_ctx) { + PARAKEET_LOG_ERROR("%s: streaming window (%d mel frames) exceeds model context (%d)\n", + __func__, max_window_mel_frames, model_audio_ctx); + return -1; + } + + state->result_all.clear(); + parakeet_reset_state(state); + + bool init_predictor_from_blank = true; + + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + + for (int chunk_start = 0; chunk_start < n_samples;) { + // [0---------------------------full audio--------------------------n_samples] + // buffer_start [----------encode-----------] buffer_end + // chunk_start [---decode---] chunk_end + const int chunk_end = std::min(n_samples, chunk_start + chunk_samples); + const int buffer_start = std::max(0, chunk_start - left_samples); + const int buffer_end = std::min(n_samples, chunk_end + right_samples); + const int buffer_samples = buffer_end - buffer_start; + + parakeet_clear_decoded_output(state); + + if (parakeet_pcm_to_mel_with_state(ctx, state, samples + buffer_start, buffer_samples, params.n_threads) != 
0) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -2; + } + + state->n_audio_ctx = state->mel.n_len; + if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) { + PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n", + __func__, state->n_audio_ctx); + return -6; + } + + if (params.encoder_begin_callback && + !params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + return -6; + } + + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + // Encoded full window : left context + chunk + right context. + // Only the middle chunk is decoded below. + const int decode_begin = (chunk_start - buffer_start) / frame_stride_samples; + const int decode_end = std::min(state->n_frames, + (chunk_end - buffer_start + frame_stride_samples - 1) / frame_stride_samples); + const int frame_offset = buffer_start / frame_stride_samples; + const int time_offset = buffer_start / PARAKEET_HOP_LENGTH; + + if (!parakeet_decode_stream(*ctx, *state, state->batch, params.n_threads, &params, + decode_begin, decode_end, frame_offset, time_offset, + init_predictor_from_blank)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + init_predictor_from_blank = false; + + if (!state->decoded_tokens.empty()) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + result_tokens.reserve(state->decoded_tokens.size()); + + for (size_t i = 0; i < state->decoded_tokens.size(); i++) { + const char * tok_str = parakeet_token_to_str(ctx, state->decoded_tokens[i]); + if (tok_str) { + text += sentencepiece_piece_to_text(tok_str, text.empty()); + } + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, 
result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = chunk_start / PARAKEET_HOP_LENGTH; + segment.t1 = (chunk_end + PARAKEET_HOP_LENGTH - 1) / PARAKEET_HOP_LENGTH; + segment.text = text; + segment.tokens = std::move(result_tokens); + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, (int) (100LL * chunk_end / n_samples), // 64-bit math: 100 * chunk_end overflows int on long inputs + params.progress_callback_user_data); + } + chunk_start = chunk_end; + } + + return 0; +} + +int parakeet_full_stream( + struct parakeet_context * ctx, + struct parakeet_full_params params, + struct parakeet_stream_params stream_params, + const float * samples, + int n_samples) { + return parakeet_full_stream_with_state(ctx, ctx->state, params, stream_params, samples, n_samples); +} + int parakeet_chunk( struct parakeet_context * ctx, struct parakeet_state * state, diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp index 22ac4c20e31..75e674fffcd 100644 --- a/tests/test-parakeet-full.cpp +++ b/tests/test-parakeet-full.cpp @@ -90,9 +90,52 @@ int main() { const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH); const bool transcript_matches = verify_transcription(expected, tstate.transcript); + parakeet_stream_params stream_params = parakeet_stream_default_params(); + assert(stream_params.left_context_ms == 10000); + assert(stream_params.chunk_ms == 2000); + assert(stream_params.right_context_ms == 2000); + + stream_params.left_context_ms = 8000; + stream_params.chunk_ms = 1600; + stream_params.right_context_ms = 2400; + + test_state stream_tstate; + params.new_token_callback_user_data = &stream_tstate; + ret = parakeet_full_stream(pctx, params, stream_params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + const bool stream_transcript_matches = 
verify_transcription(expected, stream_tstate.transcript); + const int n_stream_segments = parakeet_full_n_segments(pctx); + assert(n_stream_segments >= 2); + int64_t previous_t1 = 0; + for (int i = 0; i < n_stream_segments; ++i) { + const int64_t t0 = parakeet_full_get_segment_t0(pctx, i); + const int64_t t1 = parakeet_full_get_segment_t1(pctx, i); + assert(t0 >= previous_t1); + assert(t1 > t0); + + const int n_tokens = parakeet_full_n_tokens(pctx, i); + for (int j = 0; j < n_tokens; ++j) { + const parakeet_token_data token = parakeet_full_get_token_data(pctx, i, j); + assert(token.t0 >= t0); + assert(token.t0 < t1); + } + previous_t1 = t1; + } + + parakeet_stream_params invalid_stream_params = stream_params; + test_state repeated_stream_tstate; + params.new_token_callback_user_data = &repeated_stream_tstate; + ret = parakeet_full_stream(pctx, params, stream_params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + const bool repeated_stream_transcript_matches = verify_transcription(expected, repeated_stream_tstate.transcript); + + invalid_stream_params.chunk_ms = 100; + ret = parakeet_full_stream(pctx, params, invalid_stream_params, pcmf32.data(), pcmf32.size()); + assert(ret == -1); + parakeet_free(pctx); - if (!transcript_matches) { + if (!transcript_matches || !stream_transcript_matches || !repeated_stream_transcript_matches) { return 1; }