Skip to content

Commit fa4320d

Browse files
authored
[CUDA] Handle residue k in qmm_naive (#3379)
1 parent 859f22f commit fa4320d

File tree

7 files changed

+147
-79
lines changed

7 files changed

+147
-79
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
280280
FetchContent_Declare(
281281
cutlass
282282
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
283-
GIT_TAG v4.3.5
283+
GIT_TAG v4.4.2
284284
GIT_SHALLOW TRUE
285285
SOURCE_SUBDIR include EXCLUDE_FROM_ALL)
286286
FetchContent_MakeAvailable(cutlass)

mlx/backend/cuda/quantized/qmm/qmm.cu

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77

88
namespace mlx::core {
99

10+
namespace {
11+
12+
inline bool is_last_2_dims_row_contiguous(const array& x) {
13+
return x.flags().contiguous && (x.ndim() >= 2) && (x.strides(-1) == 1) &&
14+
(x.strides(-2) == x.shape(-1));
15+
}
16+
17+
} // namespace
18+
1019
#if defined(MLX_CUDA_SM90A_ENABLED)
1120
// Defined in qmm_impl_sm90_xxx.cu files.
1221
template <typename TileShape, typename ClusterShape>
@@ -43,8 +52,9 @@ bool supports_qmm_sm90(
4352
if (!biases) {
4453
return false;
4554
}
46-
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
47-
!scales.flags().row_contiguous || !biases->flags().row_contiguous) {
55+
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
56+
!is_last_2_dims_row_contiguous(scales) ||
57+
!is_last_2_dims_row_contiguous(*biases)) {
4858
return false;
4959
}
5060
if (!transpose) {
@@ -132,11 +142,11 @@ bool supports_qmm_sm80(
132142
if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) {
133143
return false;
134144
}
135-
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
136-
!scales.flags().row_contiguous) {
145+
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
146+
!is_last_2_dims_row_contiguous(scales)) {
137147
return false;
138148
}
139-
if (biases && !biases->flags().row_contiguous) {
149+
if (biases && !is_last_2_dims_row_contiguous(*biases)) {
140150
return false;
141151
}
142152
if (x.dtype() != float16 && x.dtype() != bfloat16) {
@@ -214,14 +224,14 @@ bool supports_qmm_naive(
214224
QuantizationMode mode,
215225
cu::Device& device) {
216226
int k = x.shape(-1);
217-
if (k % std::max(64, group_size) != 0) {
227+
if (transpose && (k % std::max(64, group_size) != 0)) {
218228
return false;
219229
}
220-
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
221-
!scales.flags().row_contiguous) {
230+
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
231+
!is_last_2_dims_row_contiguous(scales)) {
222232
return false;
223233
}
224-
if (biases && !biases->flags().row_contiguous) {
234+
if (biases && !is_last_2_dims_row_contiguous(*biases)) {
225235
return false;
226236
}
227237
return true;
@@ -313,11 +323,11 @@ bool supports_qmv(
313323
if (k % 8 != 0) {
314324
return false;
315325
}
316-
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
317-
!scales.flags().row_contiguous) {
326+
if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) ||
327+
!is_last_2_dims_row_contiguous(scales)) {
318328
return false;
319329
}
320-
if (biases && !biases->flags().row_contiguous) {
330+
if (biases && !is_last_2_dims_row_contiguous(*biases)) {
321331
return false;
322332
}
323333
if (!transpose) {

mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh

Lines changed: 121 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cute_dequant(auto w, auto s, auto z, auto out) {
4040
}
4141
}
4242

43-
template <typename ProblemShape, typename CtaTiler,
43+
template <bool HasKResidue, typename ProblemShape, typename CtaTiler,
4444
typename Element, typename Quant, typename Scale,
4545
typename StrideA, typename SmemLayoutA, typename TiledCopyA,
4646
typename StrideB, typename SmemLayoutB, typename TiledCopyB,
@@ -62,6 +62,20 @@ __global__ void qmm_naive_kernel(
6262
int thread_idx = int(threadIdx.x);
6363
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
6464

65+
auto m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord
66+
auto n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord
67+
68+
// Shift tensor so we handle residue of K in the 0th tile.
69+
auto shape_K = size<2>(shape_MNKL);
70+
auto bK = size<2>(cta_tiler);
71+
auto k_residue = shape_K - bK * ceil_div(shape_K, bK);
72+
if constexpr (HasKResidue) {
73+
A += k_residue * get<1>(dA);
74+
B += k_residue * get<1>(dB) * cuda::std::min(8, sizeof_bits_v<Quant>) / 8;
75+
S += k_residue * stride<1>(S_layout);
76+
Z += k_residue * stride<1>(S_layout);
77+
}
78+
6579
// Represent the full tensors.
6680
Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
6781
Tensor mB_nkl = make_tensor(make_gmem_ptr<Quant>(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
@@ -91,9 +105,6 @@ __global__ void qmm_naive_kernel(
91105
Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
92106
Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
93107

94-
auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord
95-
auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * n_coord; // N - BLK_N * n_coord
96-
97108
// Shared memory buffers.
98109
extern __shared__ char shared_memory[];
99110
using SharedStorage = SharedStorage<Element, SmemLayoutA, SmemLayoutB>;
@@ -164,17 +175,53 @@ __global__ void qmm_naive_kernel(
164175
__syncthreads();
165176
};
166177

178+
// Clear the rmem tiles to account for predicated off loads.
179+
if constexpr (HasKResidue) {
180+
clear(tArA);
181+
clear(tBrB);
182+
clear(tBrS);
183+
clear(tBrZ);
184+
}
185+
167186
// Prefetch first tile.
168-
fetch_gmem(0);
187+
if constexpr (HasKResidue) {
188+
Tensor tAgA_k = tAgA(_,_,_,0);
189+
CUTE_UNROLL
190+
for (int k = 0; k < size<2>(tArA); ++k) {
191+
if (get<1>(tAcA(0,0,k)) >= -k_residue) {
192+
copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k));
193+
}
194+
}
195+
Tensor tBgB_k = tBgB(_,_,_,0);
196+
Tensor tBgS_k = tBgS(_,_,_,0);
197+
Tensor tBgZ_k = tBgZ(_,_,_,0);
198+
CUTE_UNROLL
199+
for (int k = 0; k < size<2>(tBrB); ++k) {
200+
if (get<1>(tBcB(0,0,k)) >= -k_residue) {
201+
copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k));
202+
copy(tBgS_k(_,_,k), tBrS(_,_,k));
203+
copy(tBgZ_k(_,_,k), tBrZ(_,_,k));
204+
}
205+
}
206+
} else {
207+
fetch_gmem(0);
208+
}
169209

170210
// Clear accumulators.
171211
clear(tCrC);
172212

173213
// Loop over CTA tiles.
174-
auto K_TILE_MAX = size<3>(tAgA);
214+
auto K_TILE_MAX = size<3>(tAgA);
175215
for (int tile = 0; tile < K_TILE_MAX; ++tile) {
176216
store_smem();
177-
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
217+
if constexpr (HasKResidue) {
218+
// Avoid fetching full 0th-tile when there is residue.
219+
if (K_TILE_MAX > 1) {
220+
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
221+
}
222+
} else {
223+
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
224+
}
178225
gemm(mma, tCsA, tCsB, tCrC);
179226
}
180227

@@ -236,9 +283,10 @@ inline constexpr auto make_tiled_mma() {
236283
}
237284
}
238285

239-
template <typename T, bool KMajor = true>
286+
template <typename T, bool KMajor = true, bool HasKResidue = false>
240287
inline auto make_tiled_copy(auto num_threads, auto bM, auto bK) {
241-
auto n_read = Int<8>{};
288+
// TODO: Only do 1-element read for the tile of residue.
289+
auto n_read = Int<HasKResidue ? 1 : 8>{};
242290
auto atom = Copy_Atom<UniversalCopy<uint_bit_t<n_read * sizeof_bits_v<T>>>, T>{};
243291
if constexpr (KMajor) {
244292
auto k_threads = bK / n_read;
@@ -268,7 +316,7 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size
268316
}
269317
}
270318

271-
template <int TileM = 16, bool KMajor = true, bool SM80 = true,
319+
template <int TileM = 16, bool KMajor = true, bool SM80 = true, bool HasKResidue = false,
272320
typename Element, typename Quant, typename Scale>
273321
void qmm_naive(
274322
const Element* A,
@@ -314,11 +362,11 @@ void qmm_naive(
314362
auto sB_layout = make_smem_layout<KMajor>(bN, bK);
315363

316364
// Atoms.
317-
TiledCopy copy_a = make_tiled_copy<Element>(num_threads, bM, bK);
365+
TiledCopy copy_a = make_tiled_copy<Element, true, HasKResidue>(num_threads, bM, bK);
318366
TiledCopy copy_b = make_tiled_copy<Quant, KMajor>(num_threads, bN, bK);
319367

320368
auto* kernel = &qmm_naive_kernel<
321-
decltype(prob_shape), decltype(cta_tiler),
369+
HasKResidue, decltype(prob_shape), decltype(cta_tiler),
322370
Element, Quant, Scale,
323371
decltype(dA), decltype(sA_layout), decltype(copy_a),
324372
decltype(dB), decltype(sB_layout), decltype(copy_b),
@@ -348,6 +396,21 @@ void qmm_naive(
348396

349397
namespace mlx::core {
350398

399+
template <bool KMajor, typename F>
400+
inline void dispatch_k(bool has_k_residue, const char* tag, F&& f) {
401+
if constexpr (KMajor) {
402+
if (has_k_residue) {
403+
throw std::invalid_argument(
404+
fmt::format("{} K must be multiples of group_size.", tag));
405+
}
406+
f.template operator()<false>();
407+
} else {
408+
dispatch_bool(has_k_residue, [&](auto has_k_residue) {
409+
f.template operator()<has_k_residue.value>();
410+
});
411+
}
412+
}
413+
351414
template <typename F>
352415
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
353416
if (dtype == float32) {
@@ -429,53 +492,55 @@ void qmm_impl_naive(
429492
int n = out.shape(-1);
430493
int k = x.shape(-1);
431494
int l = out.size() / (m * n);
432-
bool broadcast_b = w.ndim() == 2;
495+
bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size());
433496

434497
bool is_sm80 = encoder.device().compute_capability_major() >= 8;
435498
dispatch_bool(is_sm80, [&](auto sm80) {
436-
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
437-
dispatch_quant_types<Element>(
438-
bits,
439-
group_size,
440-
mode,
441-
tag,
442-
[&]<typename Quant, typename Scale, int group_size>() {
443-
encoder.set_input_array(x);
444-
encoder.set_input_array(w);
445-
encoder.set_input_array(scales);
446-
if (biases) {
447-
encoder.set_input_array(*biases);
448-
}
449-
if (lhs_indices) {
450-
encoder.set_input_array(*lhs_indices);
451-
}
452-
if (rhs_indices) {
453-
encoder.set_input_array(*rhs_indices);
454-
}
455-
encoder.set_output_array(out);
456-
cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value>(
457-
gpu_ptr<Element>(x),
458-
gpu_ptr<Quant>(w),
459-
gpu_ptr<Scale>(scales),
460-
biases ? gpu_ptr<Element>(*biases) : nullptr,
461-
lhs_indices ? gpu_ptr<uint32_t>(*lhs_indices) : nullptr,
462-
rhs_indices ? gpu_ptr<uint32_t>(*rhs_indices) : nullptr,
463-
gpu_ptr<Element>(out),
464-
m,
465-
n,
466-
k,
467-
l,
468-
broadcast_b,
469-
cute::Int<group_size>{},
470-
[&](auto* kernel,
471-
dim3 num_blocks,
472-
dim3 block_dims,
473-
uint32_t smem_bytes,
474-
void** args) {
475-
encoder.add_kernel_node_raw(
476-
kernel, num_blocks, block_dims, {}, smem_bytes, args);
477-
});
478-
});
499+
dispatch_k<KMajor>(k % group_size != 0, tag, [&]<bool has_k_residue>() {
500+
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
501+
dispatch_quant_types<Element>(
502+
bits,
503+
group_size,
504+
mode,
505+
tag,
506+
[&]<typename Quant, typename Scale, int group_size>() {
507+
encoder.set_input_array(x);
508+
encoder.set_input_array(w);
509+
encoder.set_input_array(scales);
510+
if (biases) {
511+
encoder.set_input_array(*biases);
512+
}
513+
if (lhs_indices) {
514+
encoder.set_input_array(*lhs_indices);
515+
}
516+
if (rhs_indices) {
517+
encoder.set_input_array(*rhs_indices);
518+
}
519+
encoder.set_output_array(out);
520+
cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value, has_k_residue>(
521+
gpu_ptr<Element>(x),
522+
gpu_ptr<Quant>(w),
523+
gpu_ptr<Scale>(scales),
524+
biases ? gpu_ptr<Element>(*biases) : nullptr,
525+
lhs_indices ? gpu_ptr<uint32_t>(*lhs_indices) : nullptr,
526+
rhs_indices ? gpu_ptr<uint32_t>(*rhs_indices) : nullptr,
527+
gpu_ptr<Element>(out),
528+
m,
529+
n,
530+
k,
531+
l,
532+
broadcast_b,
533+
cute::Int<group_size>{},
534+
[&](auto* kernel,
535+
dim3 num_blocks,
536+
dim3 block_dims,
537+
uint32_t smem_bytes,
538+
void** args) {
539+
encoder.add_kernel_node_raw(
540+
kernel, num_blocks, block_dims, {}, smem_bytes, args);
541+
});
542+
});
543+
});
479544
});
480545
});
481546
}

mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ void qmm_impl_sm80(
451451
int n = out.shape(-1);
452452
int k = x.shape(-1);
453453
int l = out.size() / (m * n);
454-
bool broadcast_b = w.ndim() == 2;
454+
bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size());
455455

456456
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
457457
dispatch_quant_types<Element>(

mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void qmm_impl_sm90(
193193
int n = out.shape(-1);
194194
int k = x.shape(-1);
195195
int l = out.size() / (m * n);
196-
bool broadcast_b = w.ndim() == 2;
196+
bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size());
197197

198198
// FIXME: Copy happens for every call.
199199
array scales = transpose_last_2_dims(scales_, encoder, s);

mlx/backend/cuda/quantized/qmm/qmv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ void qmv(
406406
int n = out.shape(-1);
407407
int k = x.shape(-1);
408408
int l = out.size() / (m * n);
409-
bool broadcast_w = w.ndim() == 2;
409+
bool broadcast_w = (w.ndim() <= 2) || (w.size() != w.data_size());
410410

411411
dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
412412
dispatch_quant_types<T>(

python/tests/cuda_skip.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,4 @@
1919
"TestQuantized.test_gather_qmm",
2020
"TestQuantized.test_gather_qmm_sorted",
2121
"TestQuantized.test_gather_qmm_grad",
22-
"TestQuantized.test_non_multiples",
23-
"TestQuantized.test_qmm_shapes",
24-
"TestQuantized.test_fp_qvm",
25-
"TestQuantized.test_qvm",
26-
"TestQuantized.test_qmv_small_non_multiples",
27-
"TestQuantized.test_small_matrix",
28-
"TestExportImport.test_export_quantized_model",
2922
}

0 commit comments

Comments (0)