Skip to content

Commit fbecb9c

Browse files
metal: Q1_0 backend (ggml-org#21528)
* initial Q1_0 Metal backend * tuning q1_0 metal kernels * add Q1_0 to test-backend-ops * add Q1_0<->F32 copy test * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 7dba116 commit fbecb9c

File tree

6 files changed

+205
-0
lines changed

6 files changed

+205
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
736736
suffix = ne00 % 4 == 0 ? "_4" : "";
737737
}
738738
} break;
739+
case GGML_TYPE_Q1_0:
740+
{
741+
nsg = N_SG_Q1_0;
742+
nr0 = N_R0_Q1_0;
743+
} break;
739744
case GGML_TYPE_Q4_0:
740745
{
741746
nsg = N_SG_Q4_0;
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
948953
smem = 32*sizeof(float)*nr0;
949954
suffix = ne00 % 4 == 0 ? "_4" : "";
950955
} break;
956+
case GGML_TYPE_Q1_0:
957+
{
958+
nsg = N_SG_Q1_0;
959+
nr0 = N_R0_Q1_0;
960+
} break;
951961
case GGML_TYPE_Q4_0:
952962
{
953963
nsg = N_SG_Q4_0;

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
11841184
case GGML_TYPE_F16:
11851185
case GGML_TYPE_BF16:
11861186
case GGML_TYPE_Q8_0:
1187+
case GGML_TYPE_Q1_0:
11871188
case GGML_TYPE_Q4_0:
11881189
case GGML_TYPE_Q4_1:
11891190
case GGML_TYPE_Q5_0:
@@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
12101211
default:
12111212
return false;
12121213
}
1214+
case GGML_TYPE_Q1_0:
12131215
case GGML_TYPE_Q4_0:
12141216
case GGML_TYPE_Q4_1:
12151217
case GGML_TYPE_Q5_0:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
#define N_R0_Q1_0 8
12+
#define N_SG_Q1_0 2
13+
1114
#define N_R0_Q4_0 4
1215
#define N_SG_Q4_0 2
1316

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
20472047
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
20482048
op->src[0]->type == GGML_TYPE_F16 ||
20492049
op->src[0]->type == GGML_TYPE_BF16 ||
2050+
op->src[0]->type == GGML_TYPE_Q1_0 ||
20502051
op->src[0]->type == GGML_TYPE_Q4_0 ||
20512052
op->src[0]->type == GGML_TYPE_Q4_1 ||
20522053
op->src[0]->type == GGML_TYPE_Q5_0 ||

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
118118
}
119119
#endif
120120

121+
template <typename type4x4>
122+
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
123+
device const uint8_t * qs = xb->qs;
124+
const float d = xb->d;
125+
const float neg_d = -d;
126+
127+
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
128+
const uint8_t b0 = qs[byte_offset];
129+
const uint8_t b1 = qs[byte_offset + 1];
130+
131+
float4x4 reg_f;
132+
133+
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
134+
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
135+
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
136+
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
137+
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
138+
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
139+
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
140+
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
141+
142+
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
143+
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
144+
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
145+
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
146+
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
147+
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
148+
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
149+
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
150+
151+
reg = (type4x4) reg_f;
152+
}
153+
154+
template <typename type4>
155+
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
156+
const float d = xb->d;
157+
const float neg_d = -d;
158+
const int base = il * 4;
159+
const uint8_t byte = xb->qs[base / 8];
160+
const int s = base % 8;
161+
162+
float4 reg_f;
163+
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
164+
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
165+
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
166+
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
167+
168+
reg = (type4) reg_f;
169+
}
170+
121171
template <typename type4x4>
122172
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
123173
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
152202
}
153203
}
154204

205+
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
206+
float sum_abs = 0.0f;
207+
for (int j = 0; j < QK1_0; j++) {
208+
sum_abs += fabs(src[j]);
209+
}
210+
dst.d = sum_abs / QK1_0;
211+
212+
for (int j = 0; j < QK1_0 / 8; j++) {
213+
dst.qs[j] = 0;
214+
}
215+
for (int j = 0; j < QK1_0; j++) {
216+
if (src[j] >= 0.0f) {
217+
dst.qs[j / 8] |= (1 << (j % 8));
218+
}
219+
}
220+
}
221+
155222
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
156223
#pragma METAL fp math_mode(safe)
157224
float amax = 0.0f; // absolute max
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
31163183
}
31173184
}
31183185

3186+
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
3187+
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3188+
device const uint8_t * qs = qb_curr->qs + il / 8;
3189+
const uint8_t b0 = qs[0];
3190+
const uint8_t b1 = qs[1];
3191+
3192+
float acc = 0.0f;
3193+
3194+
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
3195+
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
3196+
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
3197+
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
3198+
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
3199+
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
3200+
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
3201+
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
3202+
3203+
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
3204+
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
3205+
acc += select(0.0f, yl[10], bool(b1 & 0x04));
3206+
acc += select(0.0f, yl[11], bool(b1 & 0x08));
3207+
acc += select(0.0f, yl[12], bool(b1 & 0x10));
3208+
acc += select(0.0f, yl[13], bool(b1 & 0x20));
3209+
acc += select(0.0f, yl[14], bool(b1 & 0x40));
3210+
acc += select(0.0f, yl[15], bool(b1 & 0x80));
3211+
3212+
return qb_curr->d * (2.0f * acc - sumy);
3213+
}
3214+
31193215
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
31203216
// il indicates where the q4 quants begin (0 or QK4_0/4)
31213217
// we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
33373433
}
33383434
}
33393435

3436+
template<int nr0, typename args_t>
3437+
void kernel_mul_mv_q1_0_f32_impl(
3438+
args_t args,
3439+
device const char * src0,
3440+
device const char * src1,
3441+
device char * dst,
3442+
threadgroup char * shmem,
3443+
uint3 tgpig,
3444+
ushort tiisg,
3445+
ushort sgitg) {
3446+
const short NSG = FC_mul_mv_nsg;
3447+
3448+
const int nb = args.ne00/QK1_0;
3449+
3450+
const int r0 = tgpig.x;
3451+
const int r1 = tgpig.y;
3452+
const int im = tgpig.z;
3453+
3454+
const int first_row = (r0 * NSG + sgitg) * nr0;
3455+
3456+
const uint i12 = im%args.ne12;
3457+
const uint i13 = im/args.ne12;
3458+
3459+
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
3460+
3461+
device const float * y = (device const float *) (src1 + offset1);
3462+
3463+
device const block_q1_0 * ax[nr0];
3464+
for (int row = 0; row < nr0; ++row) {
3465+
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3466+
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
3467+
}
3468+
3469+
float yl[16];
3470+
float sumf[nr0] = {0.f};
3471+
3472+
const short ix = (tiisg/8);
3473+
const short il = (tiisg%8)*16;
3474+
3475+
device const float * yb = y + ix*QK1_0 + il;
3476+
3477+
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
3478+
float sumy = 0.f;
3479+
3480+
FOR_UNROLL (short i = 0; i < 16; i++) {
3481+
yl[i] = yb[i];
3482+
sumy += yb[i];
3483+
}
3484+
3485+
FOR_UNROLL (short row = 0; row < nr0; row++) {
3486+
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
3487+
}
3488+
3489+
yb += QK1_0 * (N_SIMDWIDTH/8);
3490+
}
3491+
3492+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
3493+
3494+
for (int row = 0; row < nr0; ++row) {
3495+
const float tot = simd_sum(sumf[row]);
3496+
3497+
if (tiisg == 0 && first_row + row < args.ne01) {
3498+
dst_f32[first_row + row] = tot;
3499+
}
3500+
}
3501+
}
3502+
3503+
[[host_name("kernel_mul_mv_q1_0_f32")]]
3504+
kernel void kernel_mul_mv_q1_0_f32(
3505+
constant ggml_metal_kargs_mul_mv & args,
3506+
device const char * src0,
3507+
device const char * src1,
3508+
device char * dst,
3509+
uint3 tgpig[[threadgroup_position_in_grid]],
3510+
ushort tiisg[[thread_index_in_simdgroup]],
3511+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3512+
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
3513+
}
3514+
33403515
kernel void kernel_mul_mv_q4_0_f32(
33413516
constant ggml_metal_kargs_mul_mv & args,
33423517
device const char * src0,
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
37293904
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
37303905
#endif
37313906

3907+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
3908+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
3909+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
3910+
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
3911+
37323912
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
37333913
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
37343914
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
71337313
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
71347314

71357315
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
7316+
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
71367317
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
71377318
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
71387319
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
71737354

71747355
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
71757356

7357+
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
71767358
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
71777359
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
71787360
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
71797361
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
71807362
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
71817363

7364+
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
71827365
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
71837366
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
71847367
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
97769959

97779960
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
97789961

9962+
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
97799963
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
97809964
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
97819965
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
983810022
#if defined(GGML_METAL_HAS_BF16)
983910023
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
984010024
#endif
10025+
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
984110026
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
984210027
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
984310028
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
986110046

986210047
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
986310048
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10049+
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
986410050
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
986510051
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
986610052
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
1007010256

1007110257
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
1007210258

10259+
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
1007310260
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
1007410261
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
1007510262
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7251,6 +7251,7 @@ static const ggml_type all_types[] = {
72517251
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
72527252
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
72537253
GGML_TYPE_Q8_0,
7254+
GGML_TYPE_Q1_0,
72547255
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
72557256
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
72567257
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
@@ -7275,6 +7276,7 @@ static const ggml_type other_types[] = {
72757276
GGML_TYPE_Q4_1,
72767277
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
72777278
GGML_TYPE_Q8_0,
7279+
GGML_TYPE_Q1_0,
72787280
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
72797281
GGML_TYPE_Q5_K,
72807282
GGML_TYPE_Q6_K,

0 commit comments

Comments
 (0)