Skip to content

Commit 127db93

Browse files
Changed the interface to accept void *
1 parent aefe236 commit 127db93

3 files changed

Lines changed: 20 additions & 14 deletions

File tree

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,13 +1514,12 @@ class SYCLGen : public SYCLGenBase {
15141514
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
15151515
OS() << "<";
15161516
OS() << M << ", " << N << ", " << K << ", ";
1517-
OS() << ABType;
1517+
OS() << ABType << ", " << InMatrixType[0] << ", " << InMatrixType[2];
15181518
OS() << ">(";
15191519

1520-
OS() << "DMatrix_ct1";
1520+
OS() << "reinterpret_cast<void **>(DMatrix_ct1)";
15211521
for (int i = 0; i < 3; i++)
1522-
OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&"
1523-
<< InMatrixName[i] << "Matrix_ct1)";
1522+
OS() << ", &" << InMatrixName[i] << "Matrix_ct1";
15241523
OS() << ")";
15251524
endstmt();
15261525
OS() << "}";

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2227,25 +2227,33 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22272227
/// \tparam [in] MulType The type used to multiply A and B matrix elements as
22282228
/// \tparam [in] ABType The type of the input matrix (A & B) elements
22292229
/// \tparam [in] CDType The type of the output matrix (C & D) elements
2230-
/// \param [out] d The elements of the output D matrix to store the result to
2231-
/// \param [in] a The elements of the input A matrix to be multiplied with B
2230+
/// \param [out] d_mat The elements of the output D matrix to store the result
2231+
/// of A* B + C
2232+
/// \param [in] a_mat The elements of the input A matrix to be multiplied with B
22322233
/// matrix elements
2233-
/// \param [in] b The elements of the input B matrix to be multiplied with A
2234+
/// \param [in] b_mat The elements of the input B matrix to be multiplied with A
22342235
/// matrix elements
2235-
/// \param [in] c The elements of the input C matrix to be added with the result
2236-
/// of A * B
2236+
/// \param [in] c_mat The elements of the input C matrix to be added with the
2237+
/// result of A * B
22372238
template <int M, int N, int K, typename MulType, typename ABType,
22382239
typename CDType>
2239-
void mma(CDType **d, ABType *a, ABType *b, CDType *c) {
2240+
void mma(void **d_mat, void *a_mat, void *b_mat, void *c_mat) {
2241+
auto d = reinterpret_cast<CDType **>(d_mat);
2242+
auto a = reinterpret_cast<ABType *>(a_mat);
2243+
auto b = reinterpret_cast<ABType *>(b_mat);
2244+
auto c = reinterpret_cast<CDType *>(c_mat);
2245+
22402246
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
22412247
int lane = sg.get_local_linear_id();
22422248

2249+
static_assert(M == 16 && N == 8 && K == 16,
2250+
"Only m16n8k16 shape is supported!");
2251+
22432252
short ROW_LOAD_OFFSET = 4 * (lane >> 2);
22442253
short COL_LOAD_OFFSET = 8 * (lane % 4);
22452254

22462255
if constexpr (M == 16 && N == 8 && K == 16) {
22472256
if constexpr (std::is_floating_point_v<CDType>) {
2248-
// f32.f16.f16.f32
22492257
for (int i = 0; i < 4; i++) {
22502258
ABType recv_a[4], recv_b[4];
22512259

@@ -2278,7 +2286,6 @@ void mma(CDType **d, ABType *a, ABType *b, CDType *c) {
22782286
*d[2] = c[2];
22792287
*d[3] = c[3];
22802288
} else if constexpr (std::is_integral_v<MulType>) {
2281-
// s32.s8.s8.s32
22822289
for (int i = 0; i < 4; i++) {
22832290
ABType recv_a[2], recv_b[2];
22842291

clang/test/dpct/asm/mma.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
2727
// CHECK-NEXT: sycl::vec<int32_t, 4> AMatrix_ct1(a[0], a[1], a[2], a[3]);
2828
// CHECK-NEXT: sycl::vec<int32_t, 2> BMatrix_ct1(b[0], b[1]);
2929
// CHECK-NEXT: sycl::vec<float, 4> CMatrix_ct1(fc[0], fc[1], fc[2], fc[3]);
30-
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(DMatrix_ct1, reinterpret_cast<int32_t *>(&AMatrix_ct1), reinterpret_cast<int32_t *>(&BMatrix_ct1), reinterpret_cast<float *>(&CMatrix_ct1));
30+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast<void **>(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1);
3131
// CHECK-NEXT: }
3232
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
3333
" { %0, %1, %2, %3 }, "
@@ -43,7 +43,7 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
4343
// CHECK-NEXT: sycl::vec<int32_t, 2> AMatrix_ct1(a[0], a[1]);
4444
// CHECK-NEXT: sycl::vec<int32_t, 1> BMatrix_ct1(b[0]);
4545
// CHECK-NEXT: sycl::vec<int32_t, 4> CMatrix_ct1(c[0], c[1], c[2], c[3]);
46-
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(DMatrix_ct1, reinterpret_cast<int32_t *>(&AMatrix_ct1), reinterpret_cast<int32_t *>(&BMatrix_ct1), reinterpret_cast<int32_t *>(&CMatrix_ct1));
46+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast<void **>(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1);
4747
// CHECK-NEXT: }
4848
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
4949
" { %0, %1, %2, %3 }, "

0 commit comments

Comments
 (0)