@@ -2227,25 +2227,33 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22272227// / \tparam [in] MulType The type used to multiply A and B matrix elements as
22282228// / \tparam [in] ABType The type of the input matrix (A & B) elements
22292229// / \tparam [in] CDType The type of the output matrix (C & D) elements
2230- // / \param [out] d The elements of the output D matrix to store the result to
2231- // / \param [in] a The elements of the input A matrix to be multiplied with B
2230+ // / \param [out] d_mat The elements of the output D matrix to store the result
2231+ // / of A* B + C
2232+ // / \param [in] a_mat The elements of the input A matrix to be multiplied with B
22322233// / matrix elements
2233- // / \param [in] b The elements of the input B matrix to be multiplied with A
2234+ // / \param [in] b_mat The elements of the input B matrix to be multiplied with A
22342235// / matrix elements
2235- // / \param [in] c The elements of the input C matrix to be added with the result
2236- // / of A * B
2236+ // / \param [in] c_mat The elements of the input C matrix to be added with the
2237+ // / result of A * B
22372238template <int M, int N, int K, typename MulType, typename ABType,
22382239 typename CDType>
2239- void mma (CDType **d, ABType *a, ABType *b, CDType *c) {
2240+ void mma (void **d_mat, void *a_mat, void *b_mat, void *c_mat) {
2241+ auto d = reinterpret_cast <CDType **>(d_mat);
2242+ auto a = reinterpret_cast <ABType *>(a_mat);
2243+ auto b = reinterpret_cast <ABType *>(b_mat);
2244+ auto c = reinterpret_cast <CDType *>(c_mat);
2245+
22402246 auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
22412247 int lane = sg.get_local_linear_id ();
22422248
2249+ static_assert (M == 16 && N == 8 && K == 16 ,
2250+ " Only m16n8k16 shape is supported!" );
2251+
22432252 short ROW_LOAD_OFFSET = 4 * (lane >> 2 );
22442253 short COL_LOAD_OFFSET = 8 * (lane % 4 );
22452254
22462255 if constexpr (M == 16 && N == 8 && K == 16 ) {
22472256 if constexpr (std::is_floating_point_v<CDType>) {
2248- // f32.f16.f16.f32
22492257 for (int i = 0 ; i < 4 ; i++) {
22502258 ABType recv_a[4 ], recv_b[4 ];
22512259
@@ -2278,7 +2286,6 @@ void mma(CDType **d, ABType *a, ABType *b, CDType *c) {
22782286 *d[2 ] = c[2 ];
22792287 *d[3 ] = c[3 ];
22802288 } else if constexpr (std::is_integral_v<MulType>) {
2281- // s32.s8.s8.s32
22822289 for (int i = 0 ; i < 4 ; i++) {
22832290 ABType recv_a[2 ], recv_b[2 ];
22842291
0 commit comments