Skip to content

Commit 0c8107c

Browse files
authored
[CUDA] Heuristics for Hopper QMM (ml-explore#3173)
1 parent 5c4abd2 commit 0c8107c

13 files changed

Lines changed: 189 additions & 40 deletions

.github/actions/build-linux/action.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ runs:
2121
run: |
2222
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
2323
# There is no GPU in arm64 runner, use a common arch.
24-
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
24+
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=80"
2525
# Can not build tests and stubs when the built executables can not run.
2626
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF -DMLX_BUILD_PYTHON_STUBS=OFF"
2727
fi

mlx/backend/cuda/CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@ target_sources(
5656
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
5757
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
5858
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
59-
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm_sm90.cu
6059
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv.cu
6160
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
6261
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
@@ -65,6 +64,7 @@ target_sources(
6564
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
6665

6766
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
67+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm)
6868
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
6969

7070
# fp4 is not available on < 12.8
@@ -145,12 +145,11 @@ if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
145145
COMMAND __nvcc_device_query
146146
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
147147
OUTPUT_STRIP_TRAILING_WHITESPACE)
148-
set(UPGRADABLE_ARCHITECTURES "90;100;121")
149148
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
150149
message(
151150
FATAL_ERROR
152151
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
153-
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
152+
elseif(MLX_CUDA_ARCHITECTURES GREATER_EQUAL 90)
154153
# Use arch-specific compute capability whenever possible.
155154
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
156155
endif()
@@ -159,6 +158,11 @@ message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
159158
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
160159
"${MLX_CUDA_ARCHITECTURES}")
161160

161+
if(("90a" IN_LIST MLX_CUDA_ARCHITECTURES) OR ("90a-real" IN_LIST
162+
MLX_CUDA_ARCHITECTURES))
163+
target_compile_definitions(mlx PRIVATE MLX_CUDA_SM90A_ENABLED)
164+
endif()
165+
162166
# Search CUDA libs from installed python packages.
163167
if(WIN32)
164168
# Resolve paths of unfound DLL at runtime.

mlx/backend/cuda/device.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -281,8 +281,8 @@ void CommandEncoder::add_kernel_node_raw(
281281
config.blockDim = block_dim;
282282
config.dynamicSmemBytes = smem_bytes;
283283
config.stream = stream();
284+
cudaLaunchAttribute attr = {};
284285
if (use_cluster) {
285-
cudaLaunchAttribute attr;
286286
attr.id = cudaLaunchAttributeClusterDimension;
287287
attr.val.clusterDim.x = cluster_dim.x;
288288
attr.val.clusterDim.y = cluster_dim.y;
@@ -332,16 +332,16 @@ void CommandEncoder::add_kernel_node_raw(
332332
config.blockDimZ = block_dim.z;
333333
config.sharedMemBytes = smem_bytes;
334334
config.hStream = stream();
335+
CUlaunchAttribute attr = {};
335336
if (use_cluster) {
336-
CUlaunchAttribute attr = {};
337337
attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
338338
attr.value.clusterDim.x = cluster_dim.x;
339339
attr.value.clusterDim.y = cluster_dim.y;
340340
attr.value.clusterDim.z = cluster_dim.z;
341341
config.attrs = &attr;
342342
config.numAttrs = 1;
343-
CHECK_CUDA_ERROR(cuLaunchKernelEx(&config, func, params, nullptr));
344343
}
344+
CHECK_CUDA_ERROR(cuLaunchKernelEx(&config, func, params, nullptr));
345345
return;
346346
}
347347

mlx/backend/cuda/quantized/qmm/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
target_sources(
2+
mlx
3+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cpp
4+
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n16_m1.cu
5+
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n32_m1.cu
6+
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n64_m2.cu
7+
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n128_m2.cu
8+
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n256_m2.cu)
mlx/backend/cuda/quantized/qmm/qmm.cpp

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,60 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
4+
5+
#include <cute/tensor.hpp>
6+
7+
namespace mlx::core {
8+
9+
#if defined(MLX_CUDA_SM90A_ENABLED)
10+
// Defined in qmm_impl_sm90_xxx.cu files.
11+
template <typename TileShape, typename ClusterShape>
12+
void qmm_impl_sm90(
13+
const array& x,
14+
const array& w,
15+
const array& scales,
16+
const array& biases,
17+
array& out,
18+
int bits,
19+
int group_size,
20+
cu::CommandEncoder& encoder,
21+
Stream s);
22+
#endif // defined(MLX_CUDA_SM90A_ENABLED)
23+
24+
void qmm_sm90(
25+
const array& x,
26+
const array& w,
27+
const array& scales,
28+
const array& biases,
29+
array& out,
30+
int bits,
31+
int group_size,
32+
cu::CommandEncoder& encoder,
33+
Stream s) {
34+
#if defined(MLX_CUDA_SM90A_ENABLED)
35+
auto dispatch = [&]<int tile_m, int tile_n, int cluster_m>() {
36+
using cute::Int;
37+
using TileShapeMN = cute::Shape<Int<tile_m>, Int<tile_n>>;
38+
using ClusterShape = cute::Shape<Int<cluster_m>, Int<1>, Int<1>>;
39+
qmm_impl_sm90<TileShapeMN, ClusterShape>(
40+
x, w, scales, biases, out, bits, group_size, encoder, s);
41+
};
42+
int m = out.shape(-2);
43+
if (m <= 16) {
44+
dispatch.template operator()<128, 16, 1>();
45+
} else if (m <= 32) {
46+
dispatch.template operator()<128, 32, 1>();
47+
} else if (m <= 64) {
48+
dispatch.template operator()<128, 64, 2>();
49+
} else if (m <= 128) {
50+
dispatch.template operator()<128, 128, 2>();
51+
} else {
52+
dispatch.template operator()<128, 256, 2>();
53+
}
54+
#else
55+
throw std::runtime_error(
56+
"[quantized_matmul] Hopper-only kernel is not available.");
57+
#endif // defined(MLX_CUDA_SM90A_ENABLED)
58+
}
59+
60+
} // namespace mlx::core
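mlx/backend/cuda/quantized/qmm/qmm.h (file path missing from this capture; inferred from the #include in qmm.cpp, possibly moved from mlx/backend/cuda/quantized/qmm.h — confirm against the commit)

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change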
@@ -3,7 +3,6 @@
33
#pragma once
44

55
#include "mlx/backend/cuda/device.h"
6-
#include "mlx/primitives.h"
76

87
#include <optional>
98

@@ -13,11 +12,10 @@ void qmm_sm90(
1312
const array& x,
1413
const array& w,
1514
const array& scales,
16-
const std::optional<array>& biases,
15+
const array& biases,
1716
array& out,
1817
int bits,
1918
int group_size,
20-
QuantizationMode mode,
2119
cu::CommandEncoder& encoder,
2220
Stream s);
2321

mlx/backend/cuda/quantized/qmm_sm90.cu renamed to mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh

Lines changed: 54 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
// Copyright © 2026 Apple Inc.
22

33
#include "mlx/backend/cuda/cutlass_utils.cuh"
4-
#include "mlx/backend/cuda/quantized/qmm.h"
54
#include "mlx/backend/cuda/quantized/quantized_utils.h"
65
#include "mlx/backend/gpu/copy.h"
76
#include "mlx/dtype_utils.h"
@@ -13,10 +12,20 @@
1312
#include <cutlass/gemm/device/gemm_universal_adapter.h>
1413
#include <cutlass/gemm/kernel/gemm_universal.hpp>
1514

15+
#if defined(MLX_CUDA_SM90A_ENABLED)
16+
1617
// We can't put kernel code in mlx::core due to name conflicts of "Shape".
1718
namespace cutlass_gemm {
1819

19-
template <typename GroupSize, typename Element, typename Quant, typename F>
20+
using namespace cute;
21+
22+
template <
23+
typename TileShapeMN = Shape<_128, _16>,
24+
typename ClusterShape = Shape<_1, _1, _1>,
25+
typename Element,
26+
typename Quant,
27+
typename GroupSize,
28+
typename F>
2029
void qmm_sm90(
2130
const Element* A,
2231
const Quant* B,
@@ -29,9 +38,6 @@ void qmm_sm90(
2938
int64_t l,
3039
GroupSize group_size,
3140
F&& launch_kernel) {
32-
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
33-
using namespace cute;
34-
3541
constexpr int kAlignmentA = 128 / sizeof_bits<Element>::value;
3642
constexpr int kAlignmentB = 128 / sizeof_bits<Quant>::value;
3743
constexpr int kTileShapeK =
@@ -40,8 +46,7 @@ void qmm_sm90(
4046

4147
using Arch = cutlass::arch::Sm90;
4248
using Accumulator = float;
43-
using TileShape = Shape<_128, _16, Int<kTileShapeK>>;
44-
using ClusterShape = Shape<_1, _1, _1>;
49+
using TileShape = decltype(append(TileShapeMN{}, Int<kTileShapeK>{}));
4550

4651
using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
4752
Arch,
@@ -66,7 +71,7 @@ void qmm_sm90(
6671
Arch,
6772
cutlass::arch::OpClassTensorOp,
6873
// ElementA:
69-
cute::tuple<Quant, Element, Element>,
74+
tuple<Quant, Element, Element>,
7075
cutlass::layout::RowMajor,
7176
kAlignmentB,
7277
// ElementB:
@@ -101,16 +106,14 @@ void qmm_sm90(
101106

102107
auto* kernel = &cutlass::device_kernel<GemmKernel>;
103108
void* kernel_params[] = {const_cast<Gemm::Params*>(&gemm.params())};
109+
auto cluster = ClusterShape{};
104110
launch_kernel(
105111
reinterpret_cast<void*>(kernel),
106112
gemm.get_grid_shape(gemm.params()),
107113
GemmKernel::get_block_shape(),
114+
{get<0>(cluster), get<1>(cluster), get<2>(cluster)},
108115
GemmKernel::SharedStorageSize,
109116
kernel_params);
110-
#else
111-
throw std::runtime_error(
112-
"[quantized_matmul] Hopper-only kernel is not available.");
113-
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
114117
}
115118

116119
} // namespace cutlass_gemm
@@ -167,29 +170,29 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
167170
}
168171
}
169172

170-
void qmm_sm90(
171-
const array& x_,
173+
template <typename TileShapeMN, typename ClusterShape>
174+
void qmm_impl_sm90(
175+
const array& x,
172176
const array& w,
173177
const array& scales_,
174-
const std::optional<array>& biases_,
178+
const array& biases_,
175179
array& out,
176180
int bits,
177181
int group_size,
178-
QuantizationMode mode,
179182
cu::CommandEncoder& encoder,
180183
Stream s) {
181-
if ((mode != QuantizationMode::Affine) || !biases_) {
182-
throw std::runtime_error("qmm_sm90 NYI");
183-
}
184-
185184
const char* tag = "[quantized_matmul]";
186185
int m = out.shape(-2);
187186
int n = out.shape(-1);
188-
int k = x_.shape(-1);
187+
int k = x.shape(-1);
189188
int l = out.size() / (m * n);
190189
if (k % 64 != 0) {
191190
throw std::runtime_error(fmt::format("{} K must be multiples of 64.", tag));
192191
}
192+
if (!x.flags().row_contiguous) {
193+
throw std::runtime_error(
194+
fmt::format("{} Activations must be row contiguous.", tag));
195+
}
193196
if (!w.flags().row_contiguous) {
194197
throw std::runtime_error(
195198
fmt::format("{} Weights must be row contiguous.", tag));
@@ -198,16 +201,14 @@ void qmm_sm90(
198201
throw std::runtime_error(
199202
fmt::format("{} Scales must be row contiguous.", tag));
200203
}
201-
if (!biases_->flags().row_contiguous) {
204+
if (!biases_.flags().row_contiguous) {
202205
throw std::runtime_error(
203206
fmt::format("{} Biases must be row contiguous.", tag));
204207
}
205208

206-
// TODO: Support column-major x.
207-
array x = ensure_row_contiguous(x_, encoder, s);
208209
// FIXME: Copy happens for every call.
209210
array scales = transpose_last_2_dims(scales_, encoder, s);
210-
array biases = transpose_last_2_dims(*biases_, encoder, s);
211+
array biases = transpose_last_2_dims(biases_, encoder, s);
211212

212213
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
213214
dispatch_quant_types(bits, tag, [&]<typename Quant>() {
@@ -231,14 +232,40 @@ void qmm_sm90(
231232
[&](auto* kernel,
232233
dim3 num_blocks,
233234
dim3 block_dims,
235+
dim3 cluster_shape,
234236
uint32_t smem_bytes,
235237
void** args) {
236-
encoder.add_kernel_node(
237-
kernel, num_blocks, block_dims, smem_bytes, args);
238+
encoder.add_kernel_node_raw(
239+
kernel,
240+
num_blocks,
241+
block_dims,
242+
cluster_shape,
243+
smem_bytes,
244+
args);
238245
});
239246
});
240247
});
241248
});
242249
}
243250

244251
} // namespace mlx::core
252+
253+
#define QMM_SM90_GPU(TileShapeMN, ClusterShape) \
254+
namespace mlx::core { \
255+
template void qmm_impl_sm90<TileShapeMN, ClusterShape>( \
256+
const array& x, \
257+
const array& w, \
258+
const array& scales, \
259+
const array& biases, \
260+
array& out, \
261+
int bits, \
262+
int group_size, \
263+
cu::CommandEncoder& encoder, \
264+
Stream s); \
265+
}
266+
267+
#else
268+
269+
#define QMM_SM90_GPU(TileShapeMN, ClusterShape)
270+
271+
#endif // defined(MLX_CUDA_SM90A_ENABLED)
mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n128_m2.cu

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"
4+
5+
using namespace cute;
6+
7+
using TileShapeMN = Shape<_128, _128>;
8+
using ClusterShape = Shape<_2, _1, _1>;
9+
10+
QMM_SM90_GPU(TileShapeMN, ClusterShape)
mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n16_m1.cu

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"
4+
5+
using namespace cute;
6+
7+
using TileShapeMN = Shape<_128, _16>;
8+
using ClusterShape = Shape<_1, _1, _1>;
9+
10+
QMM_SM90_GPU(TileShapeMN, ClusterShape)
mlx/backend/cuda/quantized/qmm/qmm_impl_sm90_m128_n256_m2.cu

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include "mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh"
4+
5+
using namespace cute;
6+
7+
using TileShapeMN = Shape<_128, _256>;
8+
using ClusterShape = Shape<_2, _1, _1>;
9+
10+
QMM_SM90_GPU(TileShapeMN, ClusterShape)

0 commit comments

Comments
 (0)