Skip to content

Commit 225088e

Browse files
authored
sycl: Improve mul_mat_id memory efficiency and add BF16 fast path (ggml-org#22119)
* sycl: size mul_mat_id staging buffers by routed rows Previously src1_contiguous/dst_contiguous in ggml_sycl_mul_mat_id were sized to ggml_nelements(src1/dst), which over-allocates when ne12 > 1 and can fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero for MoE models (notably with --cpu-moe). Size them by the actual number of routed rows (ids->ne[1] * n_ids) instead. * sycl: add bf16 mul_mat fast path via DNNL When src0 is BF16 (commonly the case for lm_head / output.weight), the existing f16 path is skipped because bf16 isn't covered, and the f32 fallback dequantizes the entire src0 slab to f32 in a single pool alloc (row_diff*ne00 floats). For large-vocab models this can reach several GB and fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero. Add a bf16xbf16 -> f32 DNNL matmul fast path that uses the bf16 storage in place and only materializes a small src1 bf16 conversion buffer. bf16 matmul accumulates in f32, so it's correct even when the op requests GGML_PREC_F32 (as lm_head does). - gemm.hpp: map bfloat16 to dnnl::memory::data_type::bf16. - convert.{hpp,cpp}: expose ggml_get_to_bf16_sycl for f32/f16/bf16 -> bf16. - ggml-sycl.cpp: take the bf16 path early in ggml_sycl_op_mul_mat_sycl when DNNL and GGML_SYCL_HAS_BF16 are both available.
1 parent 82d3f4d commit 225088e

6 files changed

Lines changed: 70 additions & 10 deletions

File tree

ggml/src/ggml-sycl/common.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828

2929
namespace syclexp = sycl::ext::oneapi::experimental;
3030

31+
#if defined(__INTEL_LLVM_COMPILER) && __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
32+
#include <sycl/ext/oneapi/bfloat16.hpp>
33+
#ifndef GGML_SYCL_HAS_BF16
34+
#define GGML_SYCL_HAS_BF16
35+
#endif
36+
#endif
37+
3138
#if GGML_SYCL_DNNL
3239
#include "dnnl.hpp"
3340
#include "dnnl_sycl.hpp"

ggml/src/ggml-sycl/convert.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,6 @@
22
#include "dequantize.hpp"
33
#include "presets.hpp"
44

5-
#if defined(__INTEL_LLVM_COMPILER)
6-
#if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
7-
#include <sycl/ext/oneapi/bfloat16.hpp>
8-
#define GGML_SYCL_HAS_BF16
9-
#endif
10-
#endif
11-
125
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
136
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
147
const sycl::nd_item<3> &item_ct1) {
@@ -767,6 +760,22 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
767760
}
768761

769762

763+
#ifdef GGML_SYCL_HAS_BF16
764+
to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) {
765+
switch (type) {
766+
case GGML_TYPE_F32:
767+
return convert_unary_sycl<float>;
768+
case GGML_TYPE_F16:
769+
return convert_unary_sycl<sycl::half>;
770+
case GGML_TYPE_BF16:
771+
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
772+
default:
773+
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
774+
return nullptr;
775+
}
776+
}
777+
#endif
778+
770779
to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
771780
switch (type) {
772781
case GGML_TYPE_F32:

ggml/src/ggml-sycl/convert.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
2323
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
2424
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);
2525

26+
#ifdef GGML_SYCL_HAS_BF16
27+
typedef to_t_sycl_t<sycl::ext::oneapi::bfloat16> to_bf16_sycl_t;
28+
to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst);
29+
#endif
30+
2631
// Nc = Non-contiguous
2732
template <typename T>
2833
using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
@@ -35,15 +40,19 @@ template<typename dst_t, typename src_t>
3540
inline dst_t ggml_sycl_cast(src_t x) {
3641
if constexpr (std::is_same_v<dst_t, src_t>) {
3742
return x;
43+
#ifdef GGML_SYCL_HAS_BF16
3844
} else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {
3945
return sycl::ext::oneapi::bfloat16(float(x));
4046
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
4147
return static_cast<float>(x);
48+
#endif
4249
} else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
4350
return x.template convert<sycl::half, sycl::rounding_mode::rte>();
51+
#ifdef GGML_SYCL_HAS_BF16
4452
} else if constexpr (std::is_same_v<src_t, sycl::float2> &&
4553
std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
4654
return {x.x, x.y};
55+
#endif
4756
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
4857
return int32_t(x);
4958
} else {

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class DnnlGemmWrapper {
2929
static constexpr dt to_dt() {
3030
if constexpr (std::is_same_v<T, float>) return dt::f32;
3131
else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
32+
#ifdef GGML_SYCL_HAS_BF16
33+
else if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) return dt::bf16;
34+
#endif
3235
else static_assert(0);
3336
}
3437

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,6 +2176,31 @@ inline void ggml_sycl_op_mul_mat_sycl(
21762176
#else
21772177
bool use_fp16 = false;
21782178
#endif
2179+
2180+
#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16)
2181+
// Fast path for bf16 src0
2182+
if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) &&
2183+
row_diff == src0->ne[1]) {
2184+
using bf16_t = sycl::ext::oneapi::bfloat16;
2185+
ggml_sycl_pool_alloc<bf16_t> src1_as_bf16(ctx.pool(), src1_ncols*ne10);
2186+
if (src1->type != GGML_TYPE_BF16) {
2187+
const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst);
2188+
GGML_ASSERT(to_bf16_sycl != nullptr);
2189+
to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream);
2190+
} else {
2191+
stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t));
2192+
}
2193+
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10,
2194+
src0_dd_i, DnnlGemmWrapper::to_dt<bf16_t>(),
2195+
src1_as_bf16.get(), DnnlGemmWrapper::to_dt<bf16_t>(),
2196+
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
2197+
GGML_UNUSED(dst);
2198+
GGML_UNUSED(src1_ddq_i);
2199+
GGML_UNUSED(src1_padded_row_size);
2200+
return;
2201+
}
2202+
#endif
2203+
21792204
if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
21802205
row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
21812206
ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
@@ -3848,8 +3873,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
38483873
}
38493874
}
38503875
} else {
3851-
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
3852-
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
3876+
const int64_t n_routed_rows = ids->ne[1] * n_ids;
3877+
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10);
3878+
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0);
38533879

38543880
src1_row.data = src1_contiguous.get();
38553881
dst_row.data = dst_contiguous.get();

ggml/src/ggml-sycl/set_rows.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
namespace utils {
55
template<typename T>
66
static constexpr bool is_arithmetic_v() {
7-
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
7+
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half>
8+
#ifdef GGML_SYCL_HAS_BF16
9+
|| std::is_same_v<T, sycl::ext::oneapi::bfloat16>
10+
#endif
11+
;
812
}
913
}
1014

@@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
181185
stream
182186
);
183187
break;
188+
#ifdef GGML_SYCL_HAS_BF16
184189
case GGML_TYPE_BF16:
185190
set_rows_sycl<TIn, TIdx, sycl::ext::oneapi::bfloat16>(
186191
src0_d, src1_d, (char *)dst->data,
@@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
193198
stream
194199
);
195200
break;
201+
#endif
196202
case GGML_TYPE_Q8_0:
197203
set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
198204
break;

0 commit comments

Comments
 (0)