@@ -580,6 +580,18 @@ struct tdt_stream_state {
580580 bool initialized; // whether prediction LSTM state has been initialized
581581};
582582
// State for incremental (streaming) transcription: accumulates incoming PCM
// samples and slides a [left context | chunk | right context] window over them.
struct parakeet_stream {
    std::vector<float> buffer;       // pending PCM samples (retained left context + audio not yet decoded)
    int64_t n_samples_advanced = 0;  // total samples whose center chunk has already been decoded

    // window dimensions in samples (derived from the *_ms fields of params at init time)
    int n_left_ctx  = 0;
    int n_chunk     = 0;
    int n_right_ctx = 0;

    parakeet_full_params params = {}; // parameters captured by parakeet_stream_init()
    bool initialized = false;         // true between parakeet_stream_init() and parakeet_stream_flush()/reset
};
594+
583595struct parakeet_state {
584596 int64_t t_sample_us = 0 ;
585597 int64_t t_encode_us = 0 ;
@@ -638,6 +650,8 @@ struct parakeet_state {
638650 parakeet_lstm_state lstm_state;
639651
640652 struct tdt_stream_state tdt_stream_state = {0 , 0 , 0 , false };
653+
654+ parakeet_stream stream;
641655};
642656
643657// FFT cache for mel spectrogram computation
@@ -2367,7 +2381,7 @@ static bool parakeet_decode(
23672381 // Start with the blank token (8192)
23682382 parakeet_token last_token = blank_id;
23692383
2370- PARAKEET_LOG_INFO (" parakeet_decode: starting decode with n_frames=%d\n " , n_frames);
2384+ PARAKEET_LOG_DEBUG (" parakeet_decode: starting decode with n_frames=%d\n " , n_frames);
23712385
23722386 batch.n_tokens = 1 ;
23732387 batch.token [0 ] = last_token;
@@ -3609,6 +3623,259 @@ struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_
36093623 return result;
36103624}
36113625
3626+ static void parakeet_stream_reset_state (struct parakeet_state * state) {
3627+ if (state == nullptr ) {
3628+ return ;
3629+ }
3630+
3631+ if (state->lstm_state .buffer ) {
3632+ ggml_backend_buffer_clear (state->lstm_state .buffer , 0 );
3633+ }
3634+
3635+ state->decoded_tokens .clear ();
3636+ state->decoded_token_data .clear ();
3637+ state->result_all .clear ();
3638+
3639+ state->tdt_stream_state .initialized = false ;
3640+ state->tdt_stream_state .last_token = 0 ;
3641+ state->tdt_stream_state .time_step = 0 ;
3642+ state->tdt_stream_state .decoded_length = 0 ;
3643+
3644+ state->stream .buffer .clear ();
3645+ state->stream .n_samples_advanced = 0 ;
3646+ state->stream .n_left_ctx = 0 ;
3647+ state->stream .n_chunk = 0 ;
3648+ state->stream .n_right_ctx = 0 ;
3649+ state->stream .params = {};
3650+ state->stream .initialized = false ;
3651+
3652+ state->enc_out_buffer .clear ();
3653+ state->enc_out_frames = 0 ;
3654+ state->n_frames = 0 ;
3655+ state->n_audio_ctx = 0 ;
3656+ }
3657+
// Run one full [left ctx | chunk | right ctx] audio window through the
// mel -> encoder -> prediction/joint pipeline, decoding ONLY the center chunk.
// Newly decoded tokens are appended to state->result_all as a segment and the
// new_segment_callback (if any) is invoked.
//
// Returns 0 on success (including "window produced no encoder frames"),
// -2 on mel-spectrogram failure, -6 on encoder failure, -7 on decoder failure.
static int parakeet_stream_process_window(
        struct parakeet_context * ctx,
        struct parakeet_state * state,
        const float * samples,
        int n_samples,
        int n_chunk) {
    const parakeet_stream & stream = state->stream;
    const parakeet_full_params & params = stream.params;
    const int d_enc = ctx->model.hparams.n_audio_state; // encoder output dimension per frame

    // process all the samples.
    if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
        return -2;
    }

    // sample counts -> mel frame counts
    const int left_mel_frames  = stream.n_left_ctx / PARAKEET_HOP_LENGTH;
    const int chunk_mel_frames = n_chunk / PARAKEET_HOP_LENGTH;

    state->n_audio_ctx = state->mel.n_len;
    // process entire log mel spectrogram.
    if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads,
                                  params.abort_callback, params.abort_callback_user_data)) {
        return -6;
    }

    // mel frame counts -> encoder frame counts (encoder subsamples in time)
    const int left_enc_frames  = left_mel_frames / ctx->model.hparams.subsampling_factor;
    const int chunk_enc_frames = chunk_mel_frames / ctx->model.hparams.subsampling_factor;

    if (chunk_enc_frames <= 0) {
        return 0;
    }

    // Copy the center chunk so that it is the only part that the joint network sees.
    state->enc_out_buffer.resize(chunk_enc_frames * d_enc);
    ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(),
                            left_enc_frames * d_enc * sizeof(float),   // byte offset: skip left-context frames
                            chunk_enc_frames * d_enc * sizeof(float)); // byte size: center chunk only

    state->enc_out_frames = chunk_enc_frames;
    state->n_frames = chunk_enc_frames;

    const size_t tokens_before = state->decoded_tokens.size();

    // Run the prediction and joint network on the center chunk.
    if (!parakeet_decode_chunk(*ctx, *state, state->batch, chunk_enc_frames, params.n_threads, &params)) {
        return -7;
    }

    const size_t tokens_after = state->decoded_tokens.size();
    const size_t new_token_count = tokens_after - tokens_before;

    if (new_token_count > 0) {
        std::string text;
        std::vector<parakeet_token_data> result_tokens;
        // segment boundaries in centiseconds (100 ticks per second)
        const int64_t chunk_t0 = 100LL * stream.n_samples_advanced / PARAKEET_SAMPLE_RATE;
        const int64_t chunk_t1 = 100LL * (stream.n_samples_advanced + n_chunk) / PARAKEET_SAMPLE_RATE;
        // NOTE(review): this divides a centisecond timestamp by the encoder
        // subsampling factor to get an encoder-frame offset — the units look
        // inconsistent (frames would normally be
        // n_samples_advanced / (PARAKEET_HOP_LENGTH * subsampling_factor)).
        // Confirm the intended semantics of frame_index before relying on it.
        const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor;

        result_tokens.reserve(new_token_count);

        for (size_t i = tokens_before; i < tokens_after; ++i) {
            const auto token_id = state->decoded_tokens[i];
            const char * token_str = parakeet_token_to_str(ctx, token_id);
            if (token_str) {
                // only the very first piece of the whole stream suppresses the
                // leading sentencepiece whitespace marker
                const bool is_first_piece = (tokens_before == 0) && text.empty();
                text += sentencepiece_piece_to_text(token_str, is_first_piece);
            }

            // shift the per-token timing from chunk-local to stream-global
            auto token_data = state->decoded_token_data[i];
            token_data.frame_index += frame_offset;
            token_data.t0 += chunk_t0;
            token_data.t1 += chunk_t0;
            result_tokens.push_back(token_data);
        }

        refine_timestamps_tdt(ctx->vocab, result_tokens);

        if (!text.empty()) {
            parakeet_segment segment;
            segment.t0 = chunk_t0;
            segment.t1 = chunk_t1;
            segment.text = std::move(text);
            segment.tokens = std::move(result_tokens);

            state->result_all.push_back(std::move(segment));

            if (params.new_segment_callback) {
                params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
            }
        }
    }

    return 0;
}
3752+
3753+ static int ms_to_n_samples (int ms) {
3754+ return ms * PARAKEET_SAMPLE_RATE / 1000 ;
3755+ }
3756+
3757+ int parakeet_stream_init (
3758+ struct parakeet_context * ctx,
3759+ struct parakeet_state * state,
3760+ struct parakeet_full_params params) {
3761+ if (ctx == nullptr || state == nullptr ) {
3762+ return -1 ;
3763+ }
3764+
3765+ const int n_left_ctx = ms_to_n_samples (params.left_context_ms );
3766+ const int n_chunk = ms_to_n_samples (params.chunk_length_ms );
3767+ const int n_right_ctx = ms_to_n_samples (params.right_context_ms );
3768+
3769+ if (n_left_ctx < 0 || n_chunk <= 0 || n_right_ctx < 0 ) {
3770+ return -1 ;
3771+ }
3772+
3773+ parakeet_stream_reset_state (state);
3774+
3775+ state->stream .n_left_ctx = n_left_ctx;
3776+ state->stream .n_chunk = n_chunk;
3777+ state->stream .n_right_ctx = n_right_ctx;
3778+ state->stream .params = params;
3779+ state->stream .initialized = true ;
3780+
3781+ if (n_left_ctx > 0 ) {
3782+ state->stream .buffer .assign (n_left_ctx, 0 .0f );
3783+ }
3784+
3785+ return 0 ;
3786+ }
3787+
3788+ int parakeet_stream_push (
3789+ struct parakeet_context * ctx,
3790+ struct parakeet_state * state,
3791+ const float * samples,
3792+ int n_samples) {
3793+ if (ctx == nullptr || state == nullptr || samples == nullptr || n_samples <= 0 ) {
3794+ return -1 ;
3795+ }
3796+
3797+ if (!state->stream .initialized ) {
3798+ return -1 ;
3799+ }
3800+
3801+ const int n_total_samples = state->stream .n_left_ctx + state->stream .n_chunk + state->stream .n_right_ctx ;
3802+
3803+ // Insert the new chunk of samples as the new center and right context.
3804+ state->stream .buffer .insert (state->stream .buffer .end (), samples, samples + n_samples);
3805+
3806+ // As long as we have enough samples to form a complete window we process it.
3807+ while (state->stream .buffer .size () >= (size_t ) n_total_samples) {
3808+ const int ret = parakeet_stream_process_window (
3809+ ctx,
3810+ state,
3811+ state->stream .buffer .data (),
3812+ n_total_samples,
3813+ state->stream .n_chunk );
3814+ if (ret != 0 ) {
3815+ return ret;
3816+ }
3817+
3818+ // TODO: std::vector::erase is O(n) and not optimal. We should probably
3819+ // use a ring buffer instead.
3820+ // Shift the center and right context to the start of the buffer. This
3821+ // allows the next call to have the current center chunk as its left
3822+ // context, and the right context will become part of the next target
3823+ // chunk together with the new samples which will make up the rest of
3824+ // the target chunk and the new right context.
3825+ state->stream .buffer .erase (state->stream .buffer .begin (), state->stream .buffer .begin () + state->stream .n_chunk );
3826+
3827+ state->stream .n_samples_advanced += state->stream .n_chunk ;
3828+ }
3829+
3830+ return 0 ;
3831+ }
3832+
// Drain the remaining buffered audio at end of stream. Any samples beyond the
// retained left context are decoded in (possibly partial) chunks, with the
// missing right context zero-padded. On success the stream is left
// de-initialized; a new session requires parakeet_stream_init() again.
// Returns 0 on success, -1 on invalid arguments or an uninitialized stream,
// or the negative error from window processing.
int parakeet_stream_flush(
        struct parakeet_context * ctx,
        struct parakeet_state * state) {
    if (ctx == nullptr || state == nullptr) {
        return -1;
    }

    if (!state->stream.initialized) {
        return -1;
    }

    // keep flushing while anything beyond the left context remains undecoded
    while (state->stream.buffer.size() > (size_t) state->stream.n_left_ctx) {
        const int n_remaining_samples = (int) state->stream.buffer.size() - state->stream.n_left_ctx;
        // the final chunk may be shorter than a full chunk
        const int n_flush_chunk = std::min(state->stream.n_chunk, n_remaining_samples);
        // right context only as far as real samples exist; the rest stays zero
        const int n_right_available = std::min(state->stream.n_right_ctx, n_remaining_samples - n_flush_chunk);
        const int n_copied = state->stream.n_left_ctx + n_flush_chunk + n_right_available;

        // full-sized window, zero-padded where audio has run out
        std::vector<float> flush_window(state->stream.n_left_ctx + n_flush_chunk + state->stream.n_right_ctx, 0.0f);

        std::copy_n(state->stream.buffer.begin(), n_copied, flush_window.begin());

        const int ret = parakeet_stream_process_window(
            ctx,
            state,
            flush_window.data(),
            (int) flush_window.size(),
            n_flush_chunk);
        if (ret != 0) {
            return ret;
        }

        // advance past the chunk we just decoded (same sliding scheme as push)
        state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + n_flush_chunk);
        state->stream.n_samples_advanced += n_flush_chunk;
    }

    // tear the session down; the stream must be re-initialized before reuse
    state->stream.buffer.clear();
    state->stream.n_samples_advanced = 0;
    state->stream.n_left_ctx = 0;
    state->stream.n_chunk = 0;
    state->stream.n_right_ctx = 0;
    state->stream.params = {};
    state->stream.initialized = false;

    return 0;
}
3878+
36123879int parakeet_full_with_state (
36133880 struct parakeet_context * ctx,
36143881 struct parakeet_state * state,
@@ -3729,7 +3996,8 @@ int parakeet_full_with_state(
37293996 const auto token_id = state->decoded_tokens [i];
37303997 const char * token_str = parakeet_token_to_str (ctx, token_id);
37313998 if (token_str) {
3732- text += sentencepiece_piece_to_text (token_str, text.empty ());
3999+ const bool is_first_piece = (tokens_before == 0 ) && text.empty ();
4000+ text += sentencepiece_piece_to_text (token_str, is_first_piece);
37334001 }
37344002
37354003 auto token_data = state->decoded_token_data [i];
@@ -3787,6 +4055,11 @@ int parakeet_chunk(
37874055 ggml_backend_buffer_clear (state->lstm_state .buffer , 0 );
37884056 state->decoded_tokens .clear ();
37894057 state->decoded_token_data .clear ();
4058+
4059+ state->tdt_stream_state .initialized = false ;
4060+ state->tdt_stream_state .last_token = 0 ;
4061+ state->tdt_stream_state .time_step = 0 ;
4062+ state->tdt_stream_state .decoded_length = 0 ;
37904063 }
37914064
37924065 if (n_samples > 0 ) {
@@ -3800,7 +4073,7 @@ int parakeet_chunk(
38004073 const int total_len = parakeet_n_len_from_state (state);
38014074 const int model_max_ctx = parakeet_n_audio_ctx (ctx);
38024075 params.audio_ctx = std::min (total_len, model_max_ctx);
3803- PARAKEET_LOG_INFO (" Processing audio: total_frames=%d, chunk_size=%d\n " , total_len, params.audio_ctx );
4076+ PARAKEET_LOG_DEBUG (" Processing audio: total_frames=%d, chunk_size=%d\n " , total_len, params.audio_ctx );
38044077 }
38054078 state->n_audio_ctx = params.audio_ctx ;
38064079
@@ -3829,7 +4102,8 @@ int parakeet_chunk(
38294102 const auto token_id = state->decoded_tokens [i];
38304103 const char * token_str = parakeet_token_to_str (ctx, token_id);
38314104 if (token_str) {
3832- text += sentencepiece_piece_to_text (token_str, text.empty ());
4105+ const bool is_first_piece = (tokens_before == 0 ) && text.empty ();
4106+ text += sentencepiece_piece_to_text (token_str, is_first_piece);
38334107 }
38344108
38354109 // Use the stored token data from parakeet_decode
0 commit comments