-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Segmented mm nax kernel #3419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Segmented mm nax kernel #3419
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| // Copyright © 2026 Apple Inc. | ||
|
|
||
| using namespace mlx::steel; | ||
|
|
||
| constant bool segments_contiguous [[function_constant(199)]]; | ||
| constant bool align_M [[function_constant(200)]]; | ||
| constant bool align_N [[function_constant(201)]]; | ||
|
|
||
| template < | ||
| typename T, | ||
| int BM, | ||
| int BN, | ||
| int BK, | ||
| int WM, | ||
| int WN, | ||
| bool transpose_a, | ||
| bool transpose_b, | ||
| typename AccumType = float> | ||
| [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] | ||
| void segmented_mm_nax( | ||
| const device T* A [[buffer(0)]], | ||
| const device T* B [[buffer(1)]], | ||
| const device uint32_t* segments [[buffer(2)]], | ||
| device T* C [[buffer(3)]], | ||
| const constant GEMMParams* params [[buffer(4)]], | ||
| uint simd_group_id [[simdgroup_index_in_threadgroup]], | ||
| uint3 tid [[threadgroup_position_in_grid]]) { | ||
| const int tid_m = (BK > 64) ? tid.y : tid.z; | ||
| const int tid_n = (BK > 64) ? tid.x : tid.y; | ||
| const int tid_s = (BK > 64) ? tid.z : tid.x; | ||
|
|
||
| const int c_row = tid_m * BM; | ||
| const int c_col = tid_n * BN; | ||
| const size_t c_row_long = size_t(c_row); | ||
| const size_t c_col_long = size_t(c_col); | ||
|
|
||
| if (params->tiles_n <= static_cast<int>(tid_n) || | ||
| params->tiles_m <= static_cast<int>(tid_m)) { | ||
| return; | ||
| } | ||
|
|
||
| A += transpose_a ? c_row_long : c_row_long * params->lda; | ||
| B += transpose_b ? c_col_long * params->ldb : c_col_long; | ||
| C += c_row_long * params->ldd + c_col_long; | ||
|
|
||
| uint32_t k_start, k_end; | ||
| if (segments_contiguous) { | ||
| k_start = segments[2 * tid_s]; | ||
| k_end = segments[2 * tid_s + 1]; | ||
| } else { | ||
| k_start = segments[tid_s]; | ||
| k_end = segments[tid_s + 1]; | ||
| } | ||
| A += transpose_a ? k_start * params->lda : k_start; | ||
| B += transpose_b ? k_start : k_start * params->ldb; | ||
| C += tid_s * params->batch_stride_d; | ||
|
|
||
| constexpr short SM = BM / WM; | ||
| constexpr short SN = BN / WN; | ||
| constexpr short SK = 32; | ||
|
|
||
| constexpr short TM = SM / 16; | ||
| constexpr short TN = SN / 16; | ||
|
|
||
| const short tm = SM * (simd_group_id / WN); | ||
| const short tn = SN * (simd_group_id % WN); | ||
|
|
||
| const int sgp_sm_int = | ||
| align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); | ||
| const short sgp_sm = short(sgp_sm_int); | ||
| const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); | ||
|
|
||
| const int sgp_sn_int = | ||
| align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); | ||
| const short sgp_sn = short(sgp_sn_int); | ||
| const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); | ||
|
|
||
| A += transpose_a ? tm : (tm * params->lda); | ||
| B += transpose_b ? (tn * params->ldb) : tn; | ||
| C += tm * params->ldd + tn; | ||
|
|
||
| NAXTile<AccumType, TM, TN> Dtile; | ||
| Dtile.clear(); | ||
|
|
||
| const int segment_k_size = k_end - k_start; | ||
| const int segment_k_iters = segment_k_size / BK; | ||
| const bool segment_k_aligned = (segment_k_size % BK) == 0; | ||
|
|
||
| dispatch_bool(segment_k_aligned, [&](auto kAlignedK) { | ||
| dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { | ||
| dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { | ||
| Dtile = gemm_loop< | ||
| T, | ||
| SM, | ||
| SN, | ||
| SK, | ||
| BK, | ||
| transpose_a, | ||
| transpose_b, | ||
| kAlignedM.value, | ||
| kAlignedN.value, | ||
| kAlignedK.value, | ||
| AccumType>( | ||
| A, | ||
| B, | ||
| params->lda, | ||
| params->ldb, | ||
| segment_k_size, | ||
| segment_k_iters, | ||
| sgp_sm, | ||
| sgp_sn); | ||
| }); | ||
| }); | ||
| }); | ||
|
|
||
| dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { | ||
| dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { | ||
| if constexpr (kAlignedM && kAlignedN) { | ||
| Dtile.store(C, int(params->ldd)); | ||
| } else { | ||
| Dtile.store_safe(C, int(params->ldd), short2(sgp_sn, sgp_sm)); | ||
| } | ||
| }); | ||
| }); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| // Copyright © 2026 Apple Inc. | ||
|
|
||
| #include <metal_stdlib> | ||
|
|
||
| #include "mlx/backend/metal/kernels/utils.h" | ||
|
|
||
| #include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" | ||
| #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h" | ||
|
|
||
| // clang-format off | ||
| #define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ | ||
| instantiate_kernel( \ | ||
| "steel_segmented_mm_nax_" #tname "_" #iname "_" #oname \ | ||
| "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \ | ||
| segmented_mm_nax, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float) | ||
|
|
||
| #define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ | ||
| instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ | ||
| instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ | ||
| instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ | ||
| instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) | ||
|
|
||
| #define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \ | ||
| instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \ | ||
| instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 128, 2, 2) \ | ||
| instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 64, 2, 2) | ||
|
|
||
| instantiate_segmented_mm_shapes_helper(float16, half, float16, half); | ||
| instantiate_segmented_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); | ||
| instantiate_segmented_mm_shapes_helper(float32, float, float32, float); | ||
| // clang-format on |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2478,15 +2478,36 @@ void segmented_mm( | |
| char devc = d.get_architecture().back(); | ||
| GEMM_TPARAM_MACRO(devc) | ||
|
|
||
| bool use_nax = metal::is_nax_available() && | ||
| (env::enable_tf32() || out.dtype() != float32); | ||
|
|
||
| const bool align_M = (M % bm) == 0; | ||
| const bool align_N = (N % bn) == 0; | ||
|
|
||
| // Define the kernel name | ||
| metal::MTLFCList func_consts = { | ||
| {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, | ||
| {&align_M, MTL::DataType::DataTypeBool, 200}, | ||
| {&align_N, MTL::DataType::DataTypeBool, 201}, | ||
| }; | ||
|
|
||
| std::string base_name; | ||
| base_name.reserve(128); | ||
| base_name += "steel_segmented_mm_"; | ||
|
|
||
| // Use NAX kernel if available | ||
| if (use_nax) { | ||
| int average_k = K / batch_size_out; | ||
| bm = 64; | ||
| bn = 64; | ||
| bk = (average_k >= 256) ? 256 : (average_k >= 128) ? 128 : 64; | ||
| wm = 2; | ||
| wn = 2; | ||
|
|
||
| base_name += "nax_"; | ||
| } | ||
|
|
||
| concatenate( | ||
| base_name, | ||
| "steel_segmented_mm_", | ||
| transpose_a ? 't' : 'n', | ||
| transpose_b ? 't' : 'n', | ||
| "_", | ||
|
|
@@ -2504,13 +2525,6 @@ void segmented_mm( | |
| "_wn", | ||
| wn); | ||
|
|
||
| metal::MTLFCList func_consts = { | ||
| {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, | ||
| {&align_M, MTL::DataType::DataTypeBool, 200}, | ||
| {&align_N, MTL::DataType::DataTypeBool, 201}, | ||
| }; | ||
|
|
||
| // And the kernel hash that includes the function constants | ||
| std::string hash_name; | ||
| hash_name.reserve(128); | ||
| concatenate( | ||
|
|
@@ -2524,19 +2538,32 @@ void segmented_mm( | |
| align_N ? 't' : 'n'); | ||
|
|
||
| // Get and set the kernel | ||
| auto kernel = get_steel_gemm_segmented_kernel( | ||
| d, | ||
| base_name, | ||
| hash_name, | ||
| func_consts, | ||
| out, | ||
| transpose_a, | ||
| transpose_b, | ||
| bm, | ||
| bn, | ||
| bk, | ||
| wm, | ||
| wn); | ||
| auto kernel = (use_nax) ? get_steel_gemm_segmented_nax_kernel( | ||
| d, | ||
| base_name, | ||
| hash_name, | ||
| func_consts, | ||
| out, | ||
| transpose_a, | ||
| transpose_b, | ||
| bm, | ||
| bn, | ||
| bk, | ||
| wm, | ||
| wn) | ||
| : get_steel_gemm_segmented_kernel( | ||
| d, | ||
| base_name, | ||
| hash_name, | ||
| func_consts, | ||
| out, | ||
| transpose_a, | ||
| transpose_b, | ||
| bm, | ||
| bn, | ||
| bk, | ||
| wm, | ||
| wn); | ||
| compute_encoder.set_compute_pipeline_state(kernel); | ||
|
|
||
| // Prepare the matmul params | ||
|
|
@@ -2557,8 +2584,9 @@ void segmented_mm( | |
|
|
||
| // Prepare the grid | ||
| MTL::Size group_dims = MTL::Size(32, wn, wm); | ||
| MTL::Size grid_dims = | ||
| MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); | ||
| MTL::Size grid_dims = (use_nax && bk == 64) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am wondering whether the non-NAX kernel should also swap dimensions when there are many small segments?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, it is possible. The idea was that we launch threadgroups in order x, y, z, and since K is very small in these cases we want the different Ks to run one after the other, since it is very likely that the 2nd matmul's K will already be in the cache. |
||
| ? MTL::Size(batch_size_out, params.tiles_n, params.tiles_m) | ||
| : MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); | ||
|
|
||
| // Launch kernel | ||
| compute_encoder.set_input_array(a, 0); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just out of curiosity, why 2 x 2 simdgroups and 64 x 64 tiles?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, this can be tuned for sure. I haven't actually run extensive testing, but it seemed like a good default 🤷‍♂️