|
9 | 9 | #ifndef __DPCT_MATH_HPP__ |
10 | 10 | #define __DPCT_MATH_HPP__ |
11 | 11 |
|
12 | | -#include <limits> |
13 | 12 | #include <climits> |
| 13 | +#include <limits> |
14 | 14 | #include <sycl/sycl.hpp> |
15 | 15 | #include <type_traits> |
16 | 16 |
|
@@ -425,7 +425,7 @@ max(T1 a, T2 b) { |
425 | 425 | return sycl::fmax(static_cast<common_t>(a), static_cast<common_t>(b)); |
426 | 426 | } |
427 | 427 |
|
428 | | -// pow functions overload. |
| 428 | +// pow function overloads. |
429 | 429 | inline float pow(const float a, const int b) { return sycl::pown(a, b); } |
430 | 430 | inline double pow(const double a, const int b) { return sycl::pown(a, b); } |
431 | 431 | inline float pow(const float a, const float b) { return sycl::pow(a, b); } |
@@ -2055,6 +2055,102 @@ class joint_matrix { |
2055 | 2055 | matrix_accessor x; |
2056 | 2056 | const size_t num_elements; |
2057 | 2057 | }; |
| 2058 | + |
| 2059 | +/// Stores one 8x8 b16 matrix, passed by value as 32 bits per work-item, to shared memory |
| 2060 | +/// Requires the sub-group size of the kernel calling this function to be 32 |
| 2061 | +/// \tparam T The type of the per-work-item value 'm'; the transposed path reads it as two sycl::half values |
| 2062 | +/// \param [in] addr This work-item's address value; lanes mat * 8 + 0..7 supply the addresses written to (presumably one per matrix row) |
| 2063 | +/// \param [in] m The value this work-item contributes to the matrix |
| 2064 | +/// \param [in] item The work-item index object (type ItemT, e.g. sycl::nd_item); must provide get_sub_group() |
| 2065 | +/// \param [in] trans Whether the matrix is stored transposed |
| 2066 | +/// \param [in] mat The index of the matrix to store; selects lanes mat * 8 .. mat * 8 + 7 as address sources |
| 2067 | +template <typename T, typename ItemT> |
| 2068 | +void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false, |
| 2069 | + unsigned mat = 0) { |
| 2070 | + int lane = item.get_sub_group().get_local_linear_id(); |
| 2071 | + |
| 2072 | + int lane_group8_row = lane / 8; |
| 2073 | + int lane_group8_col = lane % 8; |
| 2074 | + |
| 2075 | + if (!trans) { |
| 2076 | + // Source lane offset (0..7); the two steps below amount to lane / 4 |
| 2077 | + int src_lane = 2 * lane_group8_row; |
| 2078 | + if (lane_group8_col >= 4) |
| 2079 | + src_lane += 1; |
| 2080 | + |
| 2081 | + // Fetch the address held by the source lane of the selected matrix |
| 2082 | + auto recv_addr_uintp = dpct::select_from_sub_group( |
| 2083 | + item.get_sub_group(), addr, mat * 8 + src_lane); |
| 2084 | + |
| 2085 | + // Cast the received address from uintptr_t to a pointer to the type of 'm' |
| 2086 | + auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp); |
| 2087 | + |
| 2088 | + // Non-transposed store: the 4 lanes sharing an address write slots 0..3 (lane % 4) |
| 2089 | + recv_addr[lane_group8_col % 4] = m; |
| 2090 | + } else { |
| 2091 | + // Source lane offset of the first of two consecutive addresses (0, 2, 4 or 6) |
| 2092 | + int src_lane = (lane % 4) * 2; |
| 2093 | + |
| 2094 | + // Fetch the addresses held by the source lane and by the lane after it |
| 2095 | + auto recv_addr_uintp_1 = dpct::select_from_sub_group( |
| 2096 | + item.get_sub_group(), addr, mat * 8 + src_lane); |
| 2097 | + auto recv_addr_uintp_2 = dpct::select_from_sub_group( |
| 2098 | + item.get_sub_group(), addr, mat * 8 + src_lane + 1); |
| 2099 | + |
| 2100 | + // Cast the received addresses from uintptr_t to 'half *' |
| 2101 | + auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1); |
| 2102 | + auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2); |
| 2103 | + |
| 2104 | + // View 'm' as two 16-bit halves (assumes T is 32 bits wide - TODO confirm) |
| 2105 | + sycl::half *val = reinterpret_cast<sycl::half *>(&m); |
| 2106 | + |
| 2107 | + // Transposed store: each half goes to element lane / 4 of its own address |
| 2108 | + int index = lane / 4; |
| 2109 | + recv_addr_1[index] = val[0]; |
| 2110 | + recv_addr_2[index] = val[1]; |
| 2111 | + } |
| 2112 | +} |
| 2113 | + |
| 2114 | +/// Stores two 8x8 b16 matrices, each passed by value as 32 bits per work-item, to shared memory |
| 2115 | +/// Requires the sub-group size of the kernel calling this function to be 32 |
| 2116 | +/// \tparam T The type of the per-work-item values 'm1' and 'm2' |
| 2117 | +/// \param [in] addr This work-item's address value, forwarded unchanged to each single-matrix store |
| 2118 | +/// \param [in] m1 The value this work-item contributes to the 1st matrix (matrix index 0) |
| 2119 | +/// \param [in] m2 The value this work-item contributes to the 2nd matrix (matrix index 1) |
| 2120 | +/// \param [in] item The work-item index object (type ItemT, e.g. sycl::nd_item) |
| 2121 | +/// \param [in] trans Whether the matrices are stored transposed |
| 2122 | +template <typename T, typename ItemT> |
| 2123 | +void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item, |
| 2124 | + bool trans = false) { |
| 2125 | + // Store the 1st matrix (matrix index 0) |
| 2126 | + stmatrix(addr, m1, item, trans, 0); |
| 2127 | + // Store the 2nd matrix (matrix index 1) |
| 2128 | + stmatrix(addr, m2, item, trans, 1); |
| 2129 | +} |
| 2130 | + |
| 2131 | +/// Stores four 8x8 b16 matrices, each passed by value as 32 bits per work-item, to shared memory |
| 2132 | +/// Requires the sub-group size of the kernel calling this function to be 32 |
| 2133 | +/// \tparam T The type of the per-work-item values 'm1' .. 'm4' |
| 2134 | +/// \param [in] addr This work-item's address value, forwarded unchanged to each single-matrix store |
| 2135 | +/// \param [in] m1 The value this work-item contributes to the 1st matrix (matrix index 0) |
| 2136 | +/// \param [in] m2 The value this work-item contributes to the 2nd matrix (matrix index 1) |
| 2137 | +/// \param [in] m3 The value this work-item contributes to the 3rd matrix (matrix index 2) |
| 2138 | +/// \param [in] m4 The value this work-item contributes to the 4th matrix (matrix index 3) |
| 2139 | +/// \param [in] item The work-item index object (type ItemT, e.g. sycl::nd_item) |
| 2140 | +/// \param [in] trans Whether the matrices are stored transposed |
| 2141 | +template <typename T, typename ItemT> |
| 2142 | +void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item, |
| 2143 | + bool trans = false) { |
| 2144 | + // Store the 1st matrix (matrix index 0) |
| 2145 | + stmatrix(addr, m1, item, trans, 0); |
| 2146 | + // Store the 2nd matrix (matrix index 1) |
| 2147 | + stmatrix(addr, m2, item, trans, 1); |
| 2148 | + // Store the 3rd matrix (matrix index 2) |
| 2149 | + stmatrix(addr, m3, item, trans, 2); |
| 2150 | + // Store the 4th matrix (matrix index 3) |
| 2151 | + stmatrix(addr, m4, item, trans, 3); |
| 2152 | +} |
| 2153 | + |
2058 | 2154 | } // namespace matrix |
2059 | 2155 | } // namespace experimental |
2060 | 2156 |
|
|
0 commit comments