Skip to content

Commit f454bd7

Browse files
authored
opencl: add iq4_nl support (#22272)
* opencl: add general support for iq4_nl * opencl: add iq4_nl gemm/gemv for adreno * opencl: pack 2 lut entries into a uint
1 parent b760272 commit f454bd7

8 files changed

Lines changed: 1695 additions & 0 deletions

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ set(GGML_OPENCL_KERNELS
9696
mul_mv_q6_k_f32_flat
9797
mul_mv_q8_0_f32
9898
mul_mv_q8_0_f32_flat
99+
mul_mv_iq4_nl_f32
100+
mul_mv_iq4_nl_f32_flat
99101
mul_mv_mxfp4_f32
100102
mul_mv_mxfp4_f32_flat
101103
mul_mv_id_q4_0_f32_8x_flat
@@ -110,12 +112,15 @@ set(GGML_OPENCL_KERNELS
110112
mul_mm_q4_0_f32_l4_lm
111113
mul_mm_q4_1_f32_l4_lm
112114
mul_mm_q8_0_f32_l4_lm
115+
mul_mm_iq4_nl_f32_l4_lm
113116
mul_mm_q4_k_f32_l4_lm
114117
mul_mm_q5_k_f32_l4_lm
115118
mul_mm_q6_k_f32_l4_lm
116119
mul_mm_q8_0_f32_8x4
117120
gemv_noshuffle_q4_1_f32
118121
gemm_noshuffle_q4_1_f32
122+
gemv_noshuffle_iq4_nl_f32
123+
gemm_noshuffle_iq4_nl_f32
119124
gemv_noshuffle_general_q8_0_f32
120125
gemv_noshuffle_q4_k_f32
121126
gemm_noshuffle_q4_k_f32

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 594 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,17 @@ struct block_q6_K {
8787
half d; // super-block scale
8888
};
8989

90+
//------------------------------------------------------------------------------
91+
// block_iq4_nl
92+
//------------------------------------------------------------------------------
93+
#define QK4_NL 32
94+
95+
struct block_iq4_nl
96+
{
97+
half d;
98+
uint8_t qs[QK4_NL / 2];
99+
};
100+
90101
//------------------------------------------------------------------------------
91102
// kernel_convert_block_q4_0
92103
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -895,3 +906,99 @@ kernel void kernel_restore_block_q6_K_noshuffle(
895906
b->scales[i] = s[i];
896907
}
897908
}
909+
910+
//------------------------------------------------------------------------------
911+
// kernel_convert_block_iq4_nl
912+
// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA).
913+
//------------------------------------------------------------------------------
914+
kernel void kernel_convert_block_iq4_nl(
915+
global struct block_iq4_nl * src0,
916+
global uchar * dst_q,
917+
global half * dst_d,
918+
uchar mask_0F,
919+
uchar mask_F0,
920+
ulong n_blk
921+
) {
922+
if (get_global_id(0) >= n_blk) {
923+
return;
924+
}
925+
global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
926+
global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
927+
global half * d = (global half *) dst_d + get_global_id(0);
928+
929+
*d = b->d;
930+
931+
for (int i = 0; i < QK4_NL/2; ++i) {
932+
q[i] = b->qs[i];
933+
}
934+
}
935+
936+
kernel void kernel_restore_block_iq4_nl(
937+
global uchar * src_q,
938+
global half * src_d,
939+
global struct block_iq4_nl * dst,
940+
ulong n_blk
941+
) {
942+
if (get_global_id(0) >= n_blk) {
943+
return;
944+
}
945+
global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
946+
global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
947+
global half * d = (global half *) src_d + get_global_id(0);
948+
949+
b->d = *d;
950+
951+
for (int i = 0; i < QK4_NL/2; ++i) {
952+
b->qs[i] = q[i];
953+
}
954+
}
955+
956+
kernel void kernel_convert_block_iq4_nl_noshuffle(
957+
global struct block_iq4_nl * src0,
958+
global uchar * dst_q,
959+
global half * dst_d,
960+
uchar mask_0F,
961+
uchar mask_F0,
962+
ulong n_blk
963+
) {
964+
if (get_global_id(0) >= n_blk) {
965+
return;
966+
}
967+
global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
968+
global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
969+
global half * d = (global half *) dst_d + get_global_id(0);
970+
971+
*d = b->d;
972+
for (int i = 0; i < QK4_NL/4; ++i) {
973+
uchar x0 = b->qs[2*i + 0];
974+
uchar x1 = b->qs[2*i + 1];
975+
976+
q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
977+
q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
978+
}
979+
}
980+
981+
kernel void kernel_restore_block_iq4_nl_noshuffle(
982+
global uchar * src_q,
983+
global half * src_d,
984+
global struct block_iq4_nl * dst,
985+
uchar mask_0F,
986+
uchar mask_F0,
987+
ulong n_blk
988+
) {
989+
if (get_global_id(0) >= n_blk) {
990+
return;
991+
}
992+
global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
993+
global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
994+
global half * d = (global half *) src_d + get_global_id(0);
995+
996+
b->d = *d;
997+
for (int i = 0; i < QK4_NL/4; ++i) {
998+
uchar x0 = q[i + 0 ];
999+
uchar x1 = q[i + QK4_NL/4];
1000+
1001+
b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
1002+
b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
1003+
}
1004+
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
3+
4+
#ifdef cl_qcom_reqd_sub_group_size
5+
#define ADRENO_GPU 1
6+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
7+
#endif
8+
9+
constant half kvalues_iq4nl[16] = {
10+
(half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f,
11+
(half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f,
12+
(half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f,
13+
(half) 53.f, (half) 69.f, (half) 89.f, (half)113.f
14+
};
15+
16+
// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16
17+
constant uint iq4nl_packed[8] = {
18+
0xD680D7F0u, // idx 0,1: -127, -104
19+
0xD410D530u, // idx 2,3: -83, -65
20+
0xD060D220u, // idx 4,5: -49, -35
21+
0xC900CD80u, // idx 6,7: -22, -10
22+
0x4A803C00u, // idx 8,9: 1, 13
23+
0x50C04E40u, // idx 10,11: 25, 38
24+
0x545052A0u, // idx 12,13: 53, 69
25+
0x57105590u // idx 14,15: 89, 113
26+
};
27+
28+
// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half
29+
#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4)))
30+
31+
#ifdef ADRENO_GPU
32+
REQD_SUBGROUP_SIZE_128
33+
#endif
34+
35+
kernel void kernel_gemm_noshuffle_iq4_nl_f32(
36+
global const ushort * src0_q,
37+
global const half * src0_d,
38+
read_only image1d_buffer_t src1,
39+
global float * dst,
40+
ulong offsetd,
41+
int m,
42+
int n,
43+
int k,
44+
int n_no_padding
45+
) {
46+
dst = (global float *)((global char *)dst + offsetd);
47+
48+
int m_4 = m >> 2;
49+
int n_4 = n >> 2;
50+
51+
int gy = get_global_id(0);
52+
int gx = get_global_id(1);
53+
int gx_2 = gx << 2;
54+
55+
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
56+
half8 B;
57+
half4 dequantized_weights;
58+
59+
global const ushort * weight_ptr = src0_q + gx_2;
60+
global const half * scale_ptr = src0_d + gx_2;
61+
62+
for (int i = 0; i < k; i += 4) {
63+
B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
64+
B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
65+
66+
ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));
67+
68+
half4 scale = vload4(0, scale_ptr + (i/32)*(m));
69+
70+
// j=0
71+
dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0;
72+
dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1;
73+
dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2;
74+
dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3;
75+
c0 += B * dequantized_weights.s0;
76+
c1 += B * dequantized_weights.s1;
77+
c2 += B * dequantized_weights.s2;
78+
c3 += B * dequantized_weights.s3;
79+
80+
// j=1
81+
B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
82+
B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
83+
dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0;
84+
dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1;
85+
dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2;
86+
dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3;
87+
c0 += B * dequantized_weights.s0;
88+
c1 += B * dequantized_weights.s1;
89+
c2 += B * dequantized_weights.s2;
90+
c3 += B * dequantized_weights.s3;
91+
92+
// j=2
93+
B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
94+
B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
95+
dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0;
96+
dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1;
97+
dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2;
98+
dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3;
99+
c0 += B * dequantized_weights.s0;
100+
c1 += B * dequantized_weights.s1;
101+
c2 += B * dequantized_weights.s2;
102+
c3 += B * dequantized_weights.s3;
103+
104+
// j=3
105+
B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
106+
B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
107+
dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0;
108+
dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1;
109+
dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2;
110+
dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3;
111+
c0 += B * dequantized_weights.s0;
112+
c1 += B * dequantized_weights.s1;
113+
c2 += B * dequantized_weights.s2;
114+
c3 += B * dequantized_weights.s3;
115+
}
116+
117+
int idx = (gy<<3)*m + (gx<<2);
118+
119+
if(idx+3 < m*n_no_padding){
120+
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
121+
idx += m;
122+
}
123+
if(idx+3 < m*n_no_padding){
124+
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
125+
idx += m;
126+
}
127+
if(idx+3 < m*n_no_padding){
128+
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
129+
idx += m;
130+
}
131+
if(idx+3 < m*n_no_padding){
132+
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
133+
idx += m;
134+
}
135+
if(idx+3 < m*n_no_padding){
136+
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
137+
idx += m;
138+
}
139+
if(idx+3 < m*n_no_padding){
140+
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
141+
idx += m;
142+
}
143+
if(idx+3 < m*n_no_padding){
144+
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
145+
idx += m;
146+
}
147+
if(idx+3 < m*n_no_padding){
148+
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
149+
}
150+
}

0 commit comments

Comments
 (0)