|
| 1 | +// Copyright © 2026 Apple Inc. |
| 2 | + |
| 3 | +#include "mlx/backend/cuda/cutlass_utils.cuh" |
| 4 | +#include "mlx/backend/cuda/device.h" |
| 5 | +#include "mlx/backend/cuda/kernel_utils.cuh" |
| 6 | +#include "mlx/dtype_utils.h" |
| 7 | + |
| 8 | +#include <cutlass/epilogue/collective/collective_epilogue.hpp> |
| 9 | +#include <cutlass/epilogue/thread/linear_combination.h> |
| 10 | +#include <cutlass/gemm/collective/collective_mma.hpp> |
| 11 | +#include <cutlass/gemm/device/gemm_universal_adapter.h> |
| 12 | +#include <cutlass/gemm/dispatch_policy.hpp> |
| 13 | +#include <cutlass/gemm/kernel/gemm_universal.hpp> |
| 14 | + |
| 15 | +// We can't put kernel code in mlx::core due to name conflicts of "Shape". |
| 16 | +namespace cutlass_gemm { |
| 17 | + |
| 18 | +using namespace cute; |
| 19 | + |
// Modified from cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp to fuse
// gather into GEMM: instead of indexing the batch dimension of A/B with the
// CTA's own batch coordinate, each CTA looks up which A/B slice to use
// through the lhs_indices / rhs_indices arrays.
template <
    class ProblemShape_,
    class CollectiveMainloop_,
    class CollectiveEpilogue_>
class GatherGemm {
 public:
  // Types re-exported from the collectives so adapters such as
  // GemmUniversalAdapter can introspect the kernel configuration.
  using ProblemShape = ProblemShape_;
  using CollectiveMainloop = CollectiveMainloop_;
  using TileShape = typename CollectiveMainloop::TileShape;
  using TiledMma = typename CollectiveMainloop::TiledMma;
  using ArchTag = typename CollectiveMainloop::ArchTag;
  using ElementA = typename CollectiveMainloop::ElementA;
  using StrideA = typename CollectiveMainloop::StrideA;
  using ElementB = typename CollectiveMainloop::ElementB;
  using StrideB = typename CollectiveMainloop::StrideB;
  using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;

  using CollectiveEpilogue = CollectiveEpilogue_;
  using ElementC = typename CollectiveEpilogue::ElementC;
  using StrideC = typename CollectiveEpilogue::StrideC;
  using ElementD = typename CollectiveEpilogue::ElementD;
  using StrideD = typename CollectiveEpilogue::StrideD;

  // Mainloop and epilogue run sequentially and reuse the same shared-memory
  // buffer (smem_buf below), so only the larger of the two is allocated.
  static constexpr int SharedStorageSize = static_cast<int>(cute::max(
      sizeof(typename CollectiveMainloop::SharedStorage),
      sizeof(typename CollectiveEpilogue::SharedStorage)));
  // One thread per MMA lane; block size comes from the TiledMma.
  static constexpr uint32_t MaxThreadsPerBlock =
      CUTE_STATIC_V(size(TiledMma{}));
  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

  // Host-facing arguments: the (m, n, k, l) problem shape, the per-batch
  // gather indices, and the collectives' own argument structs.
  struct Arguments {
    ProblemShape problem_shape;
    const uint32_t* lhs_indices; // for each of l batches, batch index into A
    const uint32_t* rhs_indices; // for each of l batches, batch index into B
    typename CollectiveMainloop::Arguments mainloop;
    typename CollectiveEpilogue::Arguments epilogue;
  };

  // Device-facing parameters, produced by to_underlying_arguments().
  struct Params {
    ProblemShape problem_shape;
    const uint32_t* lhs_indices;
    const uint32_t* rhs_indices;
    typename CollectiveMainloop::Params mainloop;
    typename CollectiveEpilogue::Params epilogue;
  };

  // Converts host Arguments into device Params by letting each collective
  // transform its own arguments. |workspace| is forwarded to both.
  static Params to_underlying_arguments(
      const Arguments& args,
      void* workspace) {
    return {
        args.problem_shape,
        args.lhs_indices,
        args.rhs_indices,
        CollectiveMainloop::to_underlying_arguments(
            args.problem_shape, args.mainloop, workspace),
        CollectiveEpilogue::to_underlying_arguments(
            args.problem_shape, args.epilogue, workspace)};
  }

  // This kernel needs no workspace, so initialization always succeeds.
  static cutlass::Status
  initialize_workspace(const Arguments&, void*, cudaStream_t, void*) {
    return cutlass::Status::kSuccess;
  }

  // One CTA per (M, N) output tile; grid.z covers the l batch elements.
  static dim3 get_grid_shape(const Params& params) {
    auto [m, n, k, l] = params.problem_shape;
    return dim3{
        uint32_t(ceil_div(m, shape<0>(TileShape{}))),
        uint32_t(ceil_div(n, shape<1>(TileShape{}))),
        uint32_t(l)};
  }

  static dim3 get_block_shape() {
    return dim3{MaxThreadsPerBlock, 1, 1};
  }

  // Device entry point: computes one (M, N) output tile of one batch
  // element, with smem_buf sized by SharedStorageSize.
  CUTLASS_DEVICE void operator()(const Params& params, char* smem_buf) {
    int thread_idx = int(threadIdx.x);
    auto [m_coord, n_coord, l_coord] = uint3(blockIdx);

    auto shape_MNKL = append<4>(params.problem_shape, Int<1>{});
    auto cta_tile = TileShape{};
    auto cta_coord = make_coord(m_coord, n_coord, _, l_coord);

    // Represent the full tensors.
    Tensor mA_mkl = make_tensor(
        make_gmem_ptr(params.mainloop.ptr_A),
        select<0, 2, 3>(shape_MNKL),
        params.mainloop.dA);
    Tensor mB_nkl = make_tensor(
        make_gmem_ptr(params.mainloop.ptr_B),
        select<1, 2, 3>(shape_MNKL),
        params.mainloop.dB);

    // Get batch slice. This is the fused gather: the batch index comes from
    // the index arrays rather than from l_coord itself. NOTE(review):
    // indices are assumed in-bounds for A/B's batch extent — not checked.
    Tensor mA_mk = mA_mkl(_, _, params.lhs_indices[l_coord]);
    Tensor mB_nk = mB_nkl(_, _, params.rhs_indices[l_coord]);

    // Slice to get the tiles this thread block is responsible for.
    Tensor gA =
        local_tile(mA_mk, cta_tile, take<0, 3>(cta_coord), Step<_1, X, _1>{});
    Tensor gB =
        local_tile(mB_nk, cta_tile, take<0, 3>(cta_coord), Step<X, _1, _1>{});

    // Compute tile residues for predication (handles M/N/K not divisible by
    // the tile shape).
    auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * get<0>(cta_coord);
    auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * get<1>(cta_coord);
    auto k_residue = size<2>(shape_MNKL) - size<1>(gA) * size<2>(gA);
    auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);

    // Allocate the tiled_mma and the accumulators for the (M,N) cta_tile.
    TiledMma tiled_mma;
    Tensor accum = partition_fragment_C(tiled_mma, take<0, 2>(cta_tile));
    clear(accum);

    auto k_tile_iter = make_coord_iterator(shape<2>(gA));
    int k_tile_count = size<2>(gA);

    // Perform the collective scoped MMA; accum serves as both source and
    // destination accumulator.
    CollectiveMainloop collective_mma;
    collective_mma(
        accum,
        gA,
        gB,
        accum,
        k_tile_iter,
        k_tile_count,
        residue_mnk,
        thread_idx,
        smem_buf);

    // Epilogue and write to out.
    CollectiveEpilogue epilogue(params.epilogue);
    epilogue(
        shape_MNKL,
        cta_tile,
        cta_coord,
        accum,
        tiled_mma,
        residue_mnk,
        thread_idx,
        smem_buf);
  }
};
| 167 | + |
// Per-layout copy configuration for the SIMT (SM70) mainloop, specialized
// below on whether the operand is K-major. Only the two specializations are
// usable; the primary template is intentionally empty.
template <typename Element, bool KMajor>
struct SimtCopyTraits {};
| 170 | + |
// Copy configuration for a K-major operand. Global loads use a (32, 8)
// thread layout with stride (8, 1), one element per thread. The
// shared-memory tile is (128, 8) with the second-mode stride padded from
// 128 to 129 — presumably to avoid shared-memory bank conflicts; confirm
// against the mainloop's smem access pattern.
template <typename Element>
struct SimtCopyTraits<Element, true> {
  using GmemTiledCopy = decltype(make_tiled_copy(
      Copy_Atom<UniversalCopy<Element>, Element>{},
      Layout<Shape<_32, _8>, Stride<_8, _1>>{},
      Layout<Shape<_1, _1>>{}));
  using SmemLayout = Layout<Shape<_128, _8>, Stride<_1, Int<128 + 1>>>;
  using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
};
| 180 | + |
// Copy configuration for an M/N-major operand. Global loads use a (32, 8)
// thread layout with stride (1, 32), so consecutive threads are contiguous
// in the first mode; the shared-memory tile is a plain (128, 8) layout with
// stride (1, 128) and no padding.
template <typename Element>
struct SimtCopyTraits<Element, false> {
  using GmemTiledCopy = decltype(make_tiled_copy(
      Copy_Atom<UniversalCopy<Element>, Element>{},
      Layout<Shape<_32, _8>, Stride<_1, _32>>{},
      Layout<Shape<_1, _1>>{}));
  using SmemLayout = Layout<Shape<_128, _8>, Stride<_1, _128>>;
  using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
};
| 190 | + |
// Builds the CuTe stride for an (m, k) operand and forwards it to |f|
// together with a std::true_type / std::false_type tag recording whether
// the layout is K-major, so the callee can select layout-specific kernel
// types at compile time. The batch stride is m * k in either case.
template <typename F>
void dispatch_stride(bool k_major, int m, int k, F&& f) {
  if (!k_major) {
    // M-major: unit stride along m, stride m along k.
    f(make_stride(Int<1>{}, m, m * k), std::false_type{});
    return;
  }
  // K-major: stride k along m, unit stride along k.
  f(make_stride(k, Int<1>{}, m * k), std::true_type{});
}
| 199 | + |
// Runs a gathered batched GEMM: for batch element i,
//   C[i] = A[lhs_indices[i]] @ B[rhs_indices[i]]
// where C is written row-major as an (m, n) matrix per batch element.
//
// A/B memory layouts are chosen at runtime from the transpose flags and
// dispatched to compile-time K-major / M-major kernel instantiations.
// |launch_kernel| is invoked with (kernel pointer, grid dims, block dims,
// shared-memory bytes, kernel argument array) so the caller controls how
// the kernel is enqueued (e.g. as a CUDA graph node).
template <typename Element, typename F>
void gather_mm(
    int m,
    int n,
    int k,
    int l,
    bool a_transposed,
    bool b_transposed,
    const Element* A,
    const Element* B,
    const uint32_t* lhs_indices,
    const uint32_t* rhs_indices,
    Element* C,
    F&& launch_kernel) {
  auto problem_shape = make_shape(m, n, k, l);
  // C is row-major (m, n): the stride along M is the row length n (not m),
  // matching dispatch_stride's K-major case for an (m, k) operand.
  auto dC = make_stride(n, Int<1>{}, m * n);
  dispatch_stride(!a_transposed, m, k, [&](auto dA, auto k_major_a) {
    dispatch_stride(b_transposed, n, k, [&](auto dB, auto k_major_b) {
      // Accumulate in float for sub-32-bit elements to preserve accuracy.
      using Accumulator =
          std::conditional_t<(sizeof(Element) < 4), float, Element>;
      using TileShape = Shape<_128, _128, _8>;
      using DispatchPolicy = cutlass::gemm::MainloopSm70TwoStage;
      using TiledMma = TiledMMA<
          MMA_Atom<UniversalFMA<Accumulator, Element, Element, Element>>,
          Layout<Shape<_16, _16, _1>>>;

      using CopyTraitsA = SimtCopyTraits<Element, k_major_a.value>;
      using CopyTraitsB = SimtCopyTraits<Element, k_major_b.value>;

      using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
          DispatchPolicy,
          TileShape,
          Element,
          decltype(dA),
          Element,
          decltype(dB),
          TiledMma,
          typename CopyTraitsA::GmemTiledCopy,
          typename CopyTraitsA::SmemLayout,
          typename CopyTraitsA::SmemCopyAtom,
          identity, // no element transformation on A
          typename CopyTraitsB::GmemTiledCopy,
          typename CopyTraitsB::SmemLayout,
          typename CopyTraitsB::SmemCopyAtom,
          identity>; // no element transformation on B

      using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
          Element,
          decltype(dC),
          decltype(dC),
          cutlass::epilogue::thread::
              LinearCombination<Element, 1, Accumulator, Accumulator>,
          cutlass::gemm::EpilogueDefault>;

      using GemmKernel = GatherGemm<
          decltype(problem_shape),
          CollectiveMainloop,
          CollectiveEpilogue>;
      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

      Gemm gemm;
      // alpha = 1, beta = 0: D = A @ B with no prior-C contribution.
      typename Gemm::Arguments args{
          problem_shape,
          lhs_indices,
          rhs_indices,
          {A, dA, B, dB},
          {{1.f, 0.f}, C, dC, C, dC}};

      CHECK_CUTLASS_ERROR(gemm.initialize(args, nullptr));

      auto* kernel = &cutlass::device_kernel<GemmKernel>;
      // Gemm::Params is a dependent type here; "typename" keeps the cast
      // valid before C++20 (where it became optional in cast type-ids).
      void* kernel_params[] = {
          const_cast<typename Gemm::Params*>(&gemm.params())};
      launch_kernel(
          reinterpret_cast<void*>(kernel),
          gemm.get_grid_shape(gemm.params()),
          GemmKernel::get_block_shape(),
          GemmKernel::SharedStorageSize,
          kernel_params);
    });
  });
}
| 281 | + |
| 282 | +} // namespace cutlass_gemm |
| 283 | + |
| 284 | +namespace mlx::core { |
| 285 | + |
// Dispatches a gathered matmul — out[i] = a[lhs_indices[i]] @
// b[rhs_indices[i]] — to the CUTLASS SIMT kernel for any float dtype,
// recording the launch on the given command encoder.
void cutlass_gather_mm(
    bool a_transposed,
    bool b_transposed,
    const array& a,
    const array& b,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    cu::CommandEncoder& encoder) {
  // Problem sizes. NOTE(review): K is taken from a's last axis even when
  // a_transposed is set — assumes the flag describes memory layout only,
  // not logical shape; confirm with callers.
  int M = out.shape(-2);
  int N = out.shape(-1);
  int K = a.shape(-1);
  int batches = out.size() / (M * N);
  if (M < 16 || N < 16) {
    throw std::invalid_argument("[gather_mm] M/N is too small.");
  }

  // Register data dependencies with the encoder before building the node.
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(lhs_indices);
  encoder.set_input_array(rhs_indices);
  encoder.set_output_array(out);

  dispatch_float_types(out.dtype(), "gather_mm", [&](auto type_tag) {
    using Element = cutlass_type_t<MLX_GET_TYPE(type_tag)>;
    // Enqueue the prepared kernel as a raw kernel node on the encoder.
    auto enqueue = [&](auto* kernel,
                       dim3 grid_dims,
                       dim3 block_dims,
                       uint32_t smem_bytes,
                       void** kernel_args) {
      encoder.add_kernel_node_raw(
          kernel, grid_dims, block_dims, {}, smem_bytes, kernel_args);
    };
    cutlass_gemm::gather_mm(
        M,
        N,
        K,
        batches,
        a_transposed,
        b_transposed,
        gpu_ptr<Element>(a),
        gpu_ptr<Element>(b),
        gpu_ptr<uint32_t>(lhs_indices),
        gpu_ptr<uint32_t>(rhs_indices),
        gpu_ptr<Element>(out),
        enqueue);
  });
}
| 333 | + |
| 334 | +} // namespace mlx::core |
0 commit comments