Conversation
nastya236
left a comment
There was a problem hiding this comment.
Looks great! I asked couple of questions as usual for my own understanding :)
| // Use NAX kernel if available | ||
| if (use_nax) { | ||
| int average_k = K / batch_size_out; | ||
| bm = 64; |
There was a problem hiding this comment.
Just out of curiosity, why 2 x 2 simdgroups and 64 x 64 tiles?
There was a problem hiding this comment.
Well this can be tuned for sure. I haven't actually ran extensive testing and it seemed like a good default 🤷♂️
| MTL::Size group_dims = MTL::Size(32, wn, wm); | ||
| MTL::Size grid_dims = | ||
| MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); | ||
| MTL::Size grid_dims = (use_nax && bk == 64) |
There was a problem hiding this comment.
I am wondering should non nax kernel also swap dimensions when there are many small segments?
There was a problem hiding this comment.
Yeah it is possible. The idea was that we launch threadgroups in order x, y, z and since K is very small in these cases we want the different Ks to run one after the other since it is very likely that the 2nd matmul's K will be in the cache already.
As the title suggests it adds a NAX kernel for the
segmented_mm. It also fixes #3362.Micro benchmark numbers