Skip to content

Commit e936660

Browse files
Ggml/cuda snake fusion hardening (ggml-org#22912)
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
1 parent ef22b3e commit e936660

2 files changed

Lines changed: 39 additions & 17 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
39293929
// closure check: the trailing add must read the same x as the leading mul
39303930
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
39313931

3932-
const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
3932+
// Kernel iterates over total = T * C, so x and add must be 2D and
3933+
// a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
3934+
const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
3935+
(add->ne[2] == 1 && add->ne[3] == 1) &&
3936+
(a->ne[2] == 1 && a->ne[3] == 1);
39333937
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
39343938

3935-
if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
3939+
// x must be in the supported whitelist and every operand / intermediate
3940+
// result must share x's type, since launch_snake casts a / inv_b as
3941+
// float and templates the kernel on a single T. Mixed precision chains
3942+
// fall back to the naive path.
3943+
const ggml_tensor * sin1 = cgraph->nodes[i + 1];
3944+
const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
3945+
(a->type == x->type) && (inv_b->type == x->type) &&
3946+
(mul0->type == x->type) && (sin1->type == x->type) &&
3947+
(sqr->type == x->type) && (mul1->type == x->type) &&
3948+
(add->type == x->type);
3949+
3950+
if (types_ok && shape_ok && dim_ok && x_in_add == x) {
39363951
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
39373952
return 4;
39383953
}
@@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
52915306
case GGML_OP_VIEW:
52925307
case GGML_OP_PERMUTE:
52935308
case GGML_OP_TRANSPOSE:
5294-
case GGML_OP_ADD:
52955309
case GGML_OP_ADD_ID:
52965310
case GGML_OP_ADD1:
5297-
case GGML_OP_SUB:
5298-
case GGML_OP_MUL:
5299-
case GGML_OP_DIV:
53005311
case GGML_OP_SCALE:
53015312
case GGML_OP_SQR:
53025313
case GGML_OP_SQRT:
@@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
53055316
case GGML_OP_CLAMP:
53065317
case GGML_OP_LOG:
53075318
return true;
5319+
case GGML_OP_ADD:
5320+
case GGML_OP_SUB:
5321+
case GGML_OP_MUL:
5322+
case GGML_OP_DIV:
5323+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
5324+
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
5325+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
53085326
case GGML_OP_SSM_SCAN: {
53095327
if (op->src[3]->ne[0] == 1) {
53105328
// Mamba2

tests/test-backend-ops.cpp

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case {
35613561
// and dispatches a single fused kernel.
35623562
struct test_snake_fuse : public test_case {
35633563
const ggml_type type;
3564-
const std::array<int64_t, 2> ne; // [T, C]
3564+
const std::array<int64_t, 4> ne; // [T, C, D2, D3]
35653565

35663566
std::string op_desc(ggml_tensor * t) override {
35673567
GGML_UNUSED(t);
@@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case {
35863586
}
35873587

35883588
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
3589-
std::array<int64_t, 2> ne = {256, 192})
3589+
std::array<int64_t, 4> ne = {256, 192, 1, 1})
35903590
: type(type), ne(ne) {}
35913591

35923592
ggml_tensor * build_graph(ggml_context * ctx) override {
3593-
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
3593+
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
35943594
ggml_set_name(x, "x");
35953595

35963596
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
@@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75587558

75597559
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
75607560
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
7561-
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
7562-
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
7563-
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride
7564-
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
7565-
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
7561+
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block
7562+
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary
7563+
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride
7564+
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two
7565+
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish
7566+
// higher-rank shapes: matcher must reject fusion, fallback to naive chain
7567+
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1
7568+
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // ne[3] > 1
7569+
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1
75667570
}
75677571

75687572
// glu ops
@@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
90939097
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
90949098

90959099
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
9096-
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
9097-
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
9098-
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
9100+
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1}));
9101+
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1}));
9102+
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1}));
90999103

91009104
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
91019105
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));

0 commit comments

Comments (0)