Skip to content

Commit 93cb2fe

Browse files
alankelly authored and xnnpack-bot committed
Enable DQ IGEMM in config & 32 bit DQ I8MM
PiperOrigin-RevId: 573816978
1 parent ed93717 commit 93cb2fe

16 files changed

Lines changed: 7204 additions & 2 deletions

File tree

src/amalgam/gen/avx.c

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6959,6 +6959,277 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__avx_ld128(
69596959
} while (nc != 0);
69606960
}
69616961

6962+
// Indirect GEMM (IGEMM) micro-kernel: 1 output row, 4 output columns per
// outer iteration, int8 inputs consumed 8 at a time per column ("1x4c8"),
// implemented with AVX intrinsics and 128-bit weight loads ("ld128").
// The int32 dot-product accumulators are converted to float and scaled by
// the per-batch input scale, a per-column filter scale, and a per-column
// bias read from the packed weights, then clamped to [min, max].
// NOTE(review): qd8/qc8w presumably mean dynamically-quantized 8-bit input
// with channelwise-quantized 8-bit weights — confirm against kernel docs.
void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__avx_ld128(
    size_t mr,                       // number of output rows to compute (must be 1)
    size_t nc,                       // number of output columns remaining
    size_t kc,                       // reduction (K) dimension, in bytes of int8
    size_t ks,                       // indirection-buffer span, in bytes of pointers
    const int8_t** restrict a,       // indirection buffer: one input pointer per K-step
    const void* restrict w,          // packed weights: ksum row, int8 weights, scales, biases
    float* restrict c,               // output matrix
    size_t cm_stride,                // row stride of c, in bytes (unused: single row)
    size_t cn_stride,                // column-group stride of c, in bytes
    size_t a_offset,                 // byte offset added to each non-zero input pointer
    const int8_t* zero,              // sentinel pointer marking the zero buffer
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (1 * sizeof(void*)) == 0);
  assert(a_offset % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  // The inner loop consumes 8 int8 values per step; round kc up so the loop
  // bound is a multiple of 8 (XNN_OOB_READS permits reading past valid data).
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  float* c0 = c;

  // Broadcast the 32-bit input zero point to all four lanes. _mm_broadcast_ss
  // is used purely as a 32-bit broadcast; the float<->int casts preserve the
  // bit pattern (no arithmetic is done on it as a float).
  const __m128i vinput_zero_point = _mm_castps_si128(_mm_broadcast_ss((const float*) &quantization_params->zero_point));
  const __m128 vinput_scale = _mm_broadcast_ss(&quantization_params->inv_scale);
  do {
    // First 16 bytes of the packed weights are four int32 column weight-sums;
    // multiplying by the input zero point pre-subtracts the zero-point
    // contribution from the dot products.
    const __m128i vksum = _mm_load_si128((const __m128i*) w);
    const __m128i vzero = _mm_setzero_si128();
    const __m128i vinit0 = _mm_mullo_epi32(vksum, vinput_zero_point);
    // Each of the four accumulators tracks one output column; the blend masks
    // keep exactly one 32-bit lane of vinit0 (lane 0..3 respectively) and
    // zero the rest, so the final lane-wise hadd reduction sums correctly.
    __m128i vacc0x0 = _mm_blend_epi16(vinit0, vzero, 0xFC);
    __m128i vacc0x1 = _mm_blend_epi16(vinit0, vzero, 0xF3);
    __m128i vacc0x2 = _mm_blend_epi16(vinit0, vzero, 0xCF);
    __m128i vacc0x3 = _mm_blend_epi16(vinit0, vzero, 0x3F);
    w = (const int32_t*) w + 4;

    size_t p = ks;
    do {
      // Fetch the next input pointer from the indirection buffer. Pointers
      // equal to `zero` reference the shared zero buffer and must NOT be
      // rebased by a_offset.
      const int8_t* restrict a0 = a[0];
      if XNN_UNPREDICTABLE(a0 != zero) {
        a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
      }
      a += 1;

      size_t k = 0;
      while (k < kc) {
        // Load 8 int8 inputs and sign-extend to 16 bits.
        const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
        const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
        a0 += 8;

        // 32 bytes of weights per step: columns 0..3, 8 int8 values each.
        // Low half sign-extended via cvtepi8_epi16; high half via
        // unpackhi+arithmetic-shift (duplicates bytes, then shifts the sign
        // into place), avoiding an extra shuffle.
        const __m128i vb01 = _mm_load_si128((const __m128i*) w);
        const __m128i vxb0 = _mm_cvtepi8_epi16(vb01);
        const __m128i vxb1 = _mm_srai_epi16(_mm_unpackhi_epi8(vb01, vb01), 8);

        // madd: pairwise int16 multiply-add -> four int32 partial sums.
        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0, vxb0));
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0, vxb1));
        const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
        const __m128i vxb2 = _mm_cvtepi8_epi16(vb23);
        const __m128i vxb3 = _mm_srai_epi16(_mm_unpackhi_epi8(vb23, vb23), 8);

        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0, vxb2));
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

        w = (const void*) ((const int8_t*) w + 32);
        k += 8 * sizeof(int8_t);
      }
      p -= 1 * sizeof(void*);
    } while (p != 0);

    // Horizontal reduction: fold the four per-column accumulators into a
    // single vector holding [col0, col1, col2, col3] sums.
    const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
    const __m128i vacc0x23 = _mm_hadd_epi32(vacc0x2, vacc0x3);

    __m128i vacc0x0123 = _mm_hadd_epi32(vacc0x01, vacc0x23);

    __m128 vout0x0123 = _mm_cvtepi32_ps(vacc0x0123);

    // Scale by the input (de)quantization scale.
    vout0x0123 = _mm_mul_ps(vout0x0123, vinput_scale);

    // Per-column filter scale and bias follow the int8 weights in the packed
    // weight stream.
    const __m128 vfilter_output_scale0123 = _mm_load_ps((const float*) w);
    vout0x0123 = _mm_mul_ps(vout0x0123, vfilter_output_scale0123);

    const __m128 vbias0123 = _mm_load_ps((const float*) w + 4);
    w = (const float*) w + 8;
    vout0x0123 = _mm_add_ps(vout0x0123, vbias0123);

    // Clamp to the [min, max] activation range.
    const __m128 vmin = _mm_load_ps(params->sse.min);
    vout0x0123 = _mm_max_ps(vout0x0123, vmin);

    const __m128 vmax = _mm_load_ps(params->sse.max);
    vout0x0123 = _mm_min_ps(vout0x0123, vmax);

    if XNN_LIKELY(nc >= 4) {
      // Full column group: store 4 floats, rewind the indirection buffer for
      // the next column group, and advance the output pointer.
      _mm_storeu_ps(c0, vout0x0123);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      nc -= 4;
    } else {
      // Tail: store the remaining 1-3 columns (2 then 1), shifting the upper
      // lanes down between stores.
      if (nc & 2) {
        _mm_storel_pi((__m64*) c0, vout0x0123);
        vout0x0123 = _mm_unpackhi_ps(vout0x0123, vout0x0123);
        c0 += 2;
      }
      if (nc & 1) {
        _mm_store_ss(c0, vout0x0123);
      }
      nc = 0;
    }
  } while (nc != 0);
}
7079+
7080+
// Indirect GEMM (IGEMM) micro-kernel: up to 2 output rows, 4 output columns
// per outer iteration, int8 inputs consumed 8 at a time per column ("2x4c8"),
// AVX intrinsics with 128-bit weight loads ("ld128"). Two-row sibling of the
// 1x4c8 kernel: weights are loaded once per K-step and shared by both rows.
void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__avx_ld128(
    size_t mr,                       // number of output rows to compute (1 or 2)
    size_t nc,                       // number of output columns remaining
    size_t kc,                       // reduction (K) dimension, in bytes of int8
    size_t ks,                       // indirection-buffer span, in bytes of pointers
    const int8_t** restrict a,       // indirection buffer: two input pointers per K-step
    const void* restrict w,          // packed weights: ksum row, int8 weights, scales, biases
    float* restrict c,               // output matrix
    size_t cm_stride,                // row stride of c, in bytes
    size_t cn_stride,                // column-group stride of c, in bytes
    size_t a_offset,                 // byte offset added to each non-zero input pointer
    const int8_t* zero,              // sentinel pointer marking the zero buffer
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 2);
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (2 * sizeof(void*)) == 0);
  assert(a_offset % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  // The inner loop consumes 8 int8 values per step; round kc up so the loop
  // bound is a multiple of 8 (XNN_OOB_READS permits reading past valid data).
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  float* c0 = c;
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
  // When only one row is requested, alias row 1 onto row 0 so the row-1 code
  // path harmlessly overwrites row 0's results.
  if XNN_UNPREDICTABLE(mr != 2) {
    c1 = c0;
  }

  // Broadcast the 32-bit input zero point to all four lanes. _mm_broadcast_ss
  // is used purely as a 32-bit broadcast; the float<->int casts preserve the
  // bit pattern. Both rows share quantization_params[0] here.
  const __m128i vinput_zero_point = _mm_castps_si128(_mm_broadcast_ss((const float*) &quantization_params->zero_point));
  const __m128 vinput_scale = _mm_broadcast_ss(&quantization_params->inv_scale);
  do {
    // Four int32 column weight-sums lead the packed weights; multiplying by
    // the input zero point pre-subtracts the zero-point contribution.
    const __m128i vksum = _mm_load_si128((const __m128i*) w);
    const __m128i vzero = _mm_setzero_si128();
    const __m128i vinit0 = _mm_mullo_epi32(vksum, vinput_zero_point);
    // vinit1 repeats the same product — the generated per-row pattern; with a
    // shared zero point it equals vinit0.
    const __m128i vinit1 = _mm_mullo_epi32(vksum, vinput_zero_point);
    // Each accumulator tracks one (row, column) pair; the blend masks keep
    // exactly one 32-bit lane of the init vector (lane 0..3) and zero the
    // rest, so the final lane-wise hadd reduction sums correctly.
    __m128i vacc0x0 = _mm_blend_epi16(vinit0, vzero, 0xFC);
    __m128i vacc0x1 = _mm_blend_epi16(vinit0, vzero, 0xF3);
    __m128i vacc0x2 = _mm_blend_epi16(vinit0, vzero, 0xCF);
    __m128i vacc0x3 = _mm_blend_epi16(vinit0, vzero, 0x3F);
    __m128i vacc1x0 = _mm_blend_epi16(vinit1, vzero, 0xFC);
    __m128i vacc1x1 = _mm_blend_epi16(vinit1, vzero, 0xF3);
    __m128i vacc1x2 = _mm_blend_epi16(vinit1, vzero, 0xCF);
    __m128i vacc1x3 = _mm_blend_epi16(vinit1, vzero, 0x3F);
    w = (const int32_t*) w + 4;

    size_t p = ks;
    do {
      // Fetch the next two input pointers from the indirection buffer.
      // Pointers equal to `zero` reference the shared zero buffer and must
      // NOT be rebased by a_offset.
      const int8_t* restrict a0 = a[0];
      if XNN_UNPREDICTABLE(a0 != zero) {
        a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
      }
      const int8_t* restrict a1 = a[1];
      if XNN_UNPREDICTABLE(a1 != zero) {
        a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
      }
      a += 2;

      size_t k = 0;
      while (k < kc) {
        // Load 8 int8 inputs per row and sign-extend to 16 bits.
        const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
        const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
        a0 += 8;
        const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
        const __m128i vxa1 = _mm_cvtepi8_epi16(va1);
        a1 += 8;

        // 32 bytes of weights per step: columns 0..3, 8 int8 values each.
        // Low half sign-extended via cvtepi8_epi16; high half via
        // unpackhi+arithmetic-shift (duplicates bytes, then shifts the sign
        // into place).
        const __m128i vb01 = _mm_load_si128((const __m128i*) w);
        const __m128i vxb0 = _mm_cvtepi8_epi16(vb01);
        const __m128i vxb1 = _mm_srai_epi16(_mm_unpackhi_epi8(vb01, vb01), 8);

        // madd: pairwise int16 multiply-add -> four int32 partial sums;
        // each weight vector is reused for both rows.
        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0, vxb0));
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0, vxb1));
        vacc1x0 = _mm_add_epi32(vacc1x0, _mm_madd_epi16(vxa1, vxb0));
        vacc1x1 = _mm_add_epi32(vacc1x1, _mm_madd_epi16(vxa1, vxb1));
        const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
        const __m128i vxb2 = _mm_cvtepi8_epi16(vb23);
        const __m128i vxb3 = _mm_srai_epi16(_mm_unpackhi_epi8(vb23, vb23), 8);

        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0, vxb2));
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));
        vacc1x2 = _mm_add_epi32(vacc1x2, _mm_madd_epi16(vxa1, vxb2));
        vacc1x3 = _mm_add_epi32(vacc1x3, _mm_madd_epi16(vxa1, vxb3));

        w = (const void*) ((const int8_t*) w + 32);
        k += 8 * sizeof(int8_t);
      }
      p -= 2 * sizeof(void*);
    } while (p != 0);

    // Horizontal reduction: fold the four per-column accumulators of each row
    // into one vector holding [col0, col1, col2, col3] sums.
    const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
    const __m128i vacc0x23 = _mm_hadd_epi32(vacc0x2, vacc0x3);
    const __m128i vacc1x01 = _mm_hadd_epi32(vacc1x0, vacc1x1);
    const __m128i vacc1x23 = _mm_hadd_epi32(vacc1x2, vacc1x3);

    __m128i vacc0x0123 = _mm_hadd_epi32(vacc0x01, vacc0x23);
    __m128i vacc1x0123 = _mm_hadd_epi32(vacc1x01, vacc1x23);

    __m128 vout0x0123 = _mm_cvtepi32_ps(vacc0x0123);
    __m128 vout1x0123 = _mm_cvtepi32_ps(vacc1x0123);

    // Scale by the input (de)quantization scale.
    vout0x0123 = _mm_mul_ps(vout0x0123, vinput_scale);
    vout1x0123 = _mm_mul_ps(vout1x0123, vinput_scale);

    // Per-column filter scale and bias follow the int8 weights in the packed
    // weight stream.
    const __m128 vfilter_output_scale0123 = _mm_load_ps((const float*) w);
    vout0x0123 = _mm_mul_ps(vout0x0123, vfilter_output_scale0123);
    vout1x0123 = _mm_mul_ps(vout1x0123, vfilter_output_scale0123);

    const __m128 vbias0123 = _mm_load_ps((const float*) w + 4);
    w = (const float*) w + 8;
    vout0x0123 = _mm_add_ps(vout0x0123, vbias0123);
    vout1x0123 = _mm_add_ps(vout1x0123, vbias0123);

    // Clamp to the [min, max] activation range.
    const __m128 vmin = _mm_load_ps(params->sse.min);
    vout0x0123 = _mm_max_ps(vout0x0123, vmin);
    vout1x0123 = _mm_max_ps(vout1x0123, vmin);

    const __m128 vmax = _mm_load_ps(params->sse.max);
    vout0x0123 = _mm_min_ps(vout0x0123, vmax);
    vout1x0123 = _mm_min_ps(vout1x0123, vmax);

    if XNN_LIKELY(nc >= 4) {
      // Full column group: store 4 floats per row (row 1 first; when mr == 1
      // it aliases row 0 and is overwritten), rewind the indirection buffer
      // for the next column group, and advance the output pointers.
      _mm_storeu_ps(c1, vout1x0123);
      _mm_storeu_ps(c0, vout0x0123);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      c0 = (float*) ((uintptr_t) c0 + cn_stride);
      c1 = (float*) ((uintptr_t) c1 + cn_stride);

      nc -= 4;
    } else {
      // Tail: store the remaining 1-3 columns (2 then 1), shifting the upper
      // lanes down between stores.
      if (nc & 2) {
        _mm_storel_pi((__m64*) c1, vout1x0123);
        vout1x0123 = _mm_unpackhi_ps(vout1x0123, vout1x0123);
        c1 += 2;
        _mm_storel_pi((__m64*) c0, vout0x0123);
        vout0x0123 = _mm_unpackhi_ps(vout0x0123, vout0x0123);
        c0 += 2;
      }
      if (nc & 1) {
        _mm_store_ss(c1, vout1x0123);
        _mm_store_ss(c0, vout0x0123);
      }
      nc = 0;
    }
  } while (nc != 0);
}
7232+
69627233
void xnn_qs16_qs8_vcvt_ukernel__avx_u16(
69637234
size_t batch,
69647235
const int16_t* input,

0 commit comments

Comments
 (0)