Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
suffix = ne00 % 4 == 0 ? "_4" : "";
}
} break;
case GGML_TYPE_Q1_0:
{
nsg = N_SG_Q1_0;
nr0 = N_R0_Q1_0;
} break;
case GGML_TYPE_Q4_0:
{
nsg = N_SG_Q4_0;
Expand Down Expand Up @@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
smem = 32*sizeof(float)*nr0;
suffix = ne00 % 4 == 0 ? "_4" : "";
} break;
case GGML_TYPE_Q1_0:
{
nsg = N_SG_Q1_0;
nr0 = N_R0_Q1_0;
} break;
Comment thread
khosravipasha marked this conversation as resolved.
case GGML_TYPE_Q4_0:
{
nsg = N_SG_Q4_0;
Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
//
// TODO: for optimal performance, become function of the device and work size

#define N_R0_Q1_0 8
#define N_SG_Q1_0 2

#define N_R0_Q4_0 4
#define N_SG_Q4_0 2

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q1_0 ||
op->src[0]->type == GGML_TYPE_Q4_0 ||
op->src[0]->type == GGML_TYPE_Q4_1 ||
Comment thread
khosravipasha marked this conversation as resolved.
op->src[0]->type == GGML_TYPE_Q5_0 ||
Expand Down
168 changes: 168 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
}
#endif

template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs;
const float d = xb->d;
const float neg_d = -d;

const int byte_offset = il * 2; // il*16 bits = il*2 bytes
const uint8_t b0 = qs[byte_offset];
const uint8_t b1 = qs[byte_offset + 1];

float4x4 reg_f;

reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));

reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));

reg = (type4x4) reg_f;
}

template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
const float d = xb->d;
const float neg_d = -d;
const int base = il * 4;
const uint8_t byte = xb->qs[base / 8];
const int s = base % 8;

float4 reg_f;
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));

reg = (type4) reg_f;
}

template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
Expand Down Expand Up @@ -3116,6 +3166,35 @@ kernel void kernel_group_norm_f32(
}
}

// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In block_q_n_dot_y for Q1_0, the yl parameter is not modified; it can be declared as thread const float * to better document intent and allow the compiler more freedom for optimization.

Suggested change
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread const float * yl, int il) {

Copilot uses AI. Check for mistakes.
device const uint8_t * qs = qb_curr->qs + il / 8;
const uint8_t b0 = qs[0];
const uint8_t b1 = qs[1];

float acc = 0.0f;

acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));

acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
acc += select(0.0f, yl[10], bool(b1 & 0x04));
acc += select(0.0f, yl[11], bool(b1 & 0x08));
acc += select(0.0f, yl[12], bool(b1 & 0x10));
acc += select(0.0f, yl[13], bool(b1 & 0x20));
acc += select(0.0f, yl[14], bool(b1 & 0x40));
acc += select(0.0f, yl[15], bool(b1 & 0x80));

return qb_curr->d * (2.0f * acc - sumy);
}

// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
Expand Down Expand Up @@ -3337,6 +3416,87 @@ void mul_vec_q_n_f32_impl(
}
}

template<int nr0, typename args_t>
void kernel_mul_mv_q1_0_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;

const int nb = args.ne00/QK1_0;

const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;

const int first_row = (r0 * NSG + sgitg) * nr0;

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;

const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;

device const float * y = (device const float *) (src1 + offset1);

device const block_q1_0 * ax[nr0];
for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
}

float yl[16];
float sumf[nr0] = {0.f};

const short ix = (tiisg/8);
const short il = (tiisg%8)*16;

device const float * yb = y + ix*QK1_0 + il;

for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
float sumy = 0.f;

#pragma unroll
for (short i = 0; i < 16; i++) {
yl[i] = yb[i];
sumy += yb[i];
}

#pragma unroll
for (short row = 0; row < nr0; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
}

yb += QK1_0 * (N_SIMDWIDTH/8);
}

device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

for (int row = 0; row < nr0; ++row) {
const float tot = simd_sum(sumf[row]);

if (tiisg == 0 && first_row + row < args.ne01) {
dst_f32[first_row + row] = tot;
}
}
}

[[host_name("kernel_mul_mv_q1_0_f32")]]
kernel void kernel_mul_mv_q1_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}

kernel void kernel_mul_mv_q4_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
Expand Down Expand Up @@ -3729,6 +3889,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
#endif

template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;

template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
Expand Down Expand Up @@ -9838,6 +10003,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
Expand All @@ -9861,6 +10027,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m

template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
Expand Down Expand Up @@ -10070,6 +10237,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4

template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;

template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
Expand Down
Loading