@@ -2061,13 +2061,13 @@ class joint_matrix {
/// \tparam [in] T The type of result variable
/// \param [in] addr The address of the matrix in shared memory
/// \param [in] m The local memory to store the matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
/// \param [in] mat The matrix index to be loaded
2067- template <typename T>
2068- void ldmatrix (uintptr_t addr, T *m, const sycl::nd_item< 3 > &item_ct1 ,
2069- bool trans = false , unsigned mat = 0 ) {
2070- int lane = item_ct1. get_local_id ( 2 ) % 32 ;
2067+ template <typename T, typename ItemT >
2068+ void ldmatrix (uintptr_t addr, T *m, const ItemT &item, bool trans = false ,
2069+ unsigned mat = 0 ) {
2070+ int lane = item. get_sub_group (). get_local_linear_id () ;
20712071
20722072 int lane_group8_row = lane / 8 ;
20732073 int lane_group8_col = lane % 8 ;
@@ -2080,7 +2080,7 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20802080
20812081 // Broadcast the address from the source lane
20822082 auto recv_addr_uintp = dpct::select_from_sub_group (
2083- item_ct1 .get_sub_group (), addr, mat * 8 + src_lane);
2083+ item .get_sub_group (), addr, mat * 8 + src_lane);
20842084
20852085 // Cast the received address from uintptr_t to the type of 'm'
20862086 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2093,9 +2093,9 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20932093
20942094 // Broadcast the address from the source lane
20952095 auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2096- item_ct1 .get_sub_group (), addr, mat * 8 + src_lane);
2096+ item .get_sub_group (), addr, mat * 8 + src_lane);
20972097 auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2098- item_ct1 .get_sub_group (), addr, mat * 8 + src_lane + 1 );
2098+ item .get_sub_group (), addr, mat * 8 + src_lane + 1 );
20992099
21002100 // Cast the received address from uintptr_t to 'half *'
21012101 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2118,15 +2118,15 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
/// \param [in] addr The address of the matrix in shared memory
/// \param [in] m1 The local memory to store data of 1st matrix
/// \param [in] m2 The local memory to store data of 2nd matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are to be loaded transposed
template <typename T, typename ItemT>
void ldmatrix(uintptr_t addr, T *m1, T *m2, const ItemT &item,
              bool trans = false) {
  // Destination buffers in matrix-index order: slot 0 receives the 1st
  // matrix, slot 1 the 2nd.
  T *const dst[] = {m1, m2};

  // Forward each destination to the single-matrix overload, passing the slot
  // number as its matrix index.
  for (unsigned mat = 0; mat < 2; ++mat) {
    ldmatrix(addr, dst[mat], item, trans, mat);
  }
}
21312131
/// Loads four 8x8 b16 matrices from shared memory to local memory (32-bits per wi)
@@ -2137,19 +2137,19 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
/// \param [in] m2 The local memory to store data of 2nd matrix
/// \param [in] m3 The local memory to store data of 3rd matrix
/// \param [in] m4 The local memory to store data of 4th matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are to be loaded transposed
template <typename T, typename ItemT>
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, const ItemT &item,
              bool trans = false) {
  // Destination buffers in matrix-index order: slot 0 receives the 1st
  // matrix, slot 3 the 4th.
  T *const dst[] = {m1, m2, m3, m4};

  // Forward each destination to the single-matrix overload, passing the slot
  // number as its matrix index.
  for (unsigned mat = 0; mat < 4; ++mat) {
    ldmatrix(addr, dst[mat], item, trans, mat);
  }
}
21542154
21552155} // namespace matrix
0 commit comments