@@ -425,7 +425,7 @@ max(T1 a, T2 b) {
425425 return sycl::fmax (static_cast <common_t >(a), static_cast <common_t >(b));
426426}
427427
428- // pow functions overload .
428+ // pow functions overload .
429429inline float pow (const float a, const int b) { return sycl::pown (a, b); }
430430inline double pow (const double a, const int b) { return sycl::pown (a, b); }
431431inline float pow (const float a, const float b) { return sycl::pow (a, b); }
@@ -2218,6 +2218,101 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218 ldmatrix (addr, m4, trans, 3 );
22192219}
22202220
2221+ // / Stores 1 8x8 b16 matrix from local memory to shared memory (32-bits per wi)
2222+ // / Requires the sub-group size of kernel calling this function to be 32
2223+ // / \tparam [in] T The type of matrix elements
2224+ // / \param [in] addr The address of the matrix in shared memory
2225+ // / \param [in] m The local memory containing data of matrix
2226+ // / \param [in] item The sycl::nd_item index space class
2227+ // / \param [in] trans Indicates whether the matrix to be stored transposed
2228+ // / \param [in] mat The matrix index to be stored
2229+ template <typename T, typename ItemT>
2230+ void stmatrix (uintptr_t addr, T m, const ItemT &item, bool trans = false ,
2231+ unsigned mat = 0 ) {
2232+ int lane = item.get_sub_group ().get_local_linear_id ();
2233+
2234+ int lane_group8_row = lane / 8 ;
2235+ int lane_group8_col = lane % 8 ;
2236+
2237+ if (!trans) {
2238+ // calculate the source lane
2239+ int src_lane = 2 * lane_group8_row;
2240+ if (lane_group8_col >= 4 )
2241+ src_lane += 1 ;
2242+
2243+ // Broadcast the address from the source lane
2244+ auto recv_addr_uintp = dpct::select_from_sub_group (
2245+ item.get_sub_group (), addr, mat * 8 + src_lane);
2246+
2247+ // Cast the received address from uintptr_t to the type of 'm'
2248+ auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
2249+
2250+ // Non-transposed store
2251+ recv_addr[lane_group8_col % 4 ] = m;
2252+ } else {
2253+ // calculate the source lane
2254+ int src_lane = (lane % 4 ) * 2 ;
2255+
2256+ // Broadcast the address from the source lane
2257+ auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2258+ item.get_sub_group (), addr, mat * 8 + src_lane);
2259+ auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2260+ item.get_sub_group (), addr, mat * 8 + src_lane + 1 );
2261+
2262+ // Cast the received address from uintptr_t to 'half *'
2263+ auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
2264+ auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
2265+
2266+ // Split the 32-bit value of 'm' into two 16-bits
2267+ sycl::half *val = reinterpret_cast <sycl::half *>(&m);
2268+
2269+ // Transposed store
2270+ int index = lane / 4 ;
2271+ recv_addr_1[index] = val[0 ];
2272+ recv_addr_2[index] = val[1 ];
2273+ }
2274+ }
2275+
/// Stores two 8x8 b16 matrices (32 bits per work-item each) by delegating to
/// the single-matrix stmatrix overload with matrix indices 0 and 1
/// (documented destination: shared memory).
/// Requires the sub-group size of the kernel calling this function to be 32.
/// \tparam T The type of the per-work-item matrix fragments
/// \tparam ItemT The index space class type
/// \param [in] addr The address of the matrix in shared memory
/// \param [in] m1 The fragment of the 1st matrix held by this work-item
/// \param [in] m2 The fragment of the 2nd matrix held by this work-item
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are to be stored
/// transposed
template <typename T, typename ItemT>
void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
              bool trans = false) {
  constexpr unsigned first_matrix = 0;
  constexpr unsigned second_matrix = 1;

  stmatrix(addr, m1, item, trans, first_matrix);
  stmatrix(addr, m2, item, trans, second_matrix);
}
2292+
/// Stores four 8x8 b16 matrices (32 bits per work-item each) by delegating to
/// the single-matrix stmatrix overload with matrix indices 0 through 3, in
/// that order (documented destination: shared memory).
/// Requires the sub-group size of the kernel calling this function to be 32.
/// \tparam T The type of the per-work-item matrix fragments
/// \tparam ItemT The index space class type
/// \param [in] addr The address of the matrix in shared memory
/// \param [in] m1 The fragment of the 1st matrix held by this work-item
/// \param [in] m2 The fragment of the 2nd matrix held by this work-item
/// \param [in] m3 The fragment of the 3rd matrix held by this work-item
/// \param [in] m4 The fragment of the 4th matrix held by this work-item
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are to be stored
/// transposed
template <typename T, typename ItemT>
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
              bool trans = false) {
  // Shared tail of every call; only the fragment and its index vary.
  const auto store_fragment = [&](const T &fragment, unsigned mat) {
    stmatrix(addr, fragment, item, trans, mat);
  };

  store_fragment(m1, 0);
  store_fragment(m2, 1);
  store_fragment(m3, 2);
  store_fragment(m4, 3);
}
2315+
22212316// / A helper struct that defines the pack type for the input matrix fragments
22222317// / of mma() function based on the type of input matrix fragments.
22232318// / The MMAType struct is specialized for different types of input matrices.
0 commit comments