Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ggml/src/ggml-opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ set(GGML_OPENCL_KERNELS
mul_mv_id_mxfp4_f32_flat
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
gemm_moe_mxfp4_f32_ns
gemv_moe_mxfp4_f32_ns
moe_reorder_b
moe_sort_by_expert
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
Expand Down
451 changes: 374 additions & 77 deletions ggml/src/ggml-opencl/ggml-opencl.cpp

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions ggml/src/ggml-opencl/kernels/cvt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,93 @@ kernel void kernel_restore_block_mxfp4_trans(
b->e = src_e[src_blk_offset];
}

kernel void kernel_convert_block_mxfp4_trans4_ns(
global struct block_mxfp4 * src0,
__global uint * dst_q,
__global uchar * dst_e,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);

uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

global struct block_mxfp4 * b = src0 + src_blk_offset;
dst_e[dst_blk_offset] = b->e;

// extract quantization and unshuffle
ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];

ushort8 post_block = (ushort8)(0);

uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);

for (int i = 0; i < QK_MXFP4 / 4; ++i) {
uchar x0 = pre_block_ptr[2*i + 0];
uchar x1 = pre_block_ptr[2*i + 1];

post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}

uint4 q_block = as_uint4(post_block);

uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
dst_q[offset] = q_block.x;
dst_q[offset + ne01] = q_block.y;
dst_q[offset + ne01 * 2] = q_block.z;
dst_q[offset + ne01 * 3] = q_block.w;
}

kernel void kernel_restore_block_mxfp4_trans4_ns(
__global uint * src_q,
__global uchar * src_e,
__global struct block_mxfp4 * dst0,
uint ne00,
uint ne01
) {
uint i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);

uint ne00_blk = ne00 / QK_MXFP4;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

__global struct block_mxfp4 * b = dst0 + dst_blk_offset;
b->e = src_e[src_d_offset];

// collect transposed quantization parts for a block
uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
uint4 q_block;
q_block.x = src_q[src_q_offset];
q_block.y = src_q[src_q_offset + ne01];
q_block.z = src_q[src_q_offset + ne01 * 2];
q_block.w = src_q[src_q_offset + ne01 * 3];

ushort8 post_block = as_ushort8(q_block);
ushort8 pre_block = (ushort8)(0);

uchar * pre_block_ptr = (uchar *)(&pre_block);
uchar * post_block_ptr = (uchar *)(&post_block);

for (int i = 0; i < QK_MXFP4 / 4; ++i) {
uchar x0 = post_block_ptr[i + 0];
uchar x1 = post_block_ptr[i + QK_MXFP4 / 4];

pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
}

((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
}


//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------
Expand Down
Loading