Skip to content

Commit 95b8b8e

Browse files
metal: template GLU kernels to support f16/f32 (ggml-org#23882)
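Replaces the hardcoded f32 GLU kernels with one templated kernel per GLU op (reglu, geglu, swiglu, swiglu_oai, geglu_erf, geglu_quick), each instantiated for both float and half. Loads and stores now use the tensor's native type (half or float) to save memory bandwidth, while the arithmetic is still performed in float so that the exp/tanh/erf terms in geglu/swiglu do not overflow or lose precision in half. Also relaxes the GLU check in ggml_metal_device_supports_op so that f16 inputs are accepted in addition to f32.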
1 parent 55ac090 commit 95b8b8e

2 files changed

Lines changed: 67 additions & 31 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
11071107
case GGML_GLU_OP_SWIGLU_OAI:
11081108
case GGML_GLU_OP_GEGLU_ERF:
11091109
case GGML_GLU_OP_GEGLU_QUICK:
1110-
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1110+
return ggml_is_contiguous_1(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
11111111
default:
11121112
return false;
11131113
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,81 +1421,100 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
14211421
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
14221422
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
14231423

1424-
kernel void kernel_reglu_f32(
1424+
template<typename T>
1425+
kernel void kernel_reglu(
14251426
constant ggml_metal_kargs_glu & args,
14261427
device const char * src0,
14271428
device const char * src1,
14281429
device char * dst,
14291430
uint tgpig[[threadgroup_position_in_grid]],
14301431
uint tpitg[[thread_position_in_threadgroup]],
14311432
uint ntg[[threads_per_threadgroup]]) {
1432-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1433-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1434-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1433+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1434+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1435+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
14351436

14361437
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
14371438
const float x0 = src0_row[i0];
14381439
const float x1 = src1_row[i0];
14391440

1440-
dst_row[i0] = x0*x1*(x0 > 0.0f);
1441+
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
14411442
}
14421443
}
14431444

1444-
kernel void kernel_geglu_f32(
1445+
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
1446+
1447+
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
1448+
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
1449+
1450+
template<typename T>
1451+
kernel void kernel_geglu(
14451452
constant ggml_metal_kargs_glu & args,
14461453
device const char * src0,
14471454
device const char * src1,
14481455
device char * dst,
14491456
uint tgpig[[threadgroup_position_in_grid]],
14501457
uint tpitg[[thread_position_in_threadgroup]],
14511458
uint ntg[[threads_per_threadgroup]]) {
1452-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1453-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1454-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1459+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1460+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1461+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
14551462

14561463
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
14571464
const float x0 = src0_row[i0];
14581465
const float x1 = src1_row[i0];
14591466

14601467
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
14611468

1462-
dst_row[i0] = gelu*x1;
1469+
dst_row[i0] = (T)(gelu*x1);
14631470
}
14641471
}
14651472

1466-
kernel void kernel_swiglu_f32(
1473+
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
1474+
1475+
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
1476+
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
1477+
1478+
template<typename T>
1479+
kernel void kernel_swiglu(
14671480
constant ggml_metal_kargs_glu & args,
14681481
device const char * src0,
14691482
device const char * src1,
14701483
device char * dst,
14711484
uint tgpig[[threadgroup_position_in_grid]],
14721485
uint tpitg[[thread_position_in_threadgroup]],
14731486
uint ntg[[threads_per_threadgroup]]) {
1474-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1475-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1476-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1487+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1488+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1489+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
14771490

14781491
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
14791492
const float x0 = src0_row[i0];
14801493
const float x1 = src1_row[i0];
14811494

14821495
const float silu = x0 / (1.0f + exp(-x0));
14831496

1484-
dst_row[i0] = silu*x1;
1497+
dst_row[i0] = (T)(silu*x1);
14851498
}
14861499
}
14871500

1488-
kernel void kernel_swiglu_oai_f32(
1501+
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
1502+
1503+
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
1504+
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
1505+
1506+
template<typename T>
1507+
kernel void kernel_swiglu_oai(
14891508
constant ggml_metal_kargs_glu & args,
14901509
device const char * src0,
14911510
device const char * src1,
14921511
device char * dst,
14931512
uint tgpig[[threadgroup_position_in_grid]],
14941513
uint tpitg[[thread_position_in_threadgroup]],
14951514
uint ntg[[threads_per_threadgroup]]) {
1496-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1497-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1498-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1515+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1516+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1517+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
14991518

15001519
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
15011520
float x0 = src0_row[i0];
@@ -1507,54 +1526,71 @@ kernel void kernel_swiglu_oai_f32(
15071526
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
15081527
out_glu = out_glu * (1.0f + x1);
15091528

1510-
dst_row[i0] = out_glu;
1529+
dst_row[i0] = (T)out_glu;
15111530
}
15121531
}
15131532

1514-
kernel void kernel_geglu_erf_f32(
1533+
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
1534+
1535+
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
1536+
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
1537+
1538+
template<typename T>
1539+
kernel void kernel_geglu_erf(
15151540
constant ggml_metal_kargs_glu & args,
15161541
device const char * src0,
15171542
device const char * src1,
15181543
device char * dst,
15191544
uint tgpig[[threadgroup_position_in_grid]],
15201545
uint tpitg[[thread_position_in_threadgroup]],
15211546
uint ntg[[threads_per_threadgroup]]) {
1522-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1523-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1524-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1547+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1548+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1549+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
15251550

15261551
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
15271552
const float x0 = src0_row[i0];
15281553
const float x1 = src1_row[i0];
15291554

15301555
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
15311556

1532-
dst_row[i0] = gelu_erf*x1;
1557+
dst_row[i0] = (T)(gelu_erf*x1);
15331558
}
15341559
}
15351560

1536-
kernel void kernel_geglu_quick_f32(
1561+
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
1562+
1563+
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
1564+
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
1565+
1566+
template<typename T>
1567+
kernel void kernel_geglu_quick(
15371568
constant ggml_metal_kargs_glu & args,
15381569
device const char * src0,
15391570
device const char * src1,
15401571
device char * dst,
15411572
uint tgpig[[threadgroup_position_in_grid]],
15421573
uint tpitg[[thread_position_in_threadgroup]],
15431574
uint ntg[[threads_per_threadgroup]]) {
1544-
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1545-
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1546-
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1575+
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1576+
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1577+
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
15471578

15481579
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
15491580
const float x0 = src0_row[i0];
15501581
const float x1 = src1_row[i0];
15511582

15521583
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
15531584

1554-
dst_row[i0] = gelu_quick*x1;
1585+
dst_row[i0] = (T)(gelu_quick*x1);
15551586
}
15561587
}
15571588

1589+
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
1590+
1591+
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
1592+
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
1593+
15581594
kernel void kernel_op_sum_f32(
15591595
constant ggml_metal_kargs_sum & args,
15601596
device const float * src0,

0 commit comments

Comments
 (0)