Skip to content

Commit 92080b4

Browse files
nuri-yoonryoo and nryoo authored
metal : add FLOOR, CEIL, ROUND, TRUNC unary ops (#20930)
Co-authored-by: nryoo <nryoo@nryooui-MacBookPro.local>
1 parent 342d612 commit 92080b4

4 files changed

Lines changed: 28 additions & 0 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
246246
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
247247
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
248248
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
249+
case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break;
250+
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
251+
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
252+
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
249253
default: GGML_ABORT("fatal error");
250254
} break;
251255
default: GGML_ABORT("fatal error");

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
10391039
case GGML_UNARY_OP_EXP:
10401040
case GGML_UNARY_OP_SOFTPLUS:
10411041
case GGML_UNARY_OP_EXPM1:
1042+
case GGML_UNARY_OP_FLOOR:
1043+
case GGML_UNARY_OP_CEIL:
1044+
case GGML_UNARY_OP_ROUND:
1045+
case GGML_UNARY_OP_TRUNC:
10421046
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
10431047
default:
10441048
return false;

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@
120120
#define OP_UNARY_NUM_EXP 114
121121
#define OP_UNARY_NUM_SOFTPLUS 115
122122
#define OP_UNARY_NUM_EXPM1 116
123+
#define OP_UNARY_NUM_FLOOR 117
124+
#define OP_UNARY_NUM_CEIL 118
125+
#define OP_UNARY_NUM_ROUND 119
126+
#define OP_UNARY_NUM_TRUNC 120
123127

124128
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
125129
#define OP_SUM_ROWS_NUM_MEAN 11

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,22 @@ kernel void kernel_unary_impl(
10941094
// TODO: precise implementation
10951095
dst_ptr[i0] = (T) (exp(x) - 1);
10961096
}
1097+
1098+
if (FC_OP == OP_UNARY_NUM_FLOOR) {
1099+
dst_ptr[i0] = (T) floor(x);
1100+
}
1101+
1102+
if (FC_OP == OP_UNARY_NUM_CEIL) {
1103+
dst_ptr[i0] = (T) ceil(x);
1104+
}
1105+
1106+
if (FC_OP == OP_UNARY_NUM_ROUND) {
1107+
dst_ptr[i0] = (T) round(x);
1108+
}
1109+
1110+
if (FC_OP == OP_UNARY_NUM_TRUNC) {
1111+
dst_ptr[i0] = (T) trunc(x);
1112+
}
10971113
}
10981114

10991115
#undef FC_OP

0 commit comments

Comments (0)