22
33#include " ggml-backend-impl.h"
44#include " ggml-impl.h"
5+
6+ #define GGML_COMMON_DECL_CPP
7+ #include " ggml-common.h"
8+
59#include " zendnnl.hpp"
610
711#include < cstring>
@@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() {
1923 return zendnnl::common::data_type_t ::f32 ;
2024 } else if constexpr (std::is_same_v<T, ggml_bf16_t >) {
2125 return zendnnl::common::data_type_t ::bf16 ;
26+ } else if constexpr (std::is_same_v<T, block_q8_0>) {
27+ return zendnnl::common::data_type_t ::s8;
2228 } else {
2329 return zendnnl::common::data_type_t ::none;
2430 }
@@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
4854 params.num_threads = ctx->n_threads ;
4955
5056 zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
57+
58+ if constexpr (std::is_same_v<TA , block_q8_0>) {
59+ params.dtypes .compute = zendnnl::common::data_type_t ::s8;
60+ const int64_t num_groups = k / QK8_0 ;
61+ params.dynamic_quant = true ;
62+ params.quant_params .src_scale .buff = nullptr ;
63+ params.quant_params .src_scale .dt = zendnnl::common::data_type_t ::bf16 ;
64+ params.quant_params .src_scale .dims = {n, num_groups};
65+ params.packing .pack_format_b = 1 ;
66+ }
67+
5168 zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct (
5269 ' r' , false , true , // row-major, don't transpose B, transpose A (because it's column-major)
5370 n, // M: rows of B and C
@@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6
108125 (const ggml_bf16_t *)B, ldb,
109126 (float *)C, ldc);
110127 return false ;
128+ case GGML_TYPE_Q8_0 :
129+ if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32 )
130+ return false ;
131+ return ggml_zendnn_matmul<block_q8_0, float , float >(
132+ ctx, m, n, k,
133+ (const block_q8_0 *)A, lda,
134+ (const float *)B, ldb,
135+ (float *)C, ldc);
111136 default :
112137 return false ; // unsupported type
113138 }
@@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat(
145170 const int64_t r3 = ne13/ne03;
146171
147172 void * work_data = ctx->work_data .get ();
148- if (src1->type != vec_dot_type) {
173+
174+ // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
175+ if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0 ) {
149176 const size_t nbw1 = ggml_row_size (vec_dot_type, ne10);
150177 const size_t nbw2 = nbw1 * ne11;
151178 const size_t nbw3 = nbw2 * ne12;
@@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
171198
172199 for (int64_t i13 = 0 ; i13 < ne13; i13++) {
173200 for (int64_t i12 = 0 ; i12 < ne12; i12++) {
174- const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
201+ const void * wdata = ( src1->type == vec_dot_type || src0-> type == GGML_TYPE_Q8_0 ) ? src1->data : work_data;
175202 const size_t row_size = ggml_row_size (vec_dot_type, ne10);
176203 if (!ggml_zendnn_sgemm (ctx,
177204 ne01, // m
@@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
184211 static_cast <char *>(dst->data ) + i12*nb2 + i13*nb3,
185212 ne01, // ldc
186213 src0->type ,
187- vec_dot_type,
214+ src0-> type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
188215 dst->type ))
189216 GGML_ABORT (" %s: ZenDNN sgemm failed\n " , __func__);
190217 }
@@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
261288 const size_t nbw1 = row_size;
262289 const size_t nbw2 = nbw1 * ne11;
263290 const size_t nbw3 = nbw2 * ne12;
264- const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0 ;
291+ const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0 ) ? ne13 * nbw3 : 0 ;
292+
293+ // For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes),
294+ // not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x.
295+ const size_t f32_row_size = (size_t )ne10 * sizeof (float );
296+ const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0 ) ? f32_row_size : row_size;
265297
266298 // size for MoE gather/scatter buffers
267- const size_t wdata_cur_size = max_rows * row_size ;
299+ const size_t wdata_cur_size = max_rows * gather_row_size ;
268300 const size_t dst_cur_size = max_rows * ggml_row_size (dst->type , ne01);
269301
270302 // allocate single buffer for all needs
@@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
279311 char * wdata_cur = work_data + src1_conv_size;
280312 char * dst_cur = wdata_cur + wdata_cur_size;
281313
282- if (src1->type != vec_dot_type) {
314+ // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
315+ if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0 ) {
283316 GGML_ASSERT (src1->type == GGML_TYPE_F32 );
284317
285318 #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
@@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
294327 }
295328 }
296329
297- const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
330+ const void * wdata = ( src1->type == vec_dot_type || src0-> type == GGML_TYPE_Q8_0 ) ? src1->data : work_data;
298331
299332 // process each expert with gather -> gemm -> scatter pattern
300333 for (int64_t cur_a = 0 ; cur_a < n_as; ++cur_a) {
@@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
315348 const int64_t i12 = row_mapping.i2 ;
316349
317350 std::memcpy (
318- wdata_cur + ir1 * row_size ,
319- (const char *) wdata + (i11 + i12*ne11) * row_size ,
320- row_size
351+ wdata_cur + ir1 * gather_row_size ,
352+ (const char *) wdata + (i11 + i12*ne11) * gather_row_size ,
353+ gather_row_size
321354 );
322355 }
323356
@@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
333366 dst_cur,
334367 ne01, // ldc
335368 src0->type ,
336- vec_dot_type,
369+ src0-> type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
337370 dst->type )) {
338371 GGML_ABORT (" %s: ZenDNN sgemm failed\n " , __func__);
339372 }
@@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
577610 switch (weights->type ) {
578611 case GGML_TYPE_F32 :
579612 case GGML_TYPE_BF16 :
613+ case GGML_TYPE_Q8_0 :
580614 return true ;
581615 default :
582616 return false ;
0 commit comments