@@ -816,14 +816,10 @@ struct vk_device_struct {
816816 vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
817817 vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
818818 vk_pipeline pipeline_scale_f32;
819- vk_pipeline pipeline_sqr_f32;
820- vk_pipeline pipeline_sqrt_f32;
821- vk_pipeline pipeline_sin_f32;
822- vk_pipeline pipeline_cos_f32;
823819 vk_pipeline pipeline_log[2];
824820 vk_pipeline pipeline_tri[2];
825821 vk_pipeline pipeline_diag[2];
826- vk_pipeline pipeline_clamp_f32 ;
822+ vk_pipeline pipeline_clamp[2] ;
827823 vk_pipeline pipeline_pad_f32;
828824 vk_pipeline pipeline_roll_f32;
829825 vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
@@ -855,6 +851,10 @@ struct vk_device_struct {
855851 vk_pipeline pipeline_gelu_quick[2];
856852 vk_pipeline pipeline_silu[2];
857853 vk_pipeline pipeline_relu[2];
854+ vk_pipeline pipeline_sqr[2];
855+ vk_pipeline pipeline_sqrt[2];
856+ vk_pipeline pipeline_sin[2];
857+ vk_pipeline pipeline_cos[2];
858858 vk_pipeline pipeline_xielu[2];
859859 vk_pipeline pipeline_neg[2];
860860 vk_pipeline pipeline_tanh[2];
@@ -886,7 +886,7 @@ struct vk_device_struct {
886886 vk_pipeline pipeline_geglu_erf[2];
887887 vk_pipeline pipeline_geglu_quick[2];
888888
889- vk_pipeline pipeline_leaky_relu_f32 ;
889+ vk_pipeline pipeline_leaky_relu[2] ;
890890 vk_pipeline pipeline_silu_back_f32;
891891 vk_pipeline pipeline_diag_mask_inf_f32;
892892 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -4972,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
49724972 }
49734973 ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
49744974
4975- ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants ), {1, 1, 1}, {}, 1);
4975+ ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants ), {1, 1, 1}, {}, 1);
49764976 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
49774977
49784978 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
@@ -5092,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
50925092
50935093 ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
50945094
5095- ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5096- ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5097- ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5098- ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5099-
51005095 ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
51015096 ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
51025097
@@ -5106,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
51065101 ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
51075102 ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
51085103
5109- ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5110-
51115104 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
51125105
51135106 ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5127,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
51275120 CREATE_UNARY(gelu_quick)
51285121 CREATE_UNARY(silu)
51295122 CREATE_UNARY(relu)
5123+ CREATE_UNARY(sqr)
5124+ CREATE_UNARY(sqrt)
5125+ CREATE_UNARY(sin)
5126+ CREATE_UNARY(cos)
5127+ CREATE_UNARY(clamp)
5128+ CREATE_UNARY(leaky_relu)
51305129 CREATE_UNARY(xielu)
51315130 CREATE_UNARY(neg)
51325131 CREATE_UNARY(tanh)
@@ -5166,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
51665165 CREATE_GLU(geglu_quick)
51675166#undef CREATE_GLU
51685167
5169- ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
51705168 ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
51715169
51725170 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
@@ -10521,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1052110519 }
1052210520 return nullptr;
1052310521 case GGML_OP_SQR:
10524- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
10525- return ctx->device->pipeline_sqr_f32;
10522+ if (src0->type == dst->type &&
10523+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
10524+ return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
1052610525 }
1052710526 return nullptr;
1052810527 case GGML_OP_SQRT:
10529- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
10530- return ctx->device->pipeline_sqrt_f32;
10528+ if (src0->type == dst->type &&
10529+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
10530+ return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
1053110531 }
1053210532 return nullptr;
1053310533 case GGML_OP_SIN:
10534- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
10535- return ctx->device->pipeline_sin_f32;
10534+ if (src0->type == dst->type &&
10535+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
10536+ return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
1053610537 }
1053710538 return nullptr;
1053810539 case GGML_OP_COS:
10539- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
10540- return ctx->device->pipeline_cos_f32;
10540+ if (src0->type == dst->type &&
10541+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
10542+ return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
1054110543 }
1054210544 return nullptr;
1054310545 case GGML_OP_LOG:
@@ -10559,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1055910561 }
1056010562 return nullptr;
1056110563 case GGML_OP_CLAMP:
10562- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
10563- return ctx->device->pipeline_clamp_f32;
10564+ if (src0->type == dst->type &&
10565+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
10566+ return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
1056410567 }
1056510568 return nullptr;
1056610569 case GGML_OP_PAD:
@@ -10928,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1092810931 }
1092910932 return nullptr;
1093010933 case GGML_OP_LEAKY_RELU:
10931- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
10932- return ctx->device->pipeline_leaky_relu_f32;
10934+ if (src0->type == dst->type &&
10935+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
10936+ return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
1093310937 }
1093410938 return nullptr;
1093510939 case GGML_OP_CONV_2D:
@@ -11431,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
1143111435 case GGML_OP_TRI:
1143211436 case GGML_OP_DIAG:
1143311437 case GGML_OP_CLAMP:
11438+ case GGML_OP_LEAKY_RELU:
1143411439 case GGML_OP_PAD:
1143511440 case GGML_OP_ROLL:
1143611441 case GGML_OP_REPEAT:
@@ -12297,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
1229712302
1229812303static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1229912304 float * op_params = (float *)dst->op_params;
12305+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
12306+ p.param1 = op_params[0];
1230012307
12301- ggml_vk_op_f32<vk_op_push_constants> (ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f } );
12308+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p) );
1230212309}
1230312310
1230412311static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13399,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
1339913406
1340013407static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1340113408 const float * op_params = (const float *)dst->op_params;
13402- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
13409+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
13410+ p.param1 = op_params[0];
13411+
13412+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
1340313413}
1340413414
1340513415#ifdef GGML_VULKAN_RUN_TESTS
@@ -17325,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1732517335 case GGML_OP_TRANSPOSE:
1732617336 case GGML_OP_RMS_NORM:
1732717337 return true;
17328- case GGML_OP_NORM:
1732917338 case GGML_OP_GROUP_NORM:
1733017339 return ggml_is_contiguous(op->src[0]);
17340+ case GGML_OP_NORM:
1733117341 case GGML_OP_L2_NORM:
17332- return ggml_is_contiguous_rows(op->src[0]) &&
17333- op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
17342+ return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
1733417343 case GGML_OP_ADD:
1733517344 case GGML_OP_SUB:
1733617345 case GGML_OP_MUL:
@@ -17349,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1734917358 case GGML_OP_SIN:
1735017359 case GGML_OP_COS:
1735117360 case GGML_OP_CLAMP:
17352- return op->src[0]->type == GGML_TYPE_F32;
1735317361 case GGML_OP_LEAKY_RELU:
17362+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
17363+ op->type == op->src[0]->type;
1735417364 case GGML_OP_OPT_STEP_ADAMW:
1735517365 case GGML_OP_OPT_STEP_SGD:
1735617366 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
0 commit comments