diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 772fc537494..8151aa38e08 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -105,6 +105,10 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 + gemm_moe_mxfp4_f32_ns + gemv_moe_mxfp4_f32_ns + moe_reorder_b + moe_sort_by_expert mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q4_0_f32_l4_lm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8bc7ae65a6d..04cce74e6e0 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -416,6 +416,15 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_src0; ggml_cl_buffer prealloc_src1; + // prealloc buffers for MoE router table preprocess + bool toggle_reorder = false; + ggml_cl_buffer prealloc_post_router; + ggml_cl_buffer prealloc_emap; + ggml_cl_buffer prealloc_hist; + ggml_cl_buffer prealloc_tile_offset; + ggml_cl_buffer prealloc_total_tiles; + ggml_cl_buffer prealloc_slot_counter; + cl_program program_add; cl_program program_add_id; cl_program program_clamp; @@ -531,6 +540,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; + cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -582,6 +592,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel 
kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; + cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; + cl_kernel kernel_moe_reorder_b; + cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat; cl_kernel kernel_mul_mv_id_mxfp4_f32; @@ -937,6 +950,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); @@ -2762,6 +2777,77 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = 
read_file("gemv_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_reorder_b + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_reorder_b.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_reorder_b.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_sort_by_expert + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_sort_by_expert.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_sort_by_expert.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err)); + 
CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_noshuffle_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3549,13 +3635,12 @@ struct ggml_tensor_extra_cl_mxfp4 { CL_CHECK(clReleaseMemObject(e)); e = nullptr; } - if (q != nullptr) { + if (q_img != nullptr) { CL_CHECK(clReleaseMemObject(q_img)); - q = nullptr; + q_img = nullptr; } - // Currently, q_img and d_img are not used. They can be image1d_buffer_t + // Currently, e_img is not used. They can be image1d_buffer_t // that wraps around q and d to utilize image access path. - q_img = nullptr; e_img = nullptr; size_q = 0; size_e = 0; @@ -4585,7 +4670,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -4996,8 +5081,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(err); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe mxfp4 kernel needs special transpose and unshuffling if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans; + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; @@ -5017,9 +5103,21 @@ static void 
ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + return; } -#endif + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); @@ -5676,7 +5774,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans; + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; @@ -5700,7 +5798,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } -#endif + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); @@ -12169,6 +12268,118 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } } +static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) { + cl_int err; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; + cl_ulong offset = extra->offset + src->view_offs; + + const int ne21 = src->ne[1]; + const int nb21 = src->nb[1]; + const int ne02 = nb21 / src->nb[0]; + 
const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + + cl_buffer_region region; + region.origin = offset; + region.size = nb21 * ne21; + cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + backend_ctx->prealloc_post_router.allocate(backend_ctx->context, sizeof(int) * max_post_router_tile * n_tile_size); + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + cl_mem post_router_buf = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + backend_ctx->prealloc_emap.allocate(backend_ctx->context, sizeof(short) * max_post_router_tile); + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + cl_mem emap_buf = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + backend_ctx->prealloc_hist.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem hist_buf = clCreateSubBuffer(backend_ctx->prealloc_hist.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + backend_ctx->prealloc_tile_offset.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem tile_offset_buf = clCreateSubBuffer(backend_ctx->prealloc_tile_offset.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + backend_ctx->prealloc_slot_counter.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem slot_counter_buf = clCreateSubBuffer(backend_ctx->prealloc_slot_counter.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + backend_ctx->prealloc_total_tiles.allocate(backend_ctx->context, sizeof(int)); + region.origin = 0; + region.size = sizeof(int); + cl_mem
total_tiles_buf = clCreateSubBuffer(backend_ctx->prealloc_total_tiles.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err); + CL_CHECK(err); + + // Histogram + cl_kernel kernel = backend_ctx->kernel_moe_histogram; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02)); + + size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast<size_t>(ne20), 1}; + size_t histogram_local_size[] = {64, static_cast<size_t>(ne20), 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + // Scan + kernel = backend_ctx->kernel_moe_scan; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n_tile_size)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + + size_t scan_global_size[] = {1}; + size_t scan_local_size[] = {1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src); + + // Fill + kernel = backend_ctx->kernel_moe_fill; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &n_tile_size)); + + size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), n_tile_size, 1}; + size_t fill_local_size[] = {64, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src); + + // Scatter + kernel = backend_ctx->kernel_moe_scatter; +
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + CL_CHECK(clReleaseMemObject(original_router_buf)); + CL_CHECK(clReleaseMemObject(hist_buf)); + CL_CHECK(clReleaseMemObject(tile_offset_buf)); + CL_CHECK(clReleaseMemObject(total_tiles_buf)); + CL_CHECK(clReleaseMemObject(slot_counter_buf)); + CL_CHECK(clReleaseMemObject(post_router_buf)); + CL_CHECK(clReleaseMemObject(emap_buf)); +} + static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12230,6 +12441,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; const int r2 = ne12/ne02; const int r3 = ne13/ne03; @@ -12242,6 +12454,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, int nrows = 1; // number of row in src1 int ndst = 4; // number of values produced by each subgroup + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + cl_kernel kernel; // subgroup mat vec @@ -12373,11 +12588,10 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, size_t local_size[3] = {64, 2, 1}; size_t global_size[3] = {64, 2, 1}; - cl_mem src1_sub_buffer, buf_src1_image, buf_src2; - - int tile_size = 320; 
if (ne12 == 1) { // for gemv - kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32; + kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; // create a sub_buffer for src2 cl_buffer_region region; @@ -12391,78 +12605,154 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, global_size[1] = 4; global_size[2] = static_cast<size_t>(ne20); local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); +
CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm - kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32; - - // preprocess router table - int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size; - void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short)); - void * host_src2 = malloc(ne21 * nb21); - CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL)); - int total_experts = nb21 / nb20; - int out_idx = 0; - for (int i_expert = 0; i_expert < ne02; i_expert++) { - for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) { - for (int j = 0; j < ne21; j++) { - for (int i = 0; i < ne20; i++) { - int expert = ((int *)host_src2)[j * total_experts + i]; - if (i_expert == expert) { - ((short *)host_src2_reorder)[out_idx] = static_cast(expert); - ((short *)host_src2_reorder)[out_idx + 1] = static_cast(j * ne11 + (i % ne11)); - ((short *)host_src2_reorder)[out_idx + 2] = static_cast(j * ne20 + i); - ((short *)host_src2_reorder)[out_idx + 3] = static_cast(i_tile); - out_idx += 4; - } - } - } - } + kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. 
+ if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; } - buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status); + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + GGML_ASSERT(backend_ctx->prealloc_post_router.buffer); + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); CL_CHECK(status); - // set thread grid - global_size[0] = static_cast(tile_size); - global_size[2] = static_cast(ne20 * ne21 * num_tiles_per_expert); - } + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - // create a sub_buffer for src1 - cl_buffer_region region; - region.origin = offset1; - region.size = ne10 * ne11 * ne12 * sizeof(float); - src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create image for src1 - cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; - cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; - buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); - CL_CHECK(status); - - // Set kernel args - int arg_idx = 0; - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); - 
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); - if (ne12 == 1) { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); - } else { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size)); - } + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); - // launch kernel - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + 
CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(src1_sub_buffer)); - CL_CHECK(clReleaseMemObject(buf_src1_image)); - CL_CHECK(clReleaseMemObject(buf_src2)); + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), 
&extra0_mxfp4->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } return; - } // else fallback to generic kernel + } // fallback to generic MoE mxfp4 kernel #endif // GGML_OPENCL_USE_ADRENO_KERNELS #ifdef GGML_OPENCL_SOA_Q @@ -13408,6 +13698,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + const int ne21 = dst->ne[1]; + if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) { + backend_ctx->toggle_reorder = true; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS } static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git 
a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 39af32d282b..c11354a9e2f 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -360,6 +360,93 @@ kernel void kernel_restore_block_mxfp4_trans( b->e = src_e[src_blk_offset]; } +kernel void kernel_convert_block_mxfp4_trans4_ns( + global struct block_mxfp4 * src0, + __global uint * dst_q, + __global uchar * dst_e, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = src0 + src_blk_offset; + dst_e[dst_blk_offset] = b->e; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_mxfp4_trans4_ns( + __global uint * src_q, + __global uchar * src_e, + __global struct block_mxfp4 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * 
ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_mxfp4 * b = dst0 + dst_blk_offset; + b->e = src_e[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK_MXFP4 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + + //------------------------------------------------------------------------------ // block_q8_0 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..e404f392bdd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo 
= (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? 
fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + 
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += 
convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +static inline half e8m0_to_fp16(uchar x) { + ushort bits; + bits = (ushort)(x) - (ushort)(112); + bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10); + return as_half(bits); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_mxfp4_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uchar * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current 
mxfp4 block + uint s_offset = s_sub_offset + get_global_id(0); + float s = e8m0_to_fp32(src0_d[s_offset]); + + // Load 16 fp4 (64-bits) in transposed layout + uint2 mxfp4x16; + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 fp4 (64-bits) in transposed layout + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // 
Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + 
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..e4b44c1a56a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -0,0 +1,161 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? 
fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 
0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_mxfp4_f32_ns( + __global uint * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + 
shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl new file mode 100644 index 00000000000..e6295c81648 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define QK4_0 32 + +kernel void kernel_moe_reorder_b( + global float4 * src, + global uint * router, + global float4 * dst, + global int * total_tiles, + uint K, + ushort map_ratio, + uint tile_size +) { + uint k_4 = get_global_id(0); + uint post_router_idx = get_global_id(1); + + if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) { + return; + } + + uint router_idx = router[post_router_idx]; + + float4 out = (float4)(0); + if (router_idx != 0xFFFFFFFF) { + ushort activation_idx = router_idx / map_ratio; + out = src[activation_idx * K / 4 + k_4]; + } + + dst[post_router_idx * K / 4 + k_4] = out; +} diff --git a/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl 
b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl
new file mode 100644
index 00000000000..d9703429b11
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl
@@ -0,0 +1,82 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable  // MoE router preprocess: histogram -> scan -> fill -> scatter build a per-expert tile table
+
+__kernel void kernel_moe_histogram(  // pass 1: count how many (token, k) pairs are routed to each expert
+    __global const int * input,      // router table: row n holds n_experts ints; the first topK entries are the selected expert ids
+    __global int * hist,             // per-expert counters; assumed zeroed before launch — TODO confirm (kernel_moe_scan re-zeroes it afterwards)
+    uint N,                          // number of tokens (rows of input)
+    uint topK,                       // experts selected per token
+    uint n_experts                   // row stride of input
+) {
+    uint n = get_global_id(0);       // token index
+    uint k = get_global_id(1);       // selection index within the token's topK
+
+    if (n >= N || k >= topK) {       // guard against over-sized global work size
+        return;
+    }
+
+    int expert_id = input[n * n_experts + k];  // NOTE(review): unchecked — assumes 0 <= expert_id < n_experts; confirm host guarantees valid router output
+    atomic_inc(&hist[expert_id]);    // one global atomic per (token, k) pair
+}
+
+__kernel void kernel_moe_scan(       // pass 2: serial exclusive scan of per-expert tile counts; presumably launched with a single work-item (no cross-item sync) — TODO confirm
+    __global int * hist,             // in: per-expert counts from histogram; out: zeroed for the next invocation
+    __global int * tile_offset,      // out: first tile index owned by each expert
+    __global int * total_tiles,      // out: total number of tiles across all experts
+    __global int * slot_counter,     // out: per-expert slot cursors, reset to 0 for kernel_moe_scatter
+    int tile_size,                   // rows per tile; must match the hardcoded 32 in kernel_moe_scatter
+    uint n_experts
+) {
+    int offset = 0;
+    for (int v = 0; v < n_experts; v++) {            // int vs uint comparison; benign for realistic expert counts
+        int count = hist[v];
+        int tiles = (count + tile_size - 1) / tile_size;  // ceil(count / tile_size) tiles for expert v
+        tile_offset[v] = offset;
+        offset += tiles;
+        hist[v] = 0;                                 // reset so histogram can be reused next graph
+        slot_counter[v] = 0;                         // scatter's per-expert cursor starts at 0
+    }
+
+    *total_tiles = offset;           // consumed by the gemm/reorder/fill kernels
+}
+
+__kernel void kernel_moe_scatter(    // final pass: place each (token, k) pair into a slot of its expert's tiles
+    __global const int * input,      // router table, same layout as in kernel_moe_histogram
+    __global int * post_router,      // out: flattened router index (n * topK + k) per tile slot; unfilled slots keep the 0xFFFFFFFF sentinel from kernel_moe_fill
+    __global ushort * emap,          // out: expert id per tile (expert map consumed by the ns gemm kernel)
+    __global const int * tile_offset, // in: per-expert first-tile index from kernel_moe_scan
+    __global int * slot_counter,     // in/out: per-expert cursor, zeroed by kernel_moe_scan
+    int N,
+    int topK,
+    uint n_experts
+) {
+    uint n = get_global_id(0);
+    uint k = get_global_id(1);
+
+    if (n >= N || k >= topK) {
+        return;
+    }
+
+    int val = input[n * n_experts + k];              // expert id for this (token, k); NOTE(review): unchecked, see histogram
+
+    int local_slot = atomic_inc(&slot_counter[val]); // unique slot within expert val's run of tiles
+
+    int tile_idx = tile_offset[val] + (local_slot / 32);  // hardcoded 32 == TILESIZE_N; must stay in sync with tile_size passed to scan/fill
+    int lane = local_slot % 32;
+    int out_pos = tile_idx * 32 + lane;
+
+    post_router[out_pos] = n * topK + k;
+    emap[tile_idx] = val;            // every lane of a tile rewrites the same expert id; redundant stores, but all write the same value
+}
+
+__kernel void kernel_moe_fill(       // pre-pass: mark every slot of every tile with the 0xFFFFFFFF (-1) padding sentinel; presumably enqueued before kernel_moe_scatter so real entries overwrite it — TODO confirm host enqueue order
+    __global int * post_router,
+    __global int * total_tiles,      // in: tile count from kernel_moe_scan
+    int tile_size
+) {
+    int tile_id = get_global_id(0);
+    int vec_id_in_tile = get_global_id(1);
+
+    if (tile_id < total_tiles[0]) {  // no bound on vec_id_in_tile: relies on global size dim 1 == tile_size — TODO confirm launch dims
+        post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
+    }
+}