@@ -17,9 +17,9 @@ inline bool is_last_2_dims_row_contiguous(const array& x) {
1717} // namespace
1818
1919#if defined(MLX_CUDA_SM90A_ENABLED)
20- // Defined in qmm_impl_sm90_xxx .cu files .
21- template <typename TileShape, typename ClusterShape >
22- void qmm_impl_sm90 (
20+ // Defined in qmm_sm90 .cu.
21+ template <int TileN >
22+ void qmm_sm90_impl (
2323 const array& x,
2424 const array& w,
2525 const array& scales,
@@ -83,34 +83,31 @@ void qmm_sm90(
8383 cu::CommandEncoder& encoder,
8484 Stream s) {
8585#if defined(MLX_CUDA_SM90A_ENABLED)
86- auto dispatch = [&]<int tile_m, int tile_n, int cluster_m>() {
87- using cute::Int;
88- using TileShapeMN = cute::Shape<Int<tile_m>, Int<tile_n>>;
89- using ClusterShape = cute::Shape<Int<cluster_m>, Int<1 >, Int<1 >>;
90- qmm_impl_sm90<TileShapeMN, ClusterShape>(
86+ auto dispatch = [&]<int TileN>() {
87+ qmm_sm90_impl<TileN>(
9188 x, w, scales, biases, out, bits, group_size, encoder, s);
9289 };
9390 int m = out.ndim () > 1 ? out.shape (-2 ) : 1 ;
9491 if (m <= 16 ) {
95- dispatch.template operator ()<128 , 16 , 1 >();
92+ dispatch.template operator ()<16 >();
9693 } else if (m <= 32 ) {
97- dispatch.template operator ()<128 , 32 , 1 >();
94+ dispatch.template operator ()<32 >();
9895 } else if (m <= 64 ) {
99- dispatch.template operator ()<128 , 64 , 2 >();
96+ dispatch.template operator ()<64 >();
10097 } else if (m <= 128 ) {
101- dispatch.template operator ()<128 , 128 , 2 >();
98+ dispatch.template operator ()<128 >();
10299 } else {
103- dispatch.template operator ()<128 , 256 , 2 >();
100+ dispatch.template operator ()<256 >();
104101 }
105102#else
106103 throw std::runtime_error (
107104 " [quantized_matmul] Hopper-only kernel is not available." );
108105#endif // defined(MLX_CUDA_SM90A_ENABLED)
109106}
110107
111- // Defined in qmm_impl_sm80_xxx .cu files .
108+ // Defined in qmm_sm80 .cu.
112109template <int TileM>
113- void qmm_impl_sm80 (
110+ void qmm_sm80_impl (
114111 const array& x,
115112 const array& w,
116113 const array& scales,
@@ -174,7 +171,7 @@ void qmm_sm80(
174171 QuantizationMode mode,
175172 cu::CommandEncoder& encoder) {
176173 auto dispatch = [&]<int TileM>() {
177- qmm_impl_sm80 <TileM>(
174+ qmm_sm80_impl <TileM>(
178175 x,
179176 w,
180177 scales,
@@ -197,9 +194,9 @@ void qmm_sm80(
197194 }
198195}
199196
200- // Defined in qmm_impl_naive_xxx .cu files .
201- template <int TileM, bool KMajor>
202- void qmm_impl_naive (
197+ // Defined in qmm_naive .cu.
198+ template <int TileM, bool KMajor, bool HasKResidue, bool SM80 >
199+ void qmm_naive_impl (
203200 const array& x,
204201 const array& w,
205202 const array& scales,
@@ -250,8 +247,8 @@ void qmm_naive(
250247 int group_size,
251248 QuantizationMode mode,
252249 cu::CommandEncoder& encoder) {
253- auto dispatch = [&]<int TileM, bool KMajor>() {
254- qmm_impl_naive <TileM, KMajor>(
250+ auto dispatch = [&]<int TileM, bool KMajor, bool HasKResidue, bool SM80 >() {
251+ qmm_naive_impl <TileM, KMajor, HasKResidue, SM80 >(
255252 x,
256253 w,
257254 scales,
@@ -264,15 +261,37 @@ void qmm_naive(
264261 mode,
265262 encoder);
266263 };
267- dispatch_bool (transpose, [&](auto k_major) {
268- int m = out.ndim () > 1 ? out.shape (-2 ) : 1 ;
269- if (m <= 16 ) {
270- dispatch.template operator ()<16 , k_major.value >();
271- } else if (m <= 32 ) {
272- dispatch.template operator ()<32 , k_major.value >();
264+ auto dispatch_k = [&](auto k_major, bool has_k_residue, auto && f) {
265+ if constexpr (k_major.value ) {
266+ if (has_k_residue) {
267+ throw std::invalid_argument (
268+ " [quantized_matmul] K must be multiples of group_size." );
269+ }
270+ f.template operator ()<false >();
273271 } else {
274- dispatch.template operator ()<64 , k_major.value >();
272+ dispatch_bool (has_k_residue, [&](auto has_k_residue) {
273+ f.template operator ()<has_k_residue.value >();
274+ });
275275 }
276+ };
277+ int m = out.ndim () > 1 ? out.shape (-2 ) : 1 ;
278+ int k = x.shape (-1 );
279+ bool has_k_residue = k % group_size != 0 ;
280+ bool sm80 = encoder.device ().compute_capability_major () >= 8 ;
281+ dispatch_bool (transpose, [&](auto k_major) {
282+ dispatch_k (k_major, has_k_residue, [&]<bool HasKResidue>() {
283+ dispatch_bool (sm80, [&](auto sm80) {
284+ constexpr bool KMajor = k_major.value ;
285+ constexpr bool SM80 = sm80.value ;
286+ if (m <= 16 ) {
287+ dispatch.template operator ()<16 , KMajor, HasKResidue, SM80 >();
288+ } else if (m <= 32 ) {
289+ dispatch.template operator ()<32 , KMajor, HasKResidue, SM80 >();
290+ } else {
291+ dispatch.template operator ()<64 , KMajor, HasKResidue, SM80 >();
292+ }
293+ });
294+ });
276295 });
277296}
278297
0 commit comments