Skip to content

Commit 79c8119

Browse files
PMZFX and claude
authored and committed
SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc) (ggml-org#21580)
* SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup

  BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation.

  Fixes ggml-org#20478

* SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0

  The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows.

  Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check.

  Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005).

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 5a098a3 commit 79c8119

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
#include "dequantize.hpp"
44
#include "presets.hpp"
55

6+
#if defined(__INTEL_LLVM_COMPILER)
7+
#if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
8+
#include <sycl/ext/oneapi/bfloat16.hpp>
9+
#define GGML_SYCL_DMMV_HAS_BF16
10+
#endif
11+
#endif
12+
613
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
714
const sycl::half *x = (const sycl::half *)vx;
815

@@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat
1118
v.y() = x[ib + iqs + 1];
1219
}
1320

21+
#ifdef GGML_SYCL_DMMV_HAS_BF16
22+
// Reads two consecutive bfloat16 elements, x[ib + iqs] and x[ib + iqs + 1], from the
// buffer `vx` and stores them in v.x() and v.y() respectively.
//   vx  - untyped pointer, reinterpreted here as an array of sycl::ext::oneapi::bfloat16
//   ib  - base element index into vx
//   iqs - additional element offset added to ib
//   v   - output pair receiving the two converted values
// NOTE(review): the conversion relies on implicit bfloat16 -> dfloat assignment; if
// dfloat is not float (e.g. a half type), confirm this still builds in that configuration.
static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
23+
const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx;
24+
25+
// automatic bfloat16 -> float type cast if dfloat == float
26+
v.x() = x[ib + iqs + 0];
27+
v.y() = x[ib + iqs + 1];
28+
}
29+
#endif
30+
1431
static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
1532
const float * x = (const float *) vx;
1633

@@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
217234
}
218235
}
219236

237+
#ifdef GGML_SYCL_DMMV_HAS_BF16
238+
// Host-side launcher: submits dequantize_mul_mat_vec<1, 1, convert_bf16> to `stream`.
//   vx     - matrix data, forwarded untouched to the kernel (read as bfloat16 by convert_bf16)
//   y      - input vector of dfloat
//   dst    - float output buffer
//   ncols  - number of columns; asserted below to be a multiple of 2*GGML_SYCL_DMMV_X
//   nrows  - number of rows; determines how many work-groups are launched
//   stream - queue the kernel is submitted to
static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y,
239+
float *dst, const int ncols,
240+
const int nrows,
241+
dpct::queue_ptr stream) {
242+
// The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a
243+
// multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads.
244+
GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0);
245+
// Ceiling division: enough work-groups that every row is covered when each
// work-group spans GGML_SYCL_MMV_Y entries along its middle dimension.
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
246+
const sycl::range<3> block_nums(1, 1, block_num_y);
247+
// Work-group shape: 1 x GGML_SYCL_MMV_Y x WARP_SIZE.
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
248+
{
249+
stream->parallel_for(
250+
// Global range is the element-wise product block_nums * block_dims.
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251+
// Kernel is compiled with a required sub-group size of WARP_SIZE.
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
252+
dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols,
253+
nrows, item_ct1);
254+
});
255+
}
256+
}
257+
#endif
258+
220259
/*
221260
DPCT1110:4: The total declared local variable size in device function
222261
dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -1497,7 +1536,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
14971536
bool src1_convert_f16 =
14981537
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
14991538
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
1500-
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
1539+
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 ||
1540+
src0->type == GGML_TYPE_BF16;
15011541

15021542
if (src1_convert_f16) {
15031543
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2,
@@ -1565,6 +1605,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
15651605
case GGML_TYPE_F16:
15661606
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
15671607
break;
1608+
#ifdef GGML_SYCL_DMMV_HAS_BF16
1609+
case GGML_TYPE_BF16:
1610+
convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1611+
break;
1612+
#endif
15681613
default:
15691614
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
15701615
GGML_ABORT("fatal error");

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3455,6 +3455,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
34553455
case GGML_TYPE_Q5_K:
34563456
case GGML_TYPE_Q6_K:
34573457
case GGML_TYPE_F16:
3458+
case GGML_TYPE_BF16:
34583459
return true;
34593460
default:
34603461
return false;
@@ -3818,8 +3819,13 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
38183819

38193820

38203821
// Returns true when all of the following hold:
//   - ggml_sycl_supports_dmmv(src0->type) is true,
//   - src1 and dst are both GGML_TYPE_F32,
//   - src0->ne[0] is a multiple of the type-dependent requirement computed below,
//   - src1->ne[1] == 1.
// Diff-view note: the line after the "3822-" marker is the condition removed by this
// commit; the line after the "3828+" marker is its replacement.
static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3822+
// The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be
3823+
// a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only
3824+
// need ne[0] % DMMV_X == 0.
3825+
const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ?
3826+
2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X;
38213827
return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3822-
src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3828+
src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1;
38233829
}
38243830

38253831
static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

0 commit comments

Comments
 (0)