Skip to content

Commit 38ad257

Browse files
authored
[Metal][Performance]: Add split-K for quantized matmul (small M) (#3120)
1 parent 70a0da6 commit 38ad257

5 files changed

Lines changed: 275 additions & 16 deletions

File tree

mlx/backend/metal/kernels/fp_quantized.h

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ METAL_FUNC void fp_qmm_t_impl(
642642
const constant int& K,
643643
const constant int& N,
644644
const constant int& M,
645+
const constant int& K_eff,
645646
uint3 tid [[threadgroup_position_in_grid]],
646647
uint lid [[thread_index_in_threadgroup]],
647648
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -695,7 +696,7 @@ METAL_FUNC void fp_qmm_t_impl(
695696

696697
if (num_els < BM) {
697698
if (!aligned_N && num_outs < BN) {
698-
for (int k = 0; k < K; k += BK) {
699+
for (int k = 0; k < K_eff; k += BK) {
699700
threadgroup_barrier(mem_flags::mem_threadgroup);
700701
loader_x.load_safe(short2(BK, num_els));
701702
loader_w.load_safe(short2(BK, num_outs));
@@ -705,7 +706,7 @@ METAL_FUNC void fp_qmm_t_impl(
705706
loader_w.next();
706707
}
707708
} else {
708-
for (int k = 0; k < K; k += BK) {
709+
for (int k = 0; k < K_eff; k += BK) {
709710
threadgroup_barrier(mem_flags::mem_threadgroup);
710711
loader_x.load_safe(short2(BK, num_els));
711712
loader_w.load_unsafe();
@@ -717,7 +718,7 @@ METAL_FUNC void fp_qmm_t_impl(
717718
}
718719
} else {
719720
if (!aligned_N && num_outs < BN) {
720-
for (int k = 0; k < K; k += BK) {
721+
for (int k = 0; k < K_eff; k += BK) {
721722
threadgroup_barrier(mem_flags::mem_threadgroup);
722723
loader_x.load_unsafe();
723724
loader_w.load_safe(short2(BK, num_outs));
@@ -727,7 +728,7 @@ METAL_FUNC void fp_qmm_t_impl(
727728
loader_w.next();
728729
}
729730
} else {
730-
for (int k = 0; k < K; k += BK) {
731+
for (int k = 0; k < K_eff; k += BK) {
731732
threadgroup_barrier(mem_flags::mem_threadgroup);
732733
loader_x.load_unsafe();
733734
loader_w.load_unsafe();
@@ -1219,7 +1220,7 @@ template <
12191220
tid);
12201221
}
12211222
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1222-
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1223+
w, scales, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid);
12231224
}
12241225

12251226
template <
@@ -1486,7 +1487,61 @@ template <
14861487
s_strides,
14871488
tid);
14881489
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1489-
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1490+
w, scales, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid);
1491+
}
1492+
1493+
// Split-K entry point for the transposed quantized matmul (fp-quantized
// weights).
//
// The threadgroup's grid z index (tid.z) selects one K partition of
// k_partition_size elements. The x, packed-weight and scale pointers are
// advanced to the start of that partition, y is advanced by
// tid.z * split_k_partition_stride, and fp_qmm_t_impl is then invoked with
// k_partition_size as its K_eff argument (the bound of its K loops) while the
// full K is still passed as K.
//
// Buffer bindings: w(0), scales(1), x(2), y(3), K(4), N(5), M(6),
// k_partition_size(7), split_k_partition_stride(8).
//
// NOTE(review): presumably the per-partition outputs written to y are
// reduced by a later pass that is not visible here -- confirm against the
// dispatch code.
// NOTE(review): assumes K is a multiple of k_partition_size and that
// k_partition_size is a multiple of group_size (k_start / group_size
// truncates otherwise) -- confirm against the dispatch code.
template <
1494+
typename T,
1495+
const int group_size,
1496+
const int bits,
1497+
const bool aligned_N,
1498+
const int BM = 32,
1499+
const int BK = 32,
1500+
const int BN = 32>
1501+
[[kernel]] void fp_qmm_t_splitk(
1502+
const device uint32_t* w [[buffer(0)]],
1503+
const device uint8_t* scales [[buffer(1)]],
1504+
const device T* x [[buffer(2)]],
1505+
device T* y [[buffer(3)]],
1506+
const constant int& K [[buffer(4)]],
1507+
const constant int& N [[buffer(5)]],
1508+
const constant int& M [[buffer(6)]],
1509+
const constant int& k_partition_size [[buffer(7)]],
1510+
const constant int& split_k_partition_stride [[buffer(8)]],
1511+
uint3 tid [[threadgroup_position_in_grid]],
1512+
uint lid [[thread_index_in_threadgroup]],
1513+
uint simd_gid [[simdgroup_index_in_threadgroup]],
1514+
uint simd_lid [[thread_index_in_simdgroup]]) {
1515+
// NOTE(review): lid is marked unused here but is still forwarded to
// fp_qmm_t_impl at the end of this kernel.
(void)lid;
1516+
1517+
constexpr int BK_padded = (BK + 16 / sizeof(T));
1518+
// NOTE(review): the template arguments are <8, bits> here, whereas
// affine_qmm_t_splitk in quantized.h uses get_pack_factor<bits, 8>();
// confirm the order against the get_pack_factor declaration in this header.
constexpr int pack_factor = get_pack_factor<8, bits>();
1519+
constexpr int bytes_per_pack = get_bytes_per_pack();
1520+
// Threadgroup scratch buffers (BM * BK_padded and BN * BK_padded elements)
// that are handed to fp_qmm_t_impl.
threadgroup T Xs[BM * BK_padded];
1521+
threadgroup T Ws[BN * BK_padded];
1522+
// First K index handled by this threadgroup.
const int k_start = tid.z * k_partition_size;
1523+
x += k_start;
1524+
1525+
// The packed weights are advanced through a byte pointer: k_start elements
// correspond to k_start * bytes_per_pack / pack_factor bytes.
auto wl = (const device uint8_t*)w;
1526+
wl += k_start * bytes_per_pack / pack_factor;
1527+
// Skip the scales that belong to the preceding k_start elements (one scale
// per group_size elements).
scales += k_start / group_size;
1528+
// Each partition writes to its own slice of y; the cast makes the offset
// computation 64-bit.
y += tid.z * static_cast<int64_t>(split_k_partition_stride);
1529+
1530+
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1531+
(const device uint32_t*)wl,
1532+
scales,
1533+
x,
1534+
y,
1535+
Xs,
1536+
Ws,
1537+
K,
1538+
N,
1539+
M,
1540+
k_partition_size,
1541+
tid,
1542+
lid,
1543+
simd_gid,
1544+
simd_lid);
14901545
}
14911546

14921547
template <

mlx/backend/metal/kernels/fp_quantized.metal

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@
107107

108108
// Instantiates the split-K kernels for one (type, mode, group_size, bits)
// combination: qvm_split_k with the 8 and 32 variants and, as added by this
// change, qmm_t_splitk through instantiate_quantized_aligned with the aligned
// flag set to true and to false.
#define instantiate_quantized_all_splitk(type, mode, group_size, bits) \
109109
instantiate_quantized_split_k(mode, qvm_split_k, type, 8, group_size, bits) \
110-
instantiate_quantized_split_k(mode, qvm_split_k, type, 32, group_size, bits)
110+
instantiate_quantized_split_k(mode, qvm_split_k, type, 32, group_size, bits) \
111+
instantiate_quantized_aligned(mode, qmm_t_splitk, type, true, group_size, bits) \
112+
instantiate_quantized_aligned(mode, qmm_t_splitk, type, false, group_size, bits)
111113

112114
#define instantiate_quantized_all_rhs(type, mode, group_size, bits) \
113115
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true, mode, group_size, bits) \

mlx/backend/metal/kernels/quantized.h

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,7 @@ METAL_FUNC void qmm_t_impl(
11021102
const constant int& K,
11031103
const constant int& N,
11041104
const constant int& M,
1105+
const constant int& K_eff,
11051106
uint3 tid [[threadgroup_position_in_grid]],
11061107
uint lid [[thread_index_in_threadgroup]],
11071108
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1156,7 +1157,7 @@ METAL_FUNC void qmm_t_impl(
11561157

11571158
if (num_els < BM) {
11581159
if (!aligned_N && num_outs < BN) {
1159-
for (int k = 0; k < K; k += BK) {
1160+
for (int k = 0; k < K_eff; k += BK) {
11601161
threadgroup_barrier(mem_flags::mem_threadgroup);
11611162
loader_x.load_safe(short2(BK, num_els));
11621163
loader_w.load_safe(short2(BK, num_outs));
@@ -1166,7 +1167,7 @@ METAL_FUNC void qmm_t_impl(
11661167
loader_w.next();
11671168
}
11681169
} else {
1169-
for (int k = 0; k < K; k += BK) {
1170+
for (int k = 0; k < K_eff; k += BK) {
11701171
threadgroup_barrier(mem_flags::mem_threadgroup);
11711172
loader_x.load_safe(short2(BK, num_els));
11721173
loader_w.load_unsafe();
@@ -1178,7 +1179,7 @@ METAL_FUNC void qmm_t_impl(
11781179
}
11791180
} else {
11801181
if (!aligned_N && num_outs < BN) {
1181-
for (int k = 0; k < K; k += BK) {
1182+
for (int k = 0; k < K_eff; k += BK) {
11821183
threadgroup_barrier(mem_flags::mem_threadgroup);
11831184
loader_x.load_unsafe();
11841185
loader_w.load_safe(short2(BK, num_outs));
@@ -1188,7 +1189,7 @@ METAL_FUNC void qmm_t_impl(
11881189
loader_w.next();
11891190
}
11901191
} else {
1191-
for (int k = 0; k < K; k += BK) {
1192+
for (int k = 0; k < K_eff; k += BK) {
11921193
threadgroup_barrier(mem_flags::mem_threadgroup);
11931194
loader_x.load_unsafe();
11941195
loader_w.load_unsafe();
@@ -1759,7 +1760,80 @@ template <
17591760
tid);
17601761
}
17611762
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1762-
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1763+
w,
1764+
scales,
1765+
biases,
1766+
x,
1767+
y,
1768+
Xs,
1769+
Ws,
1770+
K,
1771+
N,
1772+
M,
1773+
K,
1774+
tid,
1775+
lid,
1776+
simd_gid,
1777+
simd_lid);
1778+
}
1779+
1780+
// Split-K entry point for the transposed quantized matmul (affine
// quantization with scales and biases).
//
// The threadgroup's grid z index (tid.z) selects one K partition of
// k_partition_size elements. The x, packed-weight, scale and bias pointers
// are advanced to the start of that partition, y is advanced by
// tid.z * split_k_partition_stride, and qmm_t_impl is then invoked with
// k_partition_size as its K_eff argument (the bound of its K loops) while the
// full K is still passed as K.
//
// Buffer bindings: w(0), scales(1), biases(2), x(3), y(4), K(5), N(6), M(7),
// k_partition_size(8), split_k_partition_stride(9).
//
// NOTE(review): presumably the per-partition outputs written to y are
// reduced by a later pass that is not visible here -- confirm against the
// dispatch code.
// NOTE(review): assumes K is a multiple of k_partition_size and that
// k_partition_size is a multiple of group_size (k_start / group_size
// truncates otherwise) -- confirm against the dispatch code.
template <
1781+
typename T,
1782+
const int group_size,
1783+
const int bits,
1784+
const bool aligned_N,
1785+
const int BM = 32,
1786+
const int BK = 32,
1787+
const int BN = 32>
1788+
[[kernel]] void affine_qmm_t_splitk(
1789+
const device uint32_t* w [[buffer(0)]],
1790+
const device T* scales [[buffer(1)]],
1791+
const device T* biases [[buffer(2)]],
1792+
const device T* x [[buffer(3)]],
1793+
device T* y [[buffer(4)]],
1794+
const constant int& K [[buffer(5)]],
1795+
const constant int& N [[buffer(6)]],
1796+
const constant int& M [[buffer(7)]],
1797+
const constant int& k_partition_size [[buffer(8)]],
1798+
const constant int& split_k_partition_stride [[buffer(9)]],
1799+
uint3 tid [[threadgroup_position_in_grid]],
1800+
uint lid [[thread_index_in_threadgroup]],
1801+
uint simd_gid [[simdgroup_index_in_threadgroup]],
1802+
uint simd_lid [[thread_index_in_simdgroup]]) {
1803+
// NOTE(review): lid is marked unused here but is still forwarded to
// qmm_t_impl at the end of this kernel.
(void)lid;
1804+
1805+
constexpr int BK_padded = (BK + 16 / sizeof(T));
1806+
constexpr int pack_factor = get_pack_factor<bits, 8>();
1807+
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
1808+
1809+
// Threadgroup scratch buffers (BM * BK_padded and BN * BK_padded elements)
// that are handed to qmm_t_impl.
threadgroup T Xs[BM * BK_padded];
1810+
threadgroup T Ws[BN * BK_padded];
1811+
1812+
// First K index handled by this threadgroup.
const int k_start = tid.z * k_partition_size;
1813+
x += k_start;
1814+
1815+
// The packed weights are advanced through a byte pointer: k_start elements
// correspond to k_start * bytes_per_pack / pack_factor bytes.
auto wl = (const device uint8_t*)w;
1816+
wl += k_start * bytes_per_pack / pack_factor;
1817+
// Skip the scales and biases that belong to the preceding k_start elements
// (one of each per group_size elements).
scales += k_start / group_size;
1818+
biases += k_start / group_size;
1819+
// Each partition writes to its own slice of y; the cast makes the offset
// computation 64-bit.
y += tid.z * static_cast<int64_t>(split_k_partition_stride);
1820+
1821+
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1822+
(const device uint32_t*)wl,
1823+
scales,
1824+
biases,
1825+
x,
1826+
y,
1827+
Xs,
1828+
Ws,
1829+
K,
1830+
N,
1831+
M,
1832+
k_partition_size,
1833+
tid,
1834+
lid,
1835+
simd_gid,
1836+
simd_lid);
17631837
}
17641838

17651839
template <
@@ -2073,7 +2147,21 @@ template <
20732147
b_strides,
20742148
tid);
20752149
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
2076-
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
2150+
w,
2151+
scales,
2152+
biases,
2153+
x,
2154+
y,
2155+
Xs,
2156+
Ws,
2157+
K,
2158+
N,
2159+
M,
2160+
K,
2161+
tid,
2162+
lid,
2163+
simd_gid,
2164+
simd_lid);
20772165
}
20782166

20792167
template <

mlx/backend/metal/kernels/quantized.metal

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,20 @@
109109

110110
// Instantiates affine_qvm_split_k for one (type, group_size, bits)
// combination with the 8 and 32 variants.
// NOTE(review): the added last line of this macro ends with a
// line-continuation backslash although nothing follows it, so the macro
// relies on the blank line after it to terminate the definition. If that
// blank line is ever removed, the next #define is swallowed into this macro;
// consider dropping the trailing backslash.
#define instantiate_quantized_all_splitk(type, group_size, bits) \
111111
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
112-
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)
112+
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32) \
113+
114+
// Instantiates one split-K qmm kernel through instantiate_kernel. The host
// name is built by stringifying the macro arguments as
// "<name>_<type>_gs_<group_size>_b_<bits>_alN_<aligned>"; name is then passed
// as the kernel template and type, group_size, bits, aligned as its template
// arguments.
#define instantiate_quantized_splitk_qmm(name, type, group_size, bits, aligned) \
115+
instantiate_kernel( \
116+
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
117+
name, \
118+
type, \
119+
group_size, \
120+
bits, \
121+
aligned)
122+
123+
// Instantiates affine_qmm_t_splitk for one (type, group_size, bits)
// combination, once with the aligned flag set to true and once with false.
#define instantiate_quantized_all_splitk_qmm(type, group_size, bits) \
124+
instantiate_quantized_splitk_qmm(affine_qmm_t_splitk, type, group_size, bits, true) \
125+
instantiate_quantized_splitk_qmm(affine_qmm_t_splitk, type, group_size, bits, false)
113126

114127
#define instantiate_quantized_all_rhs(type, group_size, bits) \
115128
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
@@ -121,6 +134,7 @@
121134
instantiate_quantized_all_aligned(type, group_size, bits) \
122135
instantiate_quantized_all_quad(type, group_size, bits) \
123136
instantiate_quantized_all_splitk(type, group_size, bits) \
137+
instantiate_quantized_all_splitk_qmm(type, group_size, bits) \
124138
instantiate_quantized_all_rhs(type, group_size, bits)
125139

126140
#define instantiate_quantized_types(group_size, bits) \

0 commit comments

Comments
 (0)