@@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
4747 params.dtypes .dst = ggml_to_zendnn_type<TC>();
4848 params.num_threads = ctx->n_threads ;
4949
50+ zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
5051 zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct (
5152 ' r' , false , true , // row-major, don't transpose B, transpose A (because it's column-major)
5253 n, // M: rows of B and C
@@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
5960 0 .0f , // beta
6061 C, ldc, // output C[n,m]
6162 true , // is_weights_const
62- {}, // batch_params
63+ batch_params, // batch_params
6364 params // params
6465 );
6566
@@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm
520521 GGML_UNUSED (max_tensor_size);
521522}
522523
524+ static bool ggml_zendnn_adaptive_fallback_enabled () {
525+ static const bool enabled = std::getenv (" GGML_ZENDNN_ADAPTIVE_FALLBACK" ) == nullptr ||
526+ std::atoi (std::getenv (" GGML_ZENDNN_ADAPTIVE_FALLBACK" )) != 0 ;
527+ return enabled;
528+ }
529+
523530static bool ggml_backend_zendnn_device_supports_op (ggml_backend_dev_t dev, const struct ggml_tensor * op) {
524531 switch (op->op ) {
525532 case GGML_OP_NONE:
@@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
538545 const int64_t ne10 = inputs->ne [0 ];
539546 const int64_t ne0 = op->ne [0 ];
540547 const int64_t ne1 = op->ne [1 ];
541-
542548 const int64_t min_batch = 1 ;
543- if (!ggml_is_contiguous (weights) || !ggml_is_contiguous (inputs) ||
544- ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
549+
550+ if (!ggml_is_contiguous (weights) || !ggml_is_contiguous (inputs)) {
551+ return false ;
552+ }
553+
554+ if (ggml_zendnn_adaptive_fallback_enabled ()) {
555+ const int64_t K = inputs->ne [0 ];
556+ const int64_t N = (inputs->ne [1 ]*inputs->ne [2 ]*inputs->ne [3 ]);
557+ const int64_t M = weights->ne [1 ];
558+ if (K <= 256 || N <= 128 || M <= 96 ) {
545559 return false ;
560+ }
546561 }
562+ else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
563+ return false ;
564+ }
565+
547566 // MUL_MAT_ID performs best with a moderate number of experts due to its
548567 // gather + batched matmul + scatter approach. Future versions will leverage
549568 // ZenDNN's grouped_gemm for better scalability with larger expert counts:
0 commit comments