@@ -180,14 +180,43 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
180180#endif
181181 const float * m = bottom_blob .row (j );
182182
183- __m512 _sum = _mm512_setzero_ps ();
183+ __m512 _sum0 = _mm512_setzero_ps ();
184+ __m512 _sum1 = _mm512_setzero_ps ();
185+ __m512 _sum2 = _mm512_setzero_ps ();
186+ __m512 _sum3 = _mm512_setzero_ps ();
184187
185188 if (bias_data_ptr )
186189 {
187- _sum = _mm512_loadu_ps (bias_data_ptr + p * 16 );
190+ _sum0 = _mm512_loadu_ps (bias_data_ptr + p * 16 );
188191 }
189192
190193 int i = 0 ;
194+ for (; i + 3 < num_input ; i += 4 )
195+ {
196+ __m512 _val0 = _mm512_set1_ps (m [0 ]);
197+ __m512 _val1 = _mm512_set1_ps (m [1 ]);
198+ __m512 _val2 = _mm512_set1_ps (m [2 ]);
199+ __m512 _val3 = _mm512_set1_ps (m [3 ]);
200+ #if NCNN_IMPL_FP16S
201+ __m512 _w0 = _mm512_cvtph_ps (_mm256_lddqu_si256 ((const __m256i * )kptr ));
202+ __m512 _w1 = _mm512_cvtph_ps (_mm256_lddqu_si256 ((const __m256i * )(kptr + 16 )));
203+ __m512 _w2 = _mm512_cvtph_ps (_mm256_lddqu_si256 ((const __m256i * )(kptr + 32 )));
204+ __m512 _w3 = _mm512_cvtph_ps (_mm256_lddqu_si256 ((const __m256i * )(kptr + 48 )));
205+ #else
206+ __m512 _w0 = _mm512_loadu_ps (kptr );
207+ __m512 _w1 = _mm512_loadu_ps (kptr + 16 );
208+ __m512 _w2 = _mm512_loadu_ps (kptr + 32 );
209+ __m512 _w3 = _mm512_loadu_ps (kptr + 48 );
210+ #endif
211+
212+ _sum0 = _mm512_fmadd_ps (_val0 , _w0 , _sum0 );
213+ _sum1 = _mm512_fmadd_ps (_val1 , _w1 , _sum1 );
214+ _sum2 = _mm512_fmadd_ps (_val2 , _w2 , _sum2 );
215+ _sum3 = _mm512_fmadd_ps (_val3 , _w3 , _sum3 );
216+
217+ m += 4 ;
218+ kptr += 64 ;
219+ }
191220 for (; i < num_input ; i ++ )
192221 {
193222 __m512 _val = _mm512_set1_ps (m [0 ]);
@@ -197,15 +226,19 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
197226 __m512 _w = _mm512_loadu_ps (kptr );
198227#endif
199228
200- _sum = _mm512_fmadd_ps (_val , _w , _sum );
229+ _sum0 = _mm512_fmadd_ps (_val , _w , _sum0 );
201230
202231 m += 1 ;
203232 kptr += 16 ;
204233 }
205234
206- _sum = activation_avx512 (_sum , activation_type , activation_params );
235+ _sum0 = _mm512_add_ps (_sum0 , _sum1 );
236+ _sum2 = _mm512_add_ps (_sum2 , _sum3 );
237+ _sum0 = _mm512_add_ps (_sum0 , _sum2 );
207238
208- _mm512_storeu_ps (outptr , _sum );
239+ _sum0 = activation_avx512 (_sum0 , activation_type , activation_params );
240+
241+ _mm512_storeu_ps (outptr , _sum0 );
209242 outptr += 16 ;
210243 }
211244 }
@@ -695,6 +728,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
695728 __m256 _sum1 = _mm256_setzero_ps ();
696729 __m256 _sum2 = _mm256_setzero_ps ();
697730 __m256 _sum3 = _mm256_setzero_ps ();
731+ __m256 _sum4 = _mm256_setzero_ps ();
732+ __m256 _sum5 = _mm256_setzero_ps ();
733+ __m256 _sum6 = _mm256_setzero_ps ();
734+ __m256 _sum7 = _mm256_setzero_ps ();
698735
699736 if (bias_data_ptr )
700737 {
@@ -709,12 +746,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
709746 __m256 _val2 = _mm256_broadcast_ss (m + 2 );
710747 __m256 _val3 = _mm256_broadcast_ss (m + 3 );
711748#if NCNN_IMPL_FP16S
712- __m256i _w01 = _mm256_lddqu_si256 ((const __m256i * )kptr );
713- __m256i _w23 = _mm256_lddqu_si256 ((const __m256i * )(kptr + 16 ));
714- __m256 _w0 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w01 , 0 ));
715- __m256 _w1 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w01 , 1 ));
716- __m256 _w2 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w23 , 0 ));
717- __m256 _w3 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w23 , 1 ));
749+ __m256 _w0 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )kptr ));
750+ __m256 _w1 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 8 )));
751+ __m256 _w2 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 16 )));
752+ __m256 _w3 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 24 )));
718753#else
719754 __m256 _w0 = _mm256_loadu_ps (kptr );
720755 __m256 _w1 = _mm256_loadu_ps (kptr + 8 );
@@ -732,23 +767,21 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
732767 __m256 _val6 = _mm256_broadcast_ss (m + 6 );
733768 __m256 _val7 = _mm256_broadcast_ss (m + 7 );
734769#if NCNN_IMPL_FP16S
735- __m256i _w45 = _mm256_lddqu_si256 ((const __m256i * )(kptr + 32 ));
736- __m256i _w67 = _mm256_lddqu_si256 ((const __m256i * )(kptr + 48 ));
737- __m256 _w4 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w45 , 0 ));
738- __m256 _w5 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w45 , 1 ));
739- __m256 _w6 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w67 , 0 ));
740- __m256 _w7 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w67 , 1 ));
770+ __m256 _w4 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 32 )));
771+ __m256 _w5 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 40 )));
772+ __m256 _w6 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 48 )));
773+ __m256 _w7 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 56 )));
741774#else
742775 __m256 _w4 = _mm256_loadu_ps (kptr + 32 );
743776 __m256 _w5 = _mm256_loadu_ps (kptr + 40 );
744777 __m256 _w6 = _mm256_loadu_ps (kptr + 48 );
745778 __m256 _w7 = _mm256_loadu_ps (kptr + 56 );
746779#endif
747780
748- _sum0 = _mm256_comp_fmadd_ps (_val4 , _w4 , _sum0 );
749- _sum1 = _mm256_comp_fmadd_ps (_val5 , _w5 , _sum1 );
750- _sum2 = _mm256_comp_fmadd_ps (_val6 , _w6 , _sum2 );
751- _sum3 = _mm256_comp_fmadd_ps (_val7 , _w7 , _sum3 );
781+ _sum4 = _mm256_comp_fmadd_ps (_val4 , _w4 , _sum4 );
782+ _sum5 = _mm256_comp_fmadd_ps (_val5 , _w5 , _sum5 );
783+ _sum6 = _mm256_comp_fmadd_ps (_val6 , _w6 , _sum6 );
784+ _sum7 = _mm256_comp_fmadd_ps (_val7 , _w7 , _sum7 );
752785
753786 m += 8 ;
754787 kptr += 64 ;
@@ -760,12 +793,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
760793 __m256 _val2 = _mm256_broadcast_ss (m + 2 );
761794 __m256 _val3 = _mm256_broadcast_ss (m + 3 );
762795#if NCNN_IMPL_FP16S
763- __m256i _w01 = _mm256_lddqu_si256 ((const __m256i * )kptr );
764- __m256i _w23 = _mm256_lddqu_si256 ((const __m256i * )(kptr + 16 ));
765- __m256 _w0 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w01 , 0 ));
766- __m256 _w1 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w01 , 1 ));
767- __m256 _w2 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w23 , 0 ));
768- __m256 _w3 = _mm256_cvtph_ps (_mm256_extractf128_si256 (_w23 , 1 ));
796+ __m256 _w0 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )kptr ));
797+ __m256 _w1 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 8 )));
798+ __m256 _w2 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 16 )));
799+ __m256 _w3 = _mm256_cvtph_ps (_mm_lddqu_si128 ((const __m128i * )(kptr + 24 )));
769800#else
770801 __m256 _w0 = _mm256_loadu_ps (kptr );
771802 __m256 _w1 = _mm256_loadu_ps (kptr + 8 );
@@ -797,7 +828,11 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
797828
798829 _sum0 = _mm256_add_ps (_sum0 , _sum1 );
799830 _sum2 = _mm256_add_ps (_sum2 , _sum3 );
831+ _sum4 = _mm256_add_ps (_sum4 , _sum5 );
832+ _sum6 = _mm256_add_ps (_sum6 , _sum7 );
800833 _sum0 = _mm256_add_ps (_sum0 , _sum2 );
834+ _sum4 = _mm256_add_ps (_sum4 , _sum6 );
835+ _sum0 = _mm256_add_ps (_sum0 , _sum4 );
801836
802837 _sum0 = activation_avx (_sum0 , activation_type , activation_params );
803838
@@ -1086,14 +1121,42 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
10861121#endif
10871122 const float * m = bottom_blob .row (j );
10881123
1089- __m128 _sum = _mm_setzero_ps ();
1124+ __m128 _sum0 = _mm_setzero_ps ();
1125+ __m128 _sum1 = _mm_setzero_ps ();
1126+ __m128 _sum2 = _mm_setzero_ps ();
1127+ __m128 _sum3 = _mm_setzero_ps ();
10901128
10911129 if (bias_data_ptr )
10921130 {
1093- _sum = _mm_loadu_ps (bias_data_ptr + p * 4 );
1131+ _sum0 = _mm_loadu_ps (bias_data_ptr + p * 4 );
10941132 }
10951133
10961134 int i = 0 ;
1135+ for (; i + 3 < num_input ; i += 4 )
1136+ {
1137+ __m128 _val0 = _mm_set1_ps (m [0 ]);
1138+ __m128 _val1 = _mm_set1_ps (m [1 ]);
1139+ __m128 _val2 = _mm_set1_ps (m [2 ]);
1140+ __m128 _val3 = _mm_set1_ps (m [3 ]);
1141+ #if NCNN_IMPL_FP16S
1142+ __m128 _w0 = _mm_cvtph_ps (_mm_loadl_epi64 ((const __m128i * )kptr ));
1143+ __m128 _w1 = _mm_cvtph_ps (_mm_loadl_epi64 ((const __m128i * )(kptr + 4 )));
1144+ __m128 _w2 = _mm_cvtph_ps (_mm_loadl_epi64 ((const __m128i * )(kptr + 8 )));
1145+ __m128 _w3 = _mm_cvtph_ps (_mm_loadl_epi64 ((const __m128i * )(kptr + 12 )));
1146+ #else
1147+ __m128 _w0 = _mm_loadu_ps (kptr );
1148+ __m128 _w1 = _mm_loadu_ps (kptr + 4 );
1149+ __m128 _w2 = _mm_loadu_ps (kptr + 8 );
1150+ __m128 _w3 = _mm_loadu_ps (kptr + 12 );
1151+ #endif
1152+ _sum0 = _mm_comp_fmadd_ps (_val0 , _w0 , _sum0 );
1153+ _sum1 = _mm_comp_fmadd_ps (_val1 , _w1 , _sum1 );
1154+ _sum2 = _mm_comp_fmadd_ps (_val2 , _w2 , _sum2 );
1155+ _sum3 = _mm_comp_fmadd_ps (_val3 , _w3 , _sum3 );
1156+
1157+ m += 4 ;
1158+ kptr += 16 ;
1159+ }
10971160 for (; i < num_input ; i ++ )
10981161 {
10991162 __m128 _val = _mm_set1_ps (m [0 ]);
@@ -1102,15 +1165,19 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M
11021165#else
11031166 __m128 _w = _mm_loadu_ps (kptr );
11041167#endif
1105- _sum = _mm_comp_fmadd_ps (_val , _w , _sum );
1168+ _sum0 = _mm_comp_fmadd_ps (_val , _w , _sum0 );
11061169
11071170 m += 1 ;
11081171 kptr += 4 ;
11091172 }
11101173
1111- _sum = activation_sse (_sum , activation_type , activation_params );
1174+ _sum0 = _mm_add_ps (_sum0 , _sum1 );
1175+ _sum2 = _mm_add_ps (_sum2 , _sum3 );
1176+ _sum0 = _mm_add_ps (_sum0 , _sum2 );
11121177
1113- _mm_storeu_ps (outptr , _sum );
1178+ _sum0 = activation_sse (_sum0 , activation_type , activation_params );
1179+
1180+ _mm_storeu_ps (outptr , _sum0 );
11141181 outptr += 4 ;
11151182 }
11161183 }
0 commit comments