Commit 2ffafe0

[CUDA] 3/5/6-bit quants for qmm_naive (#3352)
1 parent 5e2c442 commit 2ffafe0

8 files changed

Lines changed: 142 additions & 112 deletions

mlx/backend/cuda/quantized/qmm/cute_dequant.cuh

Lines changed: 128 additions & 0 deletions
@@ -2,9 +2,137 @@
 
 #pragma once
 
+#include <cute/numeric/numeric_types.hpp>
 #include <cute/tensor.hpp>
 #include <cutlass/numeric_conversion.h>
 
+namespace cutlass {
+
+using uint3b_t = integer_subbyte<3, false>;
+using uint5b_t = integer_subbyte<5, false>;
+
+template <typename T, int N, FloatRoundStyle Round>
+struct NumericArrayConverter<T, uint3b_t, N, Round> {
+  static_assert(N % 8 == 0);
+
+  using result_type = Array<T, N>;
+  using source_type = Array<uint3b_t, N>;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(const source_type& source) {
+    result_type result;
+    auto* s_base = reinterpret_cast<const uint8_t*>(&source);
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 8; ++i) {
+      auto* s = s_base + i * 3;
+      result[i * 8] = T(s[0] & 0x07);
+      result[i * 8 + 1] = T((s[0] & 0x38) >> 3);
+      result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2);
+      result[i * 8 + 3] = T((s[1] & 0x0e) >> 1);
+      result[i * 8 + 4] = T((s[1] & 0x70) >> 4);
+      result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1);
+      result[i * 8 + 6] = T((s[2] & 0x1c) >> 2);
+      result[i * 8 + 7] = T((s[2] & 0xe0) >> 5);
+    }
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(const source_type& s) const {
+    return convert(s);
+  }
+};
+
+template <typename T, int N, FloatRoundStyle Round>
+struct NumericArrayConverter<T, uint5b_t, N, Round> {
+  static_assert(N % 8 == 0);
+
+  using result_type = Array<T, N>;
+  using source_type = Array<uint5b_t, N>;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(const source_type& source) {
+    result_type result;
+    auto* s_base = reinterpret_cast<const uint8_t*>(&source);
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 8; ++i) {
+      auto* s = s_base + i * 5;
+      result[i * 8] = T(s[0] & 0x1f);
+      result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3);
+      result[i * 8 + 2] = T((s[1] & 0x7c) >> 2);
+      result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1);
+      result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4);
+      result[i * 8 + 5] = T((s[3] & 0x3e) >> 1);
+      result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2);
+      result[i * 8 + 7] = T((s[4] & 0xf8) >> 3);
+    }
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(const source_type& s) const {
+    return convert(s);
+  }
+};
+
+template <typename T, int N, FloatRoundStyle Round>
+struct NumericArrayConverter<T, uint6b_t, N, Round> {
+  static_assert(N % 4 == 0);
+
+  using result_type = Array<T, N>;
+  using source_type = Array<uint6b_t, N>;
+
+  CUTLASS_HOST_DEVICE
+  static result_type convert(const source_type& source) {
+    result_type result;
+    auto* s_base = reinterpret_cast<const uint8_t*>(&source);
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 4; ++i) {
+      auto* s = s_base + i * 3;
+      result[i * 4] = T(s[0] & 0x3f);
+      result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2);
+      result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4);
+      result[i * 4 + 3] = T((s[2] >> 2) & 0x3f);
+    }
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  result_type operator()(const source_type& s) const {
+    return convert(s);
+  }
+};
+
+} // namespace cutlass
+
+namespace cute {
+
+// Required by tiled copy for 3/5/6-bit weights.
+struct uint24_t {
+  std::array<std::uint8_t, 3> bytes;
+};
+struct uint40_t {
+  std::array<std::uint8_t, 5> bytes;
+};
+struct uint48_t {
+  std::array<std::uint8_t, 6> bytes;
+};
+
+template <>
+struct uint_bit<24> {
+  using type = uint24_t;
+};
+template <>
+struct uint_bit<40> {
+  using type = uint40_t;
+};
+template <>
+struct uint_bit<48> {
+  using type = uint48_t;
+};
+
+} // namespace cute
+
 namespace cutlass_gemm {
 
 // Whether the quant type is affine quantization.
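Note: the converters above all follow the same LSB-first packing: eight 3-bit values occupy three consecutive bytes, eight 5-bit values occupy five bytes, and four 6-bit values occupy three bytes, with values that straddle a byte boundary reassembled from the low bits of the next byte. The uint24_t/uint40_t/uint48_t wrappers and the uint_bit<24/40/48> specializations exist so cute's tiled copy can move these odd-sized groups as opaque bytes. As a sanity check of the bit layout, here is a small hypothetical host-only C++ sketch (not part of the repo) that packs eight 3-bit values the same way and unpacks them with the same masks and shifts as the uint3b_t converter:

  #include <array>
  #include <cassert>
  #include <cstdint>

  // Pack eight 3-bit values (each < 8) into 24 bits, value 0 in the lowest bits.
  std::array<uint8_t, 3> pack3(const std::array<uint8_t, 8>& v) {
    uint32_t bits = 0;
    for (int i = 0; i < 8; ++i) {
      bits |= uint32_t(v[i] & 0x07) << (3 * i);
    }
    return {uint8_t(bits), uint8_t(bits >> 8), uint8_t(bits >> 16)};
  }

  // Unpack with the converter's masks/shifts ('|' and '+' are equivalent here
  // because the combined bit ranges never overlap).
  std::array<uint8_t, 8> unpack3(const std::array<uint8_t, 3>& s) {
    return {
        uint8_t(s[0] & 0x07),
        uint8_t((s[0] & 0x38) >> 3),
        uint8_t(((s[0] & 0xc0) >> 6) | ((s[1] & 0x01) << 2)),
        uint8_t((s[1] & 0x0e) >> 1),
        uint8_t((s[1] & 0x70) >> 4),
        uint8_t(((s[1] & 0x80) >> 7) | ((s[2] & 0x03) << 1)),
        uint8_t((s[2] & 0x1c) >> 2),
        uint8_t((s[2] & 0xe0) >> 5),
    };
  }

  int main() {
    std::array<uint8_t, 8> v{0, 1, 2, 3, 4, 5, 6, 7};
    assert(unpack3(pack3(v)) == v); // round-trips for any 3-bit inputs
  }

If the round-trip holds for arbitrary 3-bit inputs, the converter's masks match the packing; the 5-bit and 6-bit cases can be checked the same way with 5-byte and 3-byte groups.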

mlx/backend/cuda/quantized/qmm/qmm.cu

Lines changed: 3 additions & 6 deletions
@@ -80,7 +80,7 @@ void qmm_sm90(
     qmm_impl_sm90<TileShapeMN, ClusterShape>(
         x, w, scales, biases, out, bits, group_size, encoder, s);
   };
-  int m = out.shape(-2);
+  int m = out.ndim() > 1 ? out.shape(-2) : 1;
   if (m <= 16) {
     dispatch.template operator()<128, 16, 1>();
   } else if (m <= 32) {
@@ -163,7 +163,7 @@ void qmm_sm80(
     qmm_impl_sm80<TileM>(
         x, w, scales, biases, out, bits, group_size, mode, encoder);
   };
-  int m = out.shape(-2);
+  int m = out.ndim() > 1 ? out.shape(-2) : 1;
   if (m <= 16) {
     dispatch.template operator()<16>();
   } else if (m <= 32) {
@@ -208,9 +208,6 @@ bool supports_qmm_naive(
   if (biases && !biases->flags().row_contiguous) {
     return false;
   }
-  if (bits != 2 && bits != 4 && bits != 8) {
-    return false;
-  }
   return true;
 }
 
@@ -230,7 +227,7 @@
         x, w, scales, biases, out, bits, group_size, mode, encoder);
   };
   dispatch_bool(transpose, [&](auto k_major) {
-    int m = out.shape(-2);
+    int m = out.ndim() > 1 ? out.shape(-2) : 1;
     if (m <= 16) {
       dispatch.template operator()<16, k_major.value>();
     } else if (m <= 32) {
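Note: the recurring change in this commit replaces `int m = out.shape(-2)` with a form that tolerates 1-D outputs: when the output has only one dimension there is no -2 axis, so the row count falls back to 1. The same guard is applied below in qmm_impl_naive.cuh, qmm_impl_sm80.cuh, qmm_impl_sm90.cuh, and quantized.cpp. A minimal, hypothetical sketch of the idea, with a plain std::vector<int> standing in for mlx's shape type:

  #include <cassert>
  #include <vector>

  // Hypothetical helper mirroring `out.ndim() > 1 ? out.shape(-2) : 1`:
  // the second-to-last axis only exists for outputs with at least two dims.
  int output_rows(const std::vector<int>& shape) {
    return shape.size() > 1 ? shape[shape.size() - 2] : 1;
  }

  int main() {
    assert(output_rows({4, 16, 32}) == 16); // batched matmul output: m = 16
    assert(output_rows({32}) == 1);         // 1-D output (vector result): m = 1
  }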

mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh

Lines changed: 7 additions & 1 deletion
@@ -385,8 +385,14 @@ inline void dispatch_quant_types(
   dispatch_groups(group_size, tag, [&]<int group_size>() {
     if (bits == 2) {
       f.template operator()<cutlass::uint2b_t, T, group_size>();
+    } else if (bits == 3) {
+      f.template operator()<cutlass::uint3b_t, T, group_size>();
     } else if (bits == 4) {
       f.template operator()<cutlass::uint4b_t, T, group_size>();
+    } else if (bits == 5) {
+      f.template operator()<cutlass::uint5b_t, T, group_size>();
+    } else if (bits == 6) {
+      f.template operator()<cutlass::uint6b_t, T, group_size>();
     } else if (bits == 8) {
       f.template operator()<uint8_t, T, group_size>();
     } else {
@@ -409,7 +415,7 @@ void qmm_impl_naive(
     QuantizationMode mode,
     cu::CommandEncoder& encoder) {
   const char* tag = "[quantized_matmul]";
-  int m = out.shape(-2);
+  int m = out.ndim() > 1 ? out.shape(-2) : 1;
   int n = out.shape(-1);
   int k = x.shape(-1);
   int l = out.size() / (m * n);
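Note: with this change dispatch_quant_types covers bit widths 2, 3, 4, 5, 6, and 8, mapping each runtime value to a distinct sub-byte weight type, which is what allows supports_qmm_naive above to drop its earlier bits restriction. A simplified, hypothetical sketch of the runtime-to-compile-time dispatch pattern (using std::integral_constant rather than the repo's dispatch helpers or the CUTLASS types):

  #include <stdexcept>
  #include <type_traits>

  // Turn a runtime bit width into a compile-time constant so the callee can be
  // instantiated once per supported width.
  template <typename F>
  void dispatch_bits(int bits, F&& f) {
    switch (bits) {
      case 2: f(std::integral_constant<int, 2>{}); break;
      case 3: f(std::integral_constant<int, 3>{}); break;
      case 4: f(std::integral_constant<int, 4>{}); break;
      case 5: f(std::integral_constant<int, 5>{}); break;
      case 6: f(std::integral_constant<int, 6>{}); break;
      case 8: f(std::integral_constant<int, 8>{}); break;
      default: throw std::invalid_argument("[quantized_matmul] unsupported bits");
    }
  }

  // Usage sketch: dispatch_bits(bits, [&](auto b) { run_kernel<b.value>(); });
  // where run_kernel is a hypothetical stand-in for the templated launch above.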

mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh

Lines changed: 1 addition & 1 deletion
@@ -435,7 +435,7 @@ void qmm_impl_sm80(
     QuantizationMode mode,
     cu::CommandEncoder& encoder) {
   const char* tag = "[quantized_matmul]";
-  int m = out.shape(-2);
+  int m = out.ndim() > 1 ? out.shape(-2) : 1;
   int n = out.shape(-1);
   int k = x.shape(-1);
   int l = out.size() / (m * n);

mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ void qmm_impl_sm90(
     cu::CommandEncoder& encoder,
     Stream s) {
   const char* tag = "[quantized_matmul]";
-  int m = out.shape(-2);
+  int m = out.ndim() > 1 ? out.shape(-2) : 1;
   int n = out.shape(-1);
   int k = x.shape(-1);
   int l = out.size() / (m * n);

mlx/backend/cuda/quantized/qmm/qmv.cu

Lines changed: 1 addition & 101 deletions
@@ -1,112 +1,12 @@
11
// Copyright © 2026 Apple Inc.
22

33
#include "mlx/backend/cuda/kernel_utils.cuh"
4+
#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh"
45
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
56
#include "mlx/dtype_utils.h"
67

78
#include <cooperative_groups.h>
89
#include <cooperative_groups/reduce.h>
9-
#include <cute/numeric/numeric_types.hpp>
10-
#include <cutlass/numeric_conversion.h>
11-
12-
namespace cutlass {
13-
14-
using uint3b_t = integer_subbyte<3, false>;
15-
using uint5b_t = integer_subbyte<5, false>;
16-
17-
template <typename T, int N, FloatRoundStyle Round>
18-
struct NumericArrayConverter<T, uint3b_t, N, Round> {
19-
static_assert(N % 8 == 0);
20-
21-
using result_type = Array<T, N>;
22-
using source_type = Array<uint3b_t, N>;
23-
24-
CUTLASS_HOST_DEVICE
25-
static result_type convert(const source_type& source) {
26-
result_type result;
27-
auto* s_base = reinterpret_cast<const uint8_t*>(&source);
28-
CUTLASS_PRAGMA_UNROLL
29-
for (int i = 0; i < N / 8; ++i) {
30-
auto* s = s_base + i * 3;
31-
result[i * 8] = T(s[0] & 0x07);
32-
result[i * 8 + 1] = T((s[0] & 0x38) >> 3);
33-
result[i * 8 + 2] = T((s[0] & 0xc0) >> 6) + T((s[1] & 0x01) << 2);
34-
result[i * 8 + 3] = T((s[1] & 0x0e) >> 1);
35-
result[i * 8 + 4] = T((s[1] & 0x70) >> 4);
36-
result[i * 8 + 5] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x03) << 1);
37-
result[i * 8 + 6] = T((s[2] & 0x1c) >> 2);
38-
result[i * 8 + 7] = T((s[2] & 0xe0) >> 5);
39-
}
40-
return result;
41-
}
42-
43-
CUTLASS_HOST_DEVICE
44-
result_type operator()(const source_type& s) const {
45-
return convert(s);
46-
}
47-
};
48-
49-
template <typename T, int N, FloatRoundStyle Round>
50-
struct NumericArrayConverter<T, uint5b_t, N, Round> {
51-
static_assert(N % 8 == 0);
52-
53-
using result_type = Array<T, N>;
54-
using source_type = Array<uint5b_t, N>;
55-
56-
CUTLASS_HOST_DEVICE
57-
static result_type convert(const source_type& source) {
58-
result_type result;
59-
auto* s_base = reinterpret_cast<const uint8_t*>(&source);
60-
CUTLASS_PRAGMA_UNROLL
61-
for (int i = 0; i < N / 8; ++i) {
62-
auto* s = s_base + i * 5;
63-
result[i * 8] = T(s[0] & 0x1f);
64-
result[i * 8 + 1] = T((s[0] & 0xe0) >> 5) + T((s[1] & 0x03) << 3);
65-
result[i * 8 + 2] = T((s[1] & 0x7c) >> 2);
66-
result[i * 8 + 3] = T((s[1] & 0x80) >> 7) + T((s[2] & 0x0f) << 1);
67-
result[i * 8 + 4] = T((s[2] & 0xf0) >> 4) + T((s[3] & 0x01) << 4);
68-
result[i * 8 + 5] = T((s[3] & 0x3e) >> 1);
69-
result[i * 8 + 6] = T((s[3] & 0xc0) >> 6) + T((s[4] & 0x07) << 2);
70-
result[i * 8 + 7] = T((s[4] & 0xf8) >> 3);
71-
}
72-
return result;
73-
}
74-
75-
CUTLASS_HOST_DEVICE
76-
result_type operator()(const source_type& s) const {
77-
return convert(s);
78-
}
79-
};
80-
81-
template <typename T, int N, FloatRoundStyle Round>
82-
struct NumericArrayConverter<T, uint6b_t, N, Round> {
83-
static_assert(N % 4 == 0);
84-
85-
using result_type = Array<T, N>;
86-
using source_type = Array<uint6b_t, N>;
87-
88-
CUTLASS_HOST_DEVICE
89-
static result_type convert(const source_type& source) {
90-
result_type result;
91-
auto* s_base = reinterpret_cast<const uint8_t*>(&source);
92-
CUTLASS_PRAGMA_UNROLL
93-
for (int i = 0; i < N / 4; ++i) {
94-
auto* s = s_base + i * 3;
95-
result[i * 4] = T(s[0] & 0x3f);
96-
result[i * 4 + 1] = T((s[0] >> 6) & 0x03) + T((s[1] & 0x0f) << 2);
97-
result[i * 4 + 2] = T((s[1] >> 4) & 0x0f) + T((s[2] & 0x03) << 4);
98-
result[i * 4 + 3] = T((s[2] >> 2) & 0x3f);
99-
}
100-
return result;
101-
}
102-
103-
CUTLASS_HOST_DEVICE
104-
result_type operator()(const source_type& s) const {
105-
return convert(s);
106-
}
107-
};
108-
109-
} // namespace cutlass
11010

11111
namespace mlx::core {
11212

mlx/backend/cuda/quantized/quantized.cpp

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };
 
-  int M = out.shape(-2);
+  int M = out.ndim() > 1 ? out.shape(-2) : 1;
   int N = out.shape(-1);
   int K = x.shape(-1);
   int B = out.size() / (M * N);

python/tests/cuda_skip.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
     "TestQuantized.test_qmm_shapes",
     "TestQuantized.test_fp_qvm",
     "TestQuantized.test_qvm",
-    "TestQuantized.test_qvm_splitk",
     "TestQuantized.test_qmv_small_non_multiples",
    "TestQuantized.test_small_matrix",
    "TestExportImport.test_export_quantized_model",
