Skip to content

Commit 7d794e1

Browse files
Updated type info of args
1 parent d683438 commit 7d794e1

4 files changed

Lines changed: 22 additions & 20 deletions

File tree

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
// Option: --use-experimental-features=matrix
22
#include <mma.h>
33

4-
__global__ void test(float val) {
4+
template <typename T> __global__ void test(T val) {
55
// Start
66
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> acc_frag;
7-
nvcuda::wmma::fill_fragment(acc_frag, val /*float*/);
7+
nvcuda::wmma::fill_fragment(acc_frag, val /*const T&*/);
88
// End
99
}
Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
// Option: --use-experimental-features=matrix
22
#include <mma.h>
33

4-
__global__ void test(half *a, int row, int col, int lda) {
4+
template <typename T>
5+
__global__ void test(const T *a, int row, int col, unsigned lda) {
56
// Start
67
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
78
nvcuda::wmma::row_major>
89
a_frag;
9-
nvcuda::wmma::load_matrix_sync(a_frag, a + col + row * lda /*void **/,
10-
lda /*int*/);
10+
nvcuda::wmma::load_matrix_sync(a_frag, a + col + row * lda /*const T **/,
11+
lda /*unsigned*/);
1112
// End
1213
}
Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,15 @@
11
// Option: --use-experimental-features=matrix
22
#include <mma.h>
33

4-
__global__ void test(float *c, int row, int col, int ldc) {
4+
template <typename T>
5+
__global__ void test(const T *c, int row, int col, unsigned ldc) {
56
// Start
67
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> acc_frag;
78
nvcuda::wmma::store_matrix_sync(
8-
c + col + row * ldc /*void **/, acc_frag, ldc /*int*/,
9-
nvcuda::wmma::mem_col_major /*memory order*/);
9+
c + col + row * ldc /*const T **/, acc_frag, ldc /*unsigned*/,
10+
nvcuda::wmma::mem_col_major /*nvcuda::wmma::layout_t*/);
1011
nvcuda::wmma::store_matrix_sync(
11-
c + row + col * ldc /*void **/, acc_frag, ldc /*int*/,
12-
nvcuda::wmma::mem_row_major /*memory order*/);
12+
c + row + col * ldc /*const T **/, acc_frag, ldc /*unsigned*/,
13+
nvcuda::wmma::mem_row_major /*nvcuda::wmma::layout_t*/);
1314
// End
1415
}

clang/test/dpct/query_api_mapping/Runtime/test_wmma.cu

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=nvcuda::wmma::fill_fragment | FileCheck %s -check-prefix=NVCUDA_WMMA_FILL_FRAGMENT
55
// NVCUDA_WMMA_FILL_FRAGMENT: CUDA API:
66
// NVCUDA_WMMA_FILL_FRAGMENT-NEXT: nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> acc_frag;
7-
// NVCUDA_WMMA_FILL_FRAGMENT-NEXT: nvcuda::wmma::fill_fragment(acc_frag, val /*float*/);
7+
// NVCUDA_WMMA_FILL_FRAGMENT-NEXT: nvcuda::wmma::fill_fragment(acc_frag, val /*const T&*/);
88
// NVCUDA_WMMA_FILL_FRAGMENT-NEXT: Is migrated to (with the option --use-experimental-features=matrix):
99
// NVCUDA_WMMA_FILL_FRAGMENT-NEXT: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::accumulator, 16, 16, 16, float> acc_frag;
1010
// NVCUDA_WMMA_FILL_FRAGMENT-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sycl::ext::oneapi::this_work_item::get_sub_group(), acc_frag.get(), val);
@@ -14,26 +14,26 @@
1414
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
1515
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: nvcuda::wmma::row_major>
1616
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: a_frag;
17-
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: nvcuda::wmma::load_matrix_sync(a_frag, a + col + row * lda /*void **/,
18-
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: lda /*int*/);
17+
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: nvcuda::wmma::load_matrix_sync(a_frag, a + col + row * lda /*const T **/,
18+
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: lda /*unsigned*/);
1919
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: Is migrated to (with the option --use-experimental-features=matrix):
2020
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::a, 16, 16, 16, sycl::half, dpct::experimental::matrix::row_major>
2121
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: a_frag;
22-
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::this_work_item::get_sub_group(), a_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, const sycl::half>(a + col + row * lda), lda);
22+
// NVCUDA_WMMA_LOAD_MATRIX_SYNC-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sycl::ext::oneapi::this_work_item::get_sub_group(), a_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, typename std::remove_pointer<decltype(a + col + row * lda)>::type>(a + col + row * lda), lda);
2323

2424
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=nvcuda::wmma::store_matrix_sync | FileCheck %s -check-prefix=NVCUDA_WMMA_STORE_MATRIX_SYNC
2525
// NVCUDA_WMMA_STORE_MATRIX_SYNC: CUDA API:
2626
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> acc_frag;
2727
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::store_matrix_sync(
28-
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: c + col + row * ldc /*void **/, acc_frag, ldc /*int*/,
29-
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::mem_col_major /*memory order*/);
28+
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: c + col + row * ldc /*const T **/, acc_frag, ldc /*unsigned*/,
29+
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::mem_col_major /*nvcuda::wmma::layout_t*/);
3030
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::store_matrix_sync(
31-
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: c + row + col * ldc /*void **/, acc_frag, ldc /*int*/,
32-
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::mem_row_major /*memory order*/);
31+
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: c + row + col * ldc /*const T **/, acc_frag, ldc /*unsigned*/,
32+
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: nvcuda::wmma::mem_row_major /*nvcuda::wmma::layout_t*/);
3333
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: Is migrated to (with the option --use-experimental-features=matrix):
3434
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: dpct::experimental::matrix::joint_matrix<dpct::experimental::matrix::accumulator, 16, 16, 16, float> acc_frag;
35-
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::this_work_item::get_sub_group(), acc_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, float>(c + col + row * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major);
36-
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::this_work_item::get_sub_group(), acc_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, float>(c + row + col * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major);
35+
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::this_work_item::get_sub_group(), acc_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, typename std::remove_pointer<decltype(c + col + row * ldc)>::type>(c + col + row * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::col_major);
36+
// NVCUDA_WMMA_STORE_MATRIX_SYNC-NEXT: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::this_work_item::get_sub_group(), acc_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, typename std::remove_pointer<decltype(c + row + col * ldc)>::type>(c + row + col * ldc), ldc, sycl::ext::oneapi::experimental::matrix::layout::row_major);
3737

3838
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=nvcuda::wmma::mma_sync | FileCheck %s -check-prefix=NVCUDA_WMMA_MMA_SYNC
3939
// NVCUDA_WMMA_MMA_SYNC: CUDA API:

0 commit comments

Comments
 (0)