@@ -6205,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
62056205 const ggml_tensor * src1 = dst->src [1 ];
62066206
62076207 GGML_ASSERT (src0->type == GGML_TYPE_F16 );
6208- GGML_ASSERT (src1->type == GGML_TYPE_F32 );
6208+ GGML_ASSERT (src1->type == GGML_TYPE_F16 || src1-> type == GGML_TYPE_F32 );
62096209 GGML_ASSERT ( dst->type == GGML_TYPE_F16 );
62106210
62116211 GGML_TENSOR_BINARY_OP_LOCALS ;
@@ -6236,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
62366236 int ofs1 = is_2D ? nb12 : nb11;
62376237
62386238 GGML_ASSERT (nb00 == sizeof (ggml_fp16_t ));
6239- GGML_ASSERT (nb10 == sizeof ( float ));
6239+ GGML_ASSERT (nb10 == ggml_type_size (src1-> type ));
62406240
62416241 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
62426242 {
@@ -6249,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
62496249
62506250 // micro kernel
62516251 ggml_fp16_t * dst_data = wdata + (in*OH *OW + ioh*OW + iow)*(IC *KH *KW ); // [IC, KH, KW]
6252- const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6252+ const float * const src_data_f32 = src1->type == GGML_TYPE_F32
6253+ ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6254+ : nullptr ; // [IH, IW]
6255+ const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
6256+ ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6257+ : nullptr ; // [IH, IW]
62536258
62546259 for (int64_t ikh = 0 ; ikh < KH ; ikh++) { // 1
62556260 for (int64_t ikw = 0 ; ikw < KW ; ikw++) {
@@ -6259,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
62596264 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW ) {
62606265 dst_data[iic*(KH *KW ) + ikh*KW + ikw] = 0 ;
62616266 } else {
6262- dst_data[iic*(KH *KW ) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16 (src_data[iih*IW + iiw]);
6267+ if (src_data_f32 != nullptr ) {
6268+ dst_data[iic*(KH *KW ) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16 (src_data_f32[iih*IW + iiw]);
6269+ } else {
6270+ dst_data[iic*(KH *KW ) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
6271+ }
62636272 }
62646273 }
62656274 }
0 commit comments