From 077b2e54665e40eda9ba29e3fe43a6d592852a88 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Tue, 14 Apr 2026 18:09:26 -0700 Subject: [PATCH] opencl: Q1_0 support first attempt --- ggml/src/ggml-opencl/CMakeLists.txt | 3 + ggml/src/ggml-opencl/ggml-opencl.cpp | 282 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 44 +++ .../kernels/mul_mat_q1_0_Ab_Bi_8x4.cl | 216 ++++++++++++++ .../kernels/mul_mv_q1_0_f32_1d_8x_flat.cl | 144 +++++++++ .../kernels/mul_mv_q1_0_f32_8x_flat.cl | 137 +++++++++ ggml/src/ggml-opencl/kernels/transpose.cl | 44 +++ 7 files changed, 868 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mat_q1_0_Ab_Bi_8x4.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_1d_8x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_8x_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 112c2afe821..36033ff3af5 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -96,6 +96,9 @@ set(GGML_OPENCL_KERNELS mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat + mul_mv_q1_0_f32_8x_flat + mul_mv_q1_0_f32_1d_8x_flat + mul_mat_q1_0_Ab_Bi_8x4 mul_mv_mxfp4_f32 mul_mv_mxfp4_f32_flat mul_mv_id_q4_0_f32_8x_flat diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a581402300a..a0dd3bd95b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -430,6 +430,7 @@ struct ggml_backend_opencl_context { cl_program program_im2col_f16; cl_program program_im2col_f32; cl_program program_mul_mat_Ab_Bi_8x4; + cl_program program_mul_mat_q1_0_Ab_Bi_8x4; cl_program program_mul_mv_q4_0_f32; cl_program program_mul_mv_q4_0_f32_v; cl_program program_mul_mv_q4_0_f32_8x_flat; @@ -437,6 +438,8 @@ struct ggml_backend_opencl_context { cl_program program_mul_mv_q4_0_f32_1d_16x_flat; cl_program program_mul_mv_q6_K; cl_program program_mul_mv_q8_0_f32, program_mul_mv_q8_0_f32_flat; + cl_program program_mul_mv_q1_0_f32_8x_flat; + cl_program program_mul_mv_q1_0_f32_1d_8x_flat; cl_program program_mul_mv_mxfp4_f32; cl_program program_mul_mv_mxfp4_f32_flat; cl_program program_mul_mv_f16_f16; @@ -532,6 +535,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; + cl_kernel kernel_convert_block_q1_0, kernel_restore_block_q1_0; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; @@ -552,6 +556,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_q5_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; + cl_kernel kernel_mul_mat_q1_0_f32_8x_flat; + cl_kernel kernel_mul_mat_q1_0_f32_1d_8x_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; cl_kernel kernel_solve_tri_f32; @@ -717,6 +723,7 @@ struct ggml_backend_opencl_context { cl_program program_CL_gemv_11008_1_4096; cl_program program_CL_gemv_32000_1_4096; cl_kernel CL_mul_mat_Ab_Bi_8x4; + cl_kernel CL_mul_mat_q1_0_Ab_Bi_8x4; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; cl_kernel 
CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; @@ -938,6 +945,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q1_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q1_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q1_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q1_0", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); @@ -1353,6 +1362,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q1_0_f32_8x_flat (SOA layout, token generation) + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q1_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q1_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q1_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q1_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q1_0_f32_8x_flat, "kernel_mul_mat_q1_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q1_0_f32_1d_8x_flat (SOA layout, batch processing) + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q1_0_f32_1d_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q1_0_f32_1d_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q1_0_f32_1d_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q1_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q1_0_f32_1d_8x_flat, "kernel_mul_mat_q1_0_f32_1d_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mv_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2602,6 +2643,20 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mat_q1_0_Ab_Bi_8x4 (Q1_0 GEMM kernel) + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_q1_0_gemm { + #include "mul_mat_q1_0_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_q1_0_gemm = read_file("mul_mat_q1_0_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_mul_mat_q1_0_Ab_Bi_8x4 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q1_0_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_q1_0_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_mul_mat_q1_0_Ab_Bi_8x4, "kernel_mul_mat_q1_0_Ab_Bi_8x4", &err), err)); + GGML_LOG_CONT("."); + } + // gemm_noshuffle_q4_1_f32 { #ifdef 
GGML_OPENCL_EMBED_KERNELS @@ -3552,6 +3607,34 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_q1_0 { + cl_mem q = nullptr; + cl_mem q_img = nullptr; + cl_mem d = nullptr; + cl_mem d_img = nullptr; + size_t size_q = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q1_0() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; @@ -4058,6 +4141,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { return op->src[1]->type == GGML_TYPE_F32; + } else if (op->src[0]->type == GGML_TYPE_Q1_0) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } return false; case GGML_OP_MUL_MAT_ID: @@ -4250,6 +4335,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q1_0 * e : temp_tensor_extras_q1_0) { + delete e; + } + for (ggml_tensor_extra_cl_q1_0 * e : temp_tensor_extras_q1_0_in_use) { + delete e; + } for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { delete e; } @@ -4345,6 +4436,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q1_0 * ggml_opencl_alloc_temp_tensor_extra_q1_0() { + ggml_tensor_extra_cl_q1_0 * extra; + if (temp_tensor_extras_q1_0.empty()) { + extra = new ggml_tensor_extra_cl_q1_0(); + } else { + extra = temp_tensor_extras_q1_0.back(); + temp_tensor_extras_q1_0.pop_back(); + } + + temp_tensor_extras_q1_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { ggml_tensor_extra_cl_q4_K * extra; if (temp_tensor_extras_q4_K.empty()) { @@ -4416,6 +4522,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q8_0_in_use.clear(); + for (ggml_tensor_extra_cl_q1_0 * e : temp_tensor_extras_q1_0_in_use) { + temp_tensor_extras_q1_0.push_back(e); + } + temp_tensor_extras_q1_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { temp_tensor_extras_q4_K.push_back(e); } @@ -4447,6 +4558,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_q1_0; + std::vector temp_tensor_extras_q1_0_in_use; std::vector temp_tensor_extras_q4_K; std::vector temp_tensor_extras_q4_K_in_use; std::vector temp_tensor_extras_q5_K; @@ -5185,6 +5298,60 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q1_0) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized"); + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q1_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q1_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t 
size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/8); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin_q1 = region.origin; + + region.origin = align_to(previous_origin_q1 + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q1_0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + return; + } if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5810,6 +5977,32 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q1_0) { + ggml_tensor_extra_cl_q1_0 * extra = (ggml_tensor_extra_cl_q1_0 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q1_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; @@ -10477,6 +10670,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q1_0 * extra0_q1_0 = 
(ggml_tensor_extra_cl_q1_0 *)src0->extra; ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -10588,6 +10782,29 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // Q1_0 x FP32 GEMM for batch processing - direct buffer access (no transpose) + if(src0t == GGML_TYPE_Q1_0 && src1t == GGML_TYPE_F32 && N > 8) { + cl_kernel kernel = backend_ctx->CL_mul_mat_q1_0_Ab_Bi_8x4; + + size_t global_work_size[3] = {(size_t)((N + 7) / 8), (size_t)(M / 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q1_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &N)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + return; + } + // q4_k x fp32 if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); @@ -11332,12 +11549,38 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); break; + case GGML_TYPE_Q1_0: + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + + { + nth0 = 64; + nth1 = 1; + kernel = backend_ctx->kernel_mul_mat_q1_0_f32_1d_8x_flat; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q1_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; default: break; } // Launch kernel. 
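+    // Q1_0 reuses the Q4_0 launch geometry below: each subgroup writes 8
+    // consecutive output rows, so ne01 is rounded up to a multiple of 8 in
+    // the global work size.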
- if (src0t == GGML_TYPE_Q4_0) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q1_0) { size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; @@ -11669,6 +11912,40 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q1_0: + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + + kernel = backend_ctx->kernel_mul_mat_q1_0_f32_8x_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q1_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: { @@ -11968,7 +12245,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K) { + src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_Q1_0) { // Each SIMD group produces N_DST values in the result. Assuming each // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will // produce N_DST*N_SIMDGROUP values in the result. 
Hence, the grid size diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 1bd83d29b3d..cd0656d7cfc 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -360,6 +360,50 @@ kernel void kernel_restore_block_mxfp4_trans( b->e = src_e[src_blk_offset]; } +//------------------------------------------------------------------------------ +// block_q1_0 - 1-bit quantization with group size 128 +// group size 128, 1.125 bpw +//------------------------------------------------------------------------------ +#define QK1_0 128 + +typedef struct { + half d; // delta (scale) + uchar qs[QK1_0 / 8]; // 16 bytes = 128 bits for 128 weights +} block_q1_0; + +// Convert block_q1_0 AoS -> SoA (separate scales and quants) +kernel void kernel_convert_block_q1_0( + global block_q1_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global block_q1_0 * b = (global block_q1_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + (QK1_0/8)*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + // Copy 16 bytes of quantized bits + for (int i = 0; i < QK1_0/8; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q1_0( + global uchar * src_q, + global half * src_d, + global block_q1_0 * dst +) { + global block_q1_0 * b = (global block_q1_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + (QK1_0/8)*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK1_0/8; ++i) { + b->qs[i] = q[i]; + } +} + //------------------------------------------------------------------------------ // block_q8_0 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/mul_mat_q1_0_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mat_q1_0_Ab_Bi_8x4.cl new file mode 100644 index 00000000000..bce915311f5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mat_q1_0_Ab_Bi_8x4.cl @@ -0,0 +1,216 @@ +// Q1_0 GEMM Kernel - Direct GGML layout (no transpose needed) +// Each work-item computes an 8x4 output tile +// gy indexes 8 output rows (N dimension - batch/sequence) +// gx indexes 4 output columns (M dimension - output features) +// +// Q1_0: 128 elements per block, 16 bytes (128 bits) + 1 half scale +// GGML stores B as N rows of K elements: B[n][k] at index n*K + k +// This kernel loads B values with strided access to avoid transpose + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifndef REQD_SUBGROUP_SIZE_128 +#define REQD_SUBGROUP_SIZE_128 +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_mul_mat_q1_0_Ab_Bi_8x4( + global const uchar * src0_q, // packed 1-bit weights (SOA: q buffer) + global const half * src0_d, // scales (SOA: d buffer) + global const uchar * src1_base, // B activations base pointer + ulong src1_offset, // offset into src1 buffer + global uchar * dst_base, // output base pointer + ulong dst_offset, // offset into dst buffer + int m, // M (output features / rows of A) + int n, // N (batch size) + int k, // K (input features / cols of A) + int n_no_padding // N without padding (for bounds check) +) { + // Apply offsets + global const float * src1 = (global const float 
*)(src1_base + src1_offset); + global float * dst = (global float *)(dst_base + dst_offset); + + int gy = get_global_id(0); // output row tile (0 to N/8) + int gx = get_global_id(1); // output column tile (0 to M/4) + int gx_4 = gx << 2; // starting column (gx * 4) + + float8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output tile + + int num_blocks = k / 128; // 128 elements per block for Q1_0 + int row_base = gy << 3; // gy * 8 = starting output row + + // Pointers for 4 weight columns (SOA layout, row-major) + // For Q1_0: each block is 16 bytes (128 bits) + global const uchar* weight_base0 = src0_q + (gx_4 + 0) * num_blocks * 16; + global const uchar* weight_base1 = src0_q + (gx_4 + 1) * num_blocks * 16; + global const uchar* weight_base2 = src0_q + (gx_4 + 2) * num_blocks * 16; + global const uchar* weight_base3 = src0_q + (gx_4 + 3) * num_blocks * 16; + + // Scale pointers for 4 columns + global const half* scale_ptr0 = src0_d + (gx_4 + 0) * num_blocks; + global const half* scale_ptr1 = src0_d + (gx_4 + 1) * num_blocks; + global const half* scale_ptr2 = src0_d + (gx_4 + 2) * num_blocks; + global const half* scale_ptr3 = src0_d + (gx_4 + 3) * num_blocks; + + for (int block = 0; block < num_blocks; block++) { + // Load scales for 4 columns + float s0 = (float)scale_ptr0[block]; + float s1 = (float)scale_ptr1[block]; + float s2 = (float)scale_ptr2[block]; + float s3 = (float)scale_ptr3[block]; + + // Load 128 bits (4 uints) for each of 4 columns + global const uint* bits_ptr0 = (global const uint*)(weight_base0 + block * 16); + global const uint* bits_ptr1 = (global const uint*)(weight_base1 + block * 16); + global const uint* bits_ptr2 = (global const uint*)(weight_base2 + block * 16); + global const uint* bits_ptr3 = (global const uint*)(weight_base3 + block * 16); + + uint bits0_0 = bits_ptr0[0], bits0_1 = bits_ptr0[1], bits0_2 = bits_ptr0[2], bits0_3 = bits_ptr0[3]; + uint bits1_0 = bits_ptr1[0], bits1_1 = bits_ptr1[1], bits1_2 = bits_ptr1[2], bits1_3 = bits_ptr1[3]; + uint bits2_0 = bits_ptr2[0], bits2_1 = bits_ptr2[1], bits2_2 = bits_ptr2[2], bits2_3 = bits_ptr2[3]; + uint bits3_0 = bits_ptr3[0], bits3_1 = bits_ptr3[1], bits3_2 = bits_ptr3[2], bits3_3 = bits_ptr3[3]; + + // Process 128 K elements in this block + int k_base = block * 128; + + // Process first 32 bits (elements 0-31) + #pragma unroll 4 + for (int i = 0; i < 32; i++) { + int k_idx = k_base + i; + + // Load 8 B values for 8 output rows at K position k_idx + float8 B; + B.s0 = (row_base + 0 < n) ? src1[(row_base + 0) * k + k_idx] : 0.0f; + B.s1 = (row_base + 1 < n) ? src1[(row_base + 1) * k + k_idx] : 0.0f; + B.s2 = (row_base + 2 < n) ? src1[(row_base + 2) * k + k_idx] : 0.0f; + B.s3 = (row_base + 3 < n) ? src1[(row_base + 3) * k + k_idx] : 0.0f; + B.s4 = (row_base + 4 < n) ? src1[(row_base + 4) * k + k_idx] : 0.0f; + B.s5 = (row_base + 5 < n) ? src1[(row_base + 5) * k + k_idx] : 0.0f; + B.s6 = (row_base + 6 < n) ? src1[(row_base + 6) * k + k_idx] : 0.0f; + B.s7 = (row_base + 7 < n) ? src1[(row_base + 7) * k + k_idx] : 0.0f; + + float w0 = ((bits0_0 >> i) & 1u) ? s0 : -s0; + float w1 = ((bits1_0 >> i) & 1u) ? s1 : -s1; + float w2 = ((bits2_0 >> i) & 1u) ? s2 : -s2; + float w3 = ((bits3_0 >> i) & 1u) ? s3 : -s3; + + c0 += B * w0; + c1 += B * w1; + c2 += B * w2; + c3 += B * w3; + } + + // Process second 32 bits (elements 32-63) + #pragma unroll 4 + for (int i = 0; i < 32; i++) { + int k_idx = k_base + 32 + i; + + float8 B; + B.s0 = (row_base + 0 < n) ? src1[(row_base + 0) * k + k_idx] : 0.0f; + B.s1 = (row_base + 1 < n) ? 
src1[(row_base + 1) * k + k_idx] : 0.0f; + B.s2 = (row_base + 2 < n) ? src1[(row_base + 2) * k + k_idx] : 0.0f; + B.s3 = (row_base + 3 < n) ? src1[(row_base + 3) * k + k_idx] : 0.0f; + B.s4 = (row_base + 4 < n) ? src1[(row_base + 4) * k + k_idx] : 0.0f; + B.s5 = (row_base + 5 < n) ? src1[(row_base + 5) * k + k_idx] : 0.0f; + B.s6 = (row_base + 6 < n) ? src1[(row_base + 6) * k + k_idx] : 0.0f; + B.s7 = (row_base + 7 < n) ? src1[(row_base + 7) * k + k_idx] : 0.0f; + + float w0 = ((bits0_1 >> i) & 1u) ? s0 : -s0; + float w1 = ((bits1_1 >> i) & 1u) ? s1 : -s1; + float w2 = ((bits2_1 >> i) & 1u) ? s2 : -s2; + float w3 = ((bits3_1 >> i) & 1u) ? s3 : -s3; + + c0 += B * w0; + c1 += B * w1; + c2 += B * w2; + c3 += B * w3; + } + + // Process third 32 bits (elements 64-95) + #pragma unroll 4 + for (int i = 0; i < 32; i++) { + int k_idx = k_base + 64 + i; + + float8 B; + B.s0 = (row_base + 0 < n) ? src1[(row_base + 0) * k + k_idx] : 0.0f; + B.s1 = (row_base + 1 < n) ? src1[(row_base + 1) * k + k_idx] : 0.0f; + B.s2 = (row_base + 2 < n) ? src1[(row_base + 2) * k + k_idx] : 0.0f; + B.s3 = (row_base + 3 < n) ? src1[(row_base + 3) * k + k_idx] : 0.0f; + B.s4 = (row_base + 4 < n) ? src1[(row_base + 4) * k + k_idx] : 0.0f; + B.s5 = (row_base + 5 < n) ? src1[(row_base + 5) * k + k_idx] : 0.0f; + B.s6 = (row_base + 6 < n) ? src1[(row_base + 6) * k + k_idx] : 0.0f; + B.s7 = (row_base + 7 < n) ? src1[(row_base + 7) * k + k_idx] : 0.0f; + + float w0 = ((bits0_2 >> i) & 1u) ? s0 : -s0; + float w1 = ((bits1_2 >> i) & 1u) ? s1 : -s1; + float w2 = ((bits2_2 >> i) & 1u) ? s2 : -s2; + float w3 = ((bits3_2 >> i) & 1u) ? s3 : -s3; + + c0 += B * w0; + c1 += B * w1; + c2 += B * w2; + c3 += B * w3; + } + + // Process fourth 32 bits (elements 96-127) + #pragma unroll 4 + for (int i = 0; i < 32; i++) { + int k_idx = k_base + 96 + i; + + float8 B; + B.s0 = (row_base + 0 < n) ? src1[(row_base + 0) * k + k_idx] : 0.0f; + B.s1 = (row_base + 1 < n) ? src1[(row_base + 1) * k + k_idx] : 0.0f; + B.s2 = (row_base + 2 < n) ? src1[(row_base + 2) * k + k_idx] : 0.0f; + B.s3 = (row_base + 3 < n) ? src1[(row_base + 3) * k + k_idx] : 0.0f; + B.s4 = (row_base + 4 < n) ? src1[(row_base + 4) * k + k_idx] : 0.0f; + B.s5 = (row_base + 5 < n) ? src1[(row_base + 5) * k + k_idx] : 0.0f; + B.s6 = (row_base + 6 < n) ? src1[(row_base + 6) * k + k_idx] : 0.0f; + B.s7 = (row_base + 7 < n) ? src1[(row_base + 7) * k + k_idx] : 0.0f; + + float w0 = ((bits0_3 >> i) & 1u) ? s0 : -s0; + float w1 = ((bits1_3 >> i) & 1u) ? s1 : -s1; + float w2 = ((bits2_3 >> i) & 1u) ? s2 : -s2; + float w3 = ((bits3_3 >> i) & 1u) ? 
s3 : -s3; + + c0 += B * w0; + c1 += B * w1; + c2 += B * w2; + c3 += B * w3; + } + } + + // Write 8x4 tile to output + if (row_base + 0 < n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + (row_base + 0) * m + (gx << 2)); + } + if (row_base + 1 < n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + (row_base + 1) * m + (gx << 2)); + } + if (row_base + 2 < n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + (row_base + 2) * m + (gx << 2)); + } + if (row_base + 3 < n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + (row_base + 3) * m + (gx << 2)); + } + if (row_base + 4 < n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + (row_base + 4) * m + (gx << 2)); + } + if (row_base + 5 < n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + (row_base + 5) * m + (gx << 2)); + } + if (row_base + 6 < n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + (row_base + 6) * m + (gx << 2)); + } + if (row_base + 7 < n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + (row_base + 7) * m + (gx << 2)); + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_1d_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_1d_8x_flat.cl new file mode 100644 index 00000000000..48cdf0cb89f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_1d_8x_flat.cl @@ -0,0 +1,144 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK1_0 128 + +typedef uchar uint8_t; +typedef ushort uint16_t; + +// Based on working Q1_0 pattern - process 16 elements per thread +// Q1_0 has 128 elements per block, so we process 16 at a time +// il = 0,16,32,48,64,80,96,112 (8 different starting positions for 64 threads / 8 = 8 groups) +inline float block_q1_0_dot_y_flat(global uchar * x, global half * dh, float16 yl, int il) { + float d = *dh; + global ushort * qs = (global ushort *)x; + + // For 128-element block: 8 ushorts total (16 bytes) + // Each thread processes 1 ushort (16 bits) based on il + // il/16 gives us which ushort to read (0-7) + uint bits = qs[il/16]; + + int b0 = (bits >> 0) & 1u; + int b1 = (bits >> 1) & 1u; + int b2 = (bits >> 2) & 1u; + int b3 = (bits >> 3) & 1u; + int b4 = (bits >> 4) & 1u; + int b5 = (bits >> 5) & 1u; + int b6 = (bits >> 6) & 1u; + int b7 = (bits >> 7) & 1u; + int b8 = (bits >> 8) & 1u; + int b9 = (bits >> 9) & 1u; + int ba = (bits >> 10) & 1u; + int bb = (bits >> 11) & 1u; + int bc = (bits >> 12) & 1u; + int bd = (bits >> 13) & 1u; + int be = (bits >> 14) & 1u; + int bf = (bits >> 15) & 1u; + + float s0 = (float)(b0 * 2 - 1); + float s1 = (float)(b1 * 2 - 1); + float s2 = (float)(b2 * 2 - 1); + float s3 = (float)(b3 * 2 - 1); + float s4 = (float)(b4 * 2 - 1); + float s5 = (float)(b5 * 2 - 1); + float s6 = (float)(b6 * 2 - 1); + float s7 = (float)(b7 * 2 - 1); + float s8 = (float)(b8 * 2 - 1); + float s9 = (float)(b9 * 2 - 1); + float sa = (float)(ba * 2 - 1); + float sb = (float)(bb * 2 - 1); + float sc = (float)(bc * 2 - 1); + float sd = (float)(bd * 2 - 1); + float se = (float)(be * 2 - 
1); + float sf = (float)(bf * 2 - 1); + + float acc = 0.f; + acc += yl.s0 * s0; + acc += yl.s1 * s1; + acc += yl.s2 * s2; + acc += yl.s3 * s3; + acc += yl.s4 * s4; + acc += yl.s5 * s5; + acc += yl.s6 * s6; + acc += yl.s7 * s7; + acc += yl.s8 * s8; + acc += yl.s9 * s9; + acc += yl.sa * sa; + acc += yl.sb * sb; + acc += yl.sc * sc; + acc += yl.sd * sd; + acc += yl.se * se; + acc += yl.sf * sf; + + return d * acc; +} + +#define N_DST 8 +#define N_SIMDGROUP 1 +#ifdef cl_intel_required_subgroup_size +#define N_SIMDWIDTH 16 +#else +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q1_0_f32_8x_flat(global uchar * src0_q, global half * src0_d, global float * src1, global float * dst, int ne00, int ne01, int ne02, int ne10, int ne12, int ne0, int ne1, int r2, int r3) { + const ulong nb = ne00 / QK1_0; + int r0 = get_group_id(0), r1 = get_group_id(1), im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + int i12 = im % ne12, i13 = im / ne12; + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * (QK1_0/8); + global uchar * x = src0_q + offset0_q; + global half * d = src0_d + offset0_d; + global float * y = src1 + r1*ne10 + im*ne00*ne1; + float16 yl; float8 sumf = 0.f; + + // For 128-element blocks: 64 threads / 8 groups = 8 threads per group + // Each thread processes 16 elements (128/8 = 16) + int ix = get_sub_group_local_id() / 8, il = 16 * (get_sub_group_local_id() % 8); + global float * yb = y + ix * QK1_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + yl.s0=yb[0]; yl.s1=yb[1]; yl.s2=yb[2]; yl.s3=yb[3]; yl.s4=yb[4]; yl.s5=yb[5]; yl.s6=yb[6]; yl.s7=yb[7]; + yl.s8=yb[8]; yl.s9=yb[9]; yl.sa=yb[10]; yl.sb=yb[11]; yl.sc=yb[12]; yl.sd=yb[13]; yl.se=yb[14]; yl.sf=yb[15]; + sumf.s0 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 0*nb*(QK1_0/8), d + ib + 0*nb, yl, il); + sumf.s1 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 1*nb*(QK1_0/8), d + ib + 1*nb, yl, il); + sumf.s2 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 2*nb*(QK1_0/8), d + ib + 2*nb, yl, il); + sumf.s3 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 3*nb*(QK1_0/8), d + ib + 3*nb, yl, il); + sumf.s4 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 4*nb*(QK1_0/8), d + ib + 4*nb, yl, il); + sumf.s5 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 5*nb*(QK1_0/8), d + ib + 5*nb, yl, il); + sumf.s6 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 6*nb*(QK1_0/8), d + ib + 6*nb, yl, il); + sumf.s7 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 7*nb*(QK1_0/8), d + ib + 7*nb, yl, il); + yb += QK1_0 * (N_SIMDWIDTH/8); + } + float8 tot = (float8)(sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)); + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + if (first_row + 1 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + if (first_row + 2 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + if (first_row + 3 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + if (first_row + 4 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + if (first_row + 5 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + if (first_row + 6 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + if (first_row + 7 < ne01) dst[r1*ne0 + 
im*ne0*ne1 + first_row + 7] = tot.s7; + } +} + +#ifdef cl_intel_required_subgroup_size +REQD_SUBGROUP_SIZE_16 +#elif defined(cl_qcom_reqd_sub_group_size) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q1_0_f32_1d_8x_flat(global uchar * src0_q, global half * src0_d, global float * src1, ulong offset1, global float * dst, ulong offsetd, int ne00, int ne01, int ne02, int ne10, int ne12, int ne0, int ne1, int r2, int r3) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + mul_vec_q1_0_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_8x_flat.cl new file mode 100644 index 00000000000..02fdac7c7b7 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q1_0_f32_8x_flat.cl @@ -0,0 +1,137 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK1_0 128 + +typedef uchar uint8_t; +typedef ushort uint16_t; + +// Based on working Q1_0 pattern +inline float block_q1_0_dot_y_flat(global uchar * x, global half * dh, float16 yl, int il) { + float d = *dh; + global ushort * qs = (global ushort *)x; + uint bits = qs[il/16]; + + int b0 = (bits >> 0) & 1u; + int b1 = (bits >> 1) & 1u; + int b2 = (bits >> 2) & 1u; + int b3 = (bits >> 3) & 1u; + int b4 = (bits >> 4) & 1u; + int b5 = (bits >> 5) & 1u; + int b6 = (bits >> 6) & 1u; + int b7 = (bits >> 7) & 1u; + int b8 = (bits >> 8) & 1u; + int b9 = (bits >> 9) & 1u; + int ba = (bits >> 10) & 1u; + int bb = (bits >> 11) & 1u; + int bc = (bits >> 12) & 1u; + int bd = (bits >> 13) & 1u; + int be = (bits >> 14) & 1u; + int bf = (bits >> 15) & 1u; + + float s0 = (float)(b0 * 2 - 1); + float s1 = (float)(b1 * 2 - 1); + float s2 = (float)(b2 * 2 - 1); + float s3 = (float)(b3 * 2 - 1); + float s4 = (float)(b4 * 2 - 1); + float s5 = (float)(b5 * 2 - 1); + float s6 = (float)(b6 * 2 - 1); + float s7 = (float)(b7 * 2 - 1); + float s8 = (float)(b8 * 2 - 1); + float s9 = (float)(b9 * 2 - 1); + float sa = (float)(ba * 2 - 1); + float sb = (float)(bb * 2 - 1); + float sc = (float)(bc * 2 - 1); + float sd = (float)(bd * 2 - 1); + float se = (float)(be * 2 - 1); + float sf = (float)(bf * 2 - 1); + + float acc = 0.f; + acc += yl.s0 * s0; + acc += yl.s1 * s1; + acc += yl.s2 * s2; + acc += yl.s3 * s3; + acc += yl.s4 * s4; + acc += yl.s5 * s5; + acc += yl.s6 * s6; + acc += yl.s7 * s7; + acc += yl.s8 * s8; + acc += yl.s9 * s9; + acc += yl.sa * sa; + acc += yl.sb * sb; + acc += yl.sc * sc; + acc += yl.sd * sd; + acc += yl.se * se; + acc += yl.sf * sf; + + return d * acc; +} + +#define N_DST 8 +#define N_SIMDGROUP 1 +#ifdef cl_intel_required_subgroup_size +#define N_SIMDWIDTH 16 +#else +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q1_0_f32_8x_flat(global uchar * src0_q, global half * src0_d, global float * src1, global float * dst, int ne00, int ne01, int ne02, int ne10, int ne12, int ne0, int ne1, int r2, int r3) { + const ulong nb = ne00 / QK1_0; + int r0 = get_group_id(0), r1 = get_group_id(1), im = get_group_id(2); 
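+    // Each subgroup produces N_DST (=8) consecutive output rows; first_row is
+    // the first of them. offset0_d is counted in blocks (one half scale per
+    // block in src0_d) and offset0_q in bytes (QK1_0/8 = 16 bytes of sign bits
+    // per block in src0_q), matching the SoA layout written by
+    // kernel_convert_block_q1_0.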
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + int i12 = im % ne12, i13 = im / ne12; + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * (QK1_0/8); + global uchar * x = src0_q + offset0_q; + global half * d = src0_d + offset0_d; + global float * y = src1 + r1*ne10 + im*ne00*ne1; + float16 yl; float8 sumf = 0.f; + + // For 128-element blocks: 64 threads / 8 groups = 8 threads per group + int ix = get_sub_group_local_id() / 8, il = 16 * (get_sub_group_local_id() % 8); + global float * yb = y + ix * QK1_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + yl.s0=yb[0]; yl.s1=yb[1]; yl.s2=yb[2]; yl.s3=yb[3]; yl.s4=yb[4]; yl.s5=yb[5]; yl.s6=yb[6]; yl.s7=yb[7]; + yl.s8=yb[8]; yl.s9=yb[9]; yl.sa=yb[10]; yl.sb=yb[11]; yl.sc=yb[12]; yl.sd=yb[13]; yl.se=yb[14]; yl.sf=yb[15]; + sumf.s0 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 0*nb*(QK1_0/8), d + ib + 0*nb, yl, il); + sumf.s1 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 1*nb*(QK1_0/8), d + ib + 1*nb, yl, il); + sumf.s2 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 2*nb*(QK1_0/8), d + ib + 2*nb, yl, il); + sumf.s3 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 3*nb*(QK1_0/8), d + ib + 3*nb, yl, il); + sumf.s4 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 4*nb*(QK1_0/8), d + ib + 4*nb, yl, il); + sumf.s5 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 5*nb*(QK1_0/8), d + ib + 5*nb, yl, il); + sumf.s6 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 6*nb*(QK1_0/8), d + ib + 6*nb, yl, il); + sumf.s7 += block_q1_0_dot_y_flat(x + ib*(QK1_0/8) + 7*nb*(QK1_0/8), d + ib + 7*nb, yl, il); + yb += QK1_0 * (N_SIMDWIDTH/8); + } + float8 tot = (float8)(sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)); + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + if (first_row + 1 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + if (first_row + 2 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + if (first_row + 3 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + if (first_row + 4 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + if (first_row + 5 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + if (first_row + 6 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + if (first_row + 7 < ne01) dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } +} + +#ifdef cl_intel_required_subgroup_size +REQD_SUBGROUP_SIZE_16 +#elif defined(cl_qcom_reqd_sub_group_size) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q1_0_f32_8x_flat(global uchar * src0_q, global half * src0_d, global float * src1, ulong offset1, global float * dst, ulong offsetd, int ne00, int ne01, int ne02, int ne10, int ne12, int ne0, int ne1, int r2, int r3) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + mul_vec_q1_0_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl index ad89bdcbdec..3b38de98a24 100644 --- a/ggml/src/ggml-opencl/kernels/transpose.cl +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -108,6 +108,50 @@ kernel void kernel_transpose_32( 
} +// 32-bit transpose with bounds checking and padding support +// For Q1_0 GEMM - keeps float32 (no FP16 conversion) +// rows = original N (may not be multiple of 4) +// cols = K/4 (K dimension in tiles) +// padded_rows = ceil(N/4) for output stride +kernel void kernel_transpose_32_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols, + const uint padded_rows +) { + const int i = get_global_id(0); // column tile (0 to cols-1) + const int j = get_global_id(1); // row tile (0 to padded_rows-1) + const int i_2 = i << 2; + const int j_2 = j << 2; + + float4 temp0 = (float4)(0, 0, 0, 0); + float4 temp1 = (float4)(0, 0, 0, 0); + float4 temp2 = (float4)(0, 0, 0, 0); + float4 temp3 = (float4)(0, 0, 0, 0); + + // Only load from valid locations (rows may not be multiple of 4) + if (j_2 + 0 < rows) { + temp0 = read_imagef(input, (j_2 + 0) * cols + i); + } + if (j_2 + 1 < rows) { + temp1 = read_imagef(input, (j_2 + 1) * cols + i); + } + if (j_2 + 2 < rows) { + temp2 = read_imagef(input, (j_2 + 2) * cols + i); + } + if (j_2 + 3 < rows) { + temp3 = read_imagef(input, (j_2 + 3) * cols + i); + } + + // Output is (cols*4 x padded_rows*4) = (K x N_padded) row-major + // Write transposed 4x4 tile + write_imagef(output, (i_2 + 0) * padded_rows + j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2 + 1) * padded_rows + j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2 + 2) * padded_rows + j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2 + 3) * padded_rows + j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} + // 32-bit transpose, loading/storing a 4x4 tile of elements // Only used for activations // converts to FP16
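For reference, a minimal host-side sketch of the block layout and decode rule the kernels above assume: QK1_0 = 128 weights per block, one fp16 scale followed by 16 bytes of sign bits (18 bytes per block, i.e. 1.125 bpw), where a set bit decodes to +d and a clear bit to -d, matching the (bit ? s : -s) selects in the kernels. The struct mirrors the block_q1_0 definition added to cvt.cl; the function names and the fp16 conversion helper are illustrative only and not part of this patch.

#include <math.h>
#include <stdint.h>

#define QK1_0 128

typedef struct {
    uint16_t d;               // per-block scale, fp16 stored as raw bits (half d in cvt.cl)
    uint8_t  qs[QK1_0 / 8];   // 128 sign bits = 16 bytes
} block_q1_0;                 // 18 bytes per 128 weights -> 1.125 bpw

// Illustrative fp16 -> fp32 conversion (ggml provides its own helper for this).
static float q1_fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (h >> 15) & 1;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t mant =  h        & 0x3FF;
    float v;
    if (exp == 0) {
        v = ldexpf((float) mant, -24);                       // zero / subnormal
    } else if (exp == 31) {
        v = mant ? NAN : INFINITY;                           // NaN / inf
    } else {
        v = ldexpf((float)(mant | 0x400), (int) exp - 25);   // normal
    }
    return sign ? -v : v;
}

// Dequantize one block: bit set -> +d, bit clear -> -d.
// Bit i of the block lives in qs[i/8] at position i%8, the same ordering the
// uint/ushort loads in the kernels assume on a little-endian device.
static void dequantize_block_q1_0(const block_q1_0 * b, float * y) {
    const float d = q1_fp16_to_fp32(b->d);
    for (int i = 0; i < QK1_0; ++i) {
        y[i] = ((b->qs[i / 8] >> (i % 8)) & 1) ? d : -d;
    }
}

The SoA conversion in cvt.cl is simply a split of this struct: the d fields go into a contiguous half buffer and the qs bytes into a contiguous uchar buffer, which is why the flat mul_mv kernels can index scales as d + ib and bits as x + ib*(QK1_0/8).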