Skip to content

Commit 940ba47

Browse files
authored
Segmented mm nax kernel (#3419)
1 parent 8e649be commit 940ba47

10 files changed

Lines changed: 279 additions & 25 deletions

File tree

mlx/backend/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ if(MLX_METAL_JIT)
8989
make_jit_source(steel/gemm/kernels/steel_gemm_fused_nax)
9090
make_jit_source(steel/gemm/kernels/steel_gemm_gather_nax)
9191
make_jit_source(steel/gemm/kernels/steel_gemm_splitk_nax)
92+
make_jit_source(steel/gemm/kernels/steel_gemm_segmented_nax)
9293

9394
make_jit_source(quantized_nax kernels/quantized_utils.h)
9495
make_jit_source(fp_quantized_nax kernels/quantized_utils.h kernels/fp8.h

mlx/backend/metal/jit/includes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const char* gemm_nax();
5050
const char* steel_gemm_fused_nax();
5151
const char* steel_gemm_gather_nax();
5252
const char* steel_gemm_splitk_nax();
53+
const char* steel_gemm_segmented_nax();
5354

5455
const char* quantized_nax();
5556
const char* fp_quantized_nax();

mlx/backend/metal/jit_kernels.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,40 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
10121012
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
10131013
}
10141014

1015+
// JIT path: returns the compute pipeline for the segmented NAX matmul kernel.
//
// The Metal library is requested from the device under `kernel_name`. The
// builder lambda handed to `d.get_library` assembles the source from the
// `utils`, `gemm_nax` and `steel_gemm_segmented_nax` snippets followed by a
// template definition of `segmented_mm_nax` for the output dtype, the tile
// sizes (bm, bn, bk, wm, wn) and the two transpose flags.
//
// `hash_name` and `func_consts` are forwarded unchanged to `d.get_kernel`.
MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel(
1016+
metal::Device& d,
1017+
const std::string& kernel_name,
1018+
const std::string& hash_name,
1019+
const metal::MTLFCList& func_consts,
1020+
const array& out,
1021+
bool transpose_a,
1022+
bool transpose_b,
1023+
int bm,
1024+
int bn,
1025+
int bk,
1026+
int wm,
1027+
int wn) {
1028+
// The library is keyed by the kernel name itself.
const auto& lib_name = kernel_name;
1029+
auto lib = d.get_library(lib_name, [&]() {
1030+
std::ostringstream kernel_source;
1031+
// Template arguments are passed in the order the kernel template declares
// them: T, BM, BN, BK, WM, WN, transpose_a, transpose_b. No accumulator type
// is passed here.
kernel_source << metal::utils() << metal::gemm_nax()
1032+
<< metal::steel_gemm_segmented_nax()
1033+
<< get_template_definition(
1034+
lib_name,
1035+
"segmented_mm_nax",
1036+
get_type_string(out.dtype()),
1037+
bm,
1038+
bn,
1039+
bk,
1040+
wm,
1041+
wn,
1042+
transpose_a,
1043+
transpose_b);
1044+
return kernel_source.str();
1045+
});
1046+
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
1047+
}
1048+
10151049
MTL::ComputePipelineState* get_qmm_nax_kernel(
10161050
metal::Device& d,
10171051
const std::string& kernel_name,

mlx/backend/metal/kernels.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
311311
int wm,
312312
int wn);
313313

314+
// Returns the compute pipeline for the `segmented_mm_nax` kernel.
//
// `kernel_name` identifies the kernel, while `hash_name` and `func_consts`
// are handed to the device's kernel lookup. `out`, the transpose flags and
// the tile sizes (bm, bn, bk, wm, wn) are used by the JIT definition to
// generate the template instantiation; the non-JIT definition ignores them.
MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel(
315+
metal::Device& d,
316+
const std::string& kernel_name,
317+
const std::string& hash_name,
318+
const metal::MTLFCList& func_consts,
319+
const array& out,
320+
bool transpose_a,
321+
bool transpose_b,
322+
int bm,
323+
int bn,
324+
int bk,
325+
int wm,
326+
int wn);
327+
314328
MTL::ComputePipelineState* get_qmm_nax_kernel(
315329
metal::Device& d,
316330
const std::string& kernel_name,

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ set(STEEL_NAX_HEADERS
109109
steel/utils/integral_constant.h
110110
steel/gemm/kernels/steel_gemm_fused_nax.h
111111
steel/gemm/kernels/steel_gemm_gather_nax.h
112-
steel/gemm/kernels/steel_gemm_splitk_nax.h)
112+
steel/gemm/kernels/steel_gemm_splitk_nax.h
113+
steel/gemm/kernels/steel_gemm_segmented_nax.h)
113114

114115
set(STEEL_NAX_ATTN_HEADERS
115116
steel/defines.h
@@ -160,6 +161,8 @@ if(NOT MLX_METAL_JIT)
160161
build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
161162
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
162163
build_kernel(steel/gemm/kernels/steel_gemm_splitk_nax ${STEEL_NAX_HEADERS})
164+
build_kernel(steel/gemm/kernels/steel_gemm_segmented_nax
165+
${STEEL_NAX_HEADERS})
163166

164167
build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
165168
build_kernel(fp_quantized_nax fp4.h fp8.h fp_quantized_nax.h

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@
3535

3636
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
3737
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
38+
instantiate_gather_mm_shapes_helper(float32, float, float32, float);
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
using namespace mlx::steel;
4+
5+
// Function constants; the host supplies their values under indices 199-201
// through the function-constant list it passes when fetching the kernel.
//
// segments_contiguous: selects how the [k_start, k_end) bounds of a segment
//   are read from `segments` (as a pair per segment, or as two consecutive
//   entries shared between neighbouring segments).
// align_M / align_N: when true the kernel skips clamping the per-simdgroup
//   tile extent against M / N. The host sets them to (M % bm) == 0 and
//   (N % bn) == 0.
constant bool segments_contiguous [[function_constant(199)]];
6+
constant bool align_M [[function_constant(200)]];
7+
constant bool align_N [[function_constant(201)]];
8+
9+
// Segmented matrix multiply (NAX variant).
//
// Each threadgroup handles one BM x BN tile of the output for one segment of
// the K dimension. The segment's bounds [k_start, k_end) are read from
// `segments`, and the result is written to C offset by
// tid_s * params->batch_stride_d. Within the threadgroup, every simdgroup
// owns an SM x SN sub-tile (SM = BM / WM, SN = BN / WN).
//
// Template parameters:
//   T                        element type of A, B and C
//   BM, BN, BK               threadgroup tile sizes
//   WM, WN                   simdgroups per threadgroup along M and N
//   transpose_a, transpose_b whether A / B are addressed as transposed
//   AccumType                accumulator element type of the result tile
//
// Buffers: A (0), B (1), segments (2), C (3), params (4).
template <
10+
typename T,
11+
int BM,
12+
int BN,
13+
int BK,
14+
int WM,
15+
int WN,
16+
bool transpose_a,
17+
bool transpose_b,
18+
typename AccumType = float>
19+
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
20+
void segmented_mm_nax(
21+
const device T* A [[buffer(0)]],
22+
const device T* B [[buffer(1)]],
23+
const device uint32_t* segments [[buffer(2)]],
24+
device T* C [[buffer(3)]],
25+
const constant GEMMParams* params [[buffer(4)]],
26+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
27+
uint3 tid [[threadgroup_position_in_grid]]) {
28+
// The meaning of the grid axes depends on BK: for BK > 64 the grid is
// (n, m, segment), otherwise it is (segment, n, m). The host must build its
// grid dimensions with the same convention.
const int tid_m = (BK > 64) ? tid.y : tid.z;
29+
const int tid_n = (BK > 64) ? tid.x : tid.y;
30+
const int tid_s = (BK > 64) ? tid.z : tid.x;
31+
32+
// Top-left corner of this threadgroup's output tile; the size_t copies are
// used for the pointer offsets below.
const int c_row = tid_m * BM;
33+
const int c_col = tid_n * BN;
34+
const size_t c_row_long = size_t(c_row);
35+
const size_t c_col_long = size_t(c_col);
36+
37+
// Threadgroups that fall outside the tile grid have nothing to compute.
if (params->tiles_n <= static_cast<int>(tid_n) ||
38+
params->tiles_m <= static_cast<int>(tid_m)) {
39+
return;
40+
}
41+
42+
// Move A, B and C to the start of this threadgroup's tile along M and N.
A += transpose_a ? c_row_long : c_row_long * params->lda;
43+
B += transpose_b ? c_col_long * params->ldb : c_col_long;
44+
C += c_row_long * params->ldd + c_col_long;
45+
46+
// Read the K range of this segment. With segments_contiguous each segment
// has its own (start, end) pair; otherwise entry s is the start and entry
// s + 1 the end, so neighbouring segments share a boundary.
uint32_t k_start, k_end;
47+
if (segments_contiguous) {
48+
k_start = segments[2 * tid_s];
49+
k_end = segments[2 * tid_s + 1];
50+
} else {
51+
k_start = segments[tid_s];
52+
k_end = segments[tid_s + 1];
53+
}
54+
// Advance A and B to the first K index of the segment and C to the output
// slot of this segment.
// NOTE(review): k_start is uint32_t, so if lda / ldb are 32-bit integers
// these products are evaluated in 32 bits, unlike the size_t row/column
// offsets above. Confirm they cannot wrap for very large inputs.
A += transpose_a ? k_start * params->lda : k_start;
55+
B += transpose_b ? k_start : k_start * params->ldb;
56+
C += tid_s * params->batch_stride_d;
57+
58+
// Per-simdgroup sub-tile size (SM x SN) and the K step SK handed to
// gemm_loop.
constexpr short SM = BM / WM;
59+
constexpr short SN = BN / WN;
60+
constexpr short SK = 32;
61+
62+
// Sub-tile size expressed in units of 16 along each dimension; these are the
// dimensions of the accumulator NAXTile.
constexpr short TM = SM / 16;
63+
constexpr short TN = SN / 16;
64+
65+
// Offset of this simdgroup's sub-tile inside the threadgroup tile;
// simdgroups are laid out row-major over a WM x WN arrangement.
const short tm = SM * (simd_group_id / WN);
66+
const short tn = SN * (simd_group_id % WN);
67+
68+
// Number of sub-tile rows that lie inside M (at most SM). Skipped entirely
// when align_M is set.
// NOTE(review): when c_row + tm >= M this value is zero or negative;
// presumably gemm_loop / store_safe treat that as an empty tile - confirm.
const int sgp_sm_int =
69+
align_M ? int(SM) : min(int(SM), params->M - (c_row + tm));
70+
const short sgp_sm = short(sgp_sm_int);
71+
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
72+
73+
// Same clamping for the columns of the sub-tile against N.
const int sgp_sn_int =
74+
align_N ? int(SN) : min(int(SN), params->N - (c_col + tn));
75+
const short sgp_sn = short(sgp_sn_int);
76+
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
77+
78+
// Move A, B and C to this simdgroup's sub-tile.
A += transpose_a ? tm : (tm * params->lda);
79+
B += transpose_b ? (tn * params->ldb) : tn;
80+
C += tm * params->ldd + tn;
81+
82+
// Accumulator tile, cleared before the loop.
NAXTile<AccumType, TM, TN> Dtile;
83+
Dtile.clear();
84+
85+
// Length of the segment along K, the number of whole BK blocks it contains
// and whether it is an exact multiple of BK.
const int segment_k_size = k_end - k_start;
86+
const int segment_k_iters = segment_k_size / BK;
87+
const bool segment_k_aligned = (segment_k_size % BK) == 0;
88+
89+
// dispatch_bool turns the runtime alignment flags into values usable as
// template arguments, so gemm_loop is instantiated per K/M/N alignment
// combination. Its result replaces the accumulator tile.
dispatch_bool(segment_k_aligned, [&](auto kAlignedK) {
90+
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
91+
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
92+
Dtile = gemm_loop<
93+
T,
94+
SM,
95+
SN,
96+
SK,
97+
BK,
98+
transpose_a,
99+
transpose_b,
100+
kAlignedM.value,
101+
kAlignedN.value,
102+
kAlignedK.value,
103+
AccumType>(
104+
A,
105+
B,
106+
params->lda,
107+
params->ldb,
108+
segment_k_size,
109+
segment_k_iters,
110+
sgp_sm,
111+
sgp_sn);
112+
});
113+
});
114+
});
115+
116+
// Write the sub-tile to C: a plain store when both M and N are aligned for
// this simdgroup, otherwise a store limited to the (sgp_sn, sgp_sm) extent.
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
117+
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
118+
if constexpr (kAlignedM && kAlignedN) {
119+
Dtile.store(C, int(params->ldd));
120+
} else {
121+
Dtile.store_safe(C, int(params->ldd), short2(sgp_sn, sgp_sm));
122+
}
123+
});
124+
});
125+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include <metal_stdlib>
4+
5+
#include "mlx/backend/metal/kernels/utils.h"
6+
7+
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
8+
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented_nax.h"
9+
10+
// clang-format off
11+
// Instantiates one `segmented_mm_nax` kernel under the name
// steel_segmented_mm_nax_<tname>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>.
// `otype` is accepted but not used: the kernel is instantiated with `itype`
// as its element type and `float` as the accumulator type.
#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
12+
instantiate_kernel( \
13+
"steel_segmented_mm_nax_" #tname "_" #iname "_" #oname \
14+
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
15+
segmented_mm_nax, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)
16+
17+
// Expands to the four transpose combinations (nn, nt, tn, tt) for one type
// and tile configuration.
#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
18+
instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
19+
instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
20+
instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
21+
instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
22+
23+
// Tile configurations built for every type: 64 x 64 blocks with 2 x 2
// simdgroups and BK of 256, 128 or 64.
#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
24+
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
25+
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 128, 2, 2) \
26+
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 64, 2, 2)
27+
28+
// Element types for which the kernels are instantiated.
instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
29+
instantiate_segmented_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
30+
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);
31+
// clang-format on

mlx/backend/metal/matmul.cpp

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,15 +2478,36 @@ void segmented_mm(
24782478
char devc = d.get_architecture().back();
24792479
GEMM_TPARAM_MACRO(devc)
24802480

2481+
bool use_nax = metal::is_nax_available() &&
2482+
(env::enable_tf32() || out.dtype() != float32);
2483+
24812484
const bool align_M = (M % bm) == 0;
24822485
const bool align_N = (N % bn) == 0;
24832486

2484-
// Define the kernel name
2487+
metal::MTLFCList func_consts = {
2488+
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
2489+
{&align_M, MTL::DataType::DataTypeBool, 200},
2490+
{&align_N, MTL::DataType::DataTypeBool, 201},
2491+
};
2492+
24852493
std::string base_name;
24862494
base_name.reserve(128);
2495+
base_name += "steel_segmented_mm_";
2496+
2497+
// Use NAX kernel if available
2498+
if (use_nax) {
2499+
int average_k = K / batch_size_out;
2500+
bm = 64;
2501+
bn = 64;
2502+
bk = (average_k >= 256) ? 256 : (average_k >= 128) ? 128 : 64;
2503+
wm = 2;
2504+
wn = 2;
2505+
2506+
base_name += "nax_";
2507+
}
2508+
24872509
concatenate(
24882510
base_name,
2489-
"steel_segmented_mm_",
24902511
transpose_a ? 't' : 'n',
24912512
transpose_b ? 't' : 'n',
24922513
"_",
@@ -2504,13 +2525,6 @@ void segmented_mm(
25042525
"_wn",
25052526
wn);
25062527

2507-
metal::MTLFCList func_consts = {
2508-
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
2509-
{&align_M, MTL::DataType::DataTypeBool, 200},
2510-
{&align_N, MTL::DataType::DataTypeBool, 201},
2511-
};
2512-
2513-
// And the kernel hash that includes the function constants
25142528
std::string hash_name;
25152529
hash_name.reserve(128);
25162530
concatenate(
@@ -2524,19 +2538,32 @@ void segmented_mm(
25242538
align_N ? 't' : 'n');
25252539

25262540
// Get and set the kernel
2527-
auto kernel = get_steel_gemm_segmented_kernel(
2528-
d,
2529-
base_name,
2530-
hash_name,
2531-
func_consts,
2532-
out,
2533-
transpose_a,
2534-
transpose_b,
2535-
bm,
2536-
bn,
2537-
bk,
2538-
wm,
2539-
wn);
2541+
auto kernel = (use_nax) ? get_steel_gemm_segmented_nax_kernel(
2542+
d,
2543+
base_name,
2544+
hash_name,
2545+
func_consts,
2546+
out,
2547+
transpose_a,
2548+
transpose_b,
2549+
bm,
2550+
bn,
2551+
bk,
2552+
wm,
2553+
wn)
2554+
: get_steel_gemm_segmented_kernel(
2555+
d,
2556+
base_name,
2557+
hash_name,
2558+
func_consts,
2559+
out,
2560+
transpose_a,
2561+
transpose_b,
2562+
bm,
2563+
bn,
2564+
bk,
2565+
wm,
2566+
wn);
25402567
compute_encoder.set_compute_pipeline_state(kernel);
25412568

25422569
// Prepare the matmul params
@@ -2557,8 +2584,9 @@ void segmented_mm(
25572584

25582585
// Prepare the grid
25592586
MTL::Size group_dims = MTL::Size(32, wn, wm);
2560-
MTL::Size grid_dims =
2561-
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
2587+
MTL::Size grid_dims = (use_nax && bk == 64)
2588+
? MTL::Size(batch_size_out, params.tiles_n, params.tiles_m)
2589+
: MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
25622590

25632591
// Launch kernel
25642592
compute_encoder.set_input_array(a, 0);

mlx/backend/metal/nojit_kernels.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,22 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_nax_kernel(
368368
return d.get_kernel(kernel_name, hash_name, func_consts);
369369
}
370370

371+
// Non-JIT path: returns the compute pipeline for the segmented NAX matmul
// kernel by looking it up by name through `d.get_kernel`; no source is
// generated here.
//
// The output array, transpose flags and tile sizes are left unnamed and
// ignored; only `kernel_name`, `hash_name` and `func_consts` are used.
// NOTE(review): presumably those values are already encoded in `kernel_name`
// by the caller - confirm.
MTL::ComputePipelineState* get_steel_gemm_segmented_nax_kernel(
372+
metal::Device& d,
373+
const std::string& kernel_name,
374+
const std::string& hash_name,
375+
const metal::MTLFCList& func_consts,
376+
const array&,
377+
bool,
378+
bool,
379+
int,
380+
int,
381+
int,
382+
int,
383+
int) {
384+
return d.get_kernel(kernel_name, hash_name, func_consts);
385+
}
386+
371387
MTL::ComputePipelineState* get_qmm_nax_kernel(
372388
metal::Device& d,
373389
const std::string& kernel_name,

0 commit comments

Comments
 (0)