11// Copyright © 2026 Apple Inc.
22
33#include " mlx/backend/cuda/cutlass_utils.cuh"
4- #include " mlx/backend/cuda/quantized/qmm.h"
54#include " mlx/backend/cuda/quantized/quantized_utils.h"
65#include " mlx/backend/gpu/copy.h"
76#include " mlx/dtype_utils.h"
1312#include < cutlass/gemm/device/gemm_universal_adapter.h>
1413#include < cutlass/gemm/kernel/gemm_universal.hpp>
1514
15+ #if defined(MLX_CUDA_SM90A_ENABLED)
16+
1617// We can't put kernel code in mlx::core due to name conflicts of "Shape".
1718namespace cutlass_gemm {
1819
19- template <typename GroupSize, typename Element, typename Quant, typename F>
20+ using namespace cute ;
21+
22+ template <
23+ typename TileShapeMN = Shape<_128, _16>,
24+ typename ClusterShape = Shape<_1, _1, _1>,
25+ typename Element,
26+ typename Quant,
27+ typename GroupSize,
28+ typename F>
2029void qmm_sm90 (
2130 const Element* A,
2231 const Quant* B,
@@ -29,9 +38,6 @@ void qmm_sm90(
2938 int64_t l,
3039 GroupSize group_size,
3140 F&& launch_kernel) {
32- #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
33- using namespace cute ;
34-
3541 constexpr int kAlignmentA = 128 / sizeof_bits<Element>::value;
3642 constexpr int kAlignmentB = 128 / sizeof_bits<Quant>::value;
3743 constexpr int kTileShapeK =
@@ -40,8 +46,7 @@ void qmm_sm90(
4046
4147 using Arch = cutlass::arch::Sm90;
4248 using Accumulator = float ;
43- using TileShape = Shape<_128, _16, Int<kTileShapeK >>;
44- using ClusterShape = Shape<_1, _1, _1>;
49+ using TileShape = decltype (append (TileShapeMN{}, Int<kTileShapeK >{}));
4550
4651 using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
4752 Arch,
@@ -66,7 +71,7 @@ void qmm_sm90(
6671 Arch,
6772 cutlass::arch::OpClassTensorOp,
6873 // ElementA:
69- cute:: tuple<Quant, Element, Element>,
74+ tuple<Quant, Element, Element>,
7075 cutlass::layout::RowMajor,
7176 kAlignmentB ,
7277 // ElementB:
@@ -101,16 +106,14 @@ void qmm_sm90(
101106
102107 auto * kernel = &cutlass::device_kernel<GemmKernel>;
103108 void * kernel_params[] = {const_cast <Gemm::Params*>(&gemm.params ())};
109+ auto cluster = ClusterShape{};
104110 launch_kernel (
105111 reinterpret_cast <void *>(kernel),
106112 gemm.get_grid_shape (gemm.params ()),
107113 GemmKernel::get_block_shape (),
114+ {get<0 >(cluster), get<1 >(cluster), get<2 >(cluster)},
108115 GemmKernel::SharedStorageSize,
109116 kernel_params);
110- #else
111- throw std::runtime_error (
112- " [quantized_matmul] Hopper-only kernel is not available." );
113- #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
114117}
115118
116119} // namespace cutlass_gemm
@@ -167,29 +170,29 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
167170 }
168171}
169172
170- void qmm_sm90 (
171- const array& x_,
173+ template <typename TileShapeMN, typename ClusterShape>
174+ void qmm_impl_sm90 (
175+ const array& x,
172176 const array& w,
173177 const array& scales_,
174- const std::optional< array> & biases_,
178+ const array& biases_,
175179 array& out,
176180 int bits,
177181 int group_size,
178- QuantizationMode mode,
179182 cu::CommandEncoder& encoder,
180183 Stream s) {
181- if ((mode != QuantizationMode::Affine) || !biases_) {
182- throw std::runtime_error (" qmm_sm90 NYI" );
183- }
184-
185184 const char * tag = " [quantized_matmul]" ;
186185 int m = out.shape (-2 );
187186 int n = out.shape (-1 );
188- int k = x_ .shape (-1 );
187+ int k = x .shape (-1 );
189188 int l = out.size () / (m * n);
190189 if (k % 64 != 0 ) {
191190 throw std::runtime_error (fmt::format (" {} K must be multiples of 64." , tag));
192191 }
192+ if (!x.flags ().row_contiguous ) {
193+ throw std::runtime_error (
194+ fmt::format (" {} Activations must be row contiguous." , tag));
195+ }
193196 if (!w.flags ().row_contiguous ) {
194197 throw std::runtime_error (
195198 fmt::format (" {} Weights must be row contiguous." , tag));
@@ -198,16 +201,14 @@ void qmm_sm90(
198201 throw std::runtime_error (
199202 fmt::format (" {} Scales must be row contiguous." , tag));
200203 }
201- if (!biases_-> flags ().row_contiguous ) {
204+ if (!biases_. flags ().row_contiguous ) {
202205 throw std::runtime_error (
203206 fmt::format (" {} Biases must be row contiguous." , tag));
204207 }
205208
206- // TODO: Support column-major x.
207- array x = ensure_row_contiguous (x_, encoder, s);
208209 // FIXME: Copy happens for every call.
209210 array scales = transpose_last_2_dims (scales_, encoder, s);
210- array biases = transpose_last_2_dims (* biases_, encoder, s);
211+ array biases = transpose_last_2_dims (biases_, encoder, s);
211212
212213 dispatch_element_types (out.dtype (), tag, [&]<typename Element>() {
213214 dispatch_quant_types (bits, tag, [&]<typename Quant>() {
@@ -231,14 +232,40 @@ void qmm_sm90(
231232 [&](auto * kernel,
232233 dim3 num_blocks,
233234 dim3 block_dims,
235+ dim3 cluster_shape,
234236 uint32_t smem_bytes,
235237 void ** args) {
236- encoder.add_kernel_node (
237- kernel, num_blocks, block_dims, smem_bytes, args);
238+ encoder.add_kernel_node_raw (
239+ kernel,
240+ num_blocks,
241+ block_dims,
242+ cluster_shape,
243+ smem_bytes,
244+ args);
238245 });
239246 });
240247 });
241248 });
242249}
243250
244251} // namespace mlx::core
252+
253+ #define QMM_SM90_GPU (TileShapeMN, ClusterShape ) \
254+ namespace mlx ::core { \
255+ template void qmm_impl_sm90<TileShapeMN, ClusterShape>( \
256+ const array& x, \
257+ const array& w, \
258+ const array& scales, \
259+ const array& biases, \
260+ array& out, \
261+ int bits, \
262+ int group_size, \
263+ cu::CommandEncoder& encoder, \
264+ Stream s); \
265+ }
266+
267+ #else
268+
269+ #define QMM_SM90_GPU (TileShapeMN, ClusterShape )
270+
271+ #endif // defined(MLX_CUDA_SM90A_ENABLED)
0 commit comments