Skip to content

Commit dc8b5d6

Browse files
committed
[CUDA] Handle residue k in qmm_naive
1 parent 6a9a121 commit dc8b5d6

File tree

4 files changed: +114 additions, -55 deletions

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
279279
FetchContent_Declare(
280280
cutlass
281281
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
282-
GIT_TAG v4.3.5
282+
GIT_TAG v4.4.2
283283
GIT_SHALLOW TRUE
284284
SOURCE_SUBDIR include EXCLUDE_FROM_ALL)
285285
FetchContent_MakeAvailable(cutlass)

mlx/backend/cuda/quantized/qmm/qmm.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ bool supports_qmm_naive(
198198
QuantizationMode mode,
199199
cu::Device& device) {
200200
int k = x.shape(-1);
201-
if (k % std::max(64, group_size) != 0) {
201+
if (transpose && (k % std::max(64, group_size) != 0)) {
202202
return false;
203203
}
204204
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||

mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh

Lines changed: 112 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cute_dequant(auto w, auto s, auto z, auto out) {
4040
}
4141
}
4242

43-
template <typename ProblemShape, typename CtaTiler,
43+
template <bool HasKResidue, typename ProblemShape, typename CtaTiler,
4444
typename Element, typename Quant, typename Scale,
4545
typename StrideA, typename SmemLayoutA, typename TiledCopyA,
4646
typename StrideB, typename SmemLayoutB, typename TiledCopyB,
@@ -61,6 +61,20 @@ __global__ void qmm_naive_kernel(
6161
int thread_idx = int(threadIdx.x);
6262
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
6363

64+
auto m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord
65+
auto n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord
66+
67+
// Shift tensor so we handle residue of K in the 0th tile.
68+
auto shape_K = size<2>(shape_MNKL);
69+
auto bK = size<2>(cta_tiler);
70+
auto k_residue = shape_K - bK * ceil_div(shape_K, bK);
71+
if constexpr (HasKResidue) {
72+
A += k_residue * get<1>(dA);
73+
B += k_residue * get<1>(dB) * cuda::std::min(8, sizeof_bits_v<Quant>) / 8;
74+
S += k_residue * stride<1>(S_layout);
75+
Z += k_residue * stride<1>(S_layout);
76+
}
77+
6478
// Represent the full tensors.
6579
Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
6680
Tensor mB_nkl = make_tensor(make_gmem_ptr<Quant>(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
@@ -86,9 +100,6 @@ __global__ void qmm_naive_kernel(
86100
Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
87101
Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
88102

89-
auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord
90-
auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * n_coord; // N - BLK_N * n_coord
91-
92103
// Shared memory buffers.
93104
extern __shared__ char shared_memory[];
94105
using SharedStorage = SharedStorage<Element, SmemLayoutA, SmemLayoutB>;
@@ -159,17 +170,53 @@ __global__ void qmm_naive_kernel(
159170
__syncthreads();
160171
};
161172

173+
// Clear the rmem tiles to account for predicated off loads.
174+
if constexpr (HasKResidue) {
175+
clear(tArA);
176+
clear(tBrB);
177+
clear(tBrS);
178+
clear(tBrZ);
179+
}
180+
162181
// Prefetch first tile.
163-
fetch_gmem(0);
182+
if constexpr (HasKResidue) {
183+
Tensor tAgA_k = tAgA(_,_,_,0);
184+
CUTE_UNROLL
185+
for (int k = 0; k < size<2>(tArA); ++k) {
186+
if (get<1>(tAcA(0,0,k)) >= -k_residue) {
187+
copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k));
188+
}
189+
}
190+
Tensor tBgB_k = tBgB(_,_,_,0);
191+
Tensor tBgS_k = tBgS(_,_,_,0);
192+
Tensor tBgZ_k = tBgZ(_,_,_,0);
193+
CUTE_UNROLL
194+
for (int k = 0; k < size<2>(tBrB); ++k) {
195+
if (get<1>(tBcB(0,0,k)) >= -k_residue) {
196+
copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k));
197+
copy(tBgS_k(_,_,k), tBrS(_,_,k));
198+
copy(tBgZ_k(_,_,k), tBrZ(_,_,k));
199+
}
200+
}
201+
} else {
202+
fetch_gmem(0);
203+
}
164204

165205
// Clear accumulators.
166206
clear(tCrC);
167207

168208
// Loop over CTA tiles.
169-
auto K_TILE_MAX = size<3>(tAgA);
209+
auto K_TILE_MAX = size<3>(tAgA);
170210
for (int tile = 0; tile < K_TILE_MAX; ++tile) {
171211
store_smem();
172-
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
212+
if constexpr (HasKResidue) {
213+
// Avoid fetching full 0th-tile when there is residue.
214+
if (K_TILE_MAX > 1) {
215+
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
216+
}
217+
} else {
218+
fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
219+
}
173220
gemm(mma, tCsA, tCsB, tCrC);
174221
}
175222

@@ -231,9 +278,10 @@ inline constexpr auto make_tiled_mma() {
231278
}
232279
}
233280

234-
template <typename T, bool KMajor = true>
281+
template <typename T, bool KMajor = true, bool HasKResidue = false>
235282
inline auto make_tiled_copy(auto num_threads, auto bM, auto bK) {
236-
auto n_read = Int<8>{};
283+
// TODO: Only do 1-element read for the tile of residue.
284+
auto n_read = Int<HasKResidue ? 1 : 8>{};
237285
auto atom = Copy_Atom<UniversalCopy<uint_bit_t<n_read * sizeof_bits_v<T>>>, T>{};
238286
if constexpr (KMajor) {
239287
auto k_threads = bK / n_read;
@@ -263,7 +311,7 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size
263311
}
264312
}
265313

266-
template <int TileM = 16, bool KMajor = true, bool SM80 = true,
314+
template <int TileM = 16, bool KMajor = true, bool SM80 = true, bool HasKResidue = false,
267315
typename Element, typename Quant, typename Scale>
268316
void qmm_naive(
269317
const Element* A,
@@ -307,11 +355,11 @@ void qmm_naive(
307355
auto sB_layout = make_smem_layout<KMajor>(bN, bK);
308356

309357
// Atoms.
310-
TiledCopy copy_a = make_tiled_copy<Element>(num_threads, bM, bK);
358+
TiledCopy copy_a = make_tiled_copy<Element, true, HasKResidue>(num_threads, bM, bK);
311359
TiledCopy copy_b = make_tiled_copy<Quant, KMajor>(num_threads, bN, bK);
312360

313361
auto* kernel = &qmm_naive_kernel<
314-
decltype(prob_shape), decltype(cta_tiler),
362+
HasKResidue, decltype(prob_shape), decltype(cta_tiler),
315363
Element, Quant, Scale,
316364
decltype(dA), decltype(sA_layout), decltype(copy_a),
317365
decltype(dB), decltype(sB_layout), decltype(copy_b),
@@ -340,6 +388,21 @@ void qmm_naive(
340388

341389
namespace mlx::core {
342390

391+
template <bool KMajor, typename F>
392+
inline void dispatch_k(bool has_k_residue, const char* tag, F&& f) {
393+
if constexpr (KMajor) {
394+
if (has_k_residue) {
395+
throw std::invalid_argument(
396+
fmt::format("{} K must be multiples of group_size.", tag));
397+
}
398+
f.template operator()<false>();
399+
} else {
400+
dispatch_bool(has_k_residue, [&](auto has_k_residue) {
401+
f.template operator()<has_k_residue.value>();
402+
});
403+
}
404+
}
405+
343406
template <typename F>
344407
inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
345408
if (dtype == float32) {
@@ -423,41 +486,43 @@ void qmm_impl_naive(
423486

424487
bool is_sm80 = encoder.device().compute_capability_major() >= 8;
425488
dispatch_bool(is_sm80, [&](auto sm80) {
426-
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
427-
dispatch_quant_types<Element>(
428-
bits,
429-
group_size,
430-
mode,
431-
tag,
432-
[&]<typename Quant, typename Scale, int group_size>() {
433-
encoder.set_input_array(x);
434-
encoder.set_input_array(w);
435-
encoder.set_input_array(scales);
436-
if (biases) {
437-
encoder.set_input_array(*biases);
438-
}
439-
encoder.set_output_array(out);
440-
cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value>(
441-
gpu_ptr<Element>(x),
442-
gpu_ptr<Quant>(w),
443-
gpu_ptr<Scale>(scales),
444-
biases ? gpu_ptr<Element>(*biases) : nullptr,
445-
gpu_ptr<Element>(out),
446-
m,
447-
n,
448-
k,
449-
l,
450-
broadcast_b,
451-
cute::Int<group_size>{},
452-
[&](auto* kernel,
453-
dim3 num_blocks,
454-
dim3 block_dims,
455-
uint32_t smem_bytes,
456-
void** args) {
457-
encoder.add_kernel_node_raw(
458-
kernel, num_blocks, block_dims, {}, smem_bytes, args);
459-
});
460-
});
489+
dispatch_k<KMajor>(k % group_size != 0, tag, [&]<bool has_k_residue>() {
490+
dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
491+
dispatch_quant_types<Element>(
492+
bits,
493+
group_size,
494+
mode,
495+
tag,
496+
[&]<typename Quant, typename Scale, int group_size>() {
497+
encoder.set_input_array(x);
498+
encoder.set_input_array(w);
499+
encoder.set_input_array(scales);
500+
if (biases) {
501+
encoder.set_input_array(*biases);
502+
}
503+
encoder.set_output_array(out);
504+
cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value, has_k_residue>(
505+
gpu_ptr<Element>(x),
506+
gpu_ptr<Quant>(w),
507+
gpu_ptr<Scale>(scales),
508+
biases ? gpu_ptr<Element>(*biases) : nullptr,
509+
gpu_ptr<Element>(out),
510+
m,
511+
n,
512+
k,
513+
l,
514+
broadcast_b,
515+
cute::Int<group_size>{},
516+
[&](auto* kernel,
517+
dim3 num_blocks,
518+
dim3 block_dims,
519+
uint32_t smem_bytes,
520+
void** args) {
521+
encoder.add_kernel_node_raw(
522+
kernel, num_blocks, block_dims, {}, smem_bytes, args);
523+
});
524+
});
525+
});
461526
});
462527
});
463528
}

python/tests/cuda_skip.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,5 @@
2323
"TestQuantized.test_gather_qmm",
2424
"TestQuantized.test_gather_qmm_sorted",
2525
"TestQuantized.test_gather_qmm_grad",
26-
"TestQuantized.test_non_multiples",
27-
"TestQuantized.test_qmm_shapes",
28-
"TestQuantized.test_fp_qvm",
29-
"TestQuantized.test_qvm",
30-
"TestQuantized.test_qmv_small_non_multiples",
3126
"TestQuantized.test_small_matrix",
32-
"TestExportImport.test_export_quantized_model",
3327
}

Comments (0)