Skip to content

Commit f2b50f9

Browse files
committed
initial Q1_0 Metal backend
1 parent 16ff288 commit f2b50f9

4 files changed

Lines changed: 175 additions & 0 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
736736
suffix = ne00 % 4 == 0 ? "_4" : "";
737737
}
738738
} break;
739+
case GGML_TYPE_Q1_0:
740+
{
741+
nsg = N_SG_Q1_0;
742+
nr0 = N_R0_Q1_0;
743+
} break;
739744
case GGML_TYPE_Q4_0:
740745
{
741746
nsg = N_SG_Q4_0;
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
948953
smem = 32*sizeof(float)*nr0;
949954
suffix = ne00 % 4 == 0 ? "_4" : "";
950955
} break;
956+
case GGML_TYPE_Q1_0:
957+
{
958+
nsg = N_SG_Q1_0;
959+
nr0 = N_R0_Q1_0;
960+
} break;
951961
case GGML_TYPE_Q4_0:
952962
{
953963
nsg = N_SG_Q4_0;

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
// Q1_0 mat-vec tuning constants, as used by kernel_mul_mv_q1_0_f32:
//   N_R0_Q1_0 - number of src0 rows accumulated by each simdgroup
//   N_SG_Q1_0 - multiplier applied to the threadgroup x-index when computing the first row
//               (presumably the number of simdgroups per threadgroup - confirm against the dispatch code)
#define N_R0_Q1_0 4
12+
#define N_SG_Q1_0 2
13+
1114
#define N_R0_Q4_0 4
1215
#define N_SG_Q4_0 2
1316

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
20472047
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
20482048
op->src[0]->type == GGML_TYPE_F16 ||
20492049
op->src[0]->type == GGML_TYPE_BF16 ||
2050+
op->src[0]->type == GGML_TYPE_Q1_0 ||
20502051
op->src[0]->type == GGML_TYPE_Q4_0 ||
20512052
op->src[0]->type == GGML_TYPE_Q4_1 ||
20522053
op->src[0]->type == GGML_TYPE_Q5_0 ||

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,66 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
118118
}
119119
#endif
120120

121+
// Expand the il-th group of 16 one-bit weights of a Q1_0 block into a 4x4 register.
// A set bit becomes +d and a clear bit becomes -d, where d is the block scale.
// Weight k of the group (k = 0..15, LSB-first within each byte) is stored in reg[k/4][k%4].
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
    const float scale = xb->d;

    // 16 weights = 16 bits = 2 consecutive bytes starting at byte 2*il;
    // merge them into a single 16-bit word so that bit k corresponds to weight k
    device const uint8_t * src = xb->qs + il * 2;
    const uint16_t bits = (uint16_t) src[0] | ((uint16_t) src[1] << 8);

    float4x4 tmp;

#pragma unroll
    for (short k = 0; k < 16; ++k) {
        tmp[k/4][k%4] = ((bits >> k) & 1) ? scale : -scale;
    }

    reg = (type4x4) tmp;
}
158+
159+
// Dequantize 4 consecutive one-bit weights of a Q1_0 block into a 4-vector.
// il selects the group: element i of the result comes from bit 4*il + i of qs (LSB-first per byte).
// A set bit yields +d, a clear bit yields -d, where d is the block scale.
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
    const float scale = xb->d;

    // the group starts at a multiple of 4 bits, so all 4 bits live in the same byte
    const int first_bit = il * 4;
    const uint8_t nibble = xb->qs[first_bit / 8] >> (first_bit % 8);

    const float4 vals = float4(
        (nibble & 0x1) ? scale : -scale,
        (nibble & 0x2) ? scale : -scale,
        (nibble & 0x4) ? scale : -scale,
        (nibble & 0x8) ? scale : -scale);

    reg = (type4) vals;
}
180+
121181
template <typename type4x4>
122182
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
123183
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -3116,6 +3176,29 @@ kernel void kernel_group_norm_f32(
31163176
}
31173177
}
31183178

3179+
// inner product between 16 one-bit weights of a q1_0 block and 16 floats (yl)
// il is the index of the first weight within the block (0, 16, 32, ..., 112 for a 128-element block)
// each weight counts as +1 when its bit is set and as -1 otherwise; the sum is multiplied by the block scale d
// sumy is expected to be SUM(yl[i]) but is not used by this overload; yl is consumed as-is (no pre-scaling)
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
    // il counts weights (= bits), so the chunk starts at byte il/8 and spans 2 bytes
    device const uint8_t * bytes = qb_curr->qs + il / 8;

    float dot = 0.0f;

    // walk the two bytes LSB-first so that yl[0..15] are accumulated in order
    for (int b = 0; b < 2; ++b) {
        const uint8_t packed = bytes[b];

        for (int k = 0; k < 8; ++k) {
            const float y = yl[8*b + k];
            dot += ((packed >> k) & 1) ? y : -y;
        }
    }

    return qb_curr->d * dot;
}
3201+
31193202
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
31203203
// il indicates where the q4 quants begin (0 or QK4_0/4)
31213204
// we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3337,6 +3420,78 @@ void mul_vec_q_n_f32_impl(
33373420
}
33383421
}
33393422

3423+
// Matrix-vector product with Q1_0 weights: dst = src0 (Q1_0) * src1 (F32)
//
// work layout:
//   - tgpig.x selects a group of N_SG_Q1_0*N_R0_Q1_0 consecutive src0 rows,
//     tgpig.y selects the src1 row (r1) and tgpig.z the batch index (im)
//   - each simdgroup accumulates N_R0_Q1_0 rows starting at first_row
//   - inside a simdgroup, 8 consecutive threads share one 128-element block (16 elements per thread),
//     so N_SIMDWIDTH/8 blocks are consumed per loop iteration
//
// rows at or beyond args.ne01 (tail of the last row group) are neither read nor written
kernel void kernel_mul_mv_q1_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    // number of Q1_0 blocks per src0 row
    const int nb = args.ne00/QK1_0;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SG_Q1_0 + sgitg) * N_R0_Q1_0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // pointers to src0 rows
    // note: pointers for rows >= args.ne01 are computed but never dereferenced
    device const block_q1_0 * ax[N_R0_Q1_0];
    for (int row = 0; row < N_R0_Q1_0; ++row) {
        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax[row] = (device const block_q1_0 *) (src0 + offset0);
    }

    float yl[16]; // src1 vector cache
    float sumf[N_R0_Q1_0] = {0.f};

    // 8 threads per block: ix is the block handled by this thread within the current iteration,
    // il is the first element of its 16-element chunk inside the block (0, 16, 32, ..., 112)
    const short ix = (tiisg/8);
    const short il = (tiisg%8)*16;

    device const float * yb = y + ix*QK1_0 + il;

    // each thread in a SIMD group deals with 1/8 of a block (16 elements out of 128)
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
        float sumy = 0.f;

        // Q1_0: the src1 values are used unmodified, so this is a plain copy
#pragma unroll
        for (short i = 0; i < 16; i++) {
            yl[i] = yb[i];
            sumy += yb[i];
        }

#pragma unroll
        for (short row = 0; row < N_R0_Q1_0; row++) {
            // first_row is uniform across the simdgroup, so this branch does not diverge;
            // skipping tail rows avoids reading past the last src0 row
            if (first_row + row < args.ne01) {
                sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
            }
        }

        yb += QK1_0 * (N_SIMDWIDTH/8);
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < N_R0_Q1_0; ++row) {
        // reduce the per-thread partial sums across the simdgroup
        const float tot = simd_sum(sumf[row]);

        if (tiisg == 0 && first_row + row < args.ne01) {
            dst_f32[first_row + row] = tot;
        }
    }
}
3494+
33403495
kernel void kernel_mul_mv_q4_0_f32(
33413496
constant ggml_metal_kargs_mul_mv & args,
33423497
device const char * src0,
@@ -3729,6 +3884,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
37293884
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
37303885
#endif
37313886

3887+
// Q1_0 instantiations of kernel_mul_mv_ext_q4_f32_disp; the first template argument (2..5) matches the _r1_N suffix
// NOTE(review): 128 is presumably the number of elements per Q1_0 block - confirm against the template's parameter list
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
3888+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
3889+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
3890+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
3891+
37323892
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
37333893
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
37343894
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -9838,6 +9998,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
98389998
#if defined(GGML_METAL_HAS_BF16)
98399999
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
984010000
#endif
10001+
// Q1_0 instantiation of kernel_mul_mm using dequantize_q1_0 (16 weights per call)
// NOTE(review): the 8 is presumably the number of 16-weight chunks per 128-element block - confirm against kernel_mul_mm
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
984110002
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
984210003
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
984310004
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;

0 commit comments

Comments
 (0)