[PTX][MMA] Added support for migrating m16n8k16#2821
Conversation
tomflinda
left a comment
There was a problem hiding this comment.
Pls address the comments and verify your update with e2e test cases
| auto rb = reinterpret_cast<MulType *>(recv_b); | ||
|
|
||
| for (int j = 0; j < 4; j++) { | ||
| c[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]); |
There was a problem hiding this comment.
c matrix should not be updated.
There was a problem hiding this comment.
Changed the logic, to do
d = c;
d += a * b;
| for (int j = 0; j < 4; j++) { | ||
| c[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]); | ||
| c[1] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]); | ||
| c[2] += static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]); | ||
| c[3] += | ||
| static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]); | ||
| } |
There was a problem hiding this comment.
Pls add more comments to explain the code piece here and the reason offset 4 is used.
There was a problem hiding this comment.
Added comments to clarify the reason for using '4' offset and
that how this wouldn't overflow
| /// \tparam [in] M The rows of A, C & D matrix | ||
| /// \tparam [in] N The columns of B, C, D matrix | ||
| /// \tparam [in] K The columns & rows of A & B matrices respectively | ||
| /// \tparam [in] MulType The type used to multiply A and B matrix elements as |
There was a problem hiding this comment.
MulType is confusing to ABType; pls add more comments to explain it.
There was a problem hiding this comment.
Modified the comment to explain better
2dda4c7 to
9e2c234
Compare
| /// Multiplies 2 matrices (A & B) and adds the result to C matrix and | ||
| /// accumulates the result to a D matrix (MAD). Requires the sub-group size of |
There was a problem hiding this comment.
The functionality description for this helper function is not accurate; this helper function is called by one work item of a subgroup('the size of the subgroup is limited to 32'), the current work item i(i=0,1,..,31) only calculates the four elements of the result matrix D(e,g: D = A*B + C, where the shape of D=16x8, shape of A=16x16, shape of B=16x8, shape of C=16x8) for shape and type:m16n8k16 (f32.f16.f16.f32), pls update the description for this helper function.
There was a problem hiding this comment.
Added more description to the algo functionality
| // d2 += row8{ a0, a1, a8, a9 } * col0{ b0, b1, b8, b9 } | ||
| // d3 += row8{ a1, a1, a8, a9 } * col1{ b0, b1, b8, b9 } | ||
| for (int j = 0; j < 4; j++) { | ||
| *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]); |
There was a problem hiding this comment.
d0~d3 is the four results of result D (D=AxB+C for m16n8k16 (f32.f16.f16.f32)), from the algorithm of matrix multiplication, d0(e.g., the position of d0 in matrix D where [i, j]) is the accumulation of dot multiplication of the whole i row of matrix A, and the whole j column of matrix B. In the subgroup level, for the current work item, pls explain how the whole i row of matrix A, and the whole j column of matrix B are loaded. For example, from the parameter of void *a_mat, void *b_mat, void *c_mat shown in the lit test mmu.cu:
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
" { %0, %1, %2, %3 }, "
" { %4, %5, %6, %7 }, "
" { %8, %9 }, "
" { %0, %1, %2, %3 };"
: "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]));
only 8 elements of matrix A and 4 elements of matrix B are passed into ASM instruction, while the result of this ASM is that the four elements of in result D are calculated, so for each one of the four elements, pls explain in the helper function, how the whole i row of matrix A, and the whole j column of matrix B are loaded.
There was a problem hiding this comment.
Changed the description to reflect Added more description to the algo functionality
9e2c234 to
7527809
Compare
| template <typename T> struct MMAType { | ||
| using PackType = uint32_t; | ||
| }; |
There was a problem hiding this comment.
If only uint32_t is enough, we can use uint32_t directly instead of introducing MMAType
There was a problem hiding this comment.
Some shapes involving f64 require a pack type of double. So, suggesting to keep this
| // Each work item Wi (i=0...31) gathers 2 row & 2 col matrix fragments | ||
| // of length k (8) from A & B matrices respectively into recv_a & recv_b | ||
| // across 4 iterations using 4 neighboring work items with below mapping |
There was a problem hiding this comment.
Could you refine this comment block? it is difficult for users to understand.
There was a problem hiding this comment.
Simplified it
| // logic: | ||
| // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 | ||
| // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 | ||
| for (int i = 0; i < 4; i++) { |
There was a problem hiding this comment.
Could explain the meaning of 4?
There was a problem hiding this comment.
Added comments to describe the distribution of rows & cols across 4 work items
tomflinda
left a comment
There was a problem hiding this comment.
Pls address the comment I left.
This PR adds support for below configs of m16n8k16