@@ -124,56 +124,46 @@ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & re
124124 const float d = xb->d ;
125125 const float neg_d = -d;
126126
127- // Process 16 bits starting at offset il*16
128- // Optimization: process 2 bytes (16 bits) at once for better memory access
129127 const int byte_offset = il * 2 ; // il*16 bits = il*2 bytes
130128 const uint8_t b0 = qs[byte_offset];
131129 const uint8_t b1 = qs[byte_offset + 1 ];
132130
133131 float4x4 reg_f;
134132
135- // Unroll completely for better ILP
136- // First byte (bits 0-7)
137- reg_f[0 ][0 ] = (b0 & 0x01 ) ? d : neg_d;
138- reg_f[0 ][1 ] = (b0 & 0x02 ) ? d : neg_d;
139- reg_f[0 ][2 ] = (b0 & 0x04 ) ? d : neg_d;
140- reg_f[0 ][3 ] = (b0 & 0x08 ) ? d : neg_d;
141- reg_f[1 ][0 ] = (b0 & 0x10 ) ? d : neg_d;
142- reg_f[1 ][1 ] = (b0 & 0x20 ) ? d : neg_d;
143- reg_f[1 ][2 ] = (b0 & 0x40 ) ? d : neg_d;
144- reg_f[1 ][3 ] = (b0 & 0x80 ) ? d : neg_d;
145-
146- // Second byte (bits 8-15)
147- reg_f[2 ][0 ] = (b1 & 0x01 ) ? d : neg_d;
148- reg_f[2 ][1 ] = (b1 & 0x02 ) ? d : neg_d;
149- reg_f[2 ][2 ] = (b1 & 0x04 ) ? d : neg_d;
150- reg_f[2 ][3 ] = (b1 & 0x08 ) ? d : neg_d;
151- reg_f[3 ][0 ] = (b1 & 0x10 ) ? d : neg_d;
152- reg_f[3 ][1 ] = (b1 & 0x20 ) ? d : neg_d;
153- reg_f[3 ][2 ] = (b1 & 0x40 ) ? d : neg_d;
154- reg_f[3 ][3 ] = (b1 & 0x80 ) ? d : neg_d;
133+ reg_f[0 ][0 ] = select (neg_d, d, bool (b0 & 0x01 ));
134+ reg_f[0 ][1 ] = select (neg_d, d, bool (b0 & 0x02 ));
135+ reg_f[0 ][2 ] = select (neg_d, d, bool (b0 & 0x04 ));
136+ reg_f[0 ][3 ] = select (neg_d, d, bool (b0 & 0x08 ));
137+ reg_f[1 ][0 ] = select (neg_d, d, bool (b0 & 0x10 ));
138+ reg_f[1 ][1 ] = select (neg_d, d, bool (b0 & 0x20 ));
139+ reg_f[1 ][2 ] = select (neg_d, d, bool (b0 & 0x40 ));
140+ reg_f[1 ][3 ] = select (neg_d, d, bool (b0 & 0x80 ));
141+
142+ reg_f[2 ][0 ] = select (neg_d, d, bool (b1 & 0x01 ));
143+ reg_f[2 ][1 ] = select (neg_d, d, bool (b1 & 0x02 ));
144+ reg_f[2 ][2 ] = select (neg_d, d, bool (b1 & 0x04 ));
145+ reg_f[2 ][3 ] = select (neg_d, d, bool (b1 & 0x08 ));
146+ reg_f[3 ][0 ] = select (neg_d, d, bool (b1 & 0x10 ));
147+ reg_f[3 ][1 ] = select (neg_d, d, bool (b1 & 0x20 ));
148+ reg_f[3 ][2 ] = select (neg_d, d, bool (b1 & 0x40 ));
149+ reg_f[3 ][3 ] = select (neg_d, d, bool (b1 & 0x80 ));
155150
156151 reg = (type4x4) reg_f;
157152}
158153
159154template <typename type4>
160155void dequantize_q1_0_t4 (device const block_q1_0 * xb, short il, thread type4 & reg) {
161- device const uint8_t * qs = xb->qs ;
162156 const float d = xb->d ;
157+ const float neg_d = -d;
158+ const int base = il * 4 ;
159+ const uint8_t byte = xb->qs [base / 8 ];
160+ const int s = base % 8 ;
163161
164162 float4 reg_f;
165-
166- // Process 4 bits for each call
167- const int offset = il * 4 ;
168-
169- for (int i = 0 ; i < 4 ; i++) {
170- const int bit_idx = offset + i;
171- const int byte_idx = bit_idx / 8 ;
172- const int bit_offset = bit_idx % 8 ;
173-
174- const bool bit_val = (qs[byte_idx] >> bit_offset) & 1 ;
175- reg_f[i] = bit_val ? d : -d;
176- }
163+ reg_f[0 ] = select (neg_d, d, bool ((byte >> (s )) & 1 ));
164+ reg_f[1 ] = select (neg_d, d, bool ((byte >> (s + 1 )) & 1 ));
165+ reg_f[2 ] = select (neg_d, d, bool ((byte >> (s + 2 )) & 1 ));
166+ reg_f[3 ] = select (neg_d, d, bool ((byte >> (s + 3 )) & 1 ));
177167
178168 reg = (type4) reg_f;
179169}
@@ -3176,27 +3166,33 @@ kernel void kernel_group_norm_f32(
31763166 }
31773167}
31783168
3179- // function for calculate inner product between part of a q1_0 block and 16 floats (yl), sumy is SUM(yl[i])
3180- // il indicates where the q1 quants begin (0, 16, 32, ..., 112 for 128-element block)
3181- // we assume that the yl's have been multiplied with the appropriate scale factor
3169+ // Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
31823170inline float block_q_n_dot_y (device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3183- float d = qb_curr->d ;
3171+ device const uint8_t * qs = qb_curr->qs + il / 8 ;
3172+ const uint8_t b0 = qs[0 ];
3173+ const uint8_t b1 = qs[1 ];
31843174
31853175 float acc = 0 .0f ;
31863176
3187- // il represents which 16-element chunk of the 128-element block (0, 16, 32, ..., 112)
3188- // 16 weights = 16 bits = 2 bytes
3189- const int byte_offset = il / 8 ;
3190- device const uint8_t * qs = qb_curr->qs + byte_offset;
3177+ acc += select (0 .0f , yl[ 0 ], bool (b0 & 0x01 ));
3178+ acc += select (0 .0f , yl[ 1 ], bool (b0 & 0x02 ));
3179+ acc += select (0 .0f , yl[ 2 ], bool (b0 & 0x04 ));
3180+ acc += select (0 .0f , yl[ 3 ], bool (b0 & 0x08 ));
3181+ acc += select (0 .0f , yl[ 4 ], bool (b0 & 0x10 ));
3182+ acc += select (0 .0f , yl[ 5 ], bool (b0 & 0x20 ));
3183+ acc += select (0 .0f , yl[ 6 ], bool (b0 & 0x40 ));
3184+ acc += select (0 .0f , yl[ 7 ], bool (b0 & 0x80 ));
31913185
3192- for (int i = 0 ; i < 16 ; i++) {
3193- const uint8_t byte_idx = i / 8 ;
3194- const uint8_t bit_idx = i % 8 ;
3195- const int8_t qval = ((qs[byte_idx] >> bit_idx) & 1 ) ? 1 : -1 ;
3196- acc += yl[i] * qval;
3197- }
3186+ acc += select (0 .0f , yl[ 8 ], bool (b1 & 0x01 ));
3187+ acc += select (0 .0f , yl[ 9 ], bool (b1 & 0x02 ));
3188+ acc += select (0 .0f , yl[10 ], bool (b1 & 0x04 ));
3189+ acc += select (0 .0f , yl[11 ], bool (b1 & 0x08 ));
3190+ acc += select (0 .0f , yl[12 ], bool (b1 & 0x10 ));
3191+ acc += select (0 .0f , yl[13 ], bool (b1 & 0x20 ));
3192+ acc += select (0 .0f , yl[14 ], bool (b1 & 0x40 ));
3193+ acc += select (0 .0f , yl[15 ], bool (b1 & 0x80 ));
31983194
3199- return d * acc;
3195+ return qb_curr-> d * ( 2 . 0f * acc - sumy) ;
32003196}
32013197
32023198// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -3428,7 +3424,6 @@ kernel void kernel_mul_mv_q1_0_f32(
34283424 uint3 tgpig[[threadgroup_position_in_grid]],
34293425 ushort tiisg[[thread_index_in_simdgroup]],
34303426 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3431- // Q1_0-specific implementation with 128-element blocks
34323427 const int nb = args.ne00 /QK1_0;
34333428
34343429 const int r0 = tgpig.x ;
@@ -3444,29 +3439,23 @@ kernel void kernel_mul_mv_q1_0_f32(
34443439
34453440 device const float * y = (device const float *) (src1 + offset1);
34463441
3447- // pointers to src0 rows
34483442 device const block_q1_0 * ax[N_R0_Q1_0];
34493443 for (int row = 0 ; row < N_R0_Q1_0; ++row) {
34503444 const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2 )*args.nb02 + (i13/args.r3 )*args.nb03 ;
3451-
34523445 ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
34533446 }
34543447
3455- float yl[16 ]; // src1 vector cache
3448+ float yl[16 ];
34563449 float sumf[N_R0_Q1_0] = {0 .f };
34573450
3458- // For 128-element blocks, we need 8 passes of 16 elements each
3459- // Each thread processes a different 16-element chunk
3460- const short ix = (tiisg/8 ); // which block (0 to 3 for 32 threads / 8)
3461- const short il = (tiisg%8 )*16 ; // which 16-element chunk within the 128-element block (0, 16, 32, ..., 112)
3451+ const short ix = (tiisg/8 );
3452+ const short il = (tiisg%8 )*16 ;
34623453
34633454 device const float * yb = y + ix*QK1_0 + il;
34643455
3465- // each thread in a SIMD group deals with 1/8 of a block (16 elements out of 128)
34663456 for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8 ) {
34673457 float sumy = 0 .f ;
34683458
3469- // Q1_0: simple copy
34703459#pragma unroll
34713460 for (short i = 0 ; i < 16 ; i++) {
34723461 yl[i] = yb[i];
@@ -10022,6 +10011,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
1002210011
1002310012template [[host_name(" kernel_mul_mm_f32_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float4x4, half, half2x4>;
1002410013template [[host_name(" kernel_mul_mm_f16_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, half, half4x4, half, half2x4>;
10014+ template [[host_name(" kernel_mul_mm_q1_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8 , dequantize_q1_0, float , float4x4, half, half2x4>;
1002510015template [[host_name(" kernel_mul_mm_q4_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, float , float4x4, half, half2x4>;
1002610016template [[host_name(" kernel_mul_mm_q4_1_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, float , float4x4, half, half2x4>;
1002710017template [[host_name(" kernel_mul_mm_q5_0_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, float , float4x4, half, half2x4>;
0 commit comments