Skip to content

Commit 7a5f489

Browse files
committed
parakeet : fix output projection from pre-encode
1 parent 7de9732 commit 7a5f489

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

src/parakeet.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,20 +1786,21 @@ static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, p
17861786
ggml_set_name(cur, "pre_conv_6_relu");
17871787
ggml_set_output(cur);
17881788

1789+
// [freq, time, chan]
1790+
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1791+
// [freq, chan, time]
1792+
cur = ggml_cont(ctx0, cur);
1793+
17891794
const int n_freq = cur->ne[0]; // 16
1790-
const int n_frames = cur->ne[1]; // 188
1791-
const int n_chan = cur->ne[2]; // 256
1795+
const int n_chan = cur->ne[1]; // 256
1796+
const int n_frames = cur->ne[2]; // time
17921797

1793-
// {n_chan, n_frames, n_freq, batch} -> {d_feat (n_freq * n_chan), n_frames}
1794-
// {256, 188, 16, 1} -> {4096, 188, 1, 1}
1798+
// [freq, time, chan, batch] -> [(freq * chan), time]
17951799
cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames);
17961800

1797-
// {4096, 188, 1, 1} -> {1024, 188, 1, 1}
17981801
cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur);
17991802
cur = ggml_add(ctx0, cur, model.enc_pre_out_b);
18001803

1801-
// {embd, n_time}
1802-
// {1024, 188, 1, 1}
18031804
ggml_set_name(cur, "embd_conv");
18041805
ggml_set_output(cur);
18051806
pstate.embd_conv = cur;
@@ -2123,6 +2124,8 @@ static bool parakeet_encode_internal(
21232124

21242125
parakeet_print_tensor_gf("pre_conv_6_relu", sched, gf, 10);
21252126

2127+
parakeet_print_tensor_gf(pctx.model.enc_pre_out_w, sched, gf, 10);
2128+
parakeet_print_tensor_gf(pctx.model.enc_pre_out_b, sched, gf, 10);
21262129
parakeet_print_tensor_gf("embd_conv", sched, gf, 10);
21272130

21282131
// compare final output:

0 commit comments

Comments
 (0)