Skip to content

Commit f6da02c

Browse files
David366AIggerganov
authored andcommitted
ggml : extend im2col f16 (ggml/1434)
* examples/yolo: fix load_model memory leak * fix/issue-1433 ggml_compute_forward_im2col_f16 assert error * fix/issue-1433
1 parent dddca02 commit f6da02c

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

ggml/src/ggml-cpu/ops.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6205,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
62056205
const ggml_tensor * src1 = dst->src[1];
62066206

62076207
GGML_ASSERT(src0->type == GGML_TYPE_F16);
6208-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
6208+
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
62096209
GGML_ASSERT( dst->type == GGML_TYPE_F16);
62106210

62116211
GGML_TENSOR_BINARY_OP_LOCALS;
@@ -6236,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
62366236
int ofs1 = is_2D ? nb12 : nb11;
62376237

62386238
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6239-
GGML_ASSERT(nb10 == sizeof(float));
6239+
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
62406240

62416241
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
62426242
{
@@ -6249,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
62496249

62506250
// micro kernel
62516251
ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6252-
const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6252+
const float * const src_data_f32 = src1->type == GGML_TYPE_F32
6253+
? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6254+
: nullptr; // [IH, IW]
6255+
const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
6256+
? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6257+
: nullptr; // [IH, IW]
62536258

62546259
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
62556260
for (int64_t ikw = 0; ikw < KW; ikw++) {
@@ -6259,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
62596264
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
62606265
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
62616266
} else {
6262-
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6267+
if (src_data_f32 != nullptr) {
6268+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
6269+
} else {
6270+
dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
6271+
}
62636272
}
62646273
}
62656274
}

0 commit comments

Comments
 (0)