Skip to content

Commit aa0f189

Browse files
authored
metal : add XIELU unary op (ggml-org#20802)
1 parent be76dd0 commit aa0f189

6 files changed

Lines changed: 22 additions & 0 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
250250
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
251251
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
252252
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
253+
case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break;
253254
default: GGML_ABORT("fatal error");
254255
} break;
255256
default: GGML_ABORT("fatal error");

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
10431043
case GGML_UNARY_OP_CEIL:
10441044
case GGML_UNARY_OP_ROUND:
10451045
case GGML_UNARY_OP_TRUNC:
1046+
case GGML_UNARY_OP_XIELU:
10461047
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
10471048
default:
10481049
return false;

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,7 @@
127127
#define OP_UNARY_NUM_CEIL 118
128128
#define OP_UNARY_NUM_ROUND 119
129129
#define OP_UNARY_NUM_TRUNC 120
130+
#define OP_UNARY_NUM_XIELU 121
130131

131132
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
132133
#define OP_SUM_ROWS_NUM_MEAN 11

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
787787
args.max = ggml_get_op_params_f32(op, 1);
788788
}
789789

790+
if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
791+
args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
792+
args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
793+
args.bias = ggml_get_op_params_f32(op, 3); // beta
794+
args.val = ggml_get_op_params_f32(op, 4); // eps
795+
}
796+
790797
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
791798

792799
if (pipeline.c4) {

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl(
11771177
if (FC_OP == OP_UNARY_NUM_TRUNC) {
11781178
dst_ptr[i0] = (T) trunc(x);
11791179
}
1180+
1181+
if (FC_OP == OP_UNARY_NUM_XIELU) {
1182+
const TC xi = x;
1183+
const TC gate = TC(xi > TC(0.0f));
1184+
const TC clamped = fmin(xi, TC(args.val));
1185+
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
1186+
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
1187+
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
1188+
}
11801189
}
11811190

11821191
#undef FC_OP

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8506,6 +8506,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
85068506
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 }));
85078507

85088508
test_cases.emplace_back(new test_xielu());
8509+
test_cases.emplace_back(new test_xielu(GGML_TYPE_F16));
8510+
test_cases.emplace_back(new test_xielu(GGML_TYPE_F32, { 512, 16, 1, 1 }));
8511+
test_cases.emplace_back(new test_xielu(GGML_TYPE_F16, { 512, 16, 1, 1 }));
85098512

85108513
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
85118514
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));

0 commit comments

Comments
 (0)