Skip to content

Commit 4681f2e

Browse files
authored
Optimize x86 fp16s innerproduct gemm to eliminate loop-carried stalls (#6682)
1 parent b9c6e63 commit 4681f2e

1 file changed

Lines changed: 99 additions & 32 deletions

File tree

src/layer/x86/innerproduct_gemm_fp.h

Lines changed: 99 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -180,14 +180,43 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
180180
#endif
181181
const float* m = bottom_blob.row(j);
182182

183-
__m512 _sum = _mm512_setzero_ps();
183+
__m512 _sum0 = _mm512_setzero_ps();
184+
__m512 _sum1 = _mm512_setzero_ps();
185+
__m512 _sum2 = _mm512_setzero_ps();
186+
__m512 _sum3 = _mm512_setzero_ps();
184187

185188
if (bias_data_ptr)
186189
{
187-
_sum = _mm512_loadu_ps(bias_data_ptr + p * 16);
190+
_sum0 = _mm512_loadu_ps(bias_data_ptr + p * 16);
188191
}
189192

190193
int i = 0;
194+
for (; i + 3 < num_input; i += 4)
195+
{
196+
__m512 _val0 = _mm512_set1_ps(m[0]);
197+
__m512 _val1 = _mm512_set1_ps(m[1]);
198+
__m512 _val2 = _mm512_set1_ps(m[2]);
199+
__m512 _val3 = _mm512_set1_ps(m[3]);
200+
#if NCNN_IMPL_FP16S
201+
__m512 _w0 = _mm512_cvtph_ps(_mm256_lddqu_si256((const __m256i*)kptr));
202+
__m512 _w1 = _mm512_cvtph_ps(_mm256_lddqu_si256((const __m256i*)(kptr + 16)));
203+
__m512 _w2 = _mm512_cvtph_ps(_mm256_lddqu_si256((const __m256i*)(kptr + 32)));
204+
__m512 _w3 = _mm512_cvtph_ps(_mm256_lddqu_si256((const __m256i*)(kptr + 48)));
205+
#else
206+
__m512 _w0 = _mm512_loadu_ps(kptr);
207+
__m512 _w1 = _mm512_loadu_ps(kptr + 16);
208+
__m512 _w2 = _mm512_loadu_ps(kptr + 32);
209+
__m512 _w3 = _mm512_loadu_ps(kptr + 48);
210+
#endif
211+
212+
_sum0 = _mm512_fmadd_ps(_val0, _w0, _sum0);
213+
_sum1 = _mm512_fmadd_ps(_val1, _w1, _sum1);
214+
_sum2 = _mm512_fmadd_ps(_val2, _w2, _sum2);
215+
_sum3 = _mm512_fmadd_ps(_val3, _w3, _sum3);
216+
217+
m += 4;
218+
kptr += 64;
219+
}
191220
for (; i < num_input; i++)
192221
{
193222
__m512 _val = _mm512_set1_ps(m[0]);
@@ -197,15 +226,19 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
197226
__m512 _w = _mm512_loadu_ps(kptr);
198227
#endif
199228

200-
_sum = _mm512_fmadd_ps(_val, _w, _sum);
229+
_sum0 = _mm512_fmadd_ps(_val, _w, _sum0);
201230

202231
m += 1;
203232
kptr += 16;
204233
}
205234

206-
_sum = activation_avx512(_sum, activation_type, activation_params);
235+
_sum0 = _mm512_add_ps(_sum0, _sum1);
236+
_sum2 = _mm512_add_ps(_sum2, _sum3);
237+
_sum0 = _mm512_add_ps(_sum0, _sum2);
207238

208-
_mm512_storeu_ps(outptr, _sum);
239+
_sum0 = activation_avx512(_sum0, activation_type, activation_params);
240+
241+
_mm512_storeu_ps(outptr, _sum0);
209242
outptr += 16;
210243
}
211244
}
@@ -695,6 +728,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
695728
__m256 _sum1 = _mm256_setzero_ps();
696729
__m256 _sum2 = _mm256_setzero_ps();
697730
__m256 _sum3 = _mm256_setzero_ps();
731+
__m256 _sum4 = _mm256_setzero_ps();
732+
__m256 _sum5 = _mm256_setzero_ps();
733+
__m256 _sum6 = _mm256_setzero_ps();
734+
__m256 _sum7 = _mm256_setzero_ps();
698735

699736
if (bias_data_ptr)
700737
{
@@ -709,12 +746,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
709746
__m256 _val2 = _mm256_broadcast_ss(m + 2);
710747
__m256 _val3 = _mm256_broadcast_ss(m + 3);
711748
#if NCNN_IMPL_FP16S
712-
__m256i _w01 = _mm256_lddqu_si256((const __m256i*)kptr);
713-
__m256i _w23 = _mm256_lddqu_si256((const __m256i*)(kptr + 16));
714-
__m256 _w0 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w01, 0));
715-
__m256 _w1 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w01, 1));
716-
__m256 _w2 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w23, 0));
717-
__m256 _w3 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w23, 1));
749+
__m256 _w0 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)kptr));
750+
__m256 _w1 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 8)));
751+
__m256 _w2 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 16)));
752+
__m256 _w3 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 24)));
718753
#else
719754
__m256 _w0 = _mm256_loadu_ps(kptr);
720755
__m256 _w1 = _mm256_loadu_ps(kptr + 8);
@@ -732,23 +767,21 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
732767
__m256 _val6 = _mm256_broadcast_ss(m + 6);
733768
__m256 _val7 = _mm256_broadcast_ss(m + 7);
734769
#if NCNN_IMPL_FP16S
735-
__m256i _w45 = _mm256_lddqu_si256((const __m256i*)(kptr + 32));
736-
__m256i _w67 = _mm256_lddqu_si256((const __m256i*)(kptr + 48));
737-
__m256 _w4 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w45, 0));
738-
__m256 _w5 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w45, 1));
739-
__m256 _w6 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w67, 0));
740-
__m256 _w7 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w67, 1));
770+
__m256 _w4 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 32)));
771+
__m256 _w5 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 40)));
772+
__m256 _w6 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 48)));
773+
__m256 _w7 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 56)));
741774
#else
742775
__m256 _w4 = _mm256_loadu_ps(kptr + 32);
743776
__m256 _w5 = _mm256_loadu_ps(kptr + 40);
744777
__m256 _w6 = _mm256_loadu_ps(kptr + 48);
745778
__m256 _w7 = _mm256_loadu_ps(kptr + 56);
746779
#endif
747780

748-
_sum0 = _mm256_comp_fmadd_ps(_val4, _w4, _sum0);
749-
_sum1 = _mm256_comp_fmadd_ps(_val5, _w5, _sum1);
750-
_sum2 = _mm256_comp_fmadd_ps(_val6, _w6, _sum2);
751-
_sum3 = _mm256_comp_fmadd_ps(_val7, _w7, _sum3);
781+
_sum4 = _mm256_comp_fmadd_ps(_val4, _w4, _sum4);
782+
_sum5 = _mm256_comp_fmadd_ps(_val5, _w5, _sum5);
783+
_sum6 = _mm256_comp_fmadd_ps(_val6, _w6, _sum6);
784+
_sum7 = _mm256_comp_fmadd_ps(_val7, _w7, _sum7);
752785

753786
m += 8;
754787
kptr += 64;
@@ -760,12 +793,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
760793
__m256 _val2 = _mm256_broadcast_ss(m + 2);
761794
__m256 _val3 = _mm256_broadcast_ss(m + 3);
762795
#if NCNN_IMPL_FP16S
763-
__m256i _w01 = _mm256_lddqu_si256((const __m256i*)kptr);
764-
__m256i _w23 = _mm256_lddqu_si256((const __m256i*)(kptr + 16));
765-
__m256 _w0 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w01, 0));
766-
__m256 _w1 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w01, 1));
767-
__m256 _w2 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w23, 0));
768-
__m256 _w3 = _mm256_cvtph_ps(_mm256_extractf128_si256(_w23, 1));
796+
__m256 _w0 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)kptr));
797+
__m256 _w1 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 8)));
798+
__m256 _w2 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 16)));
799+
__m256 _w3 = _mm256_cvtph_ps(_mm_lddqu_si128((const __m128i*)(kptr + 24)));
769800
#else
770801
__m256 _w0 = _mm256_loadu_ps(kptr);
771802
__m256 _w1 = _mm256_loadu_ps(kptr + 8);
@@ -797,7 +828,11 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
797828

798829
_sum0 = _mm256_add_ps(_sum0, _sum1);
799830
_sum2 = _mm256_add_ps(_sum2, _sum3);
831+
_sum4 = _mm256_add_ps(_sum4, _sum5);
832+
_sum6 = _mm256_add_ps(_sum6, _sum7);
800833
_sum0 = _mm256_add_ps(_sum0, _sum2);
834+
_sum4 = _mm256_add_ps(_sum4, _sum6);
835+
_sum0 = _mm256_add_ps(_sum0, _sum4);
801836

802837
_sum0 = activation_avx(_sum0, activation_type, activation_params);
803838

@@ -1086,14 +1121,42 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
10861121
#endif
10871122
const float* m = bottom_blob.row(j);
10881123

1089-
__m128 _sum = _mm_setzero_ps();
1124+
__m128 _sum0 = _mm_setzero_ps();
1125+
__m128 _sum1 = _mm_setzero_ps();
1126+
__m128 _sum2 = _mm_setzero_ps();
1127+
__m128 _sum3 = _mm_setzero_ps();
10901128

10911129
if (bias_data_ptr)
10921130
{
1093-
_sum = _mm_loadu_ps(bias_data_ptr + p * 4);
1131+
_sum0 = _mm_loadu_ps(bias_data_ptr + p * 4);
10941132
}
10951133

10961134
int i = 0;
1135+
for (; i + 3 < num_input; i += 4)
1136+
{
1137+
__m128 _val0 = _mm_set1_ps(m[0]);
1138+
__m128 _val1 = _mm_set1_ps(m[1]);
1139+
__m128 _val2 = _mm_set1_ps(m[2]);
1140+
__m128 _val3 = _mm_set1_ps(m[3]);
1141+
#if NCNN_IMPL_FP16S
1142+
__m128 _w0 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i*)kptr));
1143+
__m128 _w1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i*)(kptr + 4)));
1144+
__m128 _w2 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i*)(kptr + 8)));
1145+
__m128 _w3 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i*)(kptr + 12)));
1146+
#else
1147+
__m128 _w0 = _mm_loadu_ps(kptr);
1148+
__m128 _w1 = _mm_loadu_ps(kptr + 4);
1149+
__m128 _w2 = _mm_loadu_ps(kptr + 8);
1150+
__m128 _w3 = _mm_loadu_ps(kptr + 12);
1151+
#endif
1152+
_sum0 = _mm_comp_fmadd_ps(_val0, _w0, _sum0);
1153+
_sum1 = _mm_comp_fmadd_ps(_val1, _w1, _sum1);
1154+
_sum2 = _mm_comp_fmadd_ps(_val2, _w2, _sum2);
1155+
_sum3 = _mm_comp_fmadd_ps(_val3, _w3, _sum3);
1156+
1157+
m += 4;
1158+
kptr += 16;
1159+
}
10971160
for (; i < num_input; i++)
10981161
{
10991162
__m128 _val = _mm_set1_ps(m[0]);
@@ -1102,15 +1165,19 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
11021165
#else
11031166
__m128 _w = _mm_loadu_ps(kptr);
11041167
#endif
1105-
_sum = _mm_comp_fmadd_ps(_val, _w, _sum);
1168+
_sum0 = _mm_comp_fmadd_ps(_val, _w, _sum0);
11061169

11071170
m += 1;
11081171
kptr += 4;
11091172
}
11101173

1111-
_sum = activation_sse(_sum, activation_type, activation_params);
1174+
_sum0 = _mm_add_ps(_sum0, _sum1);
1175+
_sum2 = _mm_add_ps(_sum2, _sum3);
1176+
_sum0 = _mm_add_ps(_sum0, _sum2);
11121177

1113-
_mm_storeu_ps(outptr, _sum);
1178+
_sum0 = activation_sse(_sum0, activation_type, activation_params);
1179+
1180+
_mm_storeu_ps(outptr, _sum0);
11141181
outptr += 4;
11151182
}
11161183
}

0 commit comments

Comments
 (0)