@@ -2223,13 +2223,12 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22232223// / \tparam [in] T The type of matrix elements
22242224// / \param [in] addr The address of the matrix in local memory
22252225// / \param [in] m The private memory containing the data of the matrix
2226- // / \param [in] item The sycl::nd_item index space class
22272226// / \param [in] trans Indicates whether the matrix is to be stored transposed
22282227// / \param [in] mat The index of the matrix to be stored
2229- template <typename T, typename ItemT >
2230- void stmatrix (uintptr_t addr, T m, const ItemT &item, bool trans = false ,
2231- unsigned mat = 0 ) {
2232- int lane = item. get_sub_group () .get_local_linear_id ();
2228+ template <typename T>
2229+ void stmatrix (uintptr_t addr, T m, bool trans = false , unsigned mat = 0 ) {
2230+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2231+ int lane = sg .get_local_linear_id ();
22332232
22342233 int lane_group8_row = lane / 8 ;
22352234 int lane_group8_col = lane % 8 ;
@@ -2241,8 +2240,8 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
22412240 src_lane += 1 ;
22422241
22432242 // Broadcast the address from the source lane
2244- auto recv_addr_uintp = dpct::select_from_sub_group (
2245- item. get_sub_group () , addr, mat * 8 + src_lane);
2243+ auto recv_addr_uintp =
2244+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
22462245
22472246 // Cast the received address from uintptr_t to a pointer to the type of 'm'
22482247 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2254,10 +2253,10 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
22542253 int src_lane = (lane % 4 ) * 2 ;
22552254
22562255 // Broadcast the address from the source lane
2257- auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2258- item. get_sub_group () , addr, mat * 8 + src_lane);
2259- auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2260- item. get_sub_group () , addr, mat * 8 + src_lane + 1 );
2256+ auto recv_addr_uintp_1 =
2257+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2258+ auto recv_addr_uintp_2 =
2259+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane + 1 );
22612260
22622261 // Cast the received address from uintptr_t to 'half *'
22632262 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2279,15 +2278,13 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
22792278// / \param [in] addr The address of the matrix in local memory
22802279// / \param [in] m1 The private memory containing the data of the 1st matrix
22812280// / \param [in] m2 The private memory containing the data of the 2nd matrix
2282- // / \param [in] item The sycl::nd_item index space class
22832281// / \param [in] trans Indicates whether the matrices are to be stored transposed
2284- template <typename T, typename ItemT>
2285- void stmatrix (uintptr_t addr, T m1, T m2, const ItemT &item,
2286- bool trans = false ) {
2282+ template <typename T>
2283+ void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
22872284 // Store 1st matrix
2288- stmatrix (addr, m1, item, trans, 0 );
2285+ stmatrix (addr, m1, trans, 0 );
22892286 // Store 2nd matrix
2290- stmatrix (addr, m2, item, trans, 1 );
2287+ stmatrix (addr, m2, trans, 1 );
22912288}
22922289
22932290// / Stores four 8x8 b16 matrices from private memory to local memory (32 bits per work-item)
@@ -2298,19 +2295,17 @@ void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
22982295// / \param [in] m2 The private memory containing the data of the 2nd matrix
22992296// / \param [in] m3 The private memory containing the data of the 3rd matrix
23002297// / \param [in] m4 The private memory containing the data of the 4th matrix
2301- // / \param [in] item The sycl::nd_item index space class
23022298// / \param [in] trans Indicates whether the matrices are to be stored transposed
2303- template <typename T, typename ItemT>
2304- void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
2305- bool trans = false ) {
2299+ template <typename T>
2300+ void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
23062301 // Store 1st matrix
2307- stmatrix (addr, m1, item, trans, 0 );
2302+ stmatrix (addr, m1, trans, 0 );
23082303 // Store 2nd matrix
2309- stmatrix (addr, m2, item, trans, 1 );
2304+ stmatrix (addr, m2, trans, 1 );
23102305 // Store 3rd matrix
2311- stmatrix (addr, m3, item, trans, 2 );
2306+ stmatrix (addr, m3, trans, 2 );
23122307 // Store 4th matrix
2313- stmatrix (addr, m4, item, trans, 3 );
2308+ stmatrix (addr, m4, trans, 3 );
23142309}
23152310
23162311// / A helper struct that defines the pack type for the input matrix fragments
0 commit comments