@@ -2218,8 +2218,24 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218 ldmatrix (addr, m4, trans, 3 );
22192219}
22202220
2221- // / Stores 1 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2222- // / Requires the sub-group size of kernel calling this function to be 32
2221+ // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2222+ // / Requires the sub-group size of kernel calling this function to be 32.
2223+ // / Each of the first 8 work items contains the starting address of its
2224+ // / respective matrix row.
2225+ // / Each of the 32 work items stores 32 bits (2 packed 16-bit values) for a
2226+ // / total of 128 bytes.
2227+ // / Row Major: Each row of the matrix is stored by a group of 4 work items
2228+ // / r0: t0 t1 t2 t3
2229+ // / r1: t4 t5 t6 t7
2230+ // / ...
2231+ // / r6: t24 t25 t26 t27
2232+ // / r7: t28 t29 t30 t31
2233+ // / Col Major: Each col of the matrix is stored by a group of 4 work items
2234+ // / r0: t0 t4 t8 ... t28
2235+ // / r1: t0 t4 t8 ... t28
2236+ // / ...
2237+ // / r6: t3 t7 t11 ... t31
2238+ // / r7: t3 t7 t11 ... t31
22232239// / \tparam [in] T The type of matrix elements
22242240// / \param [in] addr The address of the matrix in local memory
22252241// / \param [in] m The private memory containing data of matrix
@@ -2272,8 +2288,12 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
22722288 }
22732289}
22742290
2275- // / Stores 2 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2276- // / Requires the sub-group size of kernel calling this function to be 32
2291+ // / Stores 2 8x8 b16 matrices from private memory to local memory per sub-group.
2292+ // / Requires the sub-group size of kernel calling this function to be 32.
2293+ // / Each of the first 16 work items contains the starting address of its
2294+ // / respective matrix row.
2295+ // / Each of the 32 work items stores 64 bits (32 bits per matrix) for a total
2296+ // / of 256 bytes.
22772297// / \tparam [in] T The type of matrix elements
22782298// / \param [in] addr The address of the matrix in local memory
22792299// / \param [in] m1 The private memory containing data of 1st matrix
@@ -2287,8 +2307,12 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
22872307 stmatrix (addr, m2, trans, 1 );
22882308}
22892309
2290- // / Stores 4 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2291- // / Requires the sub-group size of kernel calling this function to be 32
2310+ // / Stores 4 8x8 b16 matrices from private memory to local memory per sub-group.
2311+ // / Requires the sub-group size of kernel calling this function to be 32.
2312+ // / Each of the 32 work items contains the starting address of its
2313+ // / respective matrix row.
2314+ // / Each of the 32 work items stores 128 bits (32 bits per matrix) for a total
2315+ // / of 512 bytes.
22922316// / \tparam [in] T The type of matrix elements
22932317// / \param [in] addr The address of the matrix in local memory
22942318// / \param [in] m1 The private memory containing data of 1st matrix
0 commit comments