@@ -2056,27 +2056,33 @@ class joint_matrix {
20562056 const size_t num_elements;
20572057};
20582058
2059- // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2059+ // / Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2060+ // / local memory per sub-group.
20602061// / Requires the sub-group size of kernel calling this function to be 32.
2061- // / Each of the first 8 work items contain the starting address of their
2062- // / respective matrix row.
2063- // / Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2064- // / of 128 bytes.
2065- // / Row Major: Each row of the matrix is stored by a group of 4 work items
2066- // / r0: t0 t1 t2 t3
2067- // / r1: t4 t5 t6 t7
2062+ // / 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2063+ // / work items of the sub-group contain the starting address of their
2064+ // / respective matrix row in 'addr'.
2065+ // / After distributing addresses to other work items, each of the 32 work items
2066+ // / stores 32 bits (2 packed 16-bit data) from 'm' for a total of 128 bytes.
2067+ // / 'trans' selects whether each work item performs a transposed or
2068+ // / non-transposed store, as shown below:
2069+ // / Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2070+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2071+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
20682072// / ...
2069- // / r7: t24 t25 t26 t27
2070- // / r7: t28 t29 t30 t31
2071- // / Col Major: Each col of the matrix is stored by a group of 4 work items
2072- // / r0: t0 t4 t8 ... t28
2073- // / r1: t0 t4 t8 ... t28
2073+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2074+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2075+ // / Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2076+ // / row-0: wi0 wi4 wi8 ... wi28
2077+ // / row-1: wi0 wi4 wi8 ... wi28
20742078// / ...
2075- // / r6: t3 t7 t11 ... t31
2076- // / r7: t3 t7 t11 ... t31
2077- // / \tparam [in] T The type of matrix elements
2078- // / \param [in] addr The address of the matrix in local memory
2079- // / \param [in] m The private memory containing data of matrix
2079+ // / row-6: wi3 wi7 wi11 ... wi31
2080+ // / row-7: wi3 wi7 wi11 ... wi31
2081+ // / \tparam [in] T Type of the variable being stored (currently only supports 16-bit type)
2082+ // / \param [in] addr The starting address of corresponding matrix row for a work
2083+ // / item in local memory
2084+ // / \param [in] m The private memory containing the data of the matrix to
2085+ // / store. It holds 2 packed b16 elements.
20802086// / \param [in] trans Indicates whether the matrix is to be stored transposed
20812087// / \param [in] mat The matrix index to be stored
20822088template <typename T>
@@ -2126,16 +2132,35 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
21262132 }
21272133}
21282134
2129- // / Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2135+ // / Collectively stores 2 8x8 b16 matrices (256 bytes) from private memory to
2136+ // / local memory per sub-group.
21302137// / Requires the sub-group size of kernel calling this function to be 32.
2131- // / Each of the first 16 work items contain the starting address of their
2132- // / respective matrix row.
2133- // / Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2134- // / of 256 bytes.
2135- // / \tparam [in] T The type of matrix elements
2136- // / \param [in] addr The address of the matrix in local memory
2137- // / \param [in] m1 The private memory containing data of 1st matrix
2138- // / \param [in] m2 The private memory containing data of 2nd matrix
2138+ // / The first 16 work items of the sub-group contain the starting address of
2139+ // / their respective matrix row in 'addr'.
2140+ // / After distributing addresses to other work items, each of the 32 work items
2141+ // / stores 64 bits (32 bits per matrix) from 'm1' & 'm2' for a total of 256
2142+ // / bytes.
2143+ // / 'trans' selects whether each work item performs a transposed or
2144+ // / non-transposed store, as shown below:
2145+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2146+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2147+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2148+ // / ...
2149+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2150+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2151+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2152+ // / row-0: wi0 wi4 wi8 ... wi28
2153+ // / row-1: wi0 wi4 wi8 ... wi28
2154+ // / ...
2155+ // / row-6: wi3 wi7 wi11 ... wi31
2156+ // / row-7: wi3 wi7 wi11 ... wi31
2157+ // / \tparam [in] T Type of the variable being stored (currently only supports 16-bit type)
2158+ // / \param [in] addr The starting address of corresponding matrix row for a work
2159+ // / item in local memory
2160+ // / \param [in] m1 The private memory containing the data of the 1st matrix.
2161+ // / It holds 2 packed b16 elements.
2162+ // / \param [in] m2 The private memory containing the data of the 2nd matrix.
2163+ // / It holds 2 packed b16 elements.
21392164// / \param [in] trans Indicates whether the matrix is to be stored transposed
21402165template <typename T>
21412166void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
@@ -2145,18 +2170,39 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
21452170 stmatrix (addr, m2, trans, 1 );
21462171}
21472172
2148- // / Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2173+ // / Collectively stores 4 8x8 b16 matrices (512 bytes) from private memory to
2174+ // / local memory per sub-group.
21492175// / Requires the sub-group size of kernel calling this function to be 32.
2150- // / Each of the 32 work items contain the starting address of their
2151- // / respective matrix row.
2152- // / Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2176+ // / Each work item of the sub-group contains the starting address of its
2177+ // / respective matrix row in 'addr'.
2178+ // / After distributing addresses to other work items, each of the 32 work items
2179+ // / stores 128 bits (32 bits per matrix) from 'm1', 'm2', 'm3' & 'm4' for a total
21532180// / of 512 bytes.
2154- // / \tparam [in] T The type of matrix elements
2155- // / \param [in] addr The address of the matrix in local memory
2156- // / \param [in] m1 The private memory containing data of 1st matrix
2157- // / \param [in] m2 The private memory containing data of 2nd matrix
2158- // / \param [in] m3 The private memory containing data of 3rd matrix
2159- // / \param [in] m4 The private memory containing data of 4th matrix
2181+ // / 'trans' selects whether each work item performs a transposed or
2182+ // / non-transposed store, as shown below:
2183+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2184+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2185+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2186+ // / ...
2187+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2188+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2189+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2190+ // / row-0: wi0 wi4 wi8 ... wi28
2191+ // / row-1: wi0 wi4 wi8 ... wi28
2192+ // / ...
2193+ // / row-6: wi3 wi7 wi11 ... wi31
2194+ // / row-7: wi3 wi7 wi11 ... wi31
2195+ // / \tparam [in] T Type of the variable being stored (currently only supports 16-bit type)
2196+ // / \param [in] addr The starting address of corresponding matrix row for a work
2197+ // / item in local memory
2198+ // / \param [in] m1 The private memory containing the data of the 1st matrix.
2199+ // / It holds 2 packed b16 elements.
2200+ // / \param [in] m2 The private memory containing the data of the 2nd matrix.
2201+ // / It holds 2 packed b16 elements.
2202+ // / \param [in] m3 The private memory containing the data of the 3rd matrix.
2203+ // / It holds 2 packed b16 elements.
2204+ // / \param [in] m4 The private memory containing the data of the 4th matrix.
2205+ // / It holds 2 packed b16 elements.
21602206// / \param [in] trans Indicates whether the matrix is to be stored transposed
21612207template <typename T>
21622208void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
0 commit comments