@@ -2056,8 +2056,24 @@ class joint_matrix {
20562056 const size_t num_elements;
20572057};
20582058
2059- // / Stores 1 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2060- // / Requires the sub-group size of kernel calling this function to be 32
2059+ // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2060+ // / Requires the sub-group size of kernel calling this function to be 32.
2061+ // / Each of the first 8 work items contain the starting address of their
2062+ // / respective matrix row.
2063+ // / Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2064+ // / of 128 bytes.
2065+ // / Row Major: Each row of the matrix is stored by a group of 4 work items
2066+ // / r0: t0 t1 t2 t3
2067+ // / r1: t4 t5 t6 t7
2068+ // / ...
2069+ // / r7: t24 t25 t26 t27
2070+ // / r7: t28 t29 t30 t31
2071+ // / Col Major: Each col of the matrix is stored by a group of 4 work items
2072+ // / r0: t0 t4 t8 ... t28
2073+ // / r1: t0 t4 t8 ... t28
2074+ // / ...
2075+ // / r6: t3 t7 t11 ... t31
2076+ // / r7: t3 t7 t11 ... t31
20612077// / \tparam [in] T The type of matrix elements
20622078// / \param [in] addr The address of the matrix in local memory
20632079// / \param [in] m The private memory containing data of matrix
@@ -2110,8 +2126,12 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
21102126 }
21112127}
21122128
2113- // / Stores 2 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2114- // / Requires the sub-group size of kernel calling this function to be 32
2129+ // / Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2130+ // / Requires the sub-group size of kernel calling this function to be 32.
2131+ // / Each of the first 16 work items contain the starting address of their
2132+ // / respective matrix row.
2133+ // / Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2134+ // / of 256 bytes.
21152135// / \tparam [in] T The type of matrix elements
21162136// / \param [in] addr The address of the matrix in local memory
21172137// / \param [in] m1 The private memory containing data of 1st matrix
@@ -2125,8 +2145,12 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
21252145 stmatrix (addr, m2, trans, 1 );
21262146}
21272147
2128- // / Stores 4 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2129- // / Requires the sub-group size of kernel calling this function to be 32
2148+ // / Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2149+ // / Requires the sub-group size of kernel calling this function to be 32.
2150+ // / Each of the 32 work items contain the starting address of their
2151+ // / respective matrix row.
2152+ // / Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2153+ // / of 512 bytes.
21302154// / \tparam [in] T The type of matrix elements
21312155// / \param [in] addr The address of the matrix in local memory
21322156// / \param [in] m1 The private memory containing data of 1st matrix
0 commit comments