@@ -1786,20 +1786,21 @@ static struct ggml_cgraph * parakeet_build_graph_conv(parakeet_context & pctx, p
17861786 ggml_set_name (cur, " pre_conv_6_relu" );
17871787 ggml_set_output (cur);
17881788
1789+ // [freq, time, chan]
1790+ cur = ggml_permute (ctx0, cur, 0 , 2 , 1 , 3 );
1791+ // [freq, chan, time]
1792+ cur = ggml_cont (ctx0, cur);
1793+
17891794 const int n_freq = cur->ne [0 ]; // 16
1790- const int n_frames = cur->ne [1 ]; // 188
1791- const int n_chan = cur->ne [2 ]; // 256
1795+ const int n_chan = cur->ne [1 ]; // 256
1796+ const int n_frames = cur->ne [2 ]; // time
17921797
1793- // {n_chan, n_frames, n_freq, batch} -> {d_feat (n_freq * n_chan), n_frames}
1794- // {256, 188, 16, 1} -> {4096, 188, 1, 1}
1798+ // [freq, chan, time, batch] -> [(freq * chan), time]
17951799 cur = ggml_reshape_2d (ctx0, cur, n_freq * n_chan, n_frames);
17961800
1797- // {4096, 188, 1, 1} -> {1024, 188, 1, 1}
17981801 cur = ggml_mul_mat (ctx0, model.enc_pre_out_w , cur);
17991802 cur = ggml_add (ctx0, cur, model.enc_pre_out_b );
18001803
1801- // {embd, n_time}
1802- // {1024, 188, 1, 1}
18031804 ggml_set_name (cur, " embd_conv" );
18041805 ggml_set_output (cur);
18051806 pstate.embd_conv = cur;
@@ -2123,6 +2124,8 @@ static bool parakeet_encode_internal(
21232124
21242125 parakeet_print_tensor_gf (" pre_conv_6_relu" , sched, gf, 10 );
21252126
2127+ parakeet_print_tensor_gf (pctx.model .enc_pre_out_w , sched, gf, 10 );
2128+ parakeet_print_tensor_gf (pctx.model .enc_pre_out_b , sched, gf, 10 );
21262129 parakeet_print_tensor_gf (" embd_conv" , sched, gf, 10 );
21272130
21282131 // compare final output:
0 commit comments