@@ -2061,13 +2061,12 @@ class joint_matrix {
20612061// / \tparam [in] T The type of matrix elements
20622062// / \param [in] addr The address of the matrix in local memory
20632063// / \param [in] m The private memory containing data of matrix
2064- // / \param [in] item The sycl::nd_item index space class
20652064// / \param [in] trans Indicates whether the matrix is to be stored transposed
20662065// / \param [in] mat The matrix index to be stored
2067- template <typename T, typename ItemT >
2068- void stmatrix (uintptr_t addr, T m, const ItemT &item, bool trans = false ,
2069- unsigned mat = 0 ) {
2070- int lane = item. get_sub_group () .get_local_linear_id ();
2066+ template <typename T>
2067+ void stmatrix (uintptr_t addr, T m, bool trans = false , unsigned mat = 0 ) {
2068+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2069+ int lane = sg .get_local_linear_id ();
20712070
20722071 int lane_group8_row = lane / 8 ;
20732072 int lane_group8_col = lane % 8 ;
@@ -2079,8 +2078,8 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
20792078 src_lane += 1 ;
20802079
20812080 // Broadcast the address from the source lane
2082- auto recv_addr_uintp = dpct::select_from_sub_group (
2083- item. get_sub_group () , addr, mat * 8 + src_lane);
2081+ auto recv_addr_uintp =
2082+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
20842083
20852084 // Cast the received address from uintptr_t to the type of 'm'
20862085 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2092,10 +2091,10 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
20922091 int src_lane = (lane % 4 ) * 2 ;
20932092
20942093 // Broadcast the address from the source lane
2095- auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2096- item. get_sub_group () , addr, mat * 8 + src_lane);
2097- auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2098- item. get_sub_group () , addr, mat * 8 + src_lane + 1 );
2094+ auto recv_addr_uintp_1 =
2095+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2096+ auto recv_addr_uintp_2 =
2097+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane + 1 );
20992098
21002099 // Cast the received address from uintptr_t to 'half *'
21012100 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2117,15 +2116,13 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
21172116// / \param [in] addr The address of the matrix in local memory
21182117// / \param [in] m1 The private memory containing data of 1st matrix
21192118// / \param [in] m2 The private memory containing data of 2nd matrix
2120- // / \param [in] item The sycl::nd_item index space class
21212119// / \param [in] trans Indicates whether the matrix is to be stored transposed
2122- template <typename T, typename ItemT>
2123- void stmatrix (uintptr_t addr, T m1, T m2, const ItemT &item,
2124- bool trans = false ) {
2120+ template <typename T>
2121+ void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
21252122 // Store 1st matrix
2126- stmatrix (addr, m1, item, trans, 0 );
2123+ stmatrix (addr, m1, trans, 0 );
21272124 // Store 2nd matrix
2128- stmatrix (addr, m2, item, trans, 1 );
2125+ stmatrix (addr, m2, trans, 1 );
21292126}
21302127
21312128// / Stores 4 8x8 b16 matrices from private memory to local memory (32-bits per wi)
@@ -2136,19 +2133,17 @@ void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
21362133// / \param [in] m2 The private memory containing data of 2nd matrix
21372134// / \param [in] m3 The private memory containing data of 3rd matrix
21382135// / \param [in] m4 The private memory containing data of 4th matrix
2139- // / \param [in] item The sycl::nd_item index space class
21402136// / \param [in] trans Indicates whether the matrix is to be stored transposed
2141- template <typename T, typename ItemT>
2142- void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
2143- bool trans = false ) {
2137+ template <typename T>
2138+ void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
21442139 // Store 1st matrix
2145- stmatrix (addr, m1, item, trans, 0 );
2140+ stmatrix (addr, m1, trans, 0 );
21462141 // Store 2nd matrix
2147- stmatrix (addr, m2, item, trans, 1 );
2142+ stmatrix (addr, m2, trans, 1 );
21482143 // Store 3rd matrix
2149- stmatrix (addr, m3, item, trans, 2 );
2144+ stmatrix (addr, m3, trans, 2 );
21502145 // Store 4th matrix
2151- stmatrix (addr, m4, item, trans, 3 );
2146+ stmatrix (addr, m4, trans, 3 );
21522147}
21532148
21542149} // namespace matrix
0 commit comments