Skip to content

Commit d48a56e

Browse files
authored
ggml : add some lsx support (ggml-org#23798)
* loongarch : optimize LSX fp16 load/store with native intrinsics Use __lsx_vfcvtl_s_h and __lsx_vfcvt_h_s instead of scalar loops in __lsx_f16x4_load and __lsx_f16x4_store. * loongarch : add LSX implementation for q8_0 dot product * loongarch : add LSX implementation for q6_K dot product * loongarch : add LSX implementation for iq4_xs dot product * Improve reduce ops when summing int16 pairs into int32
1 parent 6e093b8 commit d48a56e

2 files changed

Lines changed: 154 additions & 16 deletions

File tree

ggml/src/ggml-cpu/arch/loongarch/quants.c

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
977977
sumf = hsum_float_8(acc);
978978

979979
*s = sumf;
980+
981+
#elif defined(__loongarch_sx)
982+
983+
__m128 acc = (__m128)__lsx_vldi(0);
984+
985+
for (; ib < nb; ++ib) {
986+
const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d);
987+
const __m128i qx_0 = __lsx_vld((const __m128i *)x[ib].qs, 0);
988+
const __m128i qx_1 = __lsx_vld((const __m128i *)x[ib].qs + 1, 0);
989+
const __m128i qy_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
990+
const __m128i qy_1 = __lsx_vld((const __m128i *)y[ib].qs + 1, 0);
991+
992+
const __m128i p16_0 = lsx_maddubs_h(qx_0, qy_0);
993+
const __m128i p16_1 = lsx_maddubs_h(qx_1, qy_1);
994+
995+
// Sum int16 pairs → int32
996+
const __m128i s_0 = __lsx_vaddwev_w_h(p16_0, p16_1);
997+
const __m128i s_1 = __lsx_vaddwod_w_h(p16_0, p16_1);
998+
999+
const __m128 q = __lsx_vffint_s_w(__lsx_vadd_w(s_0, s_1));
1000+
acc = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(d), q, acc);
1001+
}
1002+
1003+
__m128 res = lsx_hadd_s(acc, acc);
1004+
res = lsx_hadd_s(res, res);
1005+
sumf = ((v4f32)res)[0];
1006+
1007+
*s = sumf;
1008+
9801009
#else
9811010
UNUSED(nb);
9821011
UNUSED(ib);
@@ -1443,6 +1472,99 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
14431472

14441473
*s = hsum_float_8(acc);
14451474

1475+
#elif defined(__loongarch_sx)
1476+
1477+
const __m128i m32s = __lsx_vreplgr2vr_b(32);
1478+
1479+
__m128 acc_0 = (__m128)__lsx_vldi(0);
1480+
__m128 acc_1 = (__m128)__lsx_vldi(0);
1481+
1482+
for (int i = 0; i < nb; ++i) {
1483+
1484+
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1485+
1486+
const uint8_t * GGML_RESTRICT q4 = x[i].ql;
1487+
const uint8_t * GGML_RESTRICT qh = x[i].qh;
1488+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
1489+
1490+
const __m128i scale_i8 = __lsx_vld(x[i].scales, 0);
1491+
const __m128i scales_lo = __lsx_vsllwil_h_b(scale_i8, 0);
1492+
const __m128i scales_hi = __lsx_vsllwil_h_b(__lsx_vbsrl_v(scale_i8, 8), 0);
1493+
1494+
__m128i sumi_0 = __lsx_vldi(0);
1495+
__m128i sumi_1 = __lsx_vldi(0);
1496+
1497+
for (int j = 0; j < QK_K/128; ++j) {
1498+
1499+
const __m128i q4bitsH_0 = __lsx_vld((const __m128i*)qh, 0); qh += 16;
1500+
const __m128i q4bitsH_1 = __lsx_vld((const __m128i*)qh, 0); qh += 16;
1501+
1502+
const __m128i q4h_0 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3), 4);
1503+
const __m128i q4h_1 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3), 4);
1504+
const __m128i q4h_2 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3 << 2), 2);
1505+
const __m128i q4h_3 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3 << 2), 2);
1506+
const __m128i q4h_4 = __lsx_vandi_b(q4bitsH_0, 3 << 4);
1507+
const __m128i q4h_5 = __lsx_vandi_b(q4bitsH_1, 3 << 4);
1508+
const __m128i q4h_6 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_0, 3 << 6), 2);
1509+
const __m128i q4h_7 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_1, 3 << 6), 2);
1510+
1511+
const __m128i q4bits1_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16;
1512+
const __m128i q4bits1_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16;
1513+
const __m128i q4bits2_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16;
1514+
const __m128i q4bits2_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16;
1515+
1516+
const __m128i q4_0 = __lsx_vor_v(__lsx_vandi_b(q4bits1_0, 0xf), q4h_0);
1517+
const __m128i q4_1 = __lsx_vor_v(__lsx_vandi_b(q4bits1_1, 0xf), q4h_1);
1518+
const __m128i q4_2 = __lsx_vor_v(__lsx_vandi_b(q4bits2_0, 0xf), q4h_2);
1519+
const __m128i q4_3 = __lsx_vor_v(__lsx_vandi_b(q4bits2_1, 0xf), q4h_3);
1520+
const __m128i q4_4 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_0, 4), q4h_4);
1521+
const __m128i q4_5 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_1, 4), q4h_5);
1522+
const __m128i q4_6 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_0, 4), q4h_6);
1523+
const __m128i q4_7 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_1, 4), q4h_7);
1524+
1525+
const __m128i q8_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1526+
const __m128i q8_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1527+
const __m128i q8_2 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1528+
const __m128i q8_3 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1529+
const __m128i q8_4 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1530+
const __m128i q8_5 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1531+
const __m128i q8_6 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1532+
const __m128i q8_7 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
1533+
1534+
__m128i p16_0 = lsx_maddubs_h(__lsx_vsub_b(q4_0, m32s), q8_0);
1535+
__m128i p16_1 = lsx_maddubs_h(__lsx_vsub_b(q4_1, m32s), q8_1);
1536+
__m128i p16_2 = lsx_maddubs_h(__lsx_vsub_b(q4_2, m32s), q8_2);
1537+
__m128i p16_3 = lsx_maddubs_h(__lsx_vsub_b(q4_3, m32s), q8_3);
1538+
__m128i p16_4 = lsx_maddubs_h(__lsx_vsub_b(q4_4, m32s), q8_4);
1539+
__m128i p16_5 = lsx_maddubs_h(__lsx_vsub_b(q4_5, m32s), q8_5);
1540+
__m128i p16_6 = lsx_maddubs_h(__lsx_vsub_b(q4_6, m32s), q8_6);
1541+
__m128i p16_7 = lsx_maddubs_h(__lsx_vsub_b(q4_7, m32s), q8_7);
1542+
1543+
const __m128i sc_vec = j == 0 ? scales_lo : scales_hi;
1544+
1545+
p16_0 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 0), p16_0);
1546+
p16_1 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 1), p16_1);
1547+
p16_2 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 2), p16_2);
1548+
p16_3 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 3), p16_3);
1549+
p16_4 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 4), p16_4);
1550+
p16_5 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 5), p16_5);
1551+
p16_6 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 6), p16_6);
1552+
p16_7 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 7), p16_7);
1553+
1554+
sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_0, p16_2));
1555+
sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_1, p16_3));
1556+
sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_4, p16_6));
1557+
sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_5, p16_7));
1558+
}
1559+
1560+
__m128 p_0 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_0));
1561+
__m128 p_1 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_1));
1562+
acc_0 = __lsx_vfadd_s(p_0, acc_0);
1563+
acc_1 = __lsx_vfadd_s(p_1, acc_1);
1564+
}
1565+
1566+
*s = hsum_float_4x4(acc_0, acc_1, (__m128)__lsx_vldi(0), (__m128)__lsx_vldi(0));
1567+
14461568
#else
14471569
UNUSED(x);
14481570
UNUSED(y);
@@ -2149,6 +2271,35 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
21492271

21502272
*s = hsum_float_8(accum);
21512273

2274+
#elif defined(__loongarch_sx)
2275+
2276+
const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
2277+
2278+
__m128 accum = (__m128)__lsx_vldi(0);
2279+
for (int ibl = 0; ibl < nb; ++ibl) {
2280+
const uint8_t * qs = x[ibl].qs;
2281+
const int8_t * q8 = y[ibl].qs;
2282+
uint16_t sh = x[ibl].scales_h;
2283+
__m128i sumi = __lsx_vldi(0);
2284+
for (int ib = 0; ib < QK_K/32; ++ib) {
2285+
const __m128i q4bits = __lsx_vld((const __m128i*)qs, 0); qs += 16;
2286+
const __m128i q8b_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
2287+
const __m128i q8b_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16;
2288+
const __m128i q4b_0 = __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits, 0xf));
2289+
const __m128i q4b_1 = __lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits, 4));
2290+
const __m128i p16_0 = lsx_maddubs_h(q4b_0, q8b_0);
2291+
const __m128i p16_1 = lsx_maddubs_h(q4b_1, q8b_1);
2292+
const int16_t ls = (((x[ibl].scales_l[ib/2] >> ((ib & 1) * 4)) & 0xf) | ((sh & 0x3) << 4)) - 32;
2293+
sh >>= 2;
2294+
sumi = __lsx_vadd_w(lsx_madd_h(p16_0, __lsx_vreplgr2vr_h(ls)), sumi);
2295+
sumi = __lsx_vadd_w(lsx_madd_h(p16_1, __lsx_vreplgr2vr_h(ls)), sumi);
2296+
}
2297+
const float ds = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
2298+
accum = __lsx_vfadd_s(__lsx_vfmul_s(__lsx_vreplfr2vr_s(ds), __lsx_vffint_s_w(sumi)), accum);
2299+
}
2300+
2301+
*s = ((v4f32)lsx_hadd_s(lsx_hadd_s(accum, accum), lsx_hadd_s(accum, accum)))[0];
2302+
21522303
#else
21532304
UNUSED(x);
21542305
UNUSED(y);

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,25 +1125,12 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
11251125
#define GGML_F16_EPR 4
11261126

11271127
static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
1128-
float tmp[4];
1129-
1130-
tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]);
1131-
tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]);
1132-
tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]);
1133-
tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]);
1134-
1135-
return (__m128)__lsx_vld(tmp, 0);
1128+
return __lsx_vfcvtl_s_h(__lsx_vld((const void *)x, 0));
11361129
}
11371130

11381131
static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
1139-
float arr[4];
1140-
1141-
__lsx_vst(y, arr, 0);
1142-
1143-
x[0] = GGML_CPU_FP32_TO_FP16(arr[0]);
1144-
x[1] = GGML_CPU_FP32_TO_FP16(arr[1]);
1145-
x[2] = GGML_CPU_FP32_TO_FP16(arr[2]);
1146-
x[3] = GGML_CPU_FP32_TO_FP16(arr[3]);
1132+
__m128i a = __lsx_vfcvt_h_s(y, y);
1133+
memcpy(x, &a, sizeof(ggml_fp16_t) * 4);
11471134
}
11481135

11491136
#define GGML_F32Cx4 __m128

0 commit comments

Comments
 (0)