@@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX {
17971797 } \
17981798 } \
17991799
1800+ template <typename T>
1801+ struct mma_instr ; // declaration only here: outer_product comes from the explicit specializations per element type
1802+
1803+ template <>
1804+ struct mma_instr <ggml_bf16_t > { // specialization chosen when the element type is ggml_bf16_t
1805+ static inline void outer_product (acc_t *acc, vec_t a, vec_t b) { // acc, a and b are passed through unchanged
1806+ __builtin_mma_xvbf16ger2pp (acc, a, b); // bf16 MMA outer-product builtin; "pp" = positive multiply, positive accumulate into *acc
1807+ }
1808+ };
1809+
1810+ template <>
1811+ struct mma_instr <ggml_fp16_t > { // specialization chosen when the element type is ggml_fp16_t
1812+ static inline void outer_product (acc_t *acc, vec_t a, vec_t b) { // acc, a and b are passed through unchanged
1813+ __builtin_mma_xvf16ger2pp (acc, a, b); // fp16 MMA outer-product builtin; "pp" = positive multiply, positive accumulate into *acc
1814+ }
1815+ };
1816+
18001817template <typename TA , typename TB , typename TC >
1801- class tinyBLAS_BF16_PPC {
1818+ class tinyBLAS_HP16_PPC {
18021819 public:
1803- tinyBLAS_BF16_PPC (int64_t k,
1820+ tinyBLAS_HP16_PPC (int64_t k,
18041821 const TA *A, int64_t lda,
18051822 const TB *B, int64_t ldb,
18061823 TC *C, int64_t ldc,
@@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC {
21182135 packNormal ((A+(ii*lda)+l), lda, 4 , 8 , (uint8_t *)vec_A);
21192136 packNormal ((B+(jj*ldb)+l), ldb, 8 , 8 , (uint8_t *)vec_B);
21202137 for (int x = 0 ; x < 4 ; x++) {
2121- __builtin_mma_xvbf16ger2pp (&acc_0, vec_A[x], vec_B[x]);
2122- __builtin_mma_xvbf16ger2pp (&acc_1, vec_A[x], vec_B[x+4 ]);
2138+ mma_instr< TA >:: outer_product (&acc_0, vec_A[x], vec_B[x]);
2139+ mma_instr< TA >:: outer_product (&acc_1, vec_A[x], vec_B[x+4 ]);
21232140 }
21242141 }
21252142 SAVE_ACC (&acc_0, ii, jj);
@@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC {
21352152 packNormal ((A+(ii*lda)+l), lda, 8 , 8 , (uint8_t *)vec_A);
21362153 packNormal ((B+(jj*ldb)+l), ldb, 8 , 4 , (uint8_t *)vec_B);
21372154 for (int x = 0 ; x < 4 ; x++) {
2138- __builtin_mma_xvbf16ger2pp (&acc_0, vec_A[x], vec_B[x]);
2139- __builtin_mma_xvbf16ger2pp (&acc_1, vec_A[x+ 4 ], vec_B[x]);
2155+ mma_instr< TA >:: outer_product (&acc_0, vec_A[x], vec_B[x]);
2156+ mma_instr< TA >:: outer_product (&acc_1, vec_A[x+ 4 ], vec_B[x]); // acc_1 pairs the upper half of vec_A with vec_B[x], exactly as the builtin call it replaces
21402157 }
21412158 }
21422159 SAVE_ACC (&acc_0, ii, jj);
@@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC {
21552172 packNormal (A+(ii*lda)+l, lda, 8 , 8 , (uint8_t *)vec_A);
21562173 packNormal (B+(jj*ldb)+l, ldb, 8 , 8 , (uint8_t *)vec_B);
21572174 for (int x = 0 ; x < 4 ; x++) {
2158- __builtin_mma_xvbf16ger2pp (&acc_0, vec_A[x], vec_B[x]);
2159- __builtin_mma_xvbf16ger2pp (&acc_1, ( vec_t ) vec_A[x], ( vec_t ) vec_B[x+4 ]);
2160- __builtin_mma_xvbf16ger2pp (&acc_2, ( vec_t ) vec_A[x+4 ], ( vec_t ) vec_B[x]);
2161- __builtin_mma_xvbf16ger2pp (&acc_3, ( vec_t ) vec_A[x+4 ], ( vec_t ) vec_B[x+4 ]);
2175+ mma_instr< TA >:: outer_product (&acc_0, vec_A[x], vec_B[x]);
2176+ mma_instr< TA >:: outer_product (&acc_1, vec_A[x], vec_B[x+4 ]);
2177+ mma_instr< TA >:: outer_product (&acc_2, vec_A[x+4 ], vec_B[x]);
2178+ mma_instr< TA >:: outer_product (&acc_3, vec_A[x+4 ], vec_B[x+4 ]);
21622179 }
21632180 }
21642181
@@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC {
21892206 packNormal (A+(ii*lda)+l, lda, RM , 4 , (uint8_t *)vec_A);
21902207 packNormal (B+(jj*ldb)+l, ldb, RN , 4 , (uint8_t *)vec_B);
21912208 for (int x = 0 ; x<2 ; x++) {
2192- __builtin_mma_xvbf16ger2pp (&acc_0, vec_A[x], vec_B[x]);
2209+ mma_instr< TA >:: outer_product (&acc_0, vec_A[x], vec_B[x]);
21932210 }
21942211 }
21952212 __builtin_mma_disassemble_acc (vec_C, &acc_0);
@@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC {
22242241 packNormal (A+(ii*lda)+l, lda, RM , 8 , (uint8_t *)vec_A);
22252242 packNormal (B+(jj*ldb)+l, ldb, RN , 8 , (uint8_t *)vec_B);
22262243 for (int x = 0 ; x<4 ; x++) {
2227- __builtin_mma_xvbf16ger2pp (&acc_0, vec_A[x], vec_B[x]);
2228- __builtin_mma_xvbf16ger2pp (&acc_1, vec_A[x], vec_B[x+4 ]);
2244+ mma_instr< TA >:: outer_product (&acc_0, vec_A[x], vec_B[x]);
2245+ mma_instr< TA >:: outer_product (&acc_1, vec_A[x], vec_B[x+4 ]);
22292246 }
22302247 }
22312248 __builtin_mma_disassemble_acc (vec_C, &acc_0);
@@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
34183435 return tb.matmul (m, n);
34193436 }
34203437#elif defined(__MMA__)
3421- if ((k % 8 ))
3422- return false ;
3423- if (Btype == GGML_TYPE_BF16 ) {
3424- tinyBLAS_BF16_PPC<ggml_bf16_t , ggml_bf16_t , float > tb{ k,
3425- (const ggml_bf16_t *)A, lda,
3426- (const ggml_bf16_t *)B, ldb,
3427- (float *)C, ldc,
3428- params->ith , params->nth };
3429- tb.matmul (m, n);
3430- return true ;
3438+ if (k % 8 ) {
3439+ return false ;
3440+ }
3441+
3442+ if (Btype == GGML_TYPE_BF16 ) {
3443+ tinyBLAS_HP16_PPC<ggml_bf16_t , ggml_bf16_t , float > tb{ k,
3444+ (const ggml_bf16_t *)A, lda,
3445+ (const ggml_bf16_t *)B, ldb,
3446+ (float *)C, ldc,
3447+ params->ith , params->nth };
3448+
3449+ tb.matmul (m, n);
3450+ return true ;
34313451 }
34323452#elif defined(__riscv_zvfbfwma)
34333453 #if LMUL == 1
@@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
35163536 #endif
35173537 return tb.matmul (m, n);
35183538 }
3539+ #elif defined(__MMA__)
3540+ if (k % 8 ) {
3541+ return false ;
3542+ }
3543+
3544+ if (Btype == GGML_TYPE_F16 ) {
3545+ tinyBLAS_HP16_PPC<ggml_fp16_t , ggml_fp16_t , float > tb{ k,
3546+ (const ggml_fp16_t *)A, lda,
3547+ (const ggml_fp16_t *)B, ldb,
3548+ (float *)C, ldc,
3549+ params->ith , params->nth };
3550+
3551+ tb.matmul (m, n);
3552+ return true ;
3553+ }
35193554#endif
35203555 return false ;
35213556 }
0 commit comments