Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ if(MLX_METAL_JIT)
make_jit_source(steel/gemm/kernels/steel_gemm_fused_nax)
make_jit_source(steel/gemm/kernels/steel_gemm_gather_nax)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk_nax)
make_jit_source(steel/gemm/kernels/steel_gemm_segmented_nax)

make_jit_source(quantized_nax kernels/quantized_utils.h)
make_jit_source(fp_quantized_nax kernels/quantized_utils.h kernels/fp8.h
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/jit/includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const char* gemm_nax();
const char* steel_gemm_fused_nax();
const char* steel_gemm_gather_nax();
const char* steel_gemm_splitk_nax();
const char* steel_gemm_segmented_nax();

const char* quantized_nax();
const char* fp_quantized_nax();
Expand Down
34 changes: 34 additions & 0 deletions mlx/backend/metal/jit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,40 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  // JIT path: the library name doubles as the kernel name. The Metal source
  // is generated only on the first request; the device caches the result.
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name, [&]() {
    // Assemble the source: shared utilities, the NAX GEMM building blocks,
    // the segmented kernel header, and one explicit template instantiation
    // for the requested dtype / tile configuration.
    std::ostringstream src;
    src << metal::utils();
    src << metal::gemm_nax();
    src << metal::steel_gemm_segmented_nax();
    src << get_template_definition(
        lib_name,
        "segmented_mm_nax",
        get_type_string(out.dtype()),
        bm,
        bn,
        bk,
        wm,
        wn,
        transpose_a,
        transpose_b);
    return src.str();
  });
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

MTL::ComputePipelineState* get_qmm_nax_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
14 changes: 14 additions & 0 deletions mlx/backend/metal/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
int wm,
int wn);

// Get (building or fetching from cache) the compute pipeline for the NAX
// segmented GEMM kernel. `bm`/`bn`/`bk` are the threadgroup tile sizes and
// `wm`/`wn` the simdgroup grid within a threadgroup; together with the
// transpose flags and `out`'s dtype they select the template instantiation.
// `func_consts` carries the runtime function constants (segment layout and
// M/N alignment) and `hash_name` disambiguates the cached pipeline.
MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);

MTL::ComputePipelineState* get_qmm_nax_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
5 changes: 4 additions & 1 deletion mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ set(STEEL_NAX_HEADERS
steel/utils/integral_constant.h
steel/gemm/kernels/steel_gemm_fused_nax.h
steel/gemm/kernels/steel_gemm_gather_nax.h
steel/gemm/kernels/steel_gemm_splitk_nax.h)
steel/gemm/kernels/steel_gemm_splitk_nax.h
steel/gemm/kernels/steel_gemm_segmented_nax.h)

set(STEEL_NAX_ATTN_HEADERS
steel/defines.h
Expand Down Expand Up @@ -160,6 +161,8 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_segmented_nax
${STEEL_NAX_HEADERS})

build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
build_kernel(fp_quantized_nax fp4.h fp8.h fp_quantized_nax.h
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@

instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright © 2026 Apple Inc.

using namespace mlx::steel;

// Function constants bound at pipeline-build time (see matmul.cpp):
//  - segments_contiguous: the segments buffer stores explicit (start, end)
//    pairs per segment (stride 2) rather than shared boundaries (stride 1).
//  - align_M / align_N: M (resp. N) is an exact multiple of the block tile,
//    so edge-of-matrix bounds checks can be compiled out.
constant bool segments_contiguous [[function_constant(199)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];

// Segmented matmul: for each segment s with K-range [k_start, k_end), compute
// an independent GEMM over that K slice of A and B and write it to the s-th
// output matrix in C (offset by batch_stride_d).
template <
    typename T,
    int BM, // threadgroup tile rows (M)
    int BN, // threadgroup tile cols (N)
    int BK, // K-block size per main-loop step
    int WM, // simdgroup grid rows within a threadgroup
    int WN, // simdgroup grid cols within a threadgroup
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
void segmented_mm_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* segments [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Map grid axes to (m-tile, n-tile, segment). This mirrors the host-side
  // dispatch: for small BK the host puts the segment index on the fastest
  // axis (x) so consecutive segments run back to back — presumably to reuse
  // cached K data (see matmul.cpp grid_dims). NOTE(review): the host swaps
  // on `bk == 64` while this checks `BK > 64`; consistent for the NAX bk
  // choices {64, 128, 256}, but worth confirming they stay in sync.
  const int tid_m = (BK > 64) ? tid.y : tid.z;
  const int tid_n = (BK > 64) ? tid.x : tid.y;
  const int tid_s = (BK > 64) ? tid.z : tid.x;

  // Top-left coordinate of this threadgroup's output tile.
  const int c_row = tid_m * BM;
  const int c_col = tid_n * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  // Drop threadgroups that fall outside the tile grid.
  if (params->tiles_n <= static_cast<int>(tid_n) ||
      params->tiles_m <= static_cast<int>(tid_m)) {
    return;
  }

  // Advance A/B to this tile's row/column and C to the output tile.
  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;

  // Fetch this segment's K range according to the segments layout.
  uint32_t k_start, k_end;
  if (segments_contiguous) {
    k_start = segments[2 * tid_s];
    k_end = segments[2 * tid_s + 1];
  } else {
    k_start = segments[tid_s];
    k_end = segments[tid_s + 1];
  }
  // Offset A/B to the segment's first K element; each segment writes its own
  // output matrix.
  A += transpose_a ? k_start * params->lda : k_start;
  B += transpose_b ? k_start : k_start * params->ldb;
  C += tid_s * params->batch_stride_d;

  // Per-simdgroup sub-tile sizes.
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  // Number of 16x16 fragments per simdgroup tile.
  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;

  // This simdgroup's offset within the threadgroup tile (row-major over WN).
  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Clip the simdgroup tile at the matrix edge when M is not tile-aligned.
  const int sgp_sm_int =
      align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
  const short sgp_sm = short(sgp_sm_int);
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  // Same for N.
  const int sgp_sn_int =
      align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
  const short sgp_sn = short(sgp_sn_int);
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  // Advance to this simdgroup's sub-tile.
  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  C += tm * params->ldd + tn;

  // Accumulator tile for this simdgroup's output fragment.
  NAXTile<AccumType, TM, TN> Dtile;
  Dtile.clear();

  // Segment K extent and whether it divides evenly into BK-sized blocks.
  const int segment_k_size = k_end - k_start;
  const int segment_k_iters = segment_k_size / BK;
  const bool segment_k_aligned = (segment_k_size % BK) == 0;

  // Lift the runtime alignment flags into compile-time booleans so the
  // fully-aligned main loop is compiled without bounds checks.
  dispatch_bool(segment_k_aligned, [&](auto kAlignedK) {
    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
        Dtile = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            AccumType>(
            A,
            B,
            params->lda,
            params->ldb,
            segment_k_size,
            segment_k_iters,
            sgp_sm,
            sgp_sn);
      });
    });
  });

  // Store the result; use the bounds-checked store only when the tile is
  // clipped at a matrix edge.
  dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
    dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
      if constexpr (kAlignedM && kAlignedN) {
        Dtile.store(C, int(params->ldd));
      } else {
        Dtile.store_safe(C, int(params->ldd), short2(sgp_sn, sgp_sm));
      }
    });
  });
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright © 2026 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h"

// clang-format off
#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_segmented_mm_nax_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
segmented_mm_nax, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 128, 2, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 64, 2, 2)

instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
instantiate_segmented_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);
// clang-format on
76 changes: 52 additions & 24 deletions mlx/backend/metal/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2478,15 +2478,36 @@ void segmented_mm(
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

bool use_nax = metal::is_nax_available() &&
(env::enable_tf32() || out.dtype() != float32);

const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;

// Define the kernel name
metal::MTLFCList func_consts = {
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
};

std::string base_name;
base_name.reserve(128);
base_name += "steel_segmented_mm_";

// Use NAX kernel if available
if (use_nax) {
int average_k = K / batch_size_out;
bm = 64;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity, why 2 x 2 simdgroups and 64 x 64 tiles?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, this can be tuned for sure. I haven't actually run extensive testing, and it seemed like a good default 🤷‍♂️

bn = 64;
bk = (average_k >= 256) ? 256 : (average_k >= 128) ? 128 : 64;
wm = 2;
wn = 2;

base_name += "nax_";
}

concatenate(
base_name,
"steel_segmented_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
Expand All @@ -2504,13 +2525,6 @@ void segmented_mm(
"_wn",
wn);

metal::MTLFCList func_consts = {
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
};

// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
Expand All @@ -2524,19 +2538,32 @@ void segmented_mm(
align_N ? 't' : 'n');

// Get and set the kernel
auto kernel = get_steel_gemm_segmented_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
auto kernel = (use_nax) ? get_steel_gemm_segmented_nax_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn)
: get_steel_gemm_segmented_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);

// Prepare the matmul params
Expand All @@ -2557,8 +2584,9 @@ void segmented_mm(

// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
MTL::Size grid_dims = (use_nax && bk == 64)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether the non-NAX kernel should also swap dimensions when there are many small segments.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it is possible. The idea was that we launch threadgroups in order x, y, z, and since K is very small in these cases we want the different Ks to run one after the other, as it is very likely that the second matmul's K will already be in the cache.

? MTL::Size(batch_size_out, params.tiles_n, params.tiles_m)
: MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);

// Launch kernel
compute_encoder.set_input_array(a, 0);
Expand Down
16 changes: 16 additions & 0 deletions mlx/backend/metal/nojit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,22 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
return d.get_kernel(kernel_name, hash_name, func_consts);
}

// Non-JIT variant: the kernel is precompiled into the metallib, so the dtype
// and tile-configuration arguments are unused (they are already baked into
// `kernel_name`); we simply look the pipeline up by name with the given
// function constants.
MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_qmm_nax_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
Loading