@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
118118}
119119#endif
120120
121+ template <typename type4x4>
122+ void dequantize_q1_0 (device const block_q1_0 * xb, short il, thread type4x4 & reg) {
123+ device const uint8_t * qs = xb->qs ;
124+ const float d = xb->d ;
125+ const float neg_d = -d;
126+
127+ const int byte_offset = il * 2 ; // il*16 bits = il*2 bytes
128+ const uint8_t b0 = qs[byte_offset];
129+ const uint8_t b1 = qs[byte_offset + 1 ];
130+
131+ float4x4 reg_f;
132+
133+ reg_f[0 ][0 ] = select (neg_d, d, bool (b0 & 0x01 ));
134+ reg_f[0 ][1 ] = select (neg_d, d, bool (b0 & 0x02 ));
135+ reg_f[0 ][2 ] = select (neg_d, d, bool (b0 & 0x04 ));
136+ reg_f[0 ][3 ] = select (neg_d, d, bool (b0 & 0x08 ));
137+ reg_f[1 ][0 ] = select (neg_d, d, bool (b0 & 0x10 ));
138+ reg_f[1 ][1 ] = select (neg_d, d, bool (b0 & 0x20 ));
139+ reg_f[1 ][2 ] = select (neg_d, d, bool (b0 & 0x40 ));
140+ reg_f[1 ][3 ] = select (neg_d, d, bool (b0 & 0x80 ));
141+
142+ reg_f[2 ][0 ] = select (neg_d, d, bool (b1 & 0x01 ));
143+ reg_f[2 ][1 ] = select (neg_d, d, bool (b1 & 0x02 ));
144+ reg_f[2 ][2 ] = select (neg_d, d, bool (b1 & 0x04 ));
145+ reg_f[2 ][3 ] = select (neg_d, d, bool (b1 & 0x08 ));
146+ reg_f[3 ][0 ] = select (neg_d, d, bool (b1 & 0x10 ));
147+ reg_f[3 ][1 ] = select (neg_d, d, bool (b1 & 0x20 ));
148+ reg_f[3 ][2 ] = select (neg_d, d, bool (b1 & 0x40 ));
149+ reg_f[3 ][3 ] = select (neg_d, d, bool (b1 & 0x80 ));
150+
151+ reg = (type4x4) reg_f;
152+ }
153+
154+ template <typename type4>
155+ void dequantize_q1_0_t4 (device const block_q1_0 * xb, short il, thread type4 & reg) {
156+ const float d = xb->d ;
157+ const float neg_d = -d;
158+ const int base = il * 4 ;
159+ const uint8_t byte = xb->qs [base / 8 ];
160+ const int s = base % 8 ;
161+
162+ float4 reg_f;
163+ reg_f[0 ] = select (neg_d, d, bool ((byte >> (s )) & 1 ));
164+ reg_f[1 ] = select (neg_d, d, bool ((byte >> (s + 1 )) & 1 ));
165+ reg_f[2 ] = select (neg_d, d, bool ((byte >> (s + 2 )) & 1 ));
166+ reg_f[3 ] = select (neg_d, d, bool ((byte >> (s + 3 )) & 1 ));
167+
168+ reg = (type4) reg_f;
169+ }
170+
121171template <typename type4x4>
122172void dequantize_q4_0 (device const block_q4_0 * xb, short il, thread type4x4 & reg) {
123173 device const uint16_t * qs = ((device const uint16_t *)xb + 1 );
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
152202 }
153203}
154204
205+ void quantize_q1_0 (device const float * src, device block_q1_0 & dst) {
206+ float sum_abs = 0 .0f ;
207+ for (int j = 0 ; j < QK1_0; j++) {
208+ sum_abs += fabs (src[j]);
209+ }
210+ dst.d = sum_abs / QK1_0;
211+
212+ for (int j = 0 ; j < QK1_0 / 8 ; j++) {
213+ dst.qs [j] = 0 ;
214+ }
215+ for (int j = 0 ; j < QK1_0; j++) {
216+ if (src[j] >= 0 .0f ) {
217+ dst.qs [j / 8 ] |= (1 << (j % 8 ));
218+ }
219+ }
220+ }
221+
155222void quantize_q4_0 (device const float * src, device block_q4_0 & dst) {
156223#pragma METAL fp math_mode(safe)
157224 float amax = 0 .0f ; // absolute max
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
31163183 }
31173184}
31183185
3186+ // Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
3187+ inline float block_q_n_dot_y (device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3188+ device const uint8_t * qs = qb_curr->qs + il / 8 ;
3189+ const uint8_t b0 = qs[0 ];
3190+ const uint8_t b1 = qs[1 ];
3191+
3192+ float acc = 0 .0f ;
3193+
3194+ acc += select (0 .0f , yl[ 0 ], bool (b0 & 0x01 ));
3195+ acc += select (0 .0f , yl[ 1 ], bool (b0 & 0x02 ));
3196+ acc += select (0 .0f , yl[ 2 ], bool (b0 & 0x04 ));
3197+ acc += select (0 .0f , yl[ 3 ], bool (b0 & 0x08 ));
3198+ acc += select (0 .0f , yl[ 4 ], bool (b0 & 0x10 ));
3199+ acc += select (0 .0f , yl[ 5 ], bool (b0 & 0x20 ));
3200+ acc += select (0 .0f , yl[ 6 ], bool (b0 & 0x40 ));
3201+ acc += select (0 .0f , yl[ 7 ], bool (b0 & 0x80 ));
3202+
3203+ acc += select (0 .0f , yl[ 8 ], bool (b1 & 0x01 ));
3204+ acc += select (0 .0f , yl[ 9 ], bool (b1 & 0x02 ));
3205+ acc += select (0 .0f , yl[10 ], bool (b1 & 0x04 ));
3206+ acc += select (0 .0f , yl[11 ], bool (b1 & 0x08 ));
3207+ acc += select (0 .0f , yl[12 ], bool (b1 & 0x10 ));
3208+ acc += select (0 .0f , yl[13 ], bool (b1 & 0x20 ));
3209+ acc += select (0 .0f , yl[14 ], bool (b1 & 0x40 ));
3210+ acc += select (0 .0f , yl[15 ], bool (b1 & 0x80 ));
3211+
3212+ return qb_curr->d * (2 .0f * acc - sumy);
3213+ }
3214+
31193215// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
31203216// il indicates where the q4 quants begin (0 or QK4_0/4)
31213217// we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
33373433 }
33383434}
33393435
3436+ template <int nr0, typename args_t >
3437+ void kernel_mul_mv_q1_0_f32_impl (
3438+ args_t args,
3439+ device const char * src0,
3440+ device const char * src1,
3441+ device char * dst,
3442+ threadgroup char * shmem,
3443+ uint3 tgpig,
3444+ ushort tiisg,
3445+ ushort sgitg) {
3446+ const short NSG = FC_mul_mv_nsg;
3447+
3448+ const int nb = args.ne00 /QK1_0;
3449+
3450+ const int r0 = tgpig.x ;
3451+ const int r1 = tgpig.y ;
3452+ const int im = tgpig.z ;
3453+
3454+ const int first_row = (r0 * NSG + sgitg) * nr0;
3455+
3456+ const uint i12 = im%args.ne12 ;
3457+ const uint i13 = im/args.ne12 ;
3458+
3459+ const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13 ;
3460+
3461+ device const float * y = (device const float *) (src1 + offset1);
3462+
3463+ device const block_q1_0 * ax[nr0];
3464+ for (int row = 0 ; row < nr0; ++row) {
3465+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2 )*args.nb02 + (i13/args.r3 )*args.nb03 ;
3466+ ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
3467+ }
3468+
3469+ float yl[16 ];
3470+ float sumf[nr0] = {0 .f };
3471+
3472+ const short ix = (tiisg/8 );
3473+ const short il = (tiisg%8 )*16 ;
3474+
3475+ device const float * yb = y + ix*QK1_0 + il;
3476+
3477+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8 ) {
3478+ float sumy = 0 .f ;
3479+
3480+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
3481+ yl[i] = yb[i];
3482+ sumy += yb[i];
3483+ }
3484+
3485+ FOR_UNROLL (short row = 0 ; row < nr0; row++) {
3486+ sumf[row] += block_q_n_dot_y (ax[row] + ib, sumy, yl, il);
3487+ }
3488+
3489+ yb += QK1_0 * (N_SIMDWIDTH/8 );
3490+ }
3491+
3492+ device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
3493+
3494+ for (int row = 0 ; row < nr0; ++row) {
3495+ const float tot = simd_sum (sumf[row]);
3496+
3497+ if (tiisg == 0 && first_row + row < args.ne01 ) {
3498+ dst_f32[first_row + row] = tot;
3499+ }
3500+ }
3501+ }
3502+
3503+ [[host_name(" kernel_mul_mv_q1_0_f32" )]]
3504+ kernel void kernel_mul_mv_q1_0_f32 (
3505+ constant ggml_metal_kargs_mul_mv & args,
3506+ device const char * src0,
3507+ device const char * src1,
3508+ device char * dst,
3509+ uint3 tgpig[[threadgroup_position_in_grid]],
3510+ ushort tiisg[[thread_index_in_simdgroup]],
3511+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3512+ kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr , tgpig, tiisg, sgitg);
3513+ }
3514+
33403515kernel void kernel_mul_mv_q4_0_f32 (
33413516 constant ggml_metal_kargs_mul_mv & args,
33423517 device const char * src0,
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
37293904template [[host_name(" kernel_mul_mv_ext_bf16_f32_r1_5" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5 , bfloat4, 4 , dequantize_bf16_t4>;
37303905#endif
37313906
3907+ template [[host_name(" kernel_mul_mv_ext_q1_0_f32_r1_2" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2 , block_q1_0, 128 , dequantize_q1_0_t4>;
3908+ template [[host_name(" kernel_mul_mv_ext_q1_0_f32_r1_3" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3 , block_q1_0, 128 , dequantize_q1_0_t4>;
3909+ template [[host_name(" kernel_mul_mv_ext_q1_0_f32_r1_4" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4 , block_q1_0, 128 , dequantize_q1_0_t4>;
3910+ template [[host_name(" kernel_mul_mv_ext_q1_0_f32_r1_5" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5 , block_q1_0, 128 , dequantize_q1_0_t4>;
3911+
37323912template [[host_name(" kernel_mul_mv_ext_q4_0_f32_r1_2" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2 , block_q4_0, 32 , dequantize_q4_0_t4>;
37333913template [[host_name(" kernel_mul_mv_ext_q4_0_f32_r1_3" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3 , block_q4_0, 32 , dequantize_q4_0_t4>;
37343914template [[host_name(" kernel_mul_mv_ext_q4_0_f32_r1_4" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4 , block_q4_0, 32 , dequantize_q4_0_t4>;
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
71337313typedef decltype (kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
71347314
71357315template [[host_name(" kernel_cpy_f32_q8_0" )]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
7316+ template [[host_name(" kernel_cpy_f32_q1_0" )]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
71367317template [[host_name(" kernel_cpy_f32_q4_0" )]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
71377318template [[host_name(" kernel_cpy_f32_q4_1" )]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
71387319template [[host_name(" kernel_cpy_f32_q5_0" )]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
71737354
71747355typedef decltype (kernel_cpy_q_f32<float4x4, block_q4_0, 2 , dequantize_q4_0>) cpy_q_f_t;
71757356
7357+ template [[host_name(" kernel_cpy_q1_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8 , dequantize_q1_0>;
71767358template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2 , dequantize_q4_0>;
71777359template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2 , dequantize_q4_1>;
71787360template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2 , dequantize_q5_0>;
71797361template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2 , dequantize_q5_1>;
71807362template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2 , dequantize_q8_0>;
71817363
7364+ template [[host_name(" kernel_cpy_q1_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8 , dequantize_q1_0>;
71827365template [[host_name(" kernel_cpy_q4_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2 , dequantize_q4_0>;
71837366template [[host_name(" kernel_cpy_q4_1_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2 , dequantize_q4_1>;
71847367template [[host_name(" kernel_cpy_q5_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2 , dequantize_q5_0>;
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
97769959
97779960typedef decltype (kernel_get_rows_q<block_q4_0, 2 , dequantize_q4_0>) get_rows_q_t;
97789961
9962+ template [[host_name(" kernel_get_rows_q1_0" )]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8 , dequantize_q1_0>;
97799963template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2 , dequantize_q4_0>;
97809964template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2 , dequantize_q4_1>;
97819965template [[host_name(" kernel_get_rows_q5_0" )]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2 , dequantize_q5_0>;
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
983810022#if defined(GGML_METAL_HAS_BF16)
983910023template [[host_name(" kernel_mul_mm_bf16_f32" )]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16, bfloat, bfloat4x4, float , float2x4>;
984010024#endif
10025+ template [[host_name(" kernel_mul_mm_q1_0_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8 , dequantize_q1_0, float , float4x4, float , float2x4>;
984110026template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, float , float4x4, float , float2x4>;
984210027template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, float , float4x4, float , float2x4>;
984310028template [[host_name(" kernel_mul_mm_q5_0_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, float , float4x4, float , float2x4>;
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
986110046
986210047template [[host_name(" kernel_mul_mm_f32_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float4x4, half, half2x4>;
986310048template [[host_name(" kernel_mul_mm_f16_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, half, half4x4, half, half2x4>;
10049+ template [[host_name(" kernel_mul_mm_q1_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8 , dequantize_q1_0, float , float4x4, half, half2x4>;
986410050template [[host_name(" kernel_mul_mm_q4_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, float , float4x4, half, half2x4>;
986510051template [[host_name(" kernel_mul_mm_q4_1_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, float , float4x4, half, half2x4>;
986610052template [[host_name(" kernel_mul_mm_q5_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, float , float4x4, half, half2x4>;
@@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
1007010256
1007110257template [[host_name(" kernel_mul_mv_id_q8_0_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
1007210258
10259+ template [[host_name(" kernel_mul_mv_id_q1_0_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
1007310260template [[host_name(" kernel_mul_mv_id_q4_0_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
1007410261template [[host_name(" kernel_mul_mv_id_q4_1_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
1007510262template [[host_name(" kernel_mul_mv_id_q5_0_f32" )]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
0 commit comments