@@ -40,7 +40,7 @@ cute_dequant(auto w, auto s, auto z, auto out) {
4040 }
4141}
4242
43- template <typename ProblemShape, typename CtaTiler,
43+ template <bool HasKResidue, typename ProblemShape, typename CtaTiler,
4444 typename Element, typename Quant, typename Scale,
4545 typename StrideA, typename SmemLayoutA, typename TiledCopyA,
4646 typename StrideB, typename SmemLayoutB, typename TiledCopyB,
@@ -61,6 +61,20 @@ __global__ void qmm_naive_kernel(
6161 int thread_idx = int (threadIdx .x );
6262 auto [m_coord, n_coord, l_coord] = static_cast <uint3 >(blockIdx );
6363
64+ auto m_max_coord = size<0 >(shape_MNKL) - size<0 >(cta_tiler) * m_coord; // M - BLK_M * m_coord
65+ auto n_max_coord = size<1 >(shape_MNKL) - size<1 >(cta_tiler) * n_coord; // N - BLK_N * n_coord
66+
67+ // Shift tensor so we handle residue of K in the 0th tile.
68+ auto shape_K = size<2 >(shape_MNKL);
69+ auto bK = size<2 >(cta_tiler);
70+ auto k_residue = shape_K - bK * ceil_div (shape_K, bK);
71+ if constexpr (HasKResidue) {
72+ A += k_residue * get<1 >(dA);
73+ B += k_residue * get<1 >(dB) * cuda::std::min (8 , sizeof_bits_v<Quant>) / 8 ;
74+ S += k_residue * stride<1 >(S_layout);
75+ Z += k_residue * stride<1 >(S_layout);
76+ }
77+
6478 // Represent the full tensors.
6579 Tensor mA_mkl = make_tensor (make_gmem_ptr (A), select<0 ,2 ,3 >(shape_MNKL), dA); // (M,K,L)
6680 Tensor mB_nkl = make_tensor (make_gmem_ptr<Quant>(B), select<1 ,2 ,3 >(shape_MNKL), dB); // (N,K,L)
@@ -86,9 +100,6 @@ __global__ void qmm_naive_kernel(
86100 Tensor gS = local_tile (mS , cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
87101 Tensor gZ = local_tile (mZ , cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
88102
89- auto m_max_coord = size<0 >(shape_MNKL) - size<0 >(gA ) * m_coord; // M - BLK_M * m_coord
90- auto n_max_coord = size<1 >(shape_MNKL) - size<0 >(gB ) * n_coord; // N - BLK_N * n_coord
91-
92103 // Shared memory buffers.
93104 extern __shared__ char shared_memory[];
94105 using SharedStorage = SharedStorage<Element, SmemLayoutA, SmemLayoutB>;
@@ -159,17 +170,53 @@ __global__ void qmm_naive_kernel(
159170 __syncthreads ();
160171 };
161172
173+ // Clear the rmem tiles to account for predicated off loads.
174+ if constexpr (HasKResidue) {
175+ clear (tArA);
176+ clear (tBrB);
177+ clear (tBrS);
178+ clear (tBrZ);
179+ }
180+
162181 // Prefetch first tile.
163- fetch_gmem (0 );
182+ if constexpr (HasKResidue) {
183+ Tensor tAgA_k = tAgA (_,_,_,0 );
184+ CUTE_UNROLL
185+ for (int k = 0 ; k < size<2 >(tArA); ++k) {
186+ if (get<1 >(tAcA (0 ,0 ,k)) >= -k_residue) {
187+ copy_if (copy_a, tApA (_,k), tAgA_k (_,_,k), tArA (_,_,k));
188+ }
189+ }
190+ Tensor tBgB_k = tBgB (_,_,_,0 );
191+ Tensor tBgS_k = tBgS (_,_,_,0 );
192+ Tensor tBgZ_k = tBgZ (_,_,_,0 );
193+ CUTE_UNROLL
194+ for (int k = 0 ; k < size<2 >(tBrB); ++k) {
195+ if (get<1 >(tBcB (0 ,0 ,k)) >= -k_residue) {
196+ copy_if (copy_b, tBpB (_,k), tBgB_k (_,_,k), tBrB (_,_,k));
197+ copy (tBgS_k (_,_,k), tBrS (_,_,k));
198+ copy (tBgZ_k (_,_,k), tBrZ (_,_,k));
199+ }
200+ }
201+ } else {
202+ fetch_gmem (0 );
203+ }
164204
165205 // Clear accumulators.
166206 clear (tCrC);
167207
168208 // Loop over CTA tiles.
169- auto K_TILE_MAX = size<3 >(tAgA);
209+ auto K_TILE_MAX = size<3 >(tAgA);
170210 for (int tile = 0 ; tile < K_TILE_MAX; ++tile) {
171211 store_smem ();
172- fetch_gmem ((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
212+ if constexpr (HasKResidue) {
213+ // Avoid fetching full 0th-tile when there is residue.
214+ if (K_TILE_MAX > 1 ) {
215+ fetch_gmem ((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
216+ }
217+ } else {
218+ fetch_gmem ((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
219+ }
173220 gemm (mma, tCsA, tCsB, tCrC);
174221 }
175222
@@ -231,9 +278,10 @@ inline constexpr auto make_tiled_mma() {
231278 }
232279}
233280
234- template <typename T, bool KMajor = true >
281+ template <typename T, bool KMajor = true , bool HasKResidue = false >
235282inline auto make_tiled_copy (auto num_threads, auto bM, auto bK) {
236- auto n_read = Int<8 >{};
283+ // TODO: Only do 1-element read for the tile of residue.
284+ auto n_read = Int<HasKResidue ? 1 : 8 >{};
237285 auto atom = Copy_Atom<UniversalCopy<uint_bit_t <n_read * sizeof_bits_v<T>>>, T>{};
238286 if constexpr (KMajor) {
239287 auto k_threads = bK / n_read;
@@ -263,7 +311,7 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size
263311 }
264312}
265313
266- template <int TileM = 16 , bool KMajor = true , bool SM80 = true ,
314+ template <int TileM = 16 , bool KMajor = true , bool SM80 = true , bool HasKResidue = false ,
267315 typename Element, typename Quant, typename Scale>
268316void qmm_naive (
269317 const Element* A,
@@ -307,11 +355,11 @@ void qmm_naive(
307355 auto sB_layout = make_smem_layout<KMajor>(bN, bK);
308356
309357 // Atoms.
310- TiledCopy copy_a = make_tiled_copy<Element>(num_threads, bM, bK);
358+ TiledCopy copy_a = make_tiled_copy<Element, true , HasKResidue >(num_threads, bM, bK);
311359 TiledCopy copy_b = make_tiled_copy<Quant, KMajor>(num_threads, bN, bK);
312360
313361 auto * kernel = &qmm_naive_kernel<
314- decltype (prob_shape), decltype (cta_tiler),
362+ HasKResidue, decltype (prob_shape), decltype (cta_tiler),
315363 Element, Quant, Scale,
316364 decltype (dA), decltype (sA_layout ), decltype (copy_a),
317365 decltype (dB), decltype (sB_layout ), decltype (copy_b),
@@ -340,6 +388,21 @@ void qmm_naive(
340388
341389namespace mlx ::core {
342390
391+ template <bool KMajor, typename F> // Converts the runtime K-residue flag into a compile-time bool template argument for `f`.
392+ inline void dispatch_k (bool has_k_residue, const char * tag, F&& f) { // `tag` is only used as the prefix of the error message below.
393+ if constexpr (KMajor) { // Compile-time branch: the KMajor path never instantiates `f` with `true`.
394+ if (has_k_residue) { // A K residue is rejected here instead of being dispatched on.
395+ throw std::invalid_argument (
396+ fmt::format (" {} K must be multiples of group_size." , tag));
397+ }
398+ f.template operator ()<false >(); // Explicit template call: `f` must provide a templated operator() taking a bool.
399+ } else {
400+ dispatch_bool (has_k_residue, [&](auto has_k_residue) { // The lambda parameter shadows the runtime flag; presumably a compile-time constant wrapper -- verify against dispatch_bool.
401+ f.template operator ()<has_k_residue.value >(); // `.value` is used as a template argument, so it must be a constant expression.
402+ });
403+ }
404+ }
405+
343406template <typename F>
344407inline void dispatch_element_types (Dtype dtype, const char * tag, F&& f) {
345408 if (dtype == float32) {
@@ -423,41 +486,43 @@ void qmm_impl_naive(
423486
424487 bool is_sm80 = encoder.device ().compute_capability_major () >= 8 ;
425488 dispatch_bool (is_sm80, [&](auto sm80) {
426- dispatch_element_types (out.dtype (), tag, [&]<typename Element>() {
427- dispatch_quant_types<Element>(
428- bits,
429- group_size,
430- mode,
431- tag,
432- [&]<typename Quant, typename Scale, int group_size>() {
433- encoder.set_input_array (x);
434- encoder.set_input_array (w);
435- encoder.set_input_array (scales);
436- if (biases) {
437- encoder.set_input_array (*biases);
438- }
439- encoder.set_output_array (out);
440- cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value >(
441- gpu_ptr<Element>(x),
442- gpu_ptr<Quant>(w),
443- gpu_ptr<Scale>(scales),
444- biases ? gpu_ptr<Element>(*biases) : nullptr ,
445- gpu_ptr<Element>(out),
446- m,
447- n,
448- k,
449- l,
450- broadcast_b,
451- cute::Int<group_size>{},
452- [&](auto * kernel,
453- dim3 num_blocks,
454- dim3 block_dims,
455- uint32_t smem_bytes,
456- void ** args) {
457- encoder.add_kernel_node_raw (
458- kernel, num_blocks, block_dims, {}, smem_bytes, args);
459- });
460- });
489+ dispatch_k<KMajor>(k % group_size != 0 , tag, [&]<bool has_k_residue>() {
490+ dispatch_element_types (out.dtype (), tag, [&]<typename Element>() {
491+ dispatch_quant_types<Element>(
492+ bits,
493+ group_size,
494+ mode,
495+ tag,
496+ [&]<typename Quant, typename Scale, int group_size>() {
497+ encoder.set_input_array (x);
498+ encoder.set_input_array (w);
499+ encoder.set_input_array (scales);
500+ if (biases) {
501+ encoder.set_input_array (*biases);
502+ }
503+ encoder.set_output_array (out);
504+ cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value , has_k_residue>(
505+ gpu_ptr<Element>(x),
506+ gpu_ptr<Quant>(w),
507+ gpu_ptr<Scale>(scales),
508+ biases ? gpu_ptr<Element>(*biases) : nullptr ,
509+ gpu_ptr<Element>(out),
510+ m,
511+ n,
512+ k,
513+ l,
514+ broadcast_b,
515+ cute::Int<group_size>{},
516+ [&](auto * kernel,
517+ dim3 num_blocks,
518+ dim3 block_dims,
519+ uint32_t smem_bytes,
520+ void ** args) {
521+ encoder.add_kernel_node_raw (
522+ kernel, num_blocks, block_dims, {}, smem_bytes, args);
523+ });
524+ });
525+ });
461526 });
462527 });
463528}
0 commit comments