@@ -1421,81 +1421,100 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
14211421template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int >;
14221422template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short >;
14231423
1424- kernel void kernel_reglu_f32 (
1424+ template <typename T>
1425+ kernel void kernel_reglu (
14251426 constant ggml_metal_kargs_glu & args,
14261427 device const char * src0,
14271428 device const char * src1,
14281429 device char * dst,
14291430 uint tgpig[[threadgroup_position_in_grid]] ,
14301431 uint tpitg[[thread_position_in_threadgroup]] ,
14311432 uint ntg[[threads_per_threadgroup]] ) {
1432- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1433- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1434- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1433+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1434+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1435+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1 );
14351436
14361437 for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
14371438 const float x0 = src0_row[i0];
14381439 const float x1 = src1_row[i0];
14391440
1440- dst_row[i0] = x0*x1*(x0 > 0 .0f );
1441+ dst_row[i0] = (T)( x0*x1*(x0 > 0 .0f ) );
14411442 }
14421443}
14431444
1444- kernel void kernel_geglu_f32 (
1445+ typedef decltype (kernel_reglu<float >) kernel_reglu_t;
1446+
1447+ template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float >;
1448+ template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
1449+
1450+ template <typename T>
1451+ kernel void kernel_geglu (
14451452 constant ggml_metal_kargs_glu & args,
14461453 device const char * src0,
14471454 device const char * src1,
14481455 device char * dst,
14491456 uint tgpig[[threadgroup_position_in_grid]] ,
14501457 uint tpitg[[thread_position_in_threadgroup]] ,
14511458 uint ntg[[threads_per_threadgroup]] ) {
1452- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1453- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1454- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1459+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1460+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1461+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1 );
14551462
14561463 for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
14571464 const float x0 = src0_row[i0];
14581465 const float x1 = src1_row[i0];
14591466
14601467 const float gelu = 0 .5f *x0*(1 .0f + precise::tanh (SQRT_2_OVER_PI *x0*(1 .0f + GELU_COEF_A *x0*x0)));
14611468
1462- dst_row[i0] = gelu*x1;
1469+ dst_row[i0] = (T)( gelu*x1) ;
14631470 }
14641471}
14651472
1466- kernel void kernel_swiglu_f32 (
1473+ typedef decltype (kernel_geglu<float >) kernel_geglu_t;
1474+
1475+ template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float >;
1476+ template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
1477+
1478+ template <typename T>
1479+ kernel void kernel_swiglu (
14671480 constant ggml_metal_kargs_glu & args,
14681481 device const char * src0,
14691482 device const char * src1,
14701483 device char * dst,
14711484 uint tgpig[[threadgroup_position_in_grid]] ,
14721485 uint tpitg[[thread_position_in_threadgroup]] ,
14731486 uint ntg[[threads_per_threadgroup]] ) {
1474- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1475- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1476- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1487+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1488+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1489+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1 );
14771490
14781491 for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
14791492 const float x0 = src0_row[i0];
14801493 const float x1 = src1_row[i0];
14811494
14821495 const float silu = x0 / (1 .0f + exp (-x0));
14831496
1484- dst_row[i0] = silu*x1;
1497+ dst_row[i0] = (T)( silu*x1) ;
14851498 }
14861499}
14871500
1488- kernel void kernel_swiglu_oai_f32 (
1501+ typedef decltype (kernel_swiglu<float >) kernel_swiglu_t;
1502+
1503+ template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float >;
1504+ template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
1505+
1506+ template <typename T>
1507+ kernel void kernel_swiglu_oai (
14891508 constant ggml_metal_kargs_glu & args,
14901509 device const char * src0,
14911510 device const char * src1,
14921511 device char * dst,
14931512 uint tgpig[[threadgroup_position_in_grid]] ,
14941513 uint tpitg[[thread_position_in_threadgroup]] ,
14951514 uint ntg[[threads_per_threadgroup]] ) {
1496- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1497- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1498- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1515+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1516+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1517+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1 );
14991518
15001519 for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
15011520 float x0 = src0_row[i0];
@@ -1507,54 +1526,71 @@ kernel void kernel_swiglu_oai_f32(
15071526 float out_glu = x0 / (1 .0f + exp (-x0 * args.alpha ));
15081527 out_glu = out_glu * (1 .0f + x1);
15091528
1510- dst_row[i0] = out_glu;
1529+ dst_row[i0] = (T) out_glu;
15111530 }
15121531}
15131532
1514- kernel void kernel_geglu_erf_f32 (
1533+ typedef decltype (kernel_swiglu_oai<float >) kernel_swiglu_oai_t;
1534+
1535+ template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float >;
1536+ template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
1537+
1538+ template <typename T>
1539+ kernel void kernel_geglu_erf (
15151540 constant ggml_metal_kargs_glu & args,
15161541 device const char * src0,
15171542 device const char * src1,
15181543 device char * dst,
15191544 uint tgpig[[threadgroup_position_in_grid]] ,
15201545 uint tpitg[[thread_position_in_threadgroup]] ,
15211546 uint ntg[[threads_per_threadgroup]] ) {
1522- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1523- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1524- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1547+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1548+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1549+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1 );
15251550
15261551 for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
15271552 const float x0 = src0_row[i0];
15281553 const float x1 = src1_row[i0];
15291554
15301555 const float gelu_erf = 0 .5f *x0*(1 .0f +erf_approx<float >(x0*SQRT_2_INV ));
15311556
1532- dst_row[i0] = gelu_erf*x1;
1557+ dst_row[i0] = (T)( gelu_erf*x1) ;
15331558 }
15341559}
15351560
1536- kernel void kernel_geglu_quick_f32 (
1561+ typedef decltype (kernel_geglu_erf<float >) kernel_geglu_erf_t;
1562+
1563+ template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float >;
1564+ template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
1565+
1566+ template <typename T>
1567+ kernel void kernel_geglu_quick (
15371568 constant ggml_metal_kargs_glu & args,
15381569 device const char * src0,
15391570 device const char * src1,
15401571 device char * dst,
15411572 uint tgpig[[threadgroup_position_in_grid]] ,
15421573 uint tpitg[[thread_position_in_threadgroup]] ,
15431574 uint ntg[[threads_per_threadgroup]] ) {
1544- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1545- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1546- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1 );
1575+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01 ) + args.i00 ;
1576+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11 ) + args.i10 ;
1577+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1 );
15471578
15481579 for (int i0 = tpitg; i0 < args.ne0 ; i0 += ntg) {
15491580 const float x0 = src0_row[i0];
15501581 const float x1 = src1_row[i0];
15511582
15521583 const float gelu_quick = x0*(1 .0f /(1 .0f +exp (GELU_QUICK_COEF *x0)));
15531584
1554- dst_row[i0] = gelu_quick*x1;
1585+ dst_row[i0] = (T)( gelu_quick*x1) ;
15551586 }
15561587}
15571588
1589+ typedef decltype (kernel_geglu_quick<float >) kernel_geglu_quick_t;
1590+
1591+ template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float >;
1592+ template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
1593+
15581594kernel void kernel_op_sum_f32 (
15591595 constant ggml_metal_kargs_sum & args,
15601596 device const float * src0,
0 commit comments