
Segmented mm nax kernel#3419

Merged
angeloskath merged 3 commits into main from segmented-mm-nax
Apr 17, 2026

@angeloskath
Member

As the title suggests, this adds a NAX kernel for segmented_mm. It also fixes #3362.
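For readers unfamiliar with the operation, here is a minimal CPU reference sketch of what a segmented matmul computes. It assumes each segment is a half-open `[start, end)` range along the inner dimension K and that the result is one M x N matrix per segment. The function name `segmented_mm_ref` and the row-major `std::vector` layout are illustrative choices, not the kernel added in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Reference sketch (assumed semantics, not the MLX implementation):
//   out[s] = a[:, start_s:end_s] * b[start_s:end_s, :]
// a is M x K and b is K x N, both row-major.
std::vector<std::vector<float>> segmented_mm_ref(
    const std::vector<float>& a,
    const std::vector<float>& b,
    const std::vector<std::pair<int, int>>& segments,
    int M,
    int N,
    int K) {
  assert(a.size() == static_cast<std::size_t>(M) * K);
  assert(b.size() == static_cast<std::size_t>(K) * N);

  std::vector<std::vector<float>> out;
  out.reserve(segments.size());
  for (const auto& seg : segments) {
    const int start = seg.first;
    const int end = seg.second;
    assert(0 <= start && start <= end && end <= K);

    // Accumulate only over the k indices that belong to this segment.
    std::vector<float> c(static_cast<std::size_t>(M) * N, 0.0f);
    for (int m = 0; m < M; ++m) {
      for (int k = start; k < end; ++k) {
        const float a_mk = a[m * K + k];
        for (int n = 0; n < N; ++n) {
          c[m * N + n] += a_mk * b[k * N + n];
        }
      }
    }
    out.push_back(std::move(c));
  }
  return out;
}
```

When the segments partition `[0, K)`, summing the per-segment outputs gives the ordinary product of `a` and `b`, which is a convenient sanity check for any faster kernel.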

Micro benchmark numbers

float32
--------

| Case                | MLX before | MLX after | Speedup |
|---------------------|------------|-----------|---------|
| 2048x2048x4096x10   |      2.609 |     1.538 |   1.70x |
| 2048x2048x4096x50   |      3.163 |     2.266 |   1.40x |
| 2048x2048x4096x100  |      5.861 |     5.111 |   1.15x |
| 2048x2048x4096x200  |     10.697 |     9.710 |   1.10x |
| 2048x2048x8192x10   |      4.947 |     2.019 |   2.45x |
| 2048x2048x8192x50   |      5.526 |     3.233 |   1.71x |
| 2048x2048x8192x100  |      6.084 |     4.254 |   1.43x |
| 2048x2048x8192x200  |      9.129 |     7.942 |   1.15x |
| 4096x4096x16384x10  |     37.024 |    13.469 |   2.75x |
| 4096x4096x16384x50  |     37.826 |    20.329 |   1.86x |
| 4096x4096x16384x100 |     41.548 |    23.822 |   1.74x |
| 4096x4096x16384x200 |     49.209 |    31.521 |   1.56x |


float16
--------

| Case                | MLX before | MLX after | Speedup |
|---------------------|------------|-----------|---------|
| 2048x2048x4096x10   |      2.479 |     1.657 |   1.50x |
| 2048x2048x4096x50   |      2.620 |     1.312 |   2.00x |
| 2048x2048x4096x100  |      3.120 |     2.229 |   1.40x |
| 2048x2048x4096x200  |      5.638 |     4.832 |   1.17x |
| 2048x2048x8192x10   |      4.687 |     1.680 |   2.79x |
| 2048x2048x8192x50   |      4.776 |     1.855 |   2.57x |
| 2048x2048x8192x100  |      5.104 |     2.943 |   1.73x |
| 2048x2048x8192x200  |      6.075 |     4.283 |   1.42x |
| 4096x4096x16384x10  |     35.108 |     9.488 |   3.70x |
| 4096x4096x16384x50  |     36.647 |    11.823 |   3.10x |
| 4096x4096x16384x100 |     36.862 |    16.417 |   2.25x |
| 4096x4096x16384x200 |     39.391 |    15.950 |   2.47x |

@angeloskath angeloskath requested review from nastya236 and zcbenz April 16, 2026 09:53
Collaborator

@nastya236 nastya236 left a comment

Looks great! I asked a couple of questions, as usual, for my own understanding :)

// Use NAX kernel if available
if (use_nax) {
  int average_k = K / batch_size_out;
  bm = 64;
Collaborator

Just out of curiosity, why 2 x 2 simdgroups and 64 x 64 tiles?

Member Author

Well, this can be tuned for sure. I haven't actually run extensive testing, and it seemed like a good default 🤷‍♂️
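To make the numbers in this exchange concrete, the sketch below works through the arithmetic implied by 64 x 64 tiles and 2 x 2 simdgroups. `TileConfig` and the helper functions are hypothetical names used only for illustration. The even split of a tile across simdgroups is an assumption, while the factor of 32 and the `average_k` expression are taken from the diff lines quoted in this review.

```cpp
#include <cassert>

// Hypothetical helper type for illustration; not part of MLX.
struct TileConfig {
  int bm; // output tile rows handled by one threadgroup
  int bn; // output tile columns handled by one threadgroup
  int wm; // simdgroups along M
  int wn; // simdgroups along N
};

// Threads per simdgroup, matching the 32 in MTL::Size(32, wn, wm).
constexpr int kSimdWidth = 32;

int threads_per_threadgroup(const TileConfig& c) {
  return kSimdWidth * c.wn * c.wm;
}

// Number of tiles needed to cover `extent` elements with tiles of size `tile`.
int ceil_div(int extent, int tile) {
  assert(extent >= 0 && tile > 0);
  return (extent + tile - 1) / tile;
}

// Rows and columns per simdgroup, assuming the tile is split evenly.
int rows_per_simdgroup(const TileConfig& c) {
  assert(c.wm > 0 && c.bm % c.wm == 0);
  return c.bm / c.wm;
}

int cols_per_simdgroup(const TileConfig& c) {
  assert(c.wn > 0 && c.bn % c.wn == 0);
  return c.bn / c.wn;
}

// Mirrors `int average_k = K / batch_size_out;` (integer division).
int average_k(int K, int batch_size_out) {
  assert(batch_size_out > 0);
  return K / batch_size_out;
}
```

With `TileConfig{64, 64, 2, 2}` this gives 128 threads per threadgroup and, under the even-split assumption, a 32 x 32 sub-tile per simdgroup. A 2048-row output needs `ceil_div(2048, 64) = 32` tiles along M.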

MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
MTL::Size grid_dims = (use_nax && bk == 64)
Collaborator

I am wondering whether the non-NAX kernel should also swap dimensions when there are many small segments?

Member Author

Yeah, it is possible. The idea was that we launch threadgroups in order x, y, z, and since K is very small in these cases, we want the different Ks to run one after the other, since it is very likely that the 2nd matmul's K will already be in the cache.
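The effect of swapping grid dimensions can be illustrated with a small simulation of the launch order. This is a sketch under stated assumptions: threadgroups are enumerated with x varying fastest, as the reply describes, and the swapped layout is assumed to put the segment index on x. The quoted diff line is truncated, so the actual swapped ordering in the PR may differ. `Grid`, `Work`, and `launch_order` are hypothetical names, not MLX or Metal types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for illustration; not MLX or Metal types.
struct Grid {
  int x, y, z;
};

// One unit of work: output tile (tile_m, tile_n) of a given segment.
struct Work {
  int tile_n, tile_m, segment;
};

// Enumerate threadgroups with x varying fastest, then y, then z.
// segment_on_x == false: x = tile_n, y = tile_m, z = segment (layout in the diff context).
// segment_on_x == true:  x = segment, y = tile_m, z = tile_n (assumed swap).
std::vector<Work> launch_order(
    int tiles_n, int tiles_m, int segments, bool segment_on_x) {
  assert(tiles_n > 0 && tiles_m > 0 && segments > 0);
  const Grid g = segment_on_x ? Grid{segments, tiles_m, tiles_n}
                              : Grid{tiles_n, tiles_m, segments};
  std::vector<Work> order;
  order.reserve(static_cast<std::size_t>(g.x) * g.y * g.z);
  for (int z = 0; z < g.z; ++z) {
    for (int y = 0; y < g.y; ++y) {
      for (int x = 0; x < g.x; ++x) {
        order.push_back(segment_on_x ? Work{z, y, x} : Work{x, y, z});
      }
    }
  }
  return order;
}
```

In the default layout, consecutive threadgroups walk across output tiles of the same segment. In the assumed swapped layout, consecutive threadgroups handle neighbouring segments of the same output tile, which is the "different Ks run one after the other" behaviour the reply is aiming for.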

@angeloskath angeloskath merged commit 940ba47 into main Apr 17, 2026
16 checks passed
@angeloskath angeloskath deleted the segmented-mm-nax branch April 17, 2026 00:26

Development

Successfully merging this pull request may close these issues.

[Bug] MoE LoRA training crashes on M5 Max — missing float32 NAX gather_mm kernel
