Skip to content

Commit 7afdfc9

Browse files
authored
ggml-cpu: Enable FP16 MMA kernels on PPC (ggml-org#19060)
1 parent 94eeb59 commit 7afdfc9

1 file changed

Lines changed: 58 additions & 23 deletions

File tree

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX {
17971797
} \
17981798
} \
17991799

1800+
template<typename T>
1801+
struct mma_instr;
1802+
1803+
template<>
1804+
struct mma_instr<ggml_bf16_t> {
1805+
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
1806+
__builtin_mma_xvbf16ger2pp(acc, a, b);
1807+
}
1808+
};
1809+
1810+
template<>
1811+
struct mma_instr<ggml_fp16_t> {
1812+
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
1813+
__builtin_mma_xvf16ger2pp(acc, a, b);
1814+
}
1815+
};
1816+
18001817
template <typename TA, typename TB, typename TC>
1801-
class tinyBLAS_BF16_PPC {
1818+
class tinyBLAS_HP16_PPC {
18021819
public:
1803-
tinyBLAS_BF16_PPC(int64_t k,
1820+
tinyBLAS_HP16_PPC(int64_t k,
18041821
const TA *A, int64_t lda,
18051822
const TB *B, int64_t ldb,
18061823
TC *C, int64_t ldc,
@@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC {
21182135
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
21192136
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
21202137
for (int x = 0; x < 4; x++) {
2121-
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2122-
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
2138+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2139+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
21232140
}
21242141
}
21252142
SAVE_ACC(&acc_0, ii, jj);
@@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC {
21352152
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
21362153
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
21372154
for (int x = 0; x < 4; x++) {
2138-
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2139-
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
2155+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2156+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
21402157
}
21412158
}
21422159
SAVE_ACC(&acc_0, ii, jj);
@@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC {
21552172
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
21562173
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
21572174
for (int x = 0; x < 4; x++) {
2158-
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2159-
__builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
2160-
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
2161-
__builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
2175+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2176+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
2177+
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
2178+
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
21622179
}
21632180
}
21642181

@@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC {
21892206
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
21902207
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
21912208
for (int x = 0; x<2; x++) {
2192-
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2209+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
21932210
}
21942211
}
21952212
__builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC {
22242241
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
22252242
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
22262243
for (int x = 0; x<4; x++) {
2227-
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
2228-
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
2244+
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
2245+
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
22292246
}
22302247
}
22312248
__builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
34183435
return tb.matmul(m, n);
34193436
}
34203437
#elif defined(__MMA__)
3421-
if ((k % 8))
3422-
return false;
3423-
if(Btype == GGML_TYPE_BF16) {
3424-
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3425-
(const ggml_bf16_t *)A, lda,
3426-
(const ggml_bf16_t *)B, ldb,
3427-
(float *)C, ldc,
3428-
params->ith, params->nth};
3429-
tb.matmul(m, n);
3430-
return true;
3438+
if (k % 8) {
3439+
return false;
3440+
}
3441+
3442+
if (Btype == GGML_TYPE_BF16) {
3443+
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3444+
(const ggml_bf16_t *)A, lda,
3445+
(const ggml_bf16_t *)B, ldb,
3446+
(float *)C, ldc,
3447+
params->ith, params->nth };
3448+
3449+
tb.matmul(m, n);
3450+
return true;
34313451
}
34323452
#elif defined(__riscv_zvfbfwma)
34333453
#if LMUL == 1
@@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
35163536
#endif
35173537
return tb.matmul(m, n);
35183538
}
3539+
#elif defined(__MMA__)
3540+
if (k % 8) {
3541+
return false;
3542+
}
3543+
3544+
if (Btype == GGML_TYPE_F16) {
3545+
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
3546+
(const ggml_fp16_t *)A, lda,
3547+
(const ggml_fp16_t *)B, ldb,
3548+
(float *)C, ldc,
3549+
params->ith, params->nth };
3550+
3551+
tb.matmul(m, n);
3552+
return true;
3553+
}
35193554
#endif
35203555
return false;
35213556
}

0 commit comments

Comments
 (0)