@@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case {
35613561// and dispatches a single fused kernel.
35623562struct test_snake_fuse : public test_case {
35633563 const ggml_type type;
3564- const std::array<int64_t , 2 > ne; // [T, C]
3564+ const std::array<int64_t , 4 > ne; // [T, C, D2, D3 ]
35653565
35663566 std::string op_desc (ggml_tensor * t) override {
35673567 GGML_UNUSED (t);
@@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case {
35863586 }
35873587
35883588 test_snake_fuse (ggml_type type = GGML_TYPE_F32,
3589- std::array<int64_t , 2 > ne = {256 , 192 })
3589+ std::array<int64_t , 4 > ne = {256 , 192 , 1 , 1 })
35903590 : type(type), ne(ne) {}
35913591
35923592 ggml_tensor * build_graph (ggml_context * ctx) override {
3593- ggml_tensor * x = ggml_new_tensor_2d (ctx, type, ne[0 ], ne[1 ]);
3593+ ggml_tensor * x = ggml_new_tensor_4d (ctx, type, ne[0 ], ne[1 ], ne[ 2 ], ne[ 3 ]);
35943594 ggml_set_name (x, " x" );
35953595
35963596 ggml_tensor * a = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, 1 , ne[1 ]);
@@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75587558
75597559 // SNAKE activation fusion: x + sin(a*x)^2 * inv_b
75607560 for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
7561- test_cases.emplace_back (new test_snake_fuse (type, { 5 , 7 })); // primes sub-block
7562- test_cases.emplace_back (new test_snake_fuse (type, { 33 , 32 })); // boundary
7563- test_cases.emplace_back (new test_snake_fuse (type, {1025 , 13 })); // large prime, grid-stride
7564- test_cases.emplace_back (new test_snake_fuse (type, { 128 , 16 })); // power-of-two
7565- test_cases.emplace_back (new test_snake_fuse (type, { 256 , 192 })); // BigVGAN-ish
7561+ test_cases.emplace_back (new test_snake_fuse (type, { 5 , 7 , 1 , 1 })); // primes sub-block
7562+ test_cases.emplace_back (new test_snake_fuse (type, { 33 , 32 , 1 , 1 })); // boundary
7563+ test_cases.emplace_back (new test_snake_fuse (type, {1025 , 13 , 1 , 1 })); // large prime, grid-stride
7564+ test_cases.emplace_back (new test_snake_fuse (type, { 128 , 16 , 1 , 1 })); // power-of-two
7565+ test_cases.emplace_back (new test_snake_fuse (type, { 256 , 192 , 1 , 1 })); // BigVGAN-ish
7566+ // higher-rank shapes: the matcher must reject fusion and fall back to the naive chain
7567+ test_cases.emplace_back (new test_snake_fuse (type, { 64 , 32 , 2 , 1 })); // ne[2] > 1
7568+ test_cases.emplace_back (new test_snake_fuse (type, { 64 , 32 , 1 , 2 })); // ne[3] > 1
7569+ test_cases.emplace_back (new test_snake_fuse (type, { 64 , 32 , 2 , 3 })); // ne[2] > 1 and ne[3] > 1
75667570 }
75677571
75687572 // glu ops
@@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
90939097 test_cases.emplace_back (new test_pad_reflect_1d (GGML_TYPE_F32, {3000 , 384 , 4 , 1 }));
90949098
90959099 // SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
9096- test_cases.emplace_back (new test_snake_fuse (GGML_TYPE_F32, {7680 , 192 }));
9097- test_cases.emplace_back (new test_snake_fuse (GGML_TYPE_F16, {7680 , 192 }));
9098- test_cases.emplace_back (new test_snake_fuse (GGML_TYPE_BF16, {7680 , 192 }));
9100+ test_cases.emplace_back (new test_snake_fuse (GGML_TYPE_F32, {7680 , 192 , 1 , 1 }));
9101+ test_cases.emplace_back (new test_snake_fuse (GGML_TYPE_F16, {7680 , 192 , 1 , 1 }));
9102+ test_cases.emplace_back (new test_snake_fuse (GGML_TYPE_BF16, {7680 , 192 , 1 , 1 }));
90999103
91009104 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 16416 , 1 , 128 , {8 , 1 }, {4 , 1 }, {0 , 2 , 1 , 3 }));
91019105 test_cases.emplace_back (new test_mul_mat (GGML_TYPE_F16, GGML_TYPE_F32, 128 , 1 , 16416 , {8 , 1 }, {4 , 1 }, {0 , 1 , 2 , 3 }, 2 *16416 ));
0 commit comments