Skip to content

Commit 7de9732

Browse files
committed
parakeet : got up to pre_conv_6_relu working
1 parent 94c2a7a commit 7de9732

2 files changed

Lines changed: 118 additions & 43 deletions

File tree

src/parakeet.cpp

Lines changed: 105 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,24 +1462,34 @@ static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_
14621462
// Encoder pre_encode
14631463
const int n_subsampling_channels = hparams.n_subsampling_channels;
14641464
model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4096, n_audio_state));
1465+
ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w");
14651466
model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
1467+
ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b");
14661468

14671469
model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels));
14681470
ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w");
14691471
model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
14701472
ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b");
14711473

14721474
model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels));
1475+
ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w");
14731476
model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1477+
ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b");
14741478

14751479
model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, wtype, 1, 1, n_subsampling_channels, n_subsampling_channels));
1480+
ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w");
14761481
model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1482+
ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b");
14771483

14781484
model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, vtype, 3, 3, 1, n_subsampling_channels));
1485+
ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w");
14791486
model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1487+
ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b");
14801488

14811489
model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, wtype, 1, 1, n_subsampling_channels, n_subsampling_channels));
1490+
ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w");
14821491
model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1));
1492+
ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b");
14831493

14841494
// Encoder layers
14851495
for (int i = 0; i < n_audio_layer; ++i) {
@@ -1722,69 +1732,64 @@ static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, p
17221732
struct ggml_context * ctx0 = ggml_init(params);
17231733
ggml_cgraph * gf = ggml_new_graph(ctx0);
17241734

1725-
// [n_time, n_mels] {1500, 128, 1, 1}
1726-
struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_time, n_mels, 1, 1);
1735+
// [freq, time]
1736+
struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_time, 1, 1);
17271737
ggml_set_name(mel, "mel");
17281738
ggml_set_input(mel);
17291739
ggml_set_output(mel);
17301740

1731-
struct ggml_tensor * cur = mel;
1732-
//ggml_set_name(cur, "input_to_conv_0");
1733-
//ggml_set_output(cur);
1734-
1735-
// enc_pre_conv_0_w: {3, 3, 1, 256}
1736-
// {1500, 128, 1, 1} -> {750, 64, 256, 1}
1737-
cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, cur, 2, 2, 1, 1, 1, 1);
1738-
ggml_set_name(cur, "pre_conv_0");
1739-
ggml_set_output(cur);
1740-
1741+
// [freq, time, channels, batch]
1742+
struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1);
17411743
cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b);
1742-
ggml_set_name(cur, "pre_conv_0_bias");
1744+
ggml_set_name(cur, "pre_conv_0");
17431745
ggml_set_output(cur);
17441746

17451747
cur = ggml_relu(ctx0, cur);
17461748
ggml_set_name(cur, "pre_conv_0_relu");
17471749
ggml_set_output(cur);
17481750

1749-
// enc_pre_conv_2_w: {3, 3, 1, 256}
1750-
// {750, 64, 256, 1} -> {375, 32, 256, 1}
1751+
// enc_pre_conv_2_w: {3, 3, 1, 256} (depthwise)
1752+
// [freq, time, channels, batch]
17511753
cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1);
1752-
ggml_set_name(cur, "pre_conv_2");
17531754
cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b);
1755+
ggml_set_output(cur);
1756+
ggml_set_name(cur, "pre_conv_2");
17541757

1755-
// enc_pre_conv: {1, 1, 256, 256}
1756-
// {375, 32, 256, 1} -> {375, 32, 256, 1}
1758+
// enc_pre_conv_3_w: {1, 1, 256, 256} (pointwise)
1759+
// [freq, time, channels, batch]
17571760
cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1);
1758-
ggml_set_name(cur, "pre_conv_3");
17591761
cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b);
1762+
ggml_set_name(cur, "pre_conv_3");
1763+
ggml_set_output(cur);
17601764

17611765
cur = ggml_relu(ctx0, cur);
1766+
ggml_set_output(cur);
17621767
ggml_set_name(cur, "pre_conv_3_relu");
17631768

1764-
// enc_pre_conv_5_w: {3, 3, 1, 256}
1765-
// {375, 32, 256, 1} -> {188, 16, 256, 1}
1769+
// enc_pre_conv_5_w: {3, 3, 1, 256} (depthwise)
1770+
// [freq, time, channels, batch]
17661771
cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1);
1767-
ggml_set_name(cur, "pre_conv_5");
1772+
ggml_set_name(cur, "pre_conv_5_direct");
1773+
ggml_set_output(cur);
17681774
cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b);
1775+
ggml_set_name(cur, "pre_conv_5");
1776+
ggml_set_output(cur);
17691777

1770-
// enc_pre_conv_6_w: {1, 1, 256, 256}
1771-
// {188, 16, 256, 1} -> {188, 16, 256, 1}
1778+
// enc_pre_conv_6_w: {1, 1, 256, 256} (pointwise)
1779+
// [freq, time, channels, batch]
17721780
cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1);
1773-
ggml_set_name(cur, "pre_conv_6");
17741781
cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b);
1782+
ggml_set_output(cur);
1783+
ggml_set_name(cur, "pre_conv_6");
17751784

17761785
cur = ggml_relu(ctx0, cur);
17771786
ggml_set_name(cur, "pre_conv_6_relu");
1787+
ggml_set_output(cur);
17781788

1779-
const int n_frames = cur->ne[0]; // 188
1780-
const int n_freq = cur->ne[1]; // 16
1789+
const int n_freq = cur->ne[0]; // 16
1790+
const int n_frames = cur->ne[1]; // 188
17811791
const int n_chan = cur->ne[2]; // 256
17821792

1783-
// {n_frames, n_freq, n_chan, batch} -> {n_chan, n_frames, n_freq, batch}
1784-
// {188, 16, 256, 1} -> {256, 188, 16, 1}
1785-
cur = ggml_permute(ctx0, cur, 1, 2, 0, 3);
1786-
cur = ggml_cont(ctx0, cur);
1787-
17881793
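// {n_freq, n_frames, n_chan, batch} -> {d_feat (n_freq * n_chan), n_frames}
17891794
// {16, 188, 256, 1} -> {4096, 188, 1, 1}
// NOTE(review): the permute + cont that used to precede this reshape was removed, so the freq and channel values of one frame are presumably no longer adjacent in memory here -- confirm the reshape still yields per-frame features.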
17901795
cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames);
@@ -2077,7 +2082,7 @@ static bool parakeet_encode_internal(
20772082

20782083
for (int j = 0; j < mel_inp.n_mel; ++j) {
20792084
for (int i = i0; i < i1; ++i) {
2080-
dst[j*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
2085+
dst[i * mel_inp.n_mel + j] = mel_inp.data[j * mel_inp.n_len + (i + i0)];
20812086
}
20822087
}
20832088

@@ -2090,11 +2095,38 @@ static bool parakeet_encode_internal(
20902095

20912096
// TODO: remove after debugging
20922097
{
2093-
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_0_w, sched, gf, 10);
2094-
// Comparing bias as that is what the pytorch conv layer does.
2095-
parakeet_print_tensor_gf(pctx.model.enc_pre_conv_0_b, sched, gf, 10);
20962098
parakeet_print_tensor_gf("mel", sched, gf, 10);
2097-
parakeet_print_tensor_gf("pre_conv_0_bias", sched, gf, 10);
2099+
// applied bias result:
2100+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_0_w, sched, gf, 10);
2101+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_0_b, sched, gf, 10);
2102+
parakeet_print_tensor_gf("pre_conv_0", sched, gf, 10);
2103+
2104+
parakeet_print_tensor_gf("pre_conv_0_relu", sched, gf, 10);
2105+
2106+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_2_w, sched, gf, 10);
2107+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_2_b, sched, gf, 10);
2108+
parakeet_print_tensor_gf("pre_conv_2", sched, gf, 10);
2109+
2110+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_3_w, sched, gf, 10);
2111+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_3_b, sched, gf, 10);
2112+
parakeet_print_tensor_gf("pre_conv_3", sched, gf, 10);
2113+
2114+
parakeet_print_tensor_gf("pre_conv_3_relu", sched, gf, 10);
2115+
2116+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_5_w, sched, gf, 10);
2117+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_5_b, sched, gf, 10);
2118+
parakeet_print_tensor_gf("pre_conv_5", sched, gf, 10);
2119+
2120+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_6_w, sched, gf, 10);
2121+
//parakeet_print_tensor_gf(pctx.model.enc_pre_conv_6_b, sched, gf, 10);
2122+
parakeet_print_tensor_gf("pre_conv_6", sched, gf, 10);
2123+
2124+
parakeet_print_tensor_gf("pre_conv_6_relu", sched, gf, 10);
2125+
2126+
parakeet_print_tensor_gf("embd_conv", sched, gf, 10);
2127+
2128+
// compare final output:
2129+
//parakeet_print_tensor_gf(pstate.embd_conv, sched, gf, 10);
20982130
}
20992131

21002132
}
@@ -2634,13 +2666,32 @@ static bool log_mel_spectrogram(
26342666
{
26352667
std::vector<std::thread> workers(n_threads - 1);
26362668
for (int iw = 0; iw < n_threads - 1; ++iw) {
2637-
workers[iw] = std::thread(
2638-
log_mel_spectrogram_worker_thread, iw + 1, window_func, window_size, std::cref(samples_padded),
2639-
samples_padded.size(), frame_size, frame_step, n_threads,
2640-
std::cref(filters), std::ref(mel), std::cref(cache));
2669+
workers[iw] = std::thread(log_mel_spectrogram_worker_thread,
2670+
iw + 1, // thread index
2671+
window_func,
2672+
window_size,
2673+
std::cref(samples_padded),
2674+
samples_padded.size(),
2675+
frame_size,
2676+
frame_step,
2677+
n_threads,
2678+
std::cref(filters),
2679+
std::ref(mel),
2680+
std::cref(cache));
26412681
}
26422682

2643-
log_mel_spectrogram_worker_thread(0, window_func, window_size, samples_padded, samples_padded.size(), frame_size, frame_step, n_threads, filters, mel, cache);
2683+
log_mel_spectrogram_worker_thread(
2684+
0,
2685+
window_func,
2686+
window_size,
2687+
samples_padded,
2688+
samples_padded.size(),
2689+
frame_size,
2690+
frame_step,
2691+
n_threads,
2692+
filters,
2693+
mel,
2694+
cache);
26442695

26452696
for (int iw = 0; iw < n_threads - 1; ++iw) {
26462697
workers[iw].join();
@@ -3027,7 +3078,18 @@ void parakeet_free_params(struct parakeet_full_params * params) {
30273078
}
30283079

30293080
int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) {
3030-
if (!log_mel_spectrogram(*state, samples, n_samples, PARAKEET_SAMPLE_RATE, ctx->model.hparams.n_fft, PARAKEET_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel, ctx->mel_cache)) {
3081+
if (!log_mel_spectrogram(*state,
3082+
samples,
3083+
n_samples,
3084+
PARAKEET_SAMPLE_RATE,
3085+
ctx->model.hparams.n_fft,
3086+
PARAKEET_HOP_LENGTH,
3087+
ctx->model.filters.n_mel,
3088+
n_threads,
3089+
ctx->model.filters,
3090+
false, // debug
3091+
state->mel,
3092+
ctx->mel_cache)) {
30313093
PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
30323094
return -1;
30333095
}

test-parakeet.sh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/sh
2+
3+
set -e
4+
5+
build_dir=build-debug
6+
cmd=test-parakeet
7+
8+
cmake --build $build_dir --target $cmd -j 12
9+
10+
#ctest -R ^$cmd$ --test-dir $build_dir --output-on-failure -VV
11+
echo "running ${build_dir}/$cmd with gdb"
12+
gdb --args ${build_dir}/bin/${cmd}
13+
#${build_dir}/bin/${cmd}

0 commit comments

Comments
 (0)