Skip to content

Commit 490a8a6

Browse files
committed
[CUDA] gather_matmul
1 parent 6cef1e9 commit 490a8a6

File tree

7 files changed

+399
-69
lines changed

7 files changed

+399
-69
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ target_sources(
3131
${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
3232
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
3333
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gather_gemm.cu
3435
${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
3536
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cu
3637
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
@@ -119,11 +120,11 @@ target_compile_options(mlx
119120
target_compile_options(
120121
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
121122

122-
if(MSVC)
123-
# Ignore warnings from CUTLASS.
124-
target_compile_options(
125-
mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=2908">)
126-
else()
123+
# Ignore warnings from CUTLASS.
124+
target_compile_options(
125+
mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=2908,2361">)
126+
127+
if(NOT MSVC)
127128
# Required for generating optimized CUTLASS code.
128129
target_compile_options(
129130
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fno-strict-aliasing>")
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#include "mlx/backend/cuda/cutlass_utils.cuh"
4+
#include "mlx/backend/cuda/device.h"
5+
#include "mlx/backend/cuda/kernel_utils.cuh"
6+
#include "mlx/dtype_utils.h"
7+
8+
#include <cutlass/epilogue/collective/collective_epilogue.hpp>
9+
#include <cutlass/epilogue/thread/linear_combination.h>
10+
#include <cutlass/gemm/collective/collective_mma.hpp>
11+
#include <cutlass/gemm/device/gemm_universal_adapter.h>
12+
#include <cutlass/gemm/dispatch_policy.hpp>
13+
#include <cutlass/gemm/kernel/gemm_universal.hpp>
14+
15+
// We can't put kernel code in mlx::core due to name conflicts of "Shape".
16+
namespace cutlass_gemm {
17+
18+
using namespace cute;
19+
20+
// Modified from cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp to fuse
// gather into GEMM.
//
// Kernel-level GEMM that, instead of stepping through batch L linearly,
// looks up per-batch operand indices: block z reads lhs_indices[z] /
// rhs_indices[z] and uses them to select the A / B batch slices. Grid layout
// (see get_grid_shape): x = M tiles, y = N tiles, z = L (one CTA per output
// tile per gather entry). Shared memory requirement is SharedStorageSize
// bytes of dynamic smem passed at launch.
template <
    class ProblemShape_,
    class CollectiveMainloop_,
    class CollectiveEpilogue_>
class GatherGemm {
 public:
  // Re-export the mainloop's configuration so GemmUniversalAdapter and the
  // host-side launch code can introspect this kernel like a stock one.
  using ProblemShape = ProblemShape_;
  using CollectiveMainloop = CollectiveMainloop_;
  using TileShape = typename CollectiveMainloop::TileShape;
  using TiledMma = typename CollectiveMainloop::TiledMma;
  using ArchTag = typename CollectiveMainloop::ArchTag;
  using ElementA = typename CollectiveMainloop::ElementA;
  using StrideA = typename CollectiveMainloop::StrideA;
  using ElementB = typename CollectiveMainloop::ElementB;
  using StrideB = typename CollectiveMainloop::StrideB;
  using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;

  using CollectiveEpilogue = CollectiveEpilogue_;
  using ElementC = typename CollectiveEpilogue::ElementC;
  using StrideC = typename CollectiveEpilogue::StrideC;
  using ElementD = typename CollectiveEpilogue::ElementD;
  using StrideD = typename CollectiveEpilogue::StrideD;

  // Mainloop and epilogue reuse the same smem buffer, so only the larger of
  // the two storages is needed.
  static constexpr int SharedStorageSize = static_cast<int>(cute::max(
      sizeof(typename CollectiveMainloop::SharedStorage),
      sizeof(typename CollectiveEpilogue::SharedStorage)));
  // Block size is fixed by the tiled MMA's thread count.
  static constexpr uint32_t MaxThreadsPerBlock =
      CUTE_STATIC_V(size(TiledMma{}));
  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

  // Host-facing arguments: the stock mainloop/epilogue arguments plus the
  // two gather index arrays (device pointers, one uint32 per batch entry).
  struct Arguments {
    ProblemShape problem_shape;
    const uint32_t* lhs_indices;
    const uint32_t* rhs_indices;
    typename CollectiveMainloop::Arguments mainloop;
    typename CollectiveEpilogue::Arguments epilogue;
  };

  // Device-side parameters; mirrors Arguments with the collectives' own
  // Params types substituted in.
  struct Params {
    ProblemShape problem_shape;
    const uint32_t* lhs_indices;
    const uint32_t* rhs_indices;
    typename CollectiveMainloop::Params mainloop;
    typename CollectiveEpilogue::Params epilogue;
  };

  // Convert Arguments -> Params, delegating to the collectives for their
  // respective pieces. `workspace` is forwarded but unused by this kernel.
  static Params to_underlying_arguments(
      const Arguments& args,
      void* workspace) {
    return {
        args.problem_shape,
        args.lhs_indices,
        args.rhs_indices,
        CollectiveMainloop::to_underlying_arguments(
            args.problem_shape, args.mainloop, workspace),
        CollectiveEpilogue::to_underlying_arguments(
            args.problem_shape, args.epilogue, workspace)};
  }

  // No workspace is required by this kernel.
  static cutlass::Status
  initialize_workspace(const Arguments&, void*, cudaStream_t, void*) {
    return cutlass::Status::kSuccess;
  }

  // Grid: ceil(M / tile_M) x ceil(N / tile_N) x L.
  static dim3 get_grid_shape(const Params& params) {
    auto [m, n, k, l] = params.problem_shape;
    return dim3{
        uint32_t(ceil_div(m, shape<0>(TileShape{}))),
        uint32_t(ceil_div(n, shape<1>(TileShape{}))),
        uint32_t(l)};
  }

  // 1-D block sized to the tiled MMA.
  static dim3 get_block_shape() {
    return dim3{MaxThreadsPerBlock, 1, 1};
  }

  // Device entry point: compute one (tile_M, tile_N) output tile for gather
  // entry blockIdx.z. `smem_buf` is the dynamic shared memory described by
  // SharedStorageSize.
  CUTLASS_DEVICE void operator()(const Params& params, char* smem_buf) {
    int thread_idx = int(threadIdx.x);
    // x -> M tile, y -> N tile, z -> gather/batch entry.
    auto [m_coord, n_coord, l_coord] = uint3(blockIdx);

    // Pad the problem shape to rank 4 (M, N, K, L) with L = 1 if absent.
    auto shape_MNKL = append<4>(params.problem_shape, Int<1>{});
    auto cta_tile = TileShape{};
    auto cta_coord = make_coord(m_coord, n_coord, _, l_coord);

    // Represent the full tensors.
    Tensor mA_mkl = make_tensor(
        make_gmem_ptr(params.mainloop.ptr_A),
        select<0, 2, 3>(shape_MNKL),
        params.mainloop.dA);
    Tensor mB_nkl = make_tensor(
        make_gmem_ptr(params.mainloop.ptr_B),
        select<1, 2, 3>(shape_MNKL),
        params.mainloop.dB);

    // Get batch slice. This is the gather fusion: the batch index comes from
    // the index arrays rather than l_coord itself.
    Tensor mA_mk = mA_mkl(_, _, params.lhs_indices[l_coord]);
    Tensor mB_nk = mB_nkl(_, _, params.rhs_indices[l_coord]);

    // Slice to get the tiles this thread block is responsible for.
    Tensor gA =
        local_tile(mA_mk, cta_tile, take<0, 3>(cta_coord), Step<_1, X, _1>{});
    Tensor gB =
        local_tile(mB_nk, cta_tile, take<0, 3>(cta_coord), Step<X, _1, _1>{});

    // Compute tile residues for predication (how much of M/N/K remains for
    // this CTA; used to mask out-of-bounds accesses on edge tiles).
    auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * get<0>(cta_coord);
    auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * get<1>(cta_coord);
    auto k_residue = size<2>(shape_MNKL) - size<1>(gA) * size<2>(gA);
    auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);

    // Allocate the tiled_mma and the accumulators for the (M,N) cta_tile.
    TiledMma tiled_mma;
    Tensor accum = partition_fragment_C(tiled_mma, take<0, 2>(cta_tile));
    clear(accum);

    // Iterate over all K tiles of this CTA's slice.
    auto k_tile_iter = make_coord_iterator(shape<2>(gA));
    int k_tile_count = size<2>(gA);

    // Perform the collective scoped MMA.
    CollectiveMainloop collective_mma;
    collective_mma(
        accum,
        gA,
        gB,
        accum,
        k_tile_iter,
        k_tile_count,
        residue_mnk,
        thread_idx,
        smem_buf);

    // Epilogue and write to out.
    CollectiveEpilogue epilogue(params.epilogue);
    epilogue(
        shape_MNKL,
        cta_tile,
        cta_coord,
        accum,
        tiled_mma,
        residue_mnk,
        thread_idx,
        smem_buf);
  }
};
167+
168+
// Global/shared-memory copy configuration for the SIMT mainloop, selected by
// whether the operand's K mode is the contiguous one. Only the two
// specializations below are valid; the primary template is intentionally
// empty so an unexpected instantiation fails to compile.
template <typename Element, bool KMajor>
struct SimtCopyTraits {};

// K-major operand: threads are arranged 32x8 with stride (8, 1), so the
// 8-wide K mode is contiguous in gmem reads.
template <typename Element>
struct SimtCopyTraits<Element, true> {
  using GmemTiledCopy = decltype(make_tiled_copy(
      Copy_Atom<UniversalCopy<Element>, Element>{},
      Layout<Shape<_32, _8>, Stride<_8, _1>>{},
      Layout<Shape<_1, _1>>{}));
  // 128x8 smem tile; the 128 + 1 pads the leading dimension (likely to avoid
  // shared-memory bank conflicts — NOTE(review): confirm intent).
  using SmemLayout = Layout<Shape<_128, _8>, Stride<_1, Int<128 + 1>>>;
  using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
};

// MN-major operand: threads arranged 32x8 with stride (1, 32), so the
// 32-wide MN mode is contiguous in gmem reads; smem tile is unpadded.
template <typename Element>
struct SimtCopyTraits<Element, false> {
  using GmemTiledCopy = decltype(make_tiled_copy(
      Copy_Atom<UniversalCopy<Element>, Element>{},
      Layout<Shape<_32, _8>, Stride<_1, _32>>{},
      Layout<Shape<_1, _1>>{}));
  using SmemLayout = Layout<Shape<_128, _8>, Stride<_1, _128>>;
  using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
};
190+
191+
// Invoke `f` with the CuTe stride for an (m, k) operand plus a compile-time
// tag carrying the majorness. K-major gets stride (k, 1); MN-major gets
// stride (1, m); both use m * k as the batch (L) stride. The tag lets the
// callee pick SimtCopyTraits at compile time.
template <typename F>
void dispatch_stride(bool k_major, int m, int k, F&& f) {
  const int batch_stride = m * k;
  if (!k_major) {
    // MN mode contiguous.
    f(make_stride(Int<1>{}, m, batch_stride), std::false_type{});
    return;
  }
  // K mode contiguous.
  f(make_stride(k, Int<1>{}, batch_stride), std::true_type{});
}
199+
200+
// Configure and launch the GatherGemm kernel for an (m, n, k) problem with l
// gather entries. A and B are the full (batched) operand buffers; the kernel
// indexes into their batch dimension via lhs_indices / rhs_indices (one
// uint32 per entry, l entries each). C receives l output matrices of shape
// (m, n). `launch_kernel(kernel, grid, block, smem_bytes, params)` performs
// the actual launch so callers can route it through their own encoder.
template <typename Element, typename F>
void gather_mm(
    int m,
    int n,
    int k,
    int l,
    bool a_transposed,
    bool b_transposed,
    const Element* A,
    const Element* B,
    const uint32_t* lhs_indices,
    const uint32_t* rhs_indices,
    Element* C,
    F&& launch_kernel) {
  auto problem_shape = make_shape(m, n, k, l);
  // Output stride: (M-stride, N-stride, batch-stride).
  // NOTE(review): M-stride here is `m`; a row-major (m, n) output would
  // conventionally use `n` — verify against the epilogue/caller convention
  // (correct as written only when it matches how `out` is laid out).
  auto dC = make_stride(m, Int<1>{}, m * n);
  // A is K-major exactly when not transposed; B is K-major when transposed.
  dispatch_stride(!a_transposed, m, k, [&](auto dA, auto k_major_a) {
    dispatch_stride(b_transposed, n, k, [&](auto dB, auto k_major_b) {
      // Accumulate half-precision inputs in float; wider types in kind.
      using Accumulator =
          std::conditional_t<(sizeof(Element) < 4), float, Element>;
      using TileShape = Shape<_128, _128, _8>;
      using DispatchPolicy = cutlass::gemm::MainloopSm70TwoStage;
      // SIMT FMA MMA over a 16x16 thread arrangement (256 threads/CTA).
      using TiledMma = TiledMMA<
          MMA_Atom<UniversalFMA<Accumulator, Element, Element, Element>>,
          Layout<Shape<_16, _16, _1>>>;

      using CopyTraitsA = SimtCopyTraits<Element, k_major_a.value>;
      using CopyTraitsB = SimtCopyTraits<Element, k_major_b.value>;

      using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
          DispatchPolicy,
          TileShape,
          Element,
          decltype(dA),
          Element,
          decltype(dB),
          TiledMma,
          typename CopyTraitsA::GmemTiledCopy,
          typename CopyTraitsA::SmemLayout,
          typename CopyTraitsA::SmemCopyAtom,
          identity,
          typename CopyTraitsB::GmemTiledCopy,
          typename CopyTraitsB::SmemLayout,
          typename CopyTraitsB::SmemCopyAtom,
          identity>;

      // Plain linear-combination epilogue: D = 1 * acc + 0 * C.
      using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
          Element,
          decltype(dC),
          decltype(dC),
          cutlass::epilogue::thread::
              LinearCombination<Element, 1, Accumulator, Accumulator>,
          cutlass::gemm::EpilogueDefault>;

      using GemmKernel = GatherGemm<
          decltype(problem_shape),
          CollectiveMainloop,
          CollectiveEpilogue>;
      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

      Gemm gemm;
      // alpha = 1, beta = 0; C is passed for both source and destination.
      typename Gemm::Arguments args{
          problem_shape,
          lhs_indices,
          rhs_indices,
          {A, dA, B, dB},
          {{1.f, 0.f}, C, dC, C, dC}};

      CHECK_CUTLASS_ERROR(gemm.initialize(args, nullptr));

      // Launch through the caller-provided hook instead of gemm.run() so the
      // kernel can be recorded into a command encoder / graph.
      auto* kernel = &cutlass::device_kernel<GemmKernel>;
      void* kernel_params[] = {const_cast<Gemm::Params*>(&gemm.params())};
      launch_kernel(
          reinterpret_cast<void*>(kernel),
          gemm.get_grid_shape(gemm.params()),
          GemmKernel::get_block_shape(),
          GemmKernel::SharedStorageSize,
          kernel_params);
    });
  });
}
281+
282+
} // namespace cutlass_gemm
283+
284+
namespace mlx::core {
285+
286+
// Run a gathered matmul: out[i] = a[lhs_indices[i]] @ b[rhs_indices[i]] for
// each batch entry, dispatched over the output's float dtype and encoded as
// a kernel node on `encoder`. Throws std::invalid_argument when M or N is
// below the minimum tile size this path supports.
void cutlass_gather_mm(
    bool a_transposed,
    bool b_transposed,
    const array& a,
    const array& b,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    cu::CommandEncoder& encoder) {
  // Problem dimensions come from the output's trailing axes and A's inner
  // dimension; the batch count is whatever remains of out's size.
  const int m = out.shape(-2);
  const int n = out.shape(-1);
  const int k = a.shape(-1);
  const int l = out.size() / (m * n);
  if (m < 16 || n < 16) {
    throw std::invalid_argument("[gather_mm] M/N is too small.");
  }

  // Register data dependencies with the encoder before launching.
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(lhs_indices);
  encoder.set_input_array(rhs_indices);
  encoder.set_output_array(out);

  // The launch hook records the kernel as a raw node on the encoder rather
  // than launching it directly.
  auto record_launch = [&](auto* kernel,
                           dim3 grid_dims,
                           dim3 block_dims,
                           uint32_t smem_bytes,
                           void** kernel_args) {
    encoder.add_kernel_node_raw(
        kernel, grid_dims, block_dims, {}, smem_bytes, kernel_args);
  };

  dispatch_float_types(out.dtype(), "gather_mm", [&](auto type_tag) {
    using Element = cutlass_type_t<MLX_GET_TYPE(type_tag)>;
    cutlass_gemm::gather_mm(
        m,
        n,
        k,
        l,
        a_transposed,
        b_transposed,
        gpu_ptr<Element>(a),
        gpu_ptr<Element>(b),
        gpu_ptr<uint32_t>(lhs_indices),
        gpu_ptr<uint32_t>(rhs_indices),
        gpu_ptr<Element>(out),
        record_launch);
  });
}
333+
334+
} // namespace mlx::core
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright © 2026 Apple Inc.
2+
3+
#pragma once
4+
5+
namespace mlx::core {

// Forward declarations keep this header free of heavy CUDA includes.
namespace cu {
class CommandEncoder;
}

class array;

// Gathered matrix multiplication on the CUDA backend (implemented with a
// CUTLASS kernel): per batch entry i, multiplies the A slice selected by
// lhs_indices[i] with the B slice selected by rhs_indices[i], writing the
// result into `out`. The launch is recorded on `encoder`.
void cutlass_gather_mm(
    bool a_transposed,
    bool b_transposed,
    const array& a,
    const array& b,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    cu::CommandEncoder& encoder);

} // namespace mlx::core

mlx/backend/cuda/gemms/gemv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ __global__ void gemv_gather(
167167
}
168168

169169
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
170-
return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
170+
return (M == 1 && b_transposed) || (N == 1 && !a_transposed);
171171
}
172172

173173
template <typename F>

0 commit comments

Comments
 (0)