@@ -40,7 +40,7 @@ cute_dequant(auto w, auto s, auto z, auto out) {
4040 }
4141}
4242
43- template <typename ProblemShape, typename CtaTiler,
43+ template <bool HasKResidue, typename ProblemShape, typename CtaTiler,
4444 typename Element, typename Quant, typename Scale,
4545 typename StrideA, typename SmemLayoutA, typename TiledCopyA,
4646 typename StrideB, typename SmemLayoutB, typename TiledCopyB,
@@ -62,6 +62,20 @@ __global__ void qmm_naive_kernel(
6262 int thread_idx = int (threadIdx .x );
6363 auto [m_coord, n_coord, l_coord] = static_cast <uint3 >(blockIdx );
6464
65+ auto m_max_coord = size<0 >(shape_MNKL) - size<0 >(cta_tiler) * m_coord; // M - BLK_M * m_coord
66+ auto n_max_coord = size<1 >(shape_MNKL) - size<1 >(cta_tiler) * n_coord; // N - BLK_N * n_coord
67+
68+ // Shift tensor so we handle residue of K in the 0th tile.
69+ auto shape_K = size<2 >(shape_MNKL);
70+ auto bK = size<2 >(cta_tiler);
71+ auto k_residue = shape_K - bK * ceil_div (shape_K, bK);
72+ if constexpr (HasKResidue) {
73+ A += k_residue * get<1 >(dA);
74+ B += k_residue * get<1 >(dB) * cuda::std::min (8 , sizeof_bits_v<Quant>) / 8 ;
75+ S += k_residue * stride<1 >(S_layout);
76+ Z += k_residue * stride<1 >(S_layout);
77+ }
78+
6579 // Represent the full tensors.
6680 Tensor mA_mkl = make_tensor (make_gmem_ptr (A), select<0 ,2 ,3 >(shape_MNKL), dA); // (M,K,L)
6781 Tensor mB_nkl = make_tensor (make_gmem_ptr<Quant>(B), select<1 ,2 ,3 >(shape_MNKL), dB); // (N,K,L)
@@ -91,9 +105,6 @@ __global__ void qmm_naive_kernel(
91105 Tensor gS = local_tile (mS , cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
92106 Tensor gZ = local_tile (mZ , cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
93107
94- auto m_max_coord = size<0 >(shape_MNKL) - size<0 >(gA ) * m_coord; // M - BLK_M * m_coord
95- auto n_max_coord = size<1 >(shape_MNKL) - size<0 >(gB ) * n_coord; // N - BLK_N * n_coord
96-
97108 // Shared memory buffers.
98109 extern __shared__ char shared_memory[];
99110 using SharedStorage = SharedStorage<Element, SmemLayoutA, SmemLayoutB>;
@@ -164,17 +175,53 @@ __global__ void qmm_naive_kernel(
164175 __syncthreads ();
165176 };
166177
178+ // Clear the rmem tiles to account for predicated off loads.
179+ if constexpr (HasKResidue) {
180+ clear (tArA);
181+ clear (tBrB);
182+ clear (tBrS);
183+ clear (tBrZ);
184+ }
185+
167186 // Prefetch first tile.
168- fetch_gmem (0 );
187+ if constexpr (HasKResidue) {
188+ Tensor tAgA_k = tAgA (_,_,_,0 );
189+ CUTE_UNROLL
190+ for (int k = 0 ; k < size<2 >(tArA); ++k) {
191+ if (get<1 >(tAcA (0 ,0 ,k)) >= -k_residue) {
192+ copy_if (copy_a, tApA (_,k), tAgA_k (_,_,k), tArA (_,_,k));
193+ }
194+ }
195+ Tensor tBgB_k = tBgB (_,_,_,0 );
196+ Tensor tBgS_k = tBgS (_,_,_,0 );
197+ Tensor tBgZ_k = tBgZ (_,_,_,0 );
198+ CUTE_UNROLL
199+ for (int k = 0 ; k < size<2 >(tBrB); ++k) {
200+ if (get<1 >(tBcB (0 ,0 ,k)) >= -k_residue) {
201+ copy_if (copy_b, tBpB (_,k), tBgB_k (_,_,k), tBrB (_,_,k));
202+ copy (tBgS_k (_,_,k), tBrS (_,_,k));
203+ copy (tBgZ_k (_,_,k), tBrZ (_,_,k));
204+ }
205+ }
206+ } else {
207+ fetch_gmem (0 );
208+ }
169209
170210 // Clear accumulators.
171211 clear (tCrC);
172212
173213 // Loop over CTA tiles.
174- auto K_TILE_MAX = size<3 >(tAgA);
214+ auto K_TILE_MAX = size<3 >(tAgA);
175215 for (int tile = 0 ; tile < K_TILE_MAX; ++tile) {
176216 store_smem ();
177- fetch_gmem ((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
217+ if constexpr (HasKResidue) {
218+ // Avoid fetching full 0th-tile when there is residue.
219+ if (K_TILE_MAX > 1 ) {
220+ fetch_gmem ((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
221+ }
222+ } else {
223+ fetch_gmem ((tile + 1 < K_TILE_MAX) ? tile + 1 : tile);
224+ }
178225 gemm (mma, tCsA, tCsB, tCrC);
179226 }
180227
@@ -236,9 +283,10 @@ inline constexpr auto make_tiled_mma() {
236283 }
237284}
238285
239- template <typename T, bool KMajor = true >
286+ template <typename T, bool KMajor = true , bool HasKResidue = false >
240287inline auto make_tiled_copy (auto num_threads, auto bM, auto bK) {
241- auto n_read = Int<8 >{};
288+ // TODO: Only do 1-element read for the tile of residue.
289+ auto n_read = Int<HasKResidue ? 1 : 8 >{};
242290 auto atom = Copy_Atom<UniversalCopy<uint_bit_t <n_read * sizeof_bits_v<T>>>, T>{};
243291 if constexpr (KMajor) {
244292 auto k_threads = bK / n_read;
@@ -268,7 +316,7 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size
268316 }
269317}
270318
271- template <int TileM = 16 , bool KMajor = true , bool SM80 = true ,
319+ template <int TileM = 16 , bool KMajor = true , bool SM80 = true , bool HasKResidue = false ,
272320 typename Element, typename Quant, typename Scale>
273321void qmm_naive (
274322 const Element* A,
@@ -314,11 +362,11 @@ void qmm_naive(
314362 auto sB_layout = make_smem_layout<KMajor>(bN, bK);
315363
316364 // Atoms.
317- TiledCopy copy_a = make_tiled_copy<Element>(num_threads, bM, bK);
365+ TiledCopy copy_a = make_tiled_copy<Element, true , HasKResidue >(num_threads, bM, bK);
318366 TiledCopy copy_b = make_tiled_copy<Quant, KMajor>(num_threads, bN, bK);
319367
320368 auto * kernel = &qmm_naive_kernel<
321- decltype (prob_shape), decltype (cta_tiler),
369+ HasKResidue, decltype (prob_shape), decltype (cta_tiler),
322370 Element, Quant, Scale,
323371 decltype (dA), decltype (sA_layout ), decltype (copy_a),
324372 decltype (dB), decltype (sB_layout ), decltype (copy_b),
@@ -348,6 +396,21 @@ void qmm_naive(
348396
349397namespace mlx ::core {
350398
399+ template <bool KMajor, typename F>
400+ inline void dispatch_k (bool has_k_residue, const char * tag, F&& f) {
401+ if constexpr (KMajor) {
402+ if (has_k_residue) {
403+ throw std::invalid_argument (
404+ fmt::format (" {} K must be multiples of group_size." , tag));
405+ }
406+ f.template operator ()<false >();
407+ } else {
408+ dispatch_bool (has_k_residue, [&](auto has_k_residue) {
409+ f.template operator ()<has_k_residue.value >();
410+ });
411+ }
412+ }
413+
351414template <typename F>
352415inline void dispatch_element_types (Dtype dtype, const char * tag, F&& f) {
353416 if (dtype == float32) {
@@ -429,53 +492,55 @@ void qmm_impl_naive(
429492 int n = out.shape (-1 );
430493 int k = x.shape (-1 );
431494 int l = out.size () / (m * n);
432- bool broadcast_b = w.ndim () == 2 ;
495+ bool broadcast_b = ( w.ndim () <= 2 ) || (w. size () != w. data_size ()) ;
433496
434497 bool is_sm80 = encoder.device ().compute_capability_major () >= 8 ;
435498 dispatch_bool (is_sm80, [&](auto sm80) {
436- dispatch_element_types (out.dtype (), tag, [&]<typename Element>() {
437- dispatch_quant_types<Element>(
438- bits,
439- group_size,
440- mode,
441- tag,
442- [&]<typename Quant, typename Scale, int group_size>() {
443- encoder.set_input_array (x);
444- encoder.set_input_array (w);
445- encoder.set_input_array (scales);
446- if (biases) {
447- encoder.set_input_array (*biases);
448- }
449- if (lhs_indices) {
450- encoder.set_input_array (*lhs_indices);
451- }
452- if (rhs_indices) {
453- encoder.set_input_array (*rhs_indices);
454- }
455- encoder.set_output_array (out);
456- cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value >(
457- gpu_ptr<Element>(x),
458- gpu_ptr<Quant>(w),
459- gpu_ptr<Scale>(scales),
460- biases ? gpu_ptr<Element>(*biases) : nullptr ,
461- lhs_indices ? gpu_ptr<uint32_t >(*lhs_indices) : nullptr ,
462- rhs_indices ? gpu_ptr<uint32_t >(*rhs_indices) : nullptr ,
463- gpu_ptr<Element>(out),
464- m,
465- n,
466- k,
467- l,
468- broadcast_b,
469- cute::Int<group_size>{},
470- [&](auto * kernel,
471- dim3 num_blocks,
472- dim3 block_dims,
473- uint32_t smem_bytes,
474- void ** args) {
475- encoder.add_kernel_node_raw (
476- kernel, num_blocks, block_dims, {}, smem_bytes, args);
477- });
478- });
499+ dispatch_k<KMajor>(k % group_size != 0 , tag, [&]<bool has_k_residue>() {
500+ dispatch_element_types (out.dtype (), tag, [&]<typename Element>() {
501+ dispatch_quant_types<Element>(
502+ bits,
503+ group_size,
504+ mode,
505+ tag,
506+ [&]<typename Quant, typename Scale, int group_size>() {
507+ encoder.set_input_array (x);
508+ encoder.set_input_array (w);
509+ encoder.set_input_array (scales);
510+ if (biases) {
511+ encoder.set_input_array (*biases);
512+ }
513+ if (lhs_indices) {
514+ encoder.set_input_array (*lhs_indices);
515+ }
516+ if (rhs_indices) {
517+ encoder.set_input_array (*rhs_indices);
518+ }
519+ encoder.set_output_array (out);
520+ cutlass_gemm::qmm_naive<TileM, KMajor, sm80.value , has_k_residue>(
521+ gpu_ptr<Element>(x),
522+ gpu_ptr<Quant>(w),
523+ gpu_ptr<Scale>(scales),
524+ biases ? gpu_ptr<Element>(*biases) : nullptr ,
525+ lhs_indices ? gpu_ptr<uint32_t >(*lhs_indices) : nullptr ,
526+ rhs_indices ? gpu_ptr<uint32_t >(*rhs_indices) : nullptr ,
527+ gpu_ptr<Element>(out),
528+ m,
529+ n,
530+ k,
531+ l,
532+ broadcast_b,
533+ cute::Int<group_size>{},
534+ [&](auto * kernel,
535+ dim3 num_blocks,
536+ dim3 block_dims,
537+ uint32_t smem_bytes,
538+ void ** args) {
539+ encoder.add_kernel_node_raw (
540+ kernel, num_blocks, block_dims, {}, smem_bytes, args);
541+ });
542+ });
543+ });
479544 });
480545 });
481546}
0 commit comments