Skip to content

Commit fcae601

Browse files
vulkan: add cpy bf16 -> f32 pipelines (ggml-org#22677)
1 parent 7ba22c6 commit fcae601

4 files changed

Lines changed: 23 additions & 5 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -759,8 +759,8 @@ struct vk_device_struct {
759759
vk_pipeline pipeline_pad_f32;
760760
vk_pipeline pipeline_roll_f32;
761761
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
762-
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
763-
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
762+
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
763+
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
764764
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
765765
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
766766
vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
@@ -4572,6 +4572,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
45724572
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45734573
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45744574
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4575+
ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45754576
ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45764577
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45774578

@@ -4580,6 +4581,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
45804581
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45814582
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45824583
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4584+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45834585
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45844586
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
45854587

@@ -7544,6 +7546,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
75447546
return ctx->device->pipeline_cpy_f32_bf16;
75457547
}
75467548
}
7549+
if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) {
7550+
if (contig) {
7551+
return ctx->device->pipeline_contig_cpy_bf16_f32;
7552+
} else {
7553+
return ctx->device->pipeline_cpy_bf16_f32;
7554+
}
7555+
}
75477556
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
75487557
if (contig) {
75497558
return ctx->device->pipeline_contig_cpy_f32_i32;
@@ -15974,6 +15983,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1597415983
if (src1_type == GGML_TYPE_F32) {
1597515984
switch (src0_type) {
1597615985
case GGML_TYPE_F16:
15986+
case GGML_TYPE_BF16:
1597715987
case GGML_TYPE_Q1_0:
1597815988
case GGML_TYPE_Q4_0:
1597915989
case GGML_TYPE_Q4_1:

ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ void main() {
1919
if (idx + (num_iter-1)*num_threads < p.ne) {
2020
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
2121

22-
#if defined(DATA_D_BF16)
22+
#if defined(DATA_A_BF16)
23+
data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx])));
24+
#elif defined(DATA_D_BF16)
2325
float f = float(data_a[get_aoffset() + idx]);
2426
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
2527
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
@@ -35,7 +37,9 @@ void main() {
3537
continue;
3638
}
3739

38-
#if defined(DATA_D_BF16)
40+
#if defined(DATA_A_BF16)
41+
data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx])));
42+
#elif defined(DATA_D_BF16)
3943
float f = float(data_a[get_aoffset() + idx]);
4044
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
4145
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)

ggml/src/ggml-vulkan/vulkan-shaders/copy.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ void main() {
1212
return;
1313
}
1414

15-
#if defined(DATA_D_BF16)
15+
#if defined(DATA_A_BF16)
16+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)])));
17+
#elif defined(DATA_D_BF16)
1618
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
1719
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
1820
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,13 +731,15 @@ void process_shaders() {
731731
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
732732
string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
733733
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
734+
string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}});
734735
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
735736
string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
736737
string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
737738
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
738739
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
739740
string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
740741
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
742+
string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}});
741743
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
742744
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
743745

0 commit comments

Comments
 (0)