Skip to content

Commit 61af07c

Browse files
authored
ggml-zendnn : adaptive fallback to CPU backend for small batch sizes (#22681)
* ggml-zendnn : add runtime env var GGML_ZENDNN_ADAPTIVE_FALLBACK to control adaptive fallback (default: enabled) * ggml-zendnn : restore original fallback logic when adaptive fallback is disabled
1 parent 856c3ad commit 61af07c

2 files changed

Lines changed: 24 additions & 5 deletions

File tree

ggml/src/ggml-zendnn/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
2828
ExternalProject_Add(
2929
zendnn
3030
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
31-
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
31+
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
3232
PREFIX ${ZENDNN_PREFIX}
3333
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
3434
BINARY_DIR ${ZENDNN_BUILD_DIR}

ggml/src/ggml-zendnn/ggml-zendnn.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
4747
params.dtypes.dst = ggml_to_zendnn_type<TC>();
4848
params.num_threads = ctx->n_threads;
4949

50+
zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
5051
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
5152
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
5253
n, // M: rows of B and C
@@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
5960
0.0f, // beta
6061
C, ldc, // output C[n,m]
6162
true, // is_weights_const
62-
{}, // batch_params
63+
batch_params, // batch_params
6364
params // params
6465
);
6566

@@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm
520521
GGML_UNUSED(max_tensor_size);
521522
}
522523

524+
static bool ggml_zendnn_adaptive_fallback_enabled() {
525+
static const bool enabled = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK") == nullptr ||
526+
std::atoi(std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK")) != 0;
527+
return enabled;
528+
}
529+
523530
static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
524531
switch (op->op) {
525532
case GGML_OP_NONE:
@@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
538545
const int64_t ne10 = inputs->ne[0];
539546
const int64_t ne0 = op->ne[0];
540547
const int64_t ne1 = op->ne[1];
541-
542548
const int64_t min_batch = 1;
543-
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
544-
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
549+
550+
if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) {
551+
return false;
552+
}
553+
554+
if (ggml_zendnn_adaptive_fallback_enabled()) {
555+
const int64_t K = inputs->ne[0];
556+
const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]);
557+
const int64_t M = weights->ne[1];
558+
if(K <= 256 || N <= 128 || M <= 96) {
545559
return false;
560+
}
546561
}
562+
else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
563+
return false;
564+
}
565+
547566
// MUL_MAT_ID performs best with a moderate number of experts due to its
548567
// gather + batched matmul + scatter approach. Future versions will leverage
549568
// ZenDNN's grouped_gemm for better scalability with larger expert counts:

0 commit comments

Comments
 (0)