Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions mlx/backend/cuda/quantized/qmm/qmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ void qmm_impl_sm80(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices,
const uint32_t* rhs_indices);

bool supports_qmm_sm80(
const array& x,
Expand Down Expand Up @@ -158,10 +160,22 @@ void qmm_sm80(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices,
const uint32_t* rhs_indices) {
auto dispatch = [&]<int TileM>() {
qmm_impl_sm80<TileM>(
x, w, scales, biases, out, bits, group_size, mode, encoder);
x,
w,
scales,
biases,
out,
bits,
group_size,
mode,
encoder,
lhs_indices,
rhs_indices);
};
int m = out.ndim() > 1 ? out.shape(-2) : 1;
if (m <= 16) {
Expand All @@ -184,7 +198,9 @@ void qmm_impl_naive(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices,
const uint32_t* rhs_indices);

bool supports_qmm_naive(
const array& x,
Expand Down Expand Up @@ -221,10 +237,22 @@ void qmm_naive(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices,
const uint32_t* rhs_indices) {
auto dispatch = [&]<int TileM, bool KMajor>() {
qmm_impl_naive<TileM, KMajor>(
x, w, scales, biases, out, bits, group_size, mode, encoder);
x,
w,
scales,
biases,
out,
bits,
group_size,
mode,
encoder,
lhs_indices,
rhs_indices);
};
dispatch_bool(transpose, [&](auto k_major) {
int m = out.ndim() > 1 ? out.shape(-2) : 1;
Expand Down
8 changes: 6 additions & 2 deletions mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ void qmm_sm80(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr);
Comment thread
Lyxot marked this conversation as resolved.
Outdated

bool supports_qmm_naive(
const array& x,
Expand All @@ -77,7 +79,9 @@ void qmm_naive(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr);

bool supports_fp_qmv(
const array& x,
Expand Down
35 changes: 25 additions & 10 deletions mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ __global__ void qmm_naive_kernel(
const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB copy_b,
Element* C, StrideC dC,
const Scale* S, const Element* Z, LayoutS S_layout,
TiledMma mma) {
TiledMma mma,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr) {
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));
CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA));
Expand All @@ -69,13 +71,17 @@ __global__ void qmm_naive_kernel(
Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)
Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)

// For gather, use index lookup for input batch slicing.
uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord;
uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord;

// Get batch slice.
Tensor mA = mA_mkl(_,_,l_coord); // (M,K)
Tensor mB = mB_nkl(_,_,l_coord); // (N,K)
Tensor mA = mA_mkl(_,_,a_batch); // (M,K)
Tensor mB = mB_nkl(_,_,b_batch); // (N,K)
Tensor mC = mC_mnl(_,_,l_coord); // (M,N)

Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size))
Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size))

// Get the appropriate blocks for this thread block.
auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k)
Expand Down Expand Up @@ -274,7 +280,9 @@ void qmm_naive(
int m, int n, int k, int l,
bool broadcast_b,
auto group_size,
auto&& launch_kernel) {
auto&& launch_kernel,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr) {
// Define shapes (dynamic).
auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L)

Expand Down Expand Up @@ -330,7 +338,8 @@ void qmm_naive(
&B, &dB, &sB_layout, &copy_b,
&C, &dC,
&S, &Z, &S_layout,
&mma};
&mma,
&lhs_indices, &rhs_indices};
launch_kernel(reinterpret_cast<void*>(kernel), num_blocks, block_dims, smem_bytes, args);
}

Expand Down Expand Up @@ -413,7 +422,9 @@ void qmm_impl_naive(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr) {
const char* tag = "[quantized_matmul]";
int m = out.ndim() > 1 ? out.shape(-2) : 1;
int n = out.shape(-1);
Expand Down Expand Up @@ -456,7 +467,9 @@ void qmm_impl_naive(
void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, smem_bytes, args);
});
},
lhs_indices,
rhs_indices);
});
});
});
Expand All @@ -475,5 +488,7 @@ void qmm_impl_naive(
int bits, \
int group_size, \
QuantizationMode mode, \
cu::CommandEncoder& encoder); \
cu::CommandEncoder& encoder, \
const uint32_t* lhs_indices, \
const uint32_t* rhs_indices); \
}
35 changes: 25 additions & 10 deletions mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ __global__ void qmm_sm80_kernel(
const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a,
const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b,
Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c,
const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma) {
const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr) {
CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma));
CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma));
CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma));
Expand All @@ -48,6 +50,10 @@ __global__ void qmm_sm80_kernel(
int thread_idx = int(threadIdx.x);
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);

// For gather, use index lookup for input batch slicing.
uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord;
uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord;

// Represent the full tensors.
Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
Tensor mB_nkl = make_tensor(make_gmem_ptr<Quant>(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
Expand All @@ -57,12 +63,12 @@ __global__ void qmm_sm80_kernel(
Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)

// Get batch slice.
Tensor mA = mA_mkl(_,_,l_coord); // (M,K)
Tensor mB = mB_nkl(_,_,l_coord); // (N,K)
Tensor mA = mA_mkl(_,_,a_batch); // (M,K)
Tensor mB = mB_nkl(_,_,b_batch); // (N,K)
Tensor mC = mC_mnl(_,_,l_coord); // (M,N)

Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size))
Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size))
Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size))

// Get the appropriate blocks for this thread block.
auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k)
Expand Down Expand Up @@ -277,7 +283,9 @@ void qmm_sm80(
int m, int n, int k, int l,
bool broadcast_b,
GroupSize group_size,
F&& launch_kernel) {
F&& launch_kernel,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr) {
// Define shapes (dynamic).
auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L)

Expand Down Expand Up @@ -360,7 +368,8 @@ void qmm_sm80(
&A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a,
&B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b,
&C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c,
&S, &Z, &S_layout, &g2r_atom_s, &mma};
&S, &Z, &S_layout, &g2r_atom_s, &mma,
&lhs_indices, &rhs_indices};
launch_kernel(reinterpret_cast<void*>(kernel), num_blocks, block_dims, smem_bytes, args);
}

Expand Down Expand Up @@ -433,7 +442,9 @@ void qmm_impl_sm80(
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
cu::CommandEncoder& encoder,
const uint32_t* lhs_indices = nullptr,
const uint32_t* rhs_indices = nullptr) {
const char* tag = "[quantized_matmul]";
int m = out.ndim() > 1 ? out.shape(-2) : 1;
int n = out.shape(-1);
Expand Down Expand Up @@ -474,7 +485,9 @@ void qmm_impl_sm80(
void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, smem_bytes, args);
});
},
lhs_indices,
rhs_indices);
});
});
}
Expand All @@ -492,5 +505,7 @@ void qmm_impl_sm80(
int bits, \
int group_size, \
QuantizationMode mode, \
cu::CommandEncoder& encoder); \
cu::CommandEncoder& encoder, \
const uint32_t* lhs_indices, \
const uint32_t* rhs_indices); \
}
57 changes: 56 additions & 1 deletion mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
array lhs_indices = ensure_contiguous(inputs[inputs.size() - 2], encoder, s);
array rhs_indices = ensure_contiguous(inputs[inputs.size() - 1], encoder, s);

int M = out.shape(-2);
int M = out.ndim() > 1 ? out.shape(-2) : 1;
int N = out.shape(-1);
int K = x.shape(-1);
int B = out.size() / (M * N);
Expand All @@ -161,8 +161,45 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mode_,
encoder.device());
};
bool can_use_qmm_sm80 = supports(supports_qmm_sm80);
bool can_use_qmm_naive = supports(supports_qmm_naive);
bool can_use_qmv = supports(supports_qmv);

// Gather path through qmm_sm80.
// Allocates the output buffer, registers both (already contiguous) index
// arrays as inputs on the encoder, then calls qmm_sm80 with the index
// pointers obtained from gpu_ptr as its trailing lhs/rhs index arguments.
auto call_qmm_sm80 = [&]() {
// Output storage is allocated before the kernel is encoded.
out.set_data(cu::malloc_async(out.nbytes(), encoder));
// The index arrays are read by the kernel, so they are declared as inputs
// here. NOTE(review): x/w/scales/biases are presumably registered inside
// qmm_sm80 itself -- not visible from this hunk; confirm.
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
qmm_sm80(
x,
w,
scales,
biases,
out,
bits_,
group_size_,
mode_,
encoder,
// Non-null index pointers make the kernel look up the input batch slice
// through lhs_indices / rhs_indices instead of using the block's batch
// coordinate directly.
gpu_ptr<uint32_t>(lhs_indices),
gpu_ptr<uint32_t>(rhs_indices));
};
// Gather path through qmm_naive.
// Same sequence as the sm80 path: allocate the output, register both index
// arrays as encoder inputs, then call qmm_naive (which additionally takes
// transpose_) with the index pointers obtained from gpu_ptr.
auto call_qmm_naive = [&]() {
// Output storage is allocated before the kernel is encoded.
out.set_data(cu::malloc_async(out.nbytes(), encoder));
// The index arrays are read by the kernel, so they are declared as inputs
// here. NOTE(review): the remaining operands are presumably registered
// inside qmm_naive -- not visible from this hunk; confirm.
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
qmm_naive(
x,
w,
scales,
biases,
out,
transpose_,
bits_,
group_size_,
mode_,
encoder,
// Non-null index pointers make the kernel look up the input batch slice
// through lhs_indices / rhs_indices instead of using the block's batch
// coordinate directly.
gpu_ptr<uint32_t>(lhs_indices),
gpu_ptr<uint32_t>(rhs_indices));
};
auto call_qmv = [&]() {
out.set_data(cu::malloc_async(out.nbytes(), encoder));
gather_qmv(
Expand All @@ -179,6 +216,24 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder);
};

if (can_use_qmm_sm80) {
if (can_use_qmv && (M * B < 8)) {
call_qmv();
} else {
call_qmm_sm80();
}
return;
}

if (can_use_qmm_naive) {
if (can_use_qmv && (M * B < 8)) {
call_qmv();
} else {
call_qmm_naive();
}
return;
}

if (can_use_qmv) {
call_qmv();
return;
Expand Down
Loading