@@ -2218,27 +2218,33 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218 ldmatrix (addr, m4, trans, 3 );
22192219}
22202220
2221- // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2221+ // / Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2222+ // / local memory per sub-group.
22222223// / Requires the sub-group size of kernel calling this function to be 32.
2223- // / Each of the first 8 work items contain the starting address of their
2224- // / respective matrix row.
2225- // / Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2226- // / of 128 bytes.
2227- // / Row Major: Each row of the matrix is stored by a group of 4 work items
2228- // / r0: t0 t1 t2 t3
2229- // / r1: t4 t5 t6 t7
2224+ // / 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2225+ // / work items of sub-group contain the starting address of their respective
2226+ // / matrix row in 'addr'.
2227+ // / After distributing addresses to other work items, each of the 32 work items
2228+ // / stores 32-bits (2 packed 16-bit data) from 'm' for a total of 128 bytes.
2229+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2230+ // / item like below
2231+ // / Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2232+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2233+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
22302234// / ...
2231- // / r7: t24 t25 t26 t27
2232- // / r7: t28 t29 t30 t31
2233- // / Col Major: Each col of the matrix is stored by a group of 4 work items
2234- // / r0: t0 t4 t8 ... t28
2235- // / r1: t0 t4 t8 ... t28
2235+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2236+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2237+ // / Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2238+ // / row-0: wi0 wi4 wi8 ... wi28
2239+ // / row-1: wi0 wi4 wi8 ... wi28
22362240// / ...
2237- // / r6: t3 t7 t11 ... t31
2238- // / r7: t3 t7 t11 ... t31
2239- // / \tparam [in] T The type of matrix elements
2240- // / \param [in] addr The address of the matrix in local memory
2241- // / \param [in] m The private memory containing data of matrix
2241+ // / row-6: wi3 wi7 wi11 ... wi31
2242+ // / row-7: wi3 wi7 wi11 ... wi31
2243+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2244+ // / \param [in] addr The starting address of corresponding matrix row for a work
2245+ // / item in local memory
2246+ // / \param [in] m The private memory containing the data of the matrix to be
2247+ // / stored. It points to 2 b16 type elements.
22422248// / \param [in] trans Indicates whether the matrix is to be stored transposed
22432249// / \param [in] mat The matrix index to be stored
22442250template <typename T>
@@ -2288,16 +2294,35 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
22882294 }
22892295}
22902296
2291- // / Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2297+ // / Collectively stores 2 8x8 b16 (256 bytes) matrices from private memory to
2298+ // / local memory per sub-group.
22922299// / Requires the sub-group size of kernel calling this function to be 32.
2293- // / Each of the first 16 work items contain the starting address of their
2294- // / respective matrix row.
2295- // / Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2296- // / of 256 bytes.
2297- // / \tparam [in] T The type of matrix elements
2298- // / \param [in] addr The address of the matrix in local memory
2299- // / \param [in] m1 The private memory containing data of 1st matrix
2300- // / \param [in] m2 The private memory containing data of 2nd matrix
2300+ // / The first 16 work items of sub-group contain the starting address of their
2301+ // / respective matrix row in 'addr'.
2302+ // / After distributing addresses to other work items, each of the 32 work items
2303+ // / stores 64-bits (32-bits per matrix) from 'm1' & 'm2' for a total of 256
2304+ // / bytes.
2305+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2306+ // / item like below
2307+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2308+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2309+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2310+ // / ...
2311+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2312+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2313+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2314+ // / row-0: wi0 wi4 wi8 ... wi28
2315+ // / row-1: wi0 wi4 wi8 ... wi28
2316+ // / ...
2317+ // / row-6: wi3 wi7 wi11 ... wi31
2318+ // / row-7: wi3 wi7 wi11 ... wi31
2319+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2320+ // / \param [in] addr The starting address of corresponding matrix row for a work
2321+ // / item in local memory
2322+ // / \param [in] m1 The private memory containing the data of 1st matrix. It
2323+ // / points to 2 b16 type elements.
2324+ // / \param [in] m2 The private memory containing the data of 2nd matrix. It
2325+ // / points to 2 b16 type elements.
23012326// / \param [in] trans Indicates whether the matrix is to be stored transposed
23022327template <typename T>
23032328void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
@@ -2307,18 +2332,39 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
23072332 stmatrix (addr, m2, trans, 1 );
23082333}
23092334
2310- // / Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2335+ // / Collectively stores 4 8x8 b16 (512 bytes) matrices from private memory to
2336+ // / local memory per sub-group.
23112337// / Requires the sub-group size of kernel calling this function to be 32.
2312- // / Each of the 32 work items contain the starting address of their
2313- // / respective matrix row.
2314- // / Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2338+ // / Each work item of the sub-group contains the starting address of its
2339+ // / respective matrix row in 'addr'.
2340+ // / After distributing addresses to other work items, each of the 32 work items
2341+ // / stores 128-bits (32-bits per matrix) from 'm1', 'm2', 'm3' & 'm4' for a
23152342// / total of 512 bytes.
2316- // / \tparam [in] T The type of matrix elements
2317- // / \param [in] addr The address of the matrix in local memory
2318- // / \param [in] m1 The private memory containing data of 1st matrix
2319- // / \param [in] m2 The private memory containing data of 2nd matrix
2320- // / \param [in] m3 The private memory containing data of 3rd matrix
2321- // / \param [in] m4 The private memory containing data of 4th matrix
2343+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2344+ // / item like below
2345+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2346+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2347+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2348+ // / ...
2349+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2350+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2351+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2352+ // / row-0: wi0 wi4 wi8 ... wi28
2353+ // / row-1: wi0 wi4 wi8 ... wi28
2354+ // / ...
2355+ // / row-6: wi3 wi7 wi11 ... wi31
2356+ // / row-7: wi3 wi7 wi11 ... wi31
2357+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2358+ // / \param [in] addr The starting address of corresponding matrix row for a work
2359+ // / item in local memory
2360+ // / \param [in] m1 The private memory containing the data of 1st matrix. It
2361+ // / points to 2 b16 type elements.
2362+ // / \param [in] m2 The private memory containing the data of 2nd matrix. It
2363+ // / points to 2 b16 type elements.
2364+ // / \param [in] m3 The private memory containing the data of 3rd matrix. It
2365+ // / points to 2 b16 type elements.
2366+ // / \param [in] m4 The private memory containing the data of 4th matrix. It
2367+ // / points to 2 b16 type elements.
23222368// / \param [in] trans Indicates whether the matrix is to be stored transposed
23232369template <typename T>
23242370void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
0 commit comments