Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v4.3.5
GIT_TAG v4.4.2
GIT_SHALLOW TRUE
SOURCE_SUBDIR include EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(cutlass)
Expand Down
34 changes: 22 additions & 12 deletions mlx/backend/cuda/quantized/qmm/qmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@

namespace mlx::core {

namespace {

// Predicate: the trailing two dimensions of `x` are laid out as a
// row-contiguous matrix — the array carries the contiguous flag, has rank
// of at least 2, its innermost stride is 1, and its second-innermost
// stride equals the innermost extent.
// NOTE(review): assumes strides(-1)/shape(-1) index from the end, per
// mlx::core::array conventions — confirm against array.h.
inline bool is_last_2_dims_row_contiguous(const array& x) {
  if (!x.flags().contiguous || x.ndim() < 2) {
    return false;
  }
  bool inner_unit = (x.strides(-1) == 1);
  bool outer_matches_row = (x.strides(-2) == x.shape(-1));
  return inner_unit && outer_matches_row;
}

} // namespace

#if defined(MLX_CUDA_SM90A_ENABLED)
// Defined in qmm_impl_sm90_xxx.cu files.
template <typename TileShape, typename ClusterShape>
Expand Down Expand Up @@ -43,8 +52,9 @@ bool supports_qmm_sm90(
if (!biases) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous || !biases->flags().row_contiguous) {
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
!is_last_2_dims_row_contiguous(scales) ||
!is_last_2_dims_row_contiguous(*biases)) {
return false;
}
if (!transpose) {
Expand Down Expand Up @@ -132,11 +142,11 @@ bool supports_qmm_sm80(
if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous) {
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
!is_last_2_dims_row_contiguous(scales)) {
return false;
}
if (biases && !biases->flags().row_contiguous) {
if (biases && !is_last_2_dims_row_contiguous(*biases)) {
return false;
}
if (x.dtype() != float16 && x.dtype() != bfloat16) {
Expand Down Expand Up @@ -214,14 +224,14 @@ bool supports_qmm_naive(
QuantizationMode mode,
cu::Device& device) {
int k = x.shape(-1);
if (k % std::max(64, group_size) != 0) {
if (transpose && (k % std::max(64, group_size) != 0)) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous) {
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
!is_last_2_dims_row_contiguous(scales)) {
return false;
}
if (biases && !biases->flags().row_contiguous) {
if (biases && !is_last_2_dims_row_contiguous(*biases)) {
return false;
}
return true;
Expand Down Expand Up @@ -313,11 +323,11 @@ bool supports_qmv(
if (k % 8 != 0) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous) {
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
!is_last_2_dims_row_contiguous(scales)) {
return false;
}
if (biases && !biases->flags().row_contiguous) {
if (biases && !is_last_2_dims_row_contiguous(*biases)) {
return false;
}
if (!transpose) {
Expand Down
177 changes: 121 additions & 56 deletions mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ cute_dequant(auto w, auto s, auto z, auto out) {
}
}

template <typename ProblemShape, typename CtaTiler,
template <bool HasKResidue, typename ProblemShape, typename CtaTiler,
typename Element, typename Quant, typename Scale,
typename StrideA, typename SmemLayoutA, typename TiledCopyA,
typename StrideB, typename SmemLayoutB, typename TiledCopyB,
Expand All @@ -62,6 +62,20 @@ __global__ void qmm_naive_kernel(
int thread_idx = int(threadIdx.x);
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);

auto m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord
auto n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord

// Shift tensor so we handle residue of K in the 0th tile.
auto shape_K = size<2>(shape_MNKL);
auto bK = size<2>(cta_tiler);
auto k_residue = shape_K - bK * ceil_div(shape_K, bK);
if constexpr (HasKResidue) {
A += k_residue * get<1>(dA);
B += k_residue * get<1>(dB) * cuda::std::min(8, sizeof_bits_v<Quant>) / 8;
S += k_residue * stride<1>(S_layout);
Z += k_residue * stride<1>(S_layout);
}

// Represent the full tensors.
Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
Tensor mB_nkl = make_tensor(make_gmem_ptr<Quant>(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
Expand Down Expand Up @@ -91,9 +105,6 @@ __global__ void qmm_naive_kernel(
Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)

auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord
auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * n_coord; // N - BLK_N * n_coord

// Shared memory buffers.
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<Element, SmemLayoutA, SmemLayoutB>;
Expand Down Expand Up @@ -164,17 +175,53 @@ __global__ void qmm_naive_kernel(
__syncthreads();
};

// Clear the rmem tiles to account for predicated off loads.
if constexpr (HasKResidue) {
clear(tArA);
clear(tBrB);
clear(tBrS);
clear(tBrZ);
}

// Prefetch first tile.
fetch_gmem(0);
if constexpr (HasKResidue) {
Tensor tAgA_k = tAgA(_,_,_,0);
CUTE_UNROLL
for (int k = 0; k < size<2>(tArA); ++k) {
if (get<1>(tAcA(0,0,k)) >= -k_residue) {
copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k));
}
}
Tensor tBgB_k = tBgB(_,_,_,0);
Tensor tBgS_k = tBgS(_,_,_,0);
Tensor tBgZ_k = tBgZ(_,_,_,0);
CUTE_UNROLL
for (int k = 0; k < size<2>(tBrB); ++k) {
if (get<1>(tBcB(0,0,k)) >= -k_residue) {
copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k));
copy(tBgS_k(_,_,k), tBrS(_,_,k));
copy(tBgZ_k(_,_,k), tBrZ(_,_,k));
}
}
} else {
fetch_gmem(0);
}

// Clear accumulators.
clear(tCrC);

// Loop over CTA tiles.
auto K_TILE_MAX = size<3>(tAgA);
auto K_TILE_MAX = size<3>(tAgA);
for (int tile = 0; tile < K_TILE_MAX; ++tile) {
store_smem();
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
if constexpr (HasKResidue) {
// Avoid fetching full 0th-tile when there is residue.
if (K_TILE_MAX > 1) {
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
}
} else {
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
}
gemm(mma, tCsA, tCsB, tCrC);
}

Expand Down Expand Up @@ -236,9 +283,10 @@ inline constexpr auto make_tiled_mma() {
}
}

template <typename T, bool KMajor = true>
template <typename T, bool KMajor = true, bool HasKResidue = false>
inline auto make_tiled_copy(auto num_threads, auto bM, auto bK) {
auto n_read = Int<8>{};
// TODO: Only do 1-element read for the tile of residue.
auto n_read = Int<HasKResidue ? 1 : 8>{};
auto atom = Copy_Atom<UniversalCopy<uint_bit_t<n_read * sizeof_bits_v<T>>>, T>{};
if constexpr (KMajor) {
auto k_threads = bK / n_read;
Expand Down Expand Up @@ -268,7 +316,7 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size
}
}

template <int TileM = 16, bool KMajor = true, bool SM80 = true,
template <int TileM = 16, bool KMajor = true, bool SM80 = true, bool HasKResidue = false,
typename Element, typename Quant, typename Scale>
void qmm_naive(
const Element* A,
Expand Down Expand Up @@ -314,11 +362,11 @@ void qmm_naive(
auto sB_layout = make_smem_layout<KMajor>(bN, bK);

// Atoms.
TiledCopy copy_a = make_tiled_copy<Element>(num_threads, bM, bK);
TiledCopy copy_a = make_tiled_copy<Element, true, HasKResidue>(num_threads, bM, bK);
TiledCopy copy_b = make_tiled_copy<Quant, KMajor>(num_threads, bN, bK);

auto* kernel = &qmm_naive_kernel<
decltype(prob_shape), decltype(cta_tiler),
HasKResidue, decltype(prob_shape), decltype(cta_tiler),
Element, Quant, Scale,
decltype(dA), decltype(sA_layout), decltype(copy_a),
decltype(dB), decltype(sB_layout), decltype(copy_b),
Expand Down Expand Up @@ -348,6 +396,21 @@ void qmm_naive(

namespace mlx::core {

// Lifts the runtime `has_k_residue` flag into a compile-time template
// argument for `f`. K-major layouts do not support a K residue, so when
// KMajor is set a non-multiple K is rejected with an invalid_argument
// carrying `tag`; otherwise both residue variants are instantiated via
// dispatch_bool and the matching one is invoked.
template <bool KMajor, typename F>
inline void dispatch_k(bool has_k_residue, const char* tag, F&& f) {
  if constexpr (!KMajor) {
    // Non-K-major path: both residue specializations are valid.
    dispatch_bool(has_k_residue, [&](auto residue) {
      f.template operator()<residue.value>();
    });
  } else {
    // K-major path only handles K that is a multiple of group_size.
    if (has_k_residue) {
      throw std::invalid_argument(
          fmt::format("{} K must be multiples of group_size.", tag));
    }
    f.template operator()<false>();
  }
}

template <typename F>
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
if (dtype == float32) {
Expand Down Expand Up @@ -429,53 +492,55 @@ void qmm_impl_naive(
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
bool broadcast_b = w.ndim() == 2;
bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size());

bool is_sm80 = encoder.device().compute_capability_major() >= 8;
dispatch_bool(is_sm80, [&](auto sm80) {
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
dispatch_quant_types<Element>(
bits,
group_size,
mode,
tag,
[&]<typename Quant, typename Scale, int group_size>() {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
if (biases) {
encoder.set_input_array(*biases);
}
if (lhs_indices) {
encoder.set_input_array(*lhs_indices);
}
if (rhs_indices) {
encoder.set_input_array(*rhs_indices);
}
encoder.set_output_array(out);
cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value>(
gpu_ptr<Element>(x),
gpu_ptr<Quant>(w),
gpu_ptr<Scale>(scales),
biases ? gpu_ptr<Element>(*biases) : nullptr,
lhs_indices ? gpu_ptr<uint32_t>(*lhs_indices) : nullptr,
rhs_indices ? gpu_ptr<uint32_t>(*rhs_indices) : nullptr,
gpu_ptr<Element>(out),
m,
n,
k,
l,
broadcast_b,
cute::Int<group_size>{},
[&](auto* kernel,
dim3 num_blocks,
dim3 block_dims,
uint32_t smem_bytes,
void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, smem_bytes, args);
});
});
dispatch_k<KMajor>(k % group_size != 0, tag, [&]<bool has_k_residue>() {
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
dispatch_quant_types<Element>(
bits,
group_size,
mode,
tag,
[&]<typename Quant, typename Scale, int group_size>() {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
if (biases) {
encoder.set_input_array(*biases);
}
if (lhs_indices) {
encoder.set_input_array(*lhs_indices);
}
if (rhs_indices) {
encoder.set_input_array(*rhs_indices);
}
encoder.set_output_array(out);
cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value, has_k_residue>(
gpu_ptr<Element>(x),
gpu_ptr<Quant>(w),
gpu_ptr<Scale>(scales),
biases ? gpu_ptr<Element>(*biases) : nullptr,
lhs_indices ? gpu_ptr<uint32_t>(*lhs_indices) : nullptr,
rhs_indices ? gpu_ptr<uint32_t>(*rhs_indices) : nullptr,
gpu_ptr<Element>(out),
m,
n,
k,
l,
broadcast_b,
cute::Int<group_size>{},
[&](auto* kernel,
dim3 num_blocks,
dim3 block_dims,
uint32_t smem_bytes,
void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, smem_bytes, args);
});
});
});
});
});
}
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ void qmm_impl_sm80(
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
bool broadcast_b = w.ndim() == 2;
bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size());

dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
dispatch_quant_types<Element>(
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void qmm_impl_sm90(
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
bool broadcast_b = w.ndim() == 2;
bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size());

// FIXME: Copy happens for every call.
array scales = transpose_last_2_dims(scales_, encoder, s);
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ void qmv(
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
bool broadcast_w = w.ndim() == 2;
bool broadcast_w = (w.ndim() <= 2) || (w.size() != w.data_size());

dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
dispatch_quant_types<T>(
Expand Down
7 changes: 0 additions & 7 deletions python/tests/cuda_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,4 @@
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
"TestQuantized.test_gather_qmm_grad",
"TestQuantized.test_non_multiples",
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_fp_qvm",
"TestQuantized.test_qvm",
"TestQuantized.test_qmv_small_non_multiples",
"TestQuantized.test_small_matrix",
"TestExportImport.test_export_quantized_model",
}
Loading