@@ -6959,6 +6959,277 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__avx_ld128(
69596959 } while (nc != 0);
69606960}
69616961
// qd8-f32-qc8w IGEMM microkernel, MR=1 x NR=4, K unrolled by 8 ("c8"),
// AVX with 128-bit weight loads ("ld128"): indirect GEMM over int8
// activations with dynamic (qd8) quantization and per-channel int8 (qc8w)
// weights, producing clamped float32 outputs.
//
// Packed-weight layout implied by the pointer advances below: per 4-column
// group, 4 x int32 column sums (vksum), then kc/8 panels of 32 int8 weights
// (columns 0..3, 8 k-elements each), then 4 float scales and 4 float biases.
6962+ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__avx_ld128(
6963+ size_t mr,
6964+ size_t nc,
6965+ size_t kc,
6966+ size_t ks,
6967+ const int8_t** restrict a,
6968+ const void* restrict w,
6969+ float* restrict c,
6970+ size_t cm_stride,
6971+ size_t cn_stride,
6972+ size_t a_offset,
6973+ const int8_t* zero,
6974+ const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
6975+ const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
6976+ {
6977+ assert(mr != 0);
6978+ assert(mr <= 1);
6979+ assert(nc != 0);
6980+ assert(kc != 0);
6981+ assert(ks != 0);
6982+ assert(ks % (1 * sizeof(void*)) == 0);
6983+ assert(a_offset % sizeof(int8_t) == 0);
6984+ assert(a != NULL);
6985+ assert(w != NULL);
6986+ assert(c != NULL);
6987+
// K is consumed 8 int8 elements per iteration; round up so the inner loop
// can step by 8 unconditionally (XNN_OOB_READS sanctions the tail over-read).
6988+ kc = round_up_po2(kc, 8 * sizeof(int8_t));
6989+ float* c0 = c;
6990+
// Broadcast the single dynamic-quantization zero point and inv_scale to all
// four lanes. _mm_broadcast_ss is used here as a raw 32-bit broadcast, hence
// the float reinterpretation of the int32 zero_point field.
6991+ const __m128i vinput_zero_point = _mm_castps_si128(_mm_broadcast_ss((const float*) &quantization_params->zero_point));
6992+ const __m128 vinput_scale = _mm_broadcast_ss(&quantization_params->inv_scale);
6993+ do {
// Seed the accumulators: the blend masks keep exactly one 32-bit lane of
// vksum*zero_point per accumulator (lane n in vacc0xn), so the hadd
// reduction below recombines them into output lanes 0..3.
6994+ const __m128i vksum = _mm_load_si128((const __m128i*) w);
6995+ const __m128i vzero = _mm_setzero_si128();
6996+ const __m128i vinit0 = _mm_mullo_epi32(vksum, vinput_zero_point);
6997+ __m128i vacc0x0 = _mm_blend_epi16(vinit0, vzero, 0xFC);
6998+ __m128i vacc0x1 = _mm_blend_epi16(vinit0, vzero, 0xF3);
6999+ __m128i vacc0x2 = _mm_blend_epi16(vinit0, vzero, 0xCF);
7000+ __m128i vacc0x3 = _mm_blend_epi16(vinit0, vzero, 0x3F);
7001+ w = (const int32_t*) w + 4;
7002+
// Walk the indirection buffer: ks bytes of pointers, one pointer per tap.
7003+ size_t p = ks;
7004+ do {
7005+ const int8_t* restrict a0 = a[0];
// The sentinel `zero` pointer marks padding taps and must not be offset.
7006+ if XNN_UNPREDICTABLE(a0 != zero) {
7007+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
7008+ }
7009+ a += 1;
7010+
7011+ size_t k = 0;
7012+ while (k < kc) {
// Load 8 int8 activations and sign-extend to 16 bits.
7013+ const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
7014+ const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
7015+ a0 += 8;
7016+
// Columns 0 and 1: low 8 bytes sign-extended via cvtepi8_epi16, high
// 8 bytes via the unpackhi+arithmetic-shift sign-extension idiom.
7017+ const __m128i vb01 = _mm_load_si128((const __m128i*) w);
7018+ const __m128i vxb0 = _mm_cvtepi8_epi16(vb01);
7019+ const __m128i vxb1 = _mm_srai_epi16(_mm_unpackhi_epi8(vb01, vb01), 8);
7020+
// madd_epi16 multiplies 8 int16 pairs and sums adjacent pairs, leaving
// 4 partial sums per accumulator that are reduced after the K loop.
7021+ vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0, vxb0));
7022+ vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0, vxb1));
// Columns 2 and 3 from the second 16-byte half of the 32-byte panel.
7023+ const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
7024+ const __m128i vxb2 = _mm_cvtepi8_epi16(vb23);
7025+ const __m128i vxb3 = _mm_srai_epi16(_mm_unpackhi_epi8(vb23, vb23), 8);
7026+
7027+ vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0, vxb2));
7028+ vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));
7029+
7030+ w = (const void*) ((const int8_t*) w + 32);
7031+ k += 8 * sizeof(int8_t);
7032+ }
7033+ p -= 1 * sizeof(void*);
7034+ } while (p != 0);
7035+
// Horizontal reduction: two hadd passes collapse the four per-column
// accumulators into one vector [col0, col1, col2, col3].
7036+ const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
7037+ const __m128i vacc0x23 = _mm_hadd_epi32(vacc0x2, vacc0x3);
7038+
7039+ __m128i vacc0x0123 = _mm_hadd_epi32(vacc0x01, vacc0x23);
7040+
// Dequantize: int32 -> float, multiply by the input's inv_scale, then by
// the per-channel filter/output scales, then add the per-channel bias.
7041+ __m128 vout0x0123 = _mm_cvtepi32_ps(vacc0x0123);
7042+
7043+ vout0x0123 = _mm_mul_ps(vout0x0123, vinput_scale);
7044+
7045+ const __m128 vfilter_output_scale0123 = _mm_load_ps((const float*) w);
7046+ vout0x0123 = _mm_mul_ps(vout0x0123, vfilter_output_scale0123);
7047+
7048+ const __m128 vbias0123 = _mm_load_ps((const float*) w + 4);
7049+ w = (const float*) w + 8;
7050+ vout0x0123 = _mm_add_ps(vout0x0123, vbias0123);
7051+
// Clamp to the caller-provided [min, max] output range.
7052+ const __m128 vmin = _mm_load_ps(params->sse.min);
7053+ vout0x0123 = _mm_max_ps(vout0x0123, vmin);
7054+
7055+ const __m128 vmax = _mm_load_ps(params->sse.max);
7056+ vout0x0123 = _mm_min_ps(vout0x0123, vmax);
7057+
7058+ if XNN_LIKELY(nc >= 4) {
// Full 4-column store; rewind the indirection pointer for the next
// column group and advance the output pointer by cn_stride.
7059+ _mm_storeu_ps(c0, vout0x0123);
7060+
7061+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
7062+
7063+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
7064+
7065+ nc -= 4;
7066+ } else {
// Tail: store the remaining 2 and/or 1 columns, then terminate.
7067+ if (nc & 2) {
7068+ _mm_storel_pi((__m64*) c0, vout0x0123);
7069+ vout0x0123 = _mm_unpackhi_ps(vout0x0123, vout0x0123);
7070+ c0 += 2;
7071+ }
7072+ if (nc & 1) {
7073+ _mm_store_ss(c0, vout0x0123);
7074+ }
7075+ nc = 0;
7076+ }
7077+ } while (nc != 0);
7078+ }
7079+
// qd8-f32-qc8w IGEMM microkernel, MR=2 x NR=4, K unrolled by 8 ("c8"),
// AVX with 128-bit weight loads ("ld128"). Two-row variant of the MR=1
// kernel: same packed-weight layout (4 x int32 ksum, kc/8 panels of 32 int8
// weights, 4 float scales + 4 float biases per 4-column group), with a
// second accumulator set and output row.
7080+ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__avx_ld128(
7081+ size_t mr,
7082+ size_t nc,
7083+ size_t kc,
7084+ size_t ks,
7085+ const int8_t** restrict a,
7086+ const void* restrict w,
7087+ float* restrict c,
7088+ size_t cm_stride,
7089+ size_t cn_stride,
7090+ size_t a_offset,
7091+ const int8_t* zero,
7092+ const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
7093+ const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7094+ {
7095+ assert(mr != 0);
7096+ assert(mr <= 2);
7097+ assert(nc != 0);
7098+ assert(kc != 0);
7099+ assert(ks != 0);
7100+ assert(ks % (2 * sizeof(void*)) == 0);
7101+ assert(a_offset % sizeof(int8_t) == 0);
7102+ assert(a != NULL);
7103+ assert(w != NULL);
7104+ assert(c != NULL);
7105+
// Round K up to the 8-element step (XNN_OOB_READS covers the tail read).
7106+ kc = round_up_po2(kc, 8 * sizeof(int8_t));
7107+ float* c0 = c;
7108+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
// When only one row is live (mr == 1), alias row 1's output onto row 0 so
// the unconditional stores below write a valid location.
7109+ if XNN_UNPREDICTABLE(mr != 2) {
7110+ c1 = c0;
7111+ }
7112+
// Single quantization-params entry (see XNN_MIN_ELEMENTS(1) above): both
// rows share the same zero point and inv_scale, broadcast to all lanes.
7113+ const __m128i vinput_zero_point = _mm_castps_si128(_mm_broadcast_ss((const float*) &quantization_params->zero_point));
7114+ const __m128 vinput_scale = _mm_broadcast_ss(&quantization_params->inv_scale);
7115+ do {
// Seed accumulators with ksum[n]*zero_point isolated in 32-bit lane n of
// vaccRxn via the blend masks. vinit1 equals vinit0 (shared quantization
// params); it is kept separate to mirror the per-row structure.
7116+ const __m128i vksum = _mm_load_si128((const __m128i*) w);
7117+ const __m128i vzero = _mm_setzero_si128();
7118+ const __m128i vinit0 = _mm_mullo_epi32(vksum, vinput_zero_point);
7119+ const __m128i vinit1 = _mm_mullo_epi32(vksum, vinput_zero_point);
7120+ __m128i vacc0x0 = _mm_blend_epi16(vinit0, vzero, 0xFC);
7121+ __m128i vacc0x1 = _mm_blend_epi16(vinit0, vzero, 0xF3);
7122+ __m128i vacc0x2 = _mm_blend_epi16(vinit0, vzero, 0xCF);
7123+ __m128i vacc0x3 = _mm_blend_epi16(vinit0, vzero, 0x3F);
7124+ __m128i vacc1x0 = _mm_blend_epi16(vinit1, vzero, 0xFC);
7125+ __m128i vacc1x1 = _mm_blend_epi16(vinit1, vzero, 0xF3);
7126+ __m128i vacc1x2 = _mm_blend_epi16(vinit1, vzero, 0xCF);
7127+ __m128i vacc1x3 = _mm_blend_epi16(vinit1, vzero, 0x3F);
7128+ w = (const int32_t*) w + 4;
7129+
// Walk the indirection buffer, two pointers (one per row) per tap.
7130+ size_t p = ks;
7131+ do {
7132+ const int8_t* restrict a0 = a[0];
// The sentinel `zero` pointer marks padding taps and must not be offset.
7133+ if XNN_UNPREDICTABLE(a0 != zero) {
7134+ a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
7135+ }
7136+ const int8_t* restrict a1 = a[1];
7137+ if XNN_UNPREDICTABLE(a1 != zero) {
7138+ a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
7139+ }
7140+ a += 2;
7141+
7142+ size_t k = 0;
7143+ while (k < kc) {
// Load 8 int8 activations per row, sign-extended to 16 bits.
7144+ const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
7145+ const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
7146+ a0 += 8;
7147+ const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
7148+ const __m128i vxa1 = _mm_cvtepi8_epi16(va1);
7149+ a1 += 8;
7150+
// Columns 0/1: low half via cvtepi8_epi16, high half via the
// unpackhi+arithmetic-shift sign-extension idiom. The B vectors are
// shared between both rows.
7151+ const __m128i vb01 = _mm_load_si128((const __m128i*) w);
7152+ const __m128i vxb0 = _mm_cvtepi8_epi16(vb01);
7153+ const __m128i vxb1 = _mm_srai_epi16(_mm_unpackhi_epi8(vb01, vb01), 8);
7154+
7155+ vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0, vxb0));
7156+ vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0, vxb1));
7157+ vacc1x0 = _mm_add_epi32(vacc1x0, _mm_madd_epi16(vxa1, vxb0));
7158+ vacc1x1 = _mm_add_epi32(vacc1x1, _mm_madd_epi16(vxa1, vxb1));
// Columns 2/3 from the second 16-byte half of the 32-byte panel.
7159+ const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
7160+ const __m128i vxb2 = _mm_cvtepi8_epi16(vb23);
7161+ const __m128i vxb3 = _mm_srai_epi16(_mm_unpackhi_epi8(vb23, vb23), 8);
7162+
7163+ vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0, vxb2));
7164+ vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));
7165+ vacc1x2 = _mm_add_epi32(vacc1x2, _mm_madd_epi16(vxa1, vxb2));
7166+ vacc1x3 = _mm_add_epi32(vacc1x3, _mm_madd_epi16(vxa1, vxb3));
7167+
7168+ w = (const void*) ((const int8_t*) w + 32);
7169+ k += 8 * sizeof(int8_t);
7170+ }
7171+ p -= 2 * sizeof(void*);
7172+ } while (p != 0);
7173+
// Per-row horizontal reduction: two hadd passes yield [col0..col3].
7174+ const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
7175+ const __m128i vacc0x23 = _mm_hadd_epi32(vacc0x2, vacc0x3);
7176+ const __m128i vacc1x01 = _mm_hadd_epi32(vacc1x0, vacc1x1);
7177+ const __m128i vacc1x23 = _mm_hadd_epi32(vacc1x2, vacc1x3);
7178+
7179+ __m128i vacc0x0123 = _mm_hadd_epi32(vacc0x01, vacc0x23);
7180+ __m128i vacc1x0123 = _mm_hadd_epi32(vacc1x01, vacc1x23);
7181+
// Dequantize: int32 -> float, scale by input inv_scale, then per-channel
// filter/output scales, then add per-channel bias.
7182+ __m128 vout0x0123 = _mm_cvtepi32_ps(vacc0x0123);
7183+ __m128 vout1x0123 = _mm_cvtepi32_ps(vacc1x0123);
7184+
7185+ vout0x0123 = _mm_mul_ps(vout0x0123, vinput_scale);
7186+ vout1x0123 = _mm_mul_ps(vout1x0123, vinput_scale);
7187+
7188+ const __m128 vfilter_output_scale0123 = _mm_load_ps((const float*) w);
7189+ vout0x0123 = _mm_mul_ps(vout0x0123, vfilter_output_scale0123);
7190+ vout1x0123 = _mm_mul_ps(vout1x0123, vfilter_output_scale0123);
7191+
7192+ const __m128 vbias0123 = _mm_load_ps((const float*) w + 4);
7193+ w = (const float*) w + 8;
7194+ vout0x0123 = _mm_add_ps(vout0x0123, vbias0123);
7195+ vout1x0123 = _mm_add_ps(vout1x0123, vbias0123);
7196+
// Clamp both rows to the caller-provided [min, max] output range.
7197+ const __m128 vmin = _mm_load_ps(params->sse.min);
7198+ vout0x0123 = _mm_max_ps(vout0x0123, vmin);
7199+ vout1x0123 = _mm_max_ps(vout1x0123, vmin);
7200+
7201+ const __m128 vmax = _mm_load_ps(params->sse.max);
7202+ vout0x0123 = _mm_min_ps(vout0x0123, vmax);
7203+ vout1x0123 = _mm_min_ps(vout1x0123, vmax);
7204+
7205+ if XNN_LIKELY(nc >= 4) {
// Full store (row 1 first; harmless when c1 aliases c0 with mr == 1,
// since row 0 then overwrites it); rewind the indirection pointer and
// advance both output pointers by cn_stride.
7206+ _mm_storeu_ps(c1, vout1x0123);
7207+ _mm_storeu_ps(c0, vout0x0123);
7208+
7209+ a = (const int8_t**restrict) ((uintptr_t) a - ks);
7210+
7211+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
7212+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
7213+
7214+ nc -= 4;
7215+ } else {
// Tail: store the remaining 2 and/or 1 columns per row, then terminate.
7216+ if (nc & 2) {
7217+ _mm_storel_pi((__m64*) c1, vout1x0123);
7218+ vout1x0123 = _mm_unpackhi_ps(vout1x0123, vout1x0123);
7219+ c1 += 2;
7220+ _mm_storel_pi((__m64*) c0, vout0x0123);
7221+ vout0x0123 = _mm_unpackhi_ps(vout0x0123, vout0x0123);
7222+ c0 += 2;
7223+ }
7224+ if (nc & 1) {
7225+ _mm_store_ss(c1, vout1x0123);
7226+ _mm_store_ss(c0, vout0x0123);
7227+ }
7228+ nc = 0;
7229+ }
7230+ } while (nc != 0);
7231+ }
7232+
69627233void xnn_qs16_qs8_vcvt_ukernel__avx_u16(
69637234 size_t batch,
69647235 const int16_t* input,
0 commit comments