Skip to content

Commit 7c3501a

Browse files
committed
[cuda] initial Q1_0 backend
1 parent d0a6dfe commit 7c3501a

11 files changed

Lines changed: 248 additions & 0 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_F16> {
918918
static constexpr int qr = 1;
919919
};
920920

921+
template<>
922+
struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
923+
static constexpr int qk = QK1_0;
924+
static constexpr int qr = QR1_0;
925+
static constexpr int qi = QI1_0;
926+
};
927+
921928
template<>
922929
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
923930
static constexpr int qk = QK4_0;

ggml/src/ggml-cuda/convert.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
711711

712712
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
713713
switch (type) {
714+
case GGML_TYPE_Q1_0:
715+
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
714716
case GGML_TYPE_Q4_0:
715717
return dequantize_row_q4_0_cuda;
716718
case GGML_TYPE_Q4_1:
@@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
767769

768770
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
769771
switch (type) {
772+
case GGML_TYPE_Q1_0:
773+
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
770774
case GGML_TYPE_Q4_0:
771775
return dequantize_row_q4_0_cuda;
772776
case GGML_TYPE_Q4_1:
@@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
822826
switch (type) {
823827
case GGML_TYPE_F32:
824828
return convert_unary_cuda<float>;
829+
case GGML_TYPE_Q1_0:
830+
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
825831
case GGML_TYPE_Q4_0:
826832
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
827833
case GGML_TYPE_Q4_1:
@@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
843849
switch (type) {
844850
case GGML_TYPE_F32:
845851
return convert_unary_cuda<float, nv_bfloat16>;
852+
case GGML_TYPE_Q1_0:
853+
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
846854
case GGML_TYPE_Q4_0:
847855
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
848856
case GGML_TYPE_Q4_1:
@@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
864872
switch (type) {
865873
case GGML_TYPE_F16:
866874
return convert_unary_cuda<half, float>;
875+
case GGML_TYPE_Q1_0:
876+
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
867877
case GGML_TYPE_Q4_0:
868878
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
869879
case GGML_TYPE_Q4_1:

ggml/src/ggml-cuda/dequantize.cuh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
#include "common.cuh"
22

3+
static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
4+
const block_q1_0 * x = (const block_q1_0 *) vx;
5+
6+
const float d = x[ib].d;
7+
const float neg_d = -d;
8+
9+
const int bit_index_0 = iqs;
10+
const int bit_index_1 = iqs + 1;
11+
12+
const int byte_index_0 = bit_index_0 / 8;
13+
const int bit_offset_0 = bit_index_0 % 8;
14+
15+
const int byte_index_1 = bit_index_1 / 8;
16+
const int bit_offset_1 = bit_index_1 % 8;
17+
18+
// Extract bits: 1 = +d, 0 = -d
19+
const uint8_t bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1;
20+
const uint8_t bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1;
21+
22+
v.x = bit_0 ? d : neg_d;
23+
v.y = bit_1 ? d : neg_d;
24+
}
25+
326
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
427
const block_q4_0 * x = (const block_q4_0 *) vx;
528

ggml/src/ggml-cuda/getrows.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
179179
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
180180
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
181181
break;
182+
case GGML_TYPE_Q1_0:
183+
get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d,
184+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
185+
break;
182186
case GGML_TYPE_Q4_0:
183187
get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
184188
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4785,6 +4785,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
47854785
switch (a->type) {
47864786
case GGML_TYPE_F32:
47874787
case GGML_TYPE_F16:
4788+
case GGML_TYPE_Q1_0:
47884789
case GGML_TYPE_Q4_0:
47894790
case GGML_TYPE_Q4_1:
47904791
case GGML_TYPE_Q5_0:
@@ -4822,6 +4823,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
48224823
case GGML_TYPE_F32:
48234824
case GGML_TYPE_BF16:
48244825
case GGML_TYPE_I32:
4826+
case GGML_TYPE_Q1_0:
48254827
case GGML_TYPE_Q4_0:
48264828
case GGML_TYPE_Q4_1:
48274829
case GGML_TYPE_Q5_0:

ggml/src/ggml-cuda/mmq.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
77
switch (args.type_x) {
8+
case GGML_TYPE_Q1_0:
9+
mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
10+
break;
811
case GGML_TYPE_Q4_0:
912
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
1013
break;
@@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
270273
bool mmq_supported;
271274

272275
switch (type) {
276+
case GGML_TYPE_Q1_0:
273277
case GGML_TYPE_Q4_0:
274278
case GGML_TYPE_Q4_1:
275279
case GGML_TYPE_Q5_0:
@@ -301,6 +305,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
301305
return false;
302306
}
303307

308+
// Q1_0 requires MMA (Turing+) — no DP4A fallback path
309+
if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc)) {
310+
return false;
311+
}
312+
304313
if (turing_mma_available(cc)) {
305314
return true;
306315
}

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
1111

1212
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
1313
#define MMQ_ITER_K 256
14+
#define MMQ_ITER_K_Q1_0 128 // For Q1_0: QK1_0=128, QI1_0=4, so threads_per_row = 128/(4*4) = 8
1415
#define MMQ_ITER_K_MXFP4_FP4 512
1516
#define MMQ_NWARPS 8
1617

@@ -57,6 +58,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
5758

5859
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
5960
switch (type_x) {
61+
case GGML_TYPE_Q1_0:
62+
return MMQ_Q8_1_DS_LAYOUT_D4;
6063
case GGML_TYPE_Q4_0:
6164
case GGML_TYPE_Q4_1:
6265
return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
229232

230233
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
231234
switch (type) {
235+
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
232236
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
233237
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
234238
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
302306

303307
// ------------------------------------------------------------
304308

309+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
310+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
311+
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
312+
GGML_UNUSED_VARS(x, x_tile, kbx0, i_max, stride, mmq_y, need_check);
313+
NO_DEVICE_CODE;
314+
#else
315+
constexpr int nwarps = mmq_get_nwarps_device();
316+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
317+
318+
int * x_qs = (int *) x_tile;
319+
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
320+
321+
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
322+
constexpr int threads_per_row = blocks_per_iter * QI1_0;
323+
constexpr int nrows = warp_size / threads_per_row;
324+
constexpr int scale_entries_per_block = QK1_0 / QK8_1;
325+
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
326+
327+
const int txi = threadIdx.x % threads_per_row;
328+
const int kbx = txi / QI1_0;
329+
const int kqsx = txi % QI1_0;
330+
331+
#pragma unroll
332+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
333+
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
334+
335+
if (need_check) {
336+
i = min(i, i_max);
337+
}
338+
339+
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
340+
const int qs_offset = 4*kqsx;
341+
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
342+
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
343+
344+
int unpacked_bytes[8];
345+
#pragma unroll
346+
for (int j = 0; j < 8; ++j) {
347+
const int shift = j * 4;
348+
const int bits4 = (qs0 >> shift) & 0x0F;
349+
const int b0 = (bits4 & 0x01) ? 1 : -1;
350+
const int b1 = (bits4 & 0x02) ? 1 : -1;
351+
const int b2 = (bits4 & 0x04) ? 1 : -1;
352+
const int b3 = (bits4 & 0x08) ? 1 : -1;
353+
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
354+
}
355+
356+
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
357+
#pragma unroll
358+
for (int j = 0; j < 8; ++j) {
359+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
360+
}
361+
}
362+
363+
const int ksx = threadIdx.x % scale_entries_per_row;
364+
const int scale_block = ksx / scale_entries_per_block;
365+
366+
#pragma unroll
367+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
368+
int i = i0 + threadIdx.y;
369+
370+
if (need_check) {
371+
i = min(i, i_max);
372+
}
373+
374+
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
375+
376+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
377+
}
378+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
379+
}
380+
381+
template <int mmq_x, int mmq_y>
382+
static __device__ __forceinline__ void vec_dot_q1_mmq_dp4a_disabled(
383+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
384+
// Q1_0 intentionally targets the MMA path only.
385+
// If DP4A support is needed later for older GPUs, it should be reintroduced and validated separately.
386+
GGML_UNUSED_VARS(x, y, sum, k00, mmq_x, mmq_y);
387+
NO_DEVICE_CODE;
388+
}
389+
305390
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
306391
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
307392
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3274,6 +3359,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
32743359
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
32753360
struct mmq_type_traits;
32763361

3362+
template <int mmq_x, int mmq_y, bool need_check>
3363+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
3364+
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
3365+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
3366+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3367+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_mmq_dp4a_disabled<mmq_x, mmq_y>;
3368+
};
3369+
32773370
template <int mmq_x, int mmq_y, bool need_check>
32783371
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
32793372
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
99

1010
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
1111
switch (type) {
12+
case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
1213
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
1314
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
1415
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
@@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
3637

3738
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
3839
switch (type) {
40+
case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
3941
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
4042
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
4143
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
@@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type(
886888
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
887889
const int ids_stride, cudaStream_t stream) {
888890
switch (type_x) {
891+
case GGML_TYPE_Q1_0:
892+
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0>
893+
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
894+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
895+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
896+
break;
889897
case GGML_TYPE_Q4_0:
890898
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
891899
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
3333

3434
TYPES_MMQ = [
35+
"GGML_TYPE_Q1_0",
3536
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
3637
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
3738
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmq.cuh"
4+
5+
DECL_MMQ_CASE(GGML_TYPE_Q1_0);

0 commit comments

Comments
 (0)