@@ -118,6 +118,66 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
118118}
119119#endif
120120
// Dequantize one 16-weight slice of a Q1_0 block into a 4x4 register tile.
//   xb  - source block; reads xb->d (scale) and two consecutive bytes of xb->qs
//   il  - slice index; slice il covers bits [16*il, 16*il + 16) of xb->qs
//   reg - output tile; weight k of the slice is stored in reg[k/4][k%4]
// A set bit decodes to +d, a clear bit to -d.
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
    const float scale = xb->d;

    // the slice is byte-aligned: 16 bits == 2 bytes, lower-indexed byte holds weights 0..7
    device const uint8_t * src = xb->qs;
    const int first_byte = il * 2;
    const ushort bits = (ushort) (src[first_byte] | (src[first_byte + 1] << 8));

    float4x4 tile;

    // fixed trip count of 16: bit k of the packed word selects the sign of weight k
    for (short k = 0; k < 16; ++k) {
        tile[k / 4][k % 4] = ((bits >> k) & 1) ? scale : -scale;
    }

    reg = (type4x4) tile;
}
158+
// Dequantize 4 consecutive Q1_0 weights into a 4-component register.
//   xb  - source block; reads xb->d (scale) and one byte of xb->qs
//   il  - chunk index (expected to be non-negative); chunk il covers bits [4*il, 4*il + 4) of xb->qs
//   reg - output vector; component i receives +d if bit (4*il + i) is set, -d otherwise
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
    device const uint8_t * qs = xb->qs;
    const float d = xb->d;

    // The bit offset il*4 is a multiple of 4, so the 4 bits never straddle a byte:
    // even il -> low nibble, odd il -> high nibble of byte il/2.
    // Load that byte once instead of re-deriving byte/bit indices for every component.
    const uint8_t nibble = qs[il / 2] >> ((il & 1) * 4);

    float4 reg_f;

    for (short i = 0; i < 4; ++i) {
        reg_f[i] = ((nibble >> i) & 1) ? d : -d;
    }

    reg = (type4) reg_f;
}
180+
121181template <typename type4x4>
122182void dequantize_q4_0 (device const block_q4_0 * xb, short il, thread type4x4 & reg) {
123183 device const uint16_t * qs = ((device const uint16_t *)xb + 1 );
@@ -3116,6 +3176,29 @@ kernel void kernel_group_norm_f32(
31163176 }
31173177}
31183178
// inner product between a 16-weight slice of a q1_0 block and 16 floats (yl)
// il is the element offset of the slice inside the block (0, 16, 32, ..., 112 for a 128-element block)
// each stored bit selects the sign of the weight: 1 -> +d, 0 -> -d
// sumy is part of the call signature but is not needed by this overload; yl is used as-is (no pre-scaling)
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
    (void) sumy; // intentionally unused

    const float d = qb_curr->d;

    // 16 weights = 16 bits = 2 bytes, starting il/8 bytes into qs
    device const uint8_t * qs = qb_curr->qs + il / 8;

    float acc = 0.0f;

    // walk the two packed bytes in order; each byte is loaded once and
    // the accumulation order over yl[0..15] is the same as a flat loop
    for (short ib = 0; ib < 2; ++ib) {
        const uint8_t bits = qs[ib];

        for (short j = 0; j < 8; ++j) {
            const float y = yl[8*ib + j];
            acc += ((bits >> j) & 1) ? y : -y;
        }
    }

    return d * acc;
}
3201+
// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
31203203// il indicates where the q4 quants begin (0 or QK4_0/4)
31213204// we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3337,6 +3420,78 @@ void mul_vec_q_n_f32_impl(
33373420 }
33383421}
33393422
// Matrix-vector product: Q1_0 weights (src0) times f32 activations (src1), f32 result (dst).
//
// Work split, as implemented below:
//   - tgpig.x selects a group of rows, tgpig.y the src1 row (r1), tgpig.z the batch index (im)
//   - each simdgroup accumulates N_R0_Q1_0 consecutive src0 rows starting at first_row
//   - lane tiisg handles the 16-element slice (tiisg%8) of block (tiisg/8) and then
//     advances by N_SIMDWIDTH/8 blocks per iteration
//   - per-row partial sums are combined with simd_sum and lane 0 stores the result
//
// Rows at or beyond args.ne01 are neither read nor written.
kernel void kernel_mul_mv_q1_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // number of Q1_0 blocks per src0 row
    const int nb = args.ne00/QK1_0;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SG_Q1_0 + sgitg) * N_R0_Q1_0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // pointers to src0 rows (only dereferenced for rows below args.ne01)
    device const block_q1_0 * ax[N_R0_Q1_0];
    for (int row = 0; row < N_R0_Q1_0; ++row) {
        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax[row] = (device const block_q1_0 *) (src0 + offset0);
    }

    float yl[16]; // src1 vector cache
    float sumf[N_R0_Q1_0] = {0.f};

    // a block is covered by 8 lanes, 16 elements each
    const short ix = (tiisg/8);    // block index handled by this lane in the first iteration
    const short il = (tiisg%8)*16; // element offset of this lane's slice inside the block (0, 16, ..., 112)

    device const float * yb = y + ix*QK1_0 + il;

    // each lane deals with 1/8 of a block (16 elements out of 128)
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
        float sumy = 0.f;

        // Q1_0: plain copy of the activations, sumy is their sum
#pragma unroll
        for (short i = 0; i < 16; i++) {
            yl[i] = yb[i];
            sumy += yb[i];
        }

#pragma unroll
        for (short row = 0; row < N_R0_Q1_0; row++) {
            // skip rows past the end of src0; the condition only depends on values shared by
            // the whole simdgroup (tgpig, sgitg, args), so all lanes take the same branch
            if (first_row + row < args.ne01) {
                sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
            }
        }

        yb += QK1_0 * (N_SIMDWIDTH/8);
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < N_R0_Q1_0; ++row) {
        // executed by every lane for every row; skipped rows contribute 0
        const float tot = simd_sum(sumf[row]);

        if (tiisg == 0 && first_row + row < args.ne01) {
            dst_f32[first_row + row] = tot;
        }
    }
}
3494+
33403495kernel void kernel_mul_mv_q4_0_f32 (
33413496 constant ggml_metal_kargs_mul_mv & args,
33423497 device const char * src0,
@@ -3729,6 +3884,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
37293884template [[host_name(" kernel_mul_mv_ext_bf16_f32_r1_5" )]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5 , bfloat4, 4 , dequantize_bf16_t4>;
37303885#endif
37313886
// q1_0 instantiations of the ext mat-vec kernel for r1 = 2..5
// NOTE(review): third template argument is 128 here vs 32 for q4_0 — presumably elements per block; confirm against kernel_mul_mv_ext_q4_f32_disp
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;

template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -9838,6 +9998,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
// q1_0 passes 8 where the q4_0/q4_1/q5_0 instantiations pass 2
// NOTE(review): presumably the number of 16-element dequantize calls per block (128/16 vs 32/16) — confirm against kernel_mul_mm
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
0 commit comments