Skip to content

Commit 27d9ed8

Browse files
shaofeiqilhez
andauthored
opencl: add basic support for q5_0 and q5_1 (ggml-org#23548)
* opencl: add general q5_0 support * opencl: add general q5_1 support * opencl: support non-uniform workgrp size --------- Co-authored-by: Li He <lih@qti.qualcomm.com>
1 parent 335abed commit 27d9ed8

9 files changed

Lines changed: 1845 additions & 5 deletions

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ set(GGML_OPENCL_KERNELS
8787
mul_mv_q4_1_f32_flat
8888
mul_mv_q4_k_f32
8989
mul_mv_q4_k_f32_flat
90+
mul_mv_q5_0_f32
91+
mul_mv_q5_0_f32_flat
92+
mul_mv_q5_1_f32
93+
mul_mv_q5_1_f32_flat
9094
mul_mv_q5_k_f32
9195
mul_mv_q5_k_f32_flat
9296
mul_mv_q6_k_f32
@@ -126,6 +130,8 @@ set(GGML_OPENCL_KERNELS
126130
mul_mm_f16_f32_l4_lm
127131
mul_mm_q4_0_f32_l4_lm
128132
mul_mm_q4_1_f32_l4_lm
133+
mul_mm_q5_0_f32_l4_lm
134+
mul_mm_q5_1_f32_l4_lm
129135
mul_mm_q8_0_f32_l4_lm
130136
mul_mm_iq4_nl_f32_l4_lm
131137
mul_mm_q4_k_f32_l4_lm

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 417 additions & 5 deletions
Large diffs are not rendered by default.

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,53 @@ kernel void kernel_restore_block_q4_1_trans4_ns(
537537
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
538538
}
539539

540+
//------------------------------------------------------------------------------
541+
// kernel_convert_block_q5_0
542+
// Convert the block_q5_0 format to 3 separate arrays (AOS -> SOA).
543+
// This kernel does not deshuffle the bits.
544+
//------------------------------------------------------------------------------
545+
kernel void kernel_convert_block_q5_0(
546+
global struct block_q5_0 * src0,
547+
global uchar * dst_qs,
548+
global uint * dst_qh,
549+
global half * dst_d,
550+
ulong n_blk
551+
) {
552+
if (get_global_id(0) >= n_blk) {
553+
return;
554+
}
555+
556+
global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0);
557+
global uchar * qs = (global uchar *) dst_qs + (QK5_0/2)*get_global_id(0);
558+
global uint * qh = (global uint *) dst_qh + get_global_id(0);
559+
global half * d = (global half *) dst_d + get_global_id(0);
560+
561+
*d = b->d;
562+
*qh = *((global uint *)(b->qh));
563+
564+
for (int i = 0; i < QK5_0/2; ++i) {
565+
qs[i] = b->qs[i];
566+
}
567+
}
568+
569+
kernel void kernel_restore_block_q5_0(
570+
global uchar * src_qs,
571+
global uint * src_qh,
572+
global half * src_d,
573+
global struct block_q5_0 * dst
574+
) {
575+
global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0);
576+
global uchar * qs = (global uchar *) src_qs + (QK5_0/2)*get_global_id(0);
577+
global uint * qh = (global uint *) src_qh + get_global_id(0);
578+
global half * d = (global half *) src_d + get_global_id(0);
579+
580+
b->d = *d;
581+
*((global uint *)(b->qh)) = *qh;
582+
for (int i = 0; i < QK5_0/2; ++i) {
583+
b->qs[i] = qs[i];
584+
}
585+
}
586+
540587
kernel void kernel_convert_block_q5_0_trans4_ns(
541588
__global struct block_q5_0 * src0,
542589
__global uint * dst_qs,
@@ -636,6 +683,59 @@ kernel void kernel_restore_block_q5_0_trans4_ns(
636683
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
637684
}
638685

686+
//------------------------------------------------------------------------------
687+
// kernel_convert_block_q5_1
688+
// Convert the block_q5_1 format to 4 separate arrays (AOS -> SOA).
689+
// This kernel does not deshuffle the bits.
690+
//------------------------------------------------------------------------------
691+
kernel void kernel_convert_block_q5_1(
692+
global struct block_q5_1 * src0,
693+
global uchar * dst_qs,
694+
global uint * dst_qh,
695+
global half * dst_d,
696+
global half * dst_m,
697+
ulong n_blk
698+
) {
699+
if (get_global_id(0) >= n_blk) {
700+
return;
701+
}
702+
703+
global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0);
704+
global uchar * qs = (global uchar *) dst_qs + (QK5_1/2)*get_global_id(0);
705+
global uint * qh = (global uint *) dst_qh + get_global_id(0);
706+
global half * d = (global half *) dst_d + get_global_id(0);
707+
global half * m = (global half *) dst_m + get_global_id(0);
708+
709+
*d = b->d;
710+
*m = b->m;
711+
*qh = *((global uint *)(b->qh));
712+
713+
for (int i = 0; i < QK5_1/2; ++i) {
714+
qs[i] = b->qs[i];
715+
}
716+
}
717+
718+
kernel void kernel_restore_block_q5_1(
719+
global uchar * src_qs,
720+
global uint * src_qh,
721+
global half * src_d,
722+
global half * src_m,
723+
global struct block_q5_1 * dst
724+
) {
725+
global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0);
726+
global uchar * qs = (global uchar *) src_qs + (QK5_1/2)*get_global_id(0);
727+
global uint * qh = (global uint *) src_qh + get_global_id(0);
728+
global half * d = (global half *) src_d + get_global_id(0);
729+
global half * m = (global half *) src_m + get_global_id(0);
730+
731+
b->d = *d;
732+
b->m = *m;
733+
*((global uint *)(b->qh)) = *qh;
734+
for (int i = 0; i < QK5_1/2; ++i) {
735+
b->qs[i] = qs[i];
736+
}
737+
}
738+
639739
kernel void kernel_convert_block_q5_1_trans4_ns(
640740
__global struct block_q5_1 * src0,
641741
__global uint * dst_qs,
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#define LOAD_VEC_A 8
4+
#define LOAD_VEC_B 4
5+
6+
#define BM 64
7+
#define BN 64
8+
#define BK 32
9+
#define TM 4
10+
#define TN 8
11+
12+
kernel void kernel_mul_mm_q5_0_f32_l4_lm(
13+
global uchar4 * src0_qs,
14+
global uint * src0_qh,
15+
global half * src0_d,
16+
global float4 * src1,
17+
ulong offset1,
18+
global float * dst,
19+
ulong offsetd,
20+
21+
int ne00,
22+
int ne01,
23+
int ne02,
24+
int ne11,
25+
int ne12,
26+
27+
int stride_a,
28+
int stride_b,
29+
int stride_d,
30+
31+
int batch_stride_a,
32+
int batch_stride_b,
33+
int batch_stride_d,
34+
35+
int r2,
36+
int r3
37+
) {
38+
src1 = (global float4*)((global char*)src1 + offset1);
39+
dst = (global float *)((global char*)dst + offsetd);
40+
41+
local float buf_a[BM * BK];
42+
local float buf_b[BN * BK];
43+
44+
const int batch_idx = get_global_id(2);
45+
46+
const int i13 = batch_idx / ne12;
47+
const int i12 = batch_idx % ne12;
48+
49+
const int i03 = i13 / r3;
50+
const int i02 = i12 / r2;
51+
52+
const int batch_idx_a = i03 * ne02 + i02;
53+
54+
const int ir = get_group_id(0);
55+
const int ic = get_group_id(1);
56+
57+
const int tid = get_local_id(0);
58+
const int th_r = tid % (BM / TM);
59+
const int th_c = tid / (BM / TM);
60+
61+
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
62+
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
63+
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
64+
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
65+
66+
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
67+
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
68+
69+
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
70+
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
71+
72+
float sums[TM * TN];
73+
float cache_a[TM];
74+
float cache_b[TN];
75+
76+
for (int i = 0; i < TM * TN; i++) {
77+
sums[i] = 0.0f;
78+
}
79+
80+
for (int block = 0; block < ne00; block += BK) {
81+
for (int l = 0; l < BM; l += loadstride_a) {
82+
if (ir*BM + loadc_a + l < ne01) {
83+
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
84+
int ib = idx / 4;
85+
int iqs = idx % 4;
86+
87+
float d = (float)src0_d[ib];
88+
uint qh_val = src0_qh[ib];
89+
90+
global uchar4 * qs_ptr = src0_qs + ib*4 + iqs;
91+
uchar4 q = *qs_ptr;
92+
93+
uint qh_lo = qh_val >> (iqs * 4);
94+
uint qh_hi = qh_val >> (iqs * 4 + 16);
95+
96+
uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1;
97+
uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1;
98+
99+
float4 v1 = (convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) - 16.0f) * d;
100+
float4 v2 = (convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) - 16.0f) * d;
101+
102+
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
103+
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
104+
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
105+
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
106+
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
107+
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
108+
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
109+
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
110+
} else {
111+
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
112+
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
113+
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
114+
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
115+
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
116+
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
117+
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
118+
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
119+
}
120+
}
121+
122+
for (int l = 0; l < BN; l += loadstride_b) {
123+
if (ic*BN + loadc_b + l < ne11) {
124+
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
125+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
126+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
127+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
128+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
129+
} else {
130+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
131+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
132+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
133+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
134+
}
135+
}
136+
137+
barrier(CLK_LOCAL_MEM_FENCE);
138+
139+
pos_a += BK / LOAD_VEC_A;
140+
pos_b += BK / LOAD_VEC_B;
141+
142+
for (int i = 0; i < BK; i++) {
143+
for (int j = 0; j < TM; j++) {
144+
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
145+
}
146+
147+
for (int j = 0; j < TN; j++) {
148+
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
149+
}
150+
151+
for (int cc = 0; cc < TN; cc++) {
152+
for (int cr = 0; cr < TM; cr++) {
153+
const int sums_idx = cc*TM + cr;
154+
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
155+
}
156+
}
157+
}
158+
barrier(CLK_LOCAL_MEM_FENCE);
159+
}
160+
161+
const int dr = ir * BM + th_r * TM;
162+
const int dc = ic * BN + th_c * TN;
163+
164+
const int offsets = batch_idx * batch_stride_d;
165+
166+
for (int cc = 0; cc < TN; cc++) {
167+
for (int cr = 0; cr < TM; cr++) {
168+
if (dr + cr < ne01 && dc + cc < ne11) {
169+
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
170+
}
171+
}
172+
}
173+
}

0 commit comments

Comments
 (0)