diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 67c69579ad..bbb08137f6 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -89,6 +89,7 @@ if(MLX_METAL_JIT) make_jit_source(steel/gemm/kernels/steel_gemm_fused_nax) make_jit_source(steel/gemm/kernels/steel_gemm_gather_nax) make_jit_source(steel/gemm/kernels/steel_gemm_splitk_nax) + make_jit_source(steel/gemm/kernels/steel_gemm_segmented_nax) make_jit_source(quantized_nax kernels/quantized_utils.h) make_jit_source(fp_quantized_nax kernels/quantized_utils.h kernels/fp8.h diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index dcaf09a1e9..e22efa96d0 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -50,6 +50,7 @@ const char* gemm_nax(); const char* steel_gemm_fused_nax(); const char* steel_gemm_gather_nax(); const char* steel_gemm_splitk_nax(); +const char* steel_gemm_segmented_nax(); const char* quantized_nax(); const char* fp_quantized_nax(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index a0703cd875..9c47b53b40 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1012,6 +1012,40 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel( return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::gemm_nax() + << metal::steel_gemm_segmented_nax() + << get_template_definition( + lib_name, + "segmented_mm_nax", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b); + return kernel_source.str(); + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_qmm_nax_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 63fccc59ff..dc0dab970d 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -311,6 +311,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel( int wm, int wn); +MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + MTL::ComputePipelineState* get_qmm_nax_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 8c4246d16f..78076c20b8 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -109,7 +109,8 @@ set(STEEL_NAX_HEADERS steel/utils/integral_constant.h steel/gemm/kernels/steel_gemm_fused_nax.h steel/gemm/kernels/steel_gemm_gather_nax.h - steel/gemm/kernels/steel_gemm_splitk_nax.h) + steel/gemm/kernels/steel_gemm_splitk_nax.h + steel/gemm/kernels/steel_gemm_segmented_nax.h) set(STEEL_NAX_ATTN_HEADERS steel/defines.h @@ -160,6 +161,8 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk_nax ${STEEL_NAX_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_segmented_nax + ${STEEL_NAX_HEADERS}) build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS}) build_kernel(fp_quantized_nax fp4.h fp8.h fp_quantized_nax.h diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal index 690c7a3059..6141297bab 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal @@ -35,3 +35,4 @@ instantiate_gather_mm_shapes_helper(float16, half, float16, half); instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); +instantiate_gather_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h new file mode 100644 index 0000000000..9cfece5c54 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h @@ -0,0 +1,125 @@ +// Copyright © 2026 Apple Inc. + +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] +void segmented_mm_nax( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + const int tid_m = (BK > 64) ? tid.y : tid.z; + const int tid_n = (BK > 64) ? tid.x : tid.y; + const int tid_s = (BK > 64) ? tid.z : tid.x; + + const int c_row = tid_m * BM; + const int c_col = tid_n * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + if (params->tiles_n <= static_cast(tid_n) || + params->tiles_m <= static_cast(tid_m)) { + return; + } + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid_s]; + k_end = segments[2 * tid_s + 1]; + } else { + k_start = segments[tid_s]; + k_end = segments[tid_s + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? k_start : k_start * params->ldb; + C += tid_s * params->batch_stride_d; + + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const int sgp_sm_int = + align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); + const short sgp_sm = short(sgp_sm_int); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const int sgp_sn_int = + align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); + const short sgp_sn = short(sgp_sn_int); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + C += tm * params->ldd + tn; + + NAXTile Dtile; + Dtile.clear(); + + const int segment_k_size = k_end - k_start; + const int segment_k_iters = segment_k_size / BK; + const bool segment_k_aligned = (segment_k_size % BK) == 0; + + dispatch_bool(segment_k_aligned, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + Dtile = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + AccumType>( + A, + B, + params->lda, + params->ldb, + segment_k_size, + segment_k_iters, + sgp_sm, + sgp_sn); + }); + }); + }); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + if constexpr (kAlignedM && kAlignedN) { + Dtile.store(C, int(params->ldd)); + } else { + Dtile.store_safe(C, int(params->ldd), short2(sgp_sn, sgp_sm)); + } + }); + }); +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.metal new file mode 100644 index 0000000000..78085cedea --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.metal @@ -0,0 +1,31 @@ +// Copyright © 2026 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h" + +// clang-format off +#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_segmented_mm_nax_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \ + segmented_mm_nax, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float) + +#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 128, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 64, 2, 2) + +instantiate_segmented_mm_shapes_helper(float16, half, float16, half); +instantiate_segmented_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); +instantiate_segmented_mm_shapes_helper(float32, float, float32, float); +// clang-format on diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index df0065be55..f65a9730c8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -2478,15 +2478,36 @@ void segmented_mm( char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) + bool use_nax = metal::is_nax_available() && + (env::enable_tf32() || out.dtype() != float32); + const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; - // Define the kernel name + metal::MTLFCList func_consts = { + {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + }; + std::string base_name; base_name.reserve(128); + base_name += "steel_segmented_mm_"; + + // Use NAX kernel if available + if (use_nax) { + int average_k = K / batch_size_out; + bm = 64; + bn = 64; + bk = (average_k >= 256) ? 256 : (average_k >= 128) ? 128 : 64; + wm = 2; + wn = 2; + + base_name += "nax_"; + } + concatenate( base_name, - "steel_segmented_mm_", transpose_a ? 't' : 'n', transpose_b ? 't' : 'n', "_", @@ -2504,13 +2525,6 @@ void segmented_mm( "_wn", wn); - metal::MTLFCList func_consts = { - {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - }; - - // And the kernel hash that includes the function constants std::string hash_name; hash_name.reserve(128); concatenate( @@ -2524,19 +2538,32 @@ void segmented_mm( align_N ? 't' : 'n'); // Get and set the kernel - auto kernel = get_steel_gemm_segmented_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); + auto kernel = (use_nax) ? get_steel_gemm_segmented_nax_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn) + : get_steel_gemm_segmented_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params @@ -2557,8 +2584,9 @@ void segmented_mm( // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = - MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); + MTL::Size grid_dims = (use_nax && bk == 64) + ? MTL::Size(batch_size_out, params.tiles_n, params.tiles_m) + : MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index a0b02084c2..2ed74f470a 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -368,6 +368,22 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel( return d.get_kernel(kernel_name, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name, hash_name, func_consts); +} + MTL::ComputePipelineState* get_qmm_nax_kernel( metal::Device& d, const std::string& kernel_name,