Skip to content

Commit 977fea5

Browse files
z-sachinkashif
authored andcommitted
ggml-zendnn : add Q8_0 quantization support (ggml-org#23414)
* ggml-zendnn : add Q8_0 quantization support * ggml-zendnn : sync with latest ZenDNN * ggml-zendnn : address review comments for Q8_0
1 parent cfd200e commit 977fea5

2 files changed

Lines changed: 46 additions & 12 deletions

File tree

ggml/src/ggml-zendnn/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
2828
ExternalProject_Add(
2929
zendnn
3030
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
31-
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
31+
GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19
3232
PREFIX ${ZENDNN_PREFIX}
3333
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
3434
BINARY_DIR ${ZENDNN_BUILD_DIR}

ggml/src/ggml-zendnn/ggml-zendnn.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
#include "ggml-backend-impl.h"
44
#include "ggml-impl.h"
5+
6+
#define GGML_COMMON_DECL_CPP
7+
#include "ggml-common.h"
8+
59
#include "zendnnl.hpp"
610

711
#include <cstring>
@@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() {
1923
return zendnnl::common::data_type_t::f32;
2024
} else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
2125
return zendnnl::common::data_type_t::bf16;
26+
} else if constexpr (std::is_same_v<T, block_q8_0>) {
27+
return zendnnl::common::data_type_t::s8;
2228
} else {
2329
return zendnnl::common::data_type_t::none;
2430
}
@@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
4854
params.num_threads = ctx->n_threads;
4955

5056
zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
57+
58+
if constexpr (std::is_same_v<TA, block_q8_0>) {
59+
params.dtypes.compute = zendnnl::common::data_type_t::s8;
60+
const int64_t num_groups = k / QK8_0;
61+
params.dynamic_quant = true;
62+
params.quant_params.src_scale.buff = nullptr;
63+
params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16;
64+
params.quant_params.src_scale.dims = {n, num_groups};
65+
params.packing.pack_format_b = 1;
66+
}
67+
5168
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
5269
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
5370
n, // M: rows of B and C
@@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6
108125
(const ggml_bf16_t *)B, ldb,
109126
(float *)C, ldc);
110127
return false;
128+
case GGML_TYPE_Q8_0:
129+
if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
130+
return false;
131+
return ggml_zendnn_matmul<block_q8_0, float, float>(
132+
ctx, m, n, k,
133+
(const block_q8_0 *)A, lda,
134+
(const float *)B, ldb,
135+
(float *)C, ldc);
111136
default:
112137
return false; // unsupported type
113138
}
@@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat(
145170
const int64_t r3 = ne13/ne03;
146171

147172
void * work_data = ctx->work_data.get();
148-
if (src1->type != vec_dot_type) {
173+
174+
// ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
175+
if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) {
149176
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
150177
const size_t nbw2 = nbw1 * ne11;
151178
const size_t nbw3 = nbw2 * ne12;
@@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
171198

172199
for (int64_t i13 = 0; i13 < ne13; i13++) {
173200
for (int64_t i12 = 0; i12 < ne12; i12++) {
174-
const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
201+
const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
175202
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
176203
if (!ggml_zendnn_sgemm(ctx,
177204
ne01, // m
@@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
184211
static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
185212
ne01, // ldc
186213
src0->type,
187-
vec_dot_type,
214+
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
188215
dst->type))
189216
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
190217
}
@@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
261288
const size_t nbw1 = row_size;
262289
const size_t nbw2 = nbw1 * ne11;
263290
const size_t nbw3 = nbw2 * ne12;
264-
const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
291+
const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0;
292+
293+
// For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes),
294+
// not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x.
295+
const size_t f32_row_size = (size_t)ne10 * sizeof(float);
296+
const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size;
265297

266298
// size for MoE gather/scatter buffers
267-
const size_t wdata_cur_size = max_rows * row_size;
299+
const size_t wdata_cur_size = max_rows * gather_row_size;
268300
const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);
269301

270302
// allocate single buffer for all needs
@@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
279311
char * wdata_cur = work_data + src1_conv_size;
280312
char * dst_cur = wdata_cur + wdata_cur_size;
281313

282-
if (src1->type != vec_dot_type) {
314+
// ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
315+
if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) {
283316
GGML_ASSERT(src1->type == GGML_TYPE_F32);
284317

285318
#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
@@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
294327
}
295328
}
296329

297-
const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
330+
const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
298331

299332
// process each expert with gather -> gemm -> scatter pattern
300333
for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
@@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
315348
const int64_t i12 = row_mapping.i2;
316349

317350
std::memcpy(
318-
wdata_cur + ir1 * row_size,
319-
(const char *) wdata + (i11 + i12*ne11) * row_size,
320-
row_size
351+
wdata_cur + ir1 * gather_row_size,
352+
(const char *) wdata + (i11 + i12*ne11) * gather_row_size,
353+
gather_row_size
321354
);
322355
}
323356

@@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
333366
dst_cur,
334367
ne01, // ldc
335368
src0->type,
336-
vec_dot_type,
369+
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
337370
dst->type)) {
338371
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
339372
}
@@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
577610
switch (weights->type) {
578611
case GGML_TYPE_F32:
579612
case GGML_TYPE_BF16:
613+
case GGML_TYPE_Q8_0:
580614
return true;
581615
default:
582616
return false;

0 commit comments

Comments
 (0)