Skip to content

Commit ffb1239

Browse files
committed
parakeet : add parakeet_stream_push API
This commit adds a new API to parakeet to support streaming audio input. The motivation for this came from trying to use parakeet.cpp with ffmpeg where the existing API did not work very well.
1 parent 7d65b61 commit ffb1239

4 files changed

Lines changed: 417 additions & 4 deletions

File tree

include/parakeet.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,29 @@ extern "C" {
314314
const float * samples,
315315
int n_samples);
316316

317+
// Initialize streaming state for a new stream. Any tokens and segments already
// held in `state` are discarded.
318+
PARAKEET_API int parakeet_stream_init(
319+
struct parakeet_context * ctx,
320+
struct parakeet_state * state,
321+
struct parakeet_full_params params);
322+
323+
// Push audio samples in streaming mode. Internally this function will structure
324+
// the samples in a buffer with a left context, a center chunk, and a
325+
// right context. The encoder will see the complete buffer which enables it
326+
// to get boundary context for the target/center audio chunk. This avoids hard
327+
// cut offs at the chunk boundaries. The joint network then only sees the
328+
// center chunk and this function internally handles the context windowing.
329+
PARAKEET_API int parakeet_stream_push(
330+
struct parakeet_context * ctx,
331+
struct parakeet_state * state,
332+
const float * samples,
333+
int n_samples);
334+
335+
// Flush the final partial chunk at end-of-stream. This also ends the stream:
// parakeet_stream_init must be called again before pushing more audio.
336+
PARAKEET_API int parakeet_stream_flush(
337+
struct parakeet_context * ctx,
338+
struct parakeet_state * state);
339+
317340
// Number of generated text segments
318341
PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx);
319342
PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state);

src/parakeet.cpp

Lines changed: 278 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,18 @@ struct tdt_stream_state {
580580
bool initialized; // whether prediction LSTM state has been initialized
581581
};
582582

583+
struct parakeet_stream {
584+
std::vector<float> buffer;
585+
int64_t n_samples_advanced = 0;
586+
587+
int n_left_ctx = 0;
588+
int n_chunk = 0;
589+
int n_right_ctx = 0;
590+
591+
parakeet_full_params params = {};
592+
bool initialized = false;
593+
};
594+
583595
struct parakeet_state {
584596
int64_t t_sample_us = 0;
585597
int64_t t_encode_us = 0;
@@ -638,6 +650,8 @@ struct parakeet_state {
638650
parakeet_lstm_state lstm_state;
639651

640652
struct tdt_stream_state tdt_stream_state = {0, 0, 0, false};
653+
654+
parakeet_stream stream;
641655
};
642656

643657
// FFT cache for mel spectrogram computation
@@ -2367,7 +2381,7 @@ static bool parakeet_decode(
23672381
// Start with the blank token (8192)
23682382
parakeet_token last_token = blank_id;
23692383

2370-
PARAKEET_LOG_INFO("parakeet_decode: starting decode with n_frames=%d\n", n_frames);
2384+
PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames);
23712385

23722386
batch.n_tokens = 1;
23732387
batch.token[0] = last_token;
@@ -3609,6 +3623,259 @@ struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_
36093623
return result;
36103624
}
36113625

3626+
static void parakeet_stream_reset_state(struct parakeet_state * state) {
3627+
if (state == nullptr) {
3628+
return;
3629+
}
3630+
3631+
if (state->lstm_state.buffer) {
3632+
ggml_backend_buffer_clear(state->lstm_state.buffer, 0);
3633+
}
3634+
3635+
state->decoded_tokens.clear();
3636+
state->decoded_token_data.clear();
3637+
state->result_all.clear();
3638+
3639+
state->tdt_stream_state.initialized = false;
3640+
state->tdt_stream_state.last_token = 0;
3641+
state->tdt_stream_state.time_step = 0;
3642+
state->tdt_stream_state.decoded_length = 0;
3643+
3644+
state->stream.buffer.clear();
3645+
state->stream.n_samples_advanced = 0;
3646+
state->stream.n_left_ctx = 0;
3647+
state->stream.n_chunk = 0;
3648+
state->stream.n_right_ctx = 0;
3649+
state->stream.params = {};
3650+
state->stream.initialized = false;
3651+
3652+
state->enc_out_buffer.clear();
3653+
state->enc_out_frames = 0;
3654+
state->n_frames = 0;
3655+
state->n_audio_ctx = 0;
3656+
}
3657+
3658+
static int parakeet_stream_process_window(
3659+
struct parakeet_context * ctx,
3660+
struct parakeet_state * state,
3661+
const float * samples,
3662+
int n_samples,
3663+
int n_chunk) {
3664+
const parakeet_stream & stream = state->stream;
3665+
const parakeet_full_params & params = stream.params;
3666+
const int d_enc = ctx->model.hparams.n_audio_state;
3667+
3668+
// process all the samples.
3669+
if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3670+
return -2;
3671+
}
3672+
3673+
const int left_mel_frames = stream.n_left_ctx / PARAKEET_HOP_LENGTH;
3674+
const int chunk_mel_frames = n_chunk / PARAKEET_HOP_LENGTH;
3675+
3676+
state->n_audio_ctx = state->mel.n_len;
3677+
// process entire log mel spectrogram.
3678+
if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads,
3679+
params.abort_callback, params.abort_callback_user_data)) {
3680+
return -6;
3681+
}
3682+
3683+
const int left_enc_frames = left_mel_frames / ctx->model.hparams.subsampling_factor;
3684+
const int chunk_enc_frames = chunk_mel_frames / ctx->model.hparams.subsampling_factor;
3685+
3686+
if (chunk_enc_frames <= 0) {
3687+
return 0;
3688+
}
3689+
3690+
// Copy the center chunk so that it is the only part that the joint network sees.
3691+
state->enc_out_buffer.resize(chunk_enc_frames * d_enc);
3692+
ggml_backend_tensor_get(state->enc_out, state->enc_out_buffer.data(),
3693+
left_enc_frames * d_enc * sizeof(float),
3694+
chunk_enc_frames * d_enc * sizeof(float));
3695+
3696+
state->enc_out_frames = chunk_enc_frames;
3697+
state->n_frames = chunk_enc_frames;
3698+
3699+
const size_t tokens_before = state->decoded_tokens.size();
3700+
3701+
// Run the prediction and joint network on the center chunk.
3702+
if (!parakeet_decode_chunk(*ctx, *state, state->batch, chunk_enc_frames, params.n_threads, &params)) {
3703+
return -7;
3704+
}
3705+
3706+
const size_t tokens_after = state->decoded_tokens.size();
3707+
const size_t new_token_count = tokens_after - tokens_before;
3708+
3709+
if (new_token_count > 0) {
3710+
std::string text;
3711+
std::vector<parakeet_token_data> result_tokens;
3712+
const int64_t chunk_t0 = 100LL * stream.n_samples_advanced / PARAKEET_SAMPLE_RATE;
3713+
const int64_t chunk_t1 = 100LL * (stream.n_samples_advanced + n_chunk) / PARAKEET_SAMPLE_RATE;
3714+
const int frame_offset = chunk_t0 / ctx->model.hparams.subsampling_factor;
3715+
3716+
result_tokens.reserve(new_token_count);
3717+
3718+
for (size_t i = tokens_before; i < tokens_after; ++i) {
3719+
const auto token_id = state->decoded_tokens[i];
3720+
const char * token_str = parakeet_token_to_str(ctx, token_id);
3721+
if (token_str) {
3722+
const bool is_first_piece = (tokens_before == 0) && text.empty();
3723+
text += sentencepiece_piece_to_text(token_str, is_first_piece);
3724+
}
3725+
3726+
auto token_data = state->decoded_token_data[i];
3727+
token_data.frame_index += frame_offset;
3728+
token_data.t0 += chunk_t0;
3729+
token_data.t1 += chunk_t0;
3730+
result_tokens.push_back(token_data);
3731+
}
3732+
3733+
refine_timestamps_tdt(ctx->vocab, result_tokens);
3734+
3735+
if (!text.empty()) {
3736+
parakeet_segment segment;
3737+
segment.t0 = chunk_t0;
3738+
segment.t1 = chunk_t1;
3739+
segment.text = std::move(text);
3740+
segment.tokens = std::move(result_tokens);
3741+
3742+
state->result_all.push_back(std::move(segment));
3743+
3744+
if (params.new_segment_callback) {
3745+
params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data);
3746+
}
3747+
}
3748+
}
3749+
3750+
return 0;
3751+
}
3752+
3753+
static int ms_to_n_samples(int ms) {
3754+
return ms * PARAKEET_SAMPLE_RATE / 1000;
3755+
}
3756+
3757+
int parakeet_stream_init(
3758+
struct parakeet_context * ctx,
3759+
struct parakeet_state * state,
3760+
struct parakeet_full_params params) {
3761+
if (ctx == nullptr || state == nullptr) {
3762+
return -1;
3763+
}
3764+
3765+
const int n_left_ctx = ms_to_n_samples(params.left_context_ms);
3766+
const int n_chunk = ms_to_n_samples(params.chunk_length_ms);
3767+
const int n_right_ctx = ms_to_n_samples(params.right_context_ms);
3768+
3769+
if (n_left_ctx < 0 || n_chunk <= 0 || n_right_ctx < 0) {
3770+
return -1;
3771+
}
3772+
3773+
parakeet_stream_reset_state(state);
3774+
3775+
state->stream.n_left_ctx = n_left_ctx;
3776+
state->stream.n_chunk = n_chunk;
3777+
state->stream.n_right_ctx = n_right_ctx;
3778+
state->stream.params = params;
3779+
state->stream.initialized = true;
3780+
3781+
if (n_left_ctx > 0) {
3782+
state->stream.buffer.assign(n_left_ctx, 0.0f);
3783+
}
3784+
3785+
return 0;
3786+
}
3787+
3788+
int parakeet_stream_push(
3789+
struct parakeet_context * ctx,
3790+
struct parakeet_state * state,
3791+
const float * samples,
3792+
int n_samples) {
3793+
if (ctx == nullptr || state == nullptr || samples == nullptr || n_samples <= 0) {
3794+
return -1;
3795+
}
3796+
3797+
if (!state->stream.initialized) {
3798+
return -1;
3799+
}
3800+
3801+
const int n_total_samples = state->stream.n_left_ctx + state->stream.n_chunk + state->stream.n_right_ctx;
3802+
3803+
// Insert the new chunk of samples as the new center and right context.
3804+
state->stream.buffer.insert(state->stream.buffer.end(), samples, samples + n_samples);
3805+
3806+
// As long as we have enough samples to form a complete window we process it.
3807+
while (state->stream.buffer.size() >= (size_t) n_total_samples) {
3808+
const int ret = parakeet_stream_process_window(
3809+
ctx,
3810+
state,
3811+
state->stream.buffer.data(),
3812+
n_total_samples,
3813+
state->stream.n_chunk);
3814+
if (ret != 0) {
3815+
return ret;
3816+
}
3817+
3818+
// TODO: std::vector::erase is O(n) and not optimal. We should probably
3819+
// use a ring buffer instead.
3820+
// Shift the center and right context to the start of the buffer. This
3821+
// allows the next call to have the current center chunk as its left
3822+
// context, and the right context will become part of the next target
3823+
// chunk together with the new samples which will make up the rest of
3824+
// the target chunk and the new right context.
3825+
state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + state->stream.n_chunk);
3826+
3827+
state->stream.n_samples_advanced += state->stream.n_chunk;
3828+
}
3829+
3830+
return 0;
3831+
}
3832+
3833+
int parakeet_stream_flush(
3834+
struct parakeet_context * ctx,
3835+
struct parakeet_state * state) {
3836+
if (ctx == nullptr || state == nullptr) {
3837+
return -1;
3838+
}
3839+
3840+
if (!state->stream.initialized) {
3841+
return -1;
3842+
}
3843+
3844+
while (state->stream.buffer.size() > (size_t) state->stream.n_left_ctx) {
3845+
const int n_remaining_samples = (int) state->stream.buffer.size() - state->stream.n_left_ctx;
3846+
const int n_flush_chunk = std::min(state->stream.n_chunk, n_remaining_samples);
3847+
const int n_right_available = std::min(state->stream.n_right_ctx, n_remaining_samples - n_flush_chunk);
3848+
const int n_copied = state->stream.n_left_ctx + n_flush_chunk + n_right_available;
3849+
3850+
std::vector<float> flush_window(state->stream.n_left_ctx + n_flush_chunk + state->stream.n_right_ctx, 0.0f);
3851+
3852+
std::copy_n(state->stream.buffer.begin(), n_copied, flush_window.begin());
3853+
3854+
const int ret = parakeet_stream_process_window(
3855+
ctx,
3856+
state,
3857+
flush_window.data(),
3858+
(int) flush_window.size(),
3859+
n_flush_chunk);
3860+
if (ret != 0) {
3861+
return ret;
3862+
}
3863+
3864+
state->stream.buffer.erase(state->stream.buffer.begin(), state->stream.buffer.begin() + n_flush_chunk);
3865+
state->stream.n_samples_advanced += n_flush_chunk;
3866+
}
3867+
3868+
state->stream.buffer.clear();
3869+
state->stream.n_samples_advanced = 0;
3870+
state->stream.n_left_ctx = 0;
3871+
state->stream.n_chunk = 0;
3872+
state->stream.n_right_ctx = 0;
3873+
state->stream.params = {};
3874+
state->stream.initialized = false;
3875+
3876+
return 0;
3877+
}
3878+
36123879
int parakeet_full_with_state(
36133880
struct parakeet_context * ctx,
36143881
struct parakeet_state * state,
@@ -3729,7 +3996,8 @@ int parakeet_full_with_state(
37293996
const auto token_id = state->decoded_tokens[i];
37303997
const char * token_str = parakeet_token_to_str(ctx, token_id);
37313998
if (token_str) {
3732-
text += sentencepiece_piece_to_text(token_str, text.empty());
3999+
const bool is_first_piece = (tokens_before == 0) && text.empty();
4000+
text += sentencepiece_piece_to_text(token_str, is_first_piece);
37334001
}
37344002

37354003
auto token_data = state->decoded_token_data[i];
@@ -3787,6 +4055,11 @@ int parakeet_chunk(
37874055
ggml_backend_buffer_clear(state->lstm_state.buffer, 0);
37884056
state->decoded_tokens.clear();
37894057
state->decoded_token_data.clear();
4058+
4059+
state->tdt_stream_state.initialized = false;
4060+
state->tdt_stream_state.last_token = 0;
4061+
state->tdt_stream_state.time_step = 0;
4062+
state->tdt_stream_state.decoded_length = 0;
37904063
}
37914064

37924065
if (n_samples > 0) {
@@ -3800,7 +4073,7 @@ int parakeet_chunk(
38004073
const int total_len = parakeet_n_len_from_state(state);
38014074
const int model_max_ctx = parakeet_n_audio_ctx(ctx);
38024075
params.audio_ctx = std::min(total_len, model_max_ctx);
3803-
PARAKEET_LOG_INFO("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx);
4076+
PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx);
38044077
}
38054078
state->n_audio_ctx = params.audio_ctx;
38064079

@@ -3829,7 +4102,8 @@ int parakeet_chunk(
38294102
const auto token_id = state->decoded_tokens[i];
38304103
const char * token_str = parakeet_token_to_str(ctx, token_id);
38314104
if (token_str) {
3832-
text += sentencepiece_piece_to_text(token_str, text.empty());
4105+
const bool is_first_piece = (tokens_before == 0) && text.empty();
4106+
text += sentencepiece_piece_to_text(token_str, is_first_piece);
38334107
}
38344108

38354109
// Use the stored token data from parakeet_decode

tests/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ target_compile_definitions(${PARAKEET_TEST} PRIVATE
121121
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
122122
add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST})
123123

124+
# test-parakeet-stream: built from test-parakeet-stream.cpp and registered with CTest.
# The model and audio sample locations are compiled in as string macros.
set(PARAKEET_TEST test-parakeet-stream)
add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp)
target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common)
target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples)
target_compile_definitions(${PARAKEET_TEST} PRIVATE
    PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3.bin")
target_compile_definitions(${PARAKEET_TEST} PRIVATE
    SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/gb1.wav")
add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST})
132+
124133
set(PARAKEET_TEST test-parakeet-full)
125134
add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp)
126135
target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples)

0 commit comments

Comments
 (0)