Skip to content

Commit b9b1bfb

Browse files
authored
Generate qmm implementations with cmake (#3424)
1 parent 68cf2fd commit b9b1bfb

19 files changed

Lines changed: 176 additions & 263 deletions
Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,35 @@
11
target_sources(
22
mlx
3-
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu
4-
${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu
5-
${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu
6-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m16_k.cu
7-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m16_n.cu
8-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m32_k.cu
9-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m32_n.cu
10-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m64_k.cu
11-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m64_n.cu
12-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m16.cu
13-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m32.cu
14-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m64.cu
15-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n16_m1.cu
16-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n32_m1.cu
17-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n64_m2.cu
18-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n128_m2.cu
19-
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n256_m2.cu)
3+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu
4+
${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu)
5+
6+
foreach(TileN 16 32 64 128 256)
7+
set(OUTPUT_FILE "qmm_sm90_impl_n${TileN}.cu")
8+
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/qmm_sm90.cu"
9+
"${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY)
10+
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE})
11+
endforeach()
12+
13+
foreach(TileM 16 32 64)
14+
set(OUTPUT_FILE "qmm_sm80_impl_m${TileM}.cu")
15+
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/qmm_sm80.cu"
16+
"${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY)
17+
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE})
18+
endforeach()
19+
20+
foreach(TileM 16 32 64)
21+
foreach(KMajor true false)
22+
foreach(HasKResidue true false)
23+
foreach(SM80 true false)
24+
if(${KMajor} AND ${HasKResidue})
25+
continue()
26+
endif()
27+
set(OUTPUT_FILE
28+
"qmm_naive_impl_m${TileM}_${KMajor}_${HasKResidue}_${SM80}.cu")
29+
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/qmm_naive.cu"
30+
"${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY)
31+
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE})
32+
endforeach()
33+
endforeach()
34+
endforeach()
35+
endforeach()

mlx/backend/cuda/quantized/qmm/qmm.cu

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ inline bool is_last_2_dims_row_contiguous(const array& x) {
1717
} // namespace
1818

1919
#if defined(MLX_CUDA_SM90A_ENABLED)
20-
// Defined in qmm_impl_sm90_xxx.cu files.
21-
template <typename TileShape, typename ClusterShape>
22-
void qmm_impl_sm90(
20+
// Defined in qmm_sm90.cu.
21+
template <int TileN>
22+
void qmm_sm90_impl(
2323
const array& x,
2424
const array& w,
2525
const array& scales,
@@ -83,34 +83,31 @@ void qmm_sm90(
8383
cu::CommandEncoder& encoder,
8484
Stream s) {
8585
#if defined(MLX_CUDA_SM90A_ENABLED)
86-
auto dispatch = [&]<int tile_m, int tile_n, int cluster_m>() {
87-
using cute::Int;
88-
using TileShapeMN = cute::Shape<Int<tile_m>, Int<tile_n>>;
89-
using ClusterShape = cute::Shape<Int<cluster_m>, Int<1>, Int<1>>;
90-
qmm_impl_sm90<TileShapeMN, ClusterShape>(
86+
auto dispatch = [&]<int TileN>() {
87+
qmm_sm90_impl<TileN>(
9188
x, w, scales, biases, out, bits, group_size, encoder, s);
9289
};
9390
int m = out.ndim() > 1 ? out.shape(-2) : 1;
9491
if (m <= 16) {
95-
dispatch.template operator()<128, 16, 1>();
92+
dispatch.template operator()<16>();
9693
} else if (m <= 32) {
97-
dispatch.template operator()<128, 32, 1>();
94+
dispatch.template operator()<32>();
9895
} else if (m <= 64) {
99-
dispatch.template operator()<128, 64, 2>();
96+
dispatch.template operator()<64>();
10097
} else if (m <= 128) {
101-
dispatch.template operator()<128, 128, 2>();
98+
dispatch.template operator()<128>();
10299
} else {
103-
dispatch.template operator()<128, 256, 2>();
100+
dispatch.template operator()<256>();
104101
}
105102
#else
106103
throw std::runtime_error(
107104
"[quantized_matmul] Hopper-only kernel is not available.");
108105
#endif // defined(MLX_CUDA_SM90A_ENABLED)
109106
}
110107

111-
// Defined in qmm_impl_sm80_xxx.cu files.
108+
// Defined in qmm_sm80.cu.
112109
template <int TileM>
113-
void qmm_impl_sm80(
110+
void qmm_sm80_impl(
114111
const array& x,
115112
const array& w,
116113
const array& scales,
@@ -174,7 +171,7 @@ void qmm_sm80(
174171
QuantizationMode mode,
175172
cu::CommandEncoder& encoder) {
176173
auto dispatch = [&]<int TileM>() {
177-
qmm_impl_sm80<TileM>(
174+
qmm_sm80_impl<TileM>(
178175
x,
179176
w,
180177
scales,
@@ -197,9 +194,9 @@ void qmm_sm80(
197194
}
198195
}
199196

200-
// Defined in qmm_impl_naive_xxx.cu files.
201-
template <int TileM, bool KMajor>
202-
void qmm_impl_naive(
197+
// Defined in qmm_naive.cu.
198+
template <int TileM, bool KMajor, bool HasKResidue, bool SM80>
199+
void qmm_naive_impl(
203200
const array& x,
204201
const array& w,
205202
const array& scales,
@@ -250,8 +247,8 @@ void qmm_naive(
250247
int group_size,
251248
QuantizationMode mode,
252249
cu::CommandEncoder& encoder) {
253-
auto dispatch = [&]<int TileM, bool KMajor>() {
254-
qmm_impl_naive<TileM, KMajor>(
250+
auto dispatch = [&]<int TileM, bool KMajor, bool HasKResidue, bool SM80>() {
251+
qmm_naive_impl<TileM, KMajor, HasKResidue, SM80>(
255252
x,
256253
w,
257254
scales,
@@ -264,15 +261,37 @@ void qmm_naive(
264261
mode,
265262
encoder);
266263
};
267-
dispatch_bool(transpose, [&](auto k_major) {
268-
int m = out.ndim() > 1 ? out.shape(-2) : 1;
269-
if (m <= 16) {
270-
dispatch.template operator()<16, k_major.value>();
271-
} else if (m <= 32) {
272-
dispatch.template operator()<32, k_major.value>();
264+
auto dispatch_k = [&](auto k_major, bool has_k_residue, auto&& f) {
265+
if constexpr (k_major.value) {
266+
if (has_k_residue) {
267+
throw std::invalid_argument(
268+
"[quantized_matmul] K must be multiples of group_size.");
269+
}
270+
f.template operator()<false>();
273271
} else {
274-
dispatch.template operator()<64, k_major.value>();
272+
dispatch_bool(has_k_residue, [&](auto has_k_residue) {
273+
f.template operator()<has_k_residue.value>();
274+
});
275275
}
276+
};
277+
int m = out.ndim() > 1 ? out.shape(-2) : 1;
278+
int k = x.shape(-1);
279+
bool has_k_residue = k % group_size != 0;
280+
bool sm80 = encoder.device().compute_capability_major() >= 8;
281+
dispatch_bool(transpose, [&](auto k_major) {
282+
dispatch_k(k_major, has_k_residue, [&]<bool HasKResidue>() {
283+
dispatch_bool(sm80, [&](auto sm80) {
284+
constexpr bool KMajor = k_major.value;
285+
constexpr bool SM80 = sm80.value;
286+
if (m <= 16) {
287+
dispatch.template operator()<16, KMajor, HasKResidue, SM80>();
288+
} else if (m <= 32) {
289+
dispatch.template operator()<32, KMajor, HasKResidue, SM80>();
290+
} else {
291+
dispatch.template operator()<64, KMajor, HasKResidue, SM80>();
292+
}
293+
});
294+
});
276295
});
277296
}
278297

mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_k.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_n.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_k.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_n.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_k.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_n.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m16.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlx/backend/cuda/quantized/qmm/qmm_impl_sm80_m32.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)