|
| 1 | +// Copyright © 2026 Apple Inc. |
| 2 | + |
| 3 | +using namespace mlx::steel; |
| 4 | + |
// Function constants that specialise segmented_mm_nax when the pipeline is
// built.

// Layout of the `segments` buffer. When true, the bounds of segment s are read
// from entries [2 * s] and [2 * s + 1]; when false, from entries [s] and
// [s + 1].
constant bool segments_contiguous [[function_constant(199)]];

// When true, every simdgroup is taken to own a full block of rows and the
// kernel does not clamp its row extent against params->M.
constant bool align_M [[function_constant(200)]];

// When true, every simdgroup is taken to own a full block of columns and the
// kernel does not clamp its column extent against params->N.
constant bool align_N [[function_constant(201)]];
| 8 | + |
/**
 * Segmented matrix multiplication kernel built on NAXTile accumulators.
 *
 * Each threadgroup handles one BM x BN tile of the output for one segment of
 * the reduction dimension. The segment bounds [k_begin, k_finish) are read
 * from `segments`, and the result for segment s is written starting at
 * C + s * params->batch_stride_d. Inside the threadgroup, each simdgroup owns
 * an SM x SN (= BM / WM x BN / WN) sub-tile.
 *
 * Template parameters:
 *   T            element type of A, B and C
 *   BM, BN, BK   threadgroup tile sizes; BK also selects which grid axis
 *                carries the tile row, tile column and segment index
 *   WM, WN       simdgroup grid inside the threadgroup
 *   transpose_a  selects how row/K offsets are applied to A
 *   transpose_b  selects how column/K offsets are applied to B
 *   AccumType    accumulator element type of the NAXTile
 *
 * Buffers:
 *   A, B       input matrices
 *   segments   segment boundaries (layout selected by `segments_contiguous`)
 *   C          output, one matrix per segment
 *   params     GEMMParams (M, N, lda, ldb, ldd, tiles_m, tiles_n,
 *              batch_stride_d are read here)
 */
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]]
void segmented_mm_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* segments [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  // The grid axis holding each index depends on the BK instantiation.
  // NOTE(review): this must agree with the host-side dispatch - confirm there.
  constexpr bool kWideK = (BK > 64);
  const int tile_m = kWideK ? tid.y : tid.z;
  const int tile_n = kWideK ? tid.x : tid.y;
  const int seg = kWideK ? tid.z : tid.x;

  // Threadgroups that fall outside the tile grid have nothing to do.
  if (tile_n >= params->tiles_n || tile_m >= params->tiles_m) {
    return;
  }

  // Top-left corner of this threadgroup's output tile.
  const int c_row = tile_m * BM;
  const int c_col = tile_n * BN;
  const size_t row_off = size_t(c_row);
  const size_t col_off = size_t(c_col);

  // Step 1: move the pointers to the threadgroup tile.
  if constexpr (transpose_a) {
    A += row_off;
  } else {
    A += row_off * params->lda;
  }
  if constexpr (transpose_b) {
    B += col_off * params->ldb;
  } else {
    B += col_off;
  }
  C += row_off * params->ldd + col_off;

  // Step 2: look up the segment bounds and move to the segment start along K.
  // Contiguous layout stores (start, end) pairs; otherwise consecutive entries
  // are used as start and end.
  const int seg_slot = segments_contiguous ? 2 * seg : seg;
  const uint32_t k_begin = segments[seg_slot];
  const uint32_t k_finish = segments[seg_slot + 1];

  if constexpr (transpose_a) {
    A += k_begin * params->lda;
  } else {
    A += k_begin;
  }
  if constexpr (transpose_b) {
    B += k_begin;
  } else {
    B += k_begin * params->ldb;
  }
  C += seg * params->batch_stride_d;

  // Per-simdgroup sub-tile shape.
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  // Sub-tile shape expressed in units of 16
  // (presumably the NAX fragment edge - TODO confirm).
  constexpr short TM = SM / 16;
  constexpr short TN = SN / 16;

  // Offset of this simdgroup's sub-tile inside the threadgroup tile.
  const short sg_row = SM * (simd_group_id / WN);
  const short sg_col = SN * (simd_group_id % WN);

  // Rows / columns this simdgroup actually covers; clamped against the matrix
  // edge unless the matching align_* constant says clamping is unnecessary.
  const int rows_here =
      align_M ? int(SM) : min(int(SM), params->M - (c_row + sg_row));
  const int cols_here =
      align_N ? int(SN) : min(int(SN), params->N - (c_col + sg_col));
  const short sgp_sm = short(rows_here);
  const short sgp_sn = short(cols_here);
  const bool rows_full = align_M || (sgp_sm == SM);
  const bool cols_full = align_N || (sgp_sn == SN);

  // Step 3: move the pointers to the simdgroup sub-tile.
  if constexpr (transpose_a) {
    A += sg_row;
  } else {
    A += sg_row * params->lda;
  }
  if constexpr (transpose_b) {
    B += sg_col * params->ldb;
  } else {
    B += sg_col;
  }
  C += sg_row * params->ldd + sg_col;

  NAXTile<AccumType, TM, TN> acc;
  acc.clear();

  const int k_len = k_finish - k_begin;
  const int k_iters = k_len / BK;
  const bool k_full = (k_len % BK) == 0;

  // Turn the three runtime "is this extent full?" flags into compile-time
  // constants so gemm_loop is instantiated for the exact alignment case.
  dispatch_bool(k_full, [&](auto kAlignedK) {
    dispatch_bool(rows_full, [&](auto kAlignedM) {
      dispatch_bool(cols_full, [&](auto kAlignedN) {
        acc = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            AccumType>(
            A, B, params->lda, params->ldb, k_len, k_iters, sgp_sm, sgp_sn);
      });
    });
  });

  // A sub-tile that is full in both dimensions is stored directly; any other
  // sub-tile goes through store_safe with its (columns, rows) extent.
  dispatch_bool(rows_full && cols_full, [&](auto kFullTile) {
    if constexpr (kFullTile.value) {
      acc.store(C, int(params->ldd));
    } else {
      acc.store_safe(C, int(params->ldd), short2(sgp_sn, sgp_sm));
    }
  });
}
0 commit comments