Skip to content

Commit 0411d41

Browse files
hariharans29github-actions[bot]Copilot
authored
[MLAS/CPU EP]: Introduce a backend kernel selector config in MLAS (microsoft#27136)
### Description Introduces a backend kernel selector config struct in MLAS that allows users to configure selection of backend kernels at runtime based on their preference. The immediate use-case of such a feature is to allow users to opt-out of using/selecting KleidiAI kernels should they choose to do so on ARM platforms. This solution should scale to other kernel implementation backends in the future. ### Motivation and Context Allow users to opt-out of using/selecting KleidiAI kernels should they choose to do so on ARM platforms --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 479dd39 commit 0411d41

132 files changed

Lines changed: 1080 additions & 530 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas
380380
// - "1": Use LUT based GEMM when available.
381381
static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm";
382382

383+
// Use KleidiAI kernels in MLAS if available.
384+
// Option values:
385+
// - "0": Use KleidiAI kernels when available. [DEFAULT]
386+
// - "1": Disable KleidiAI kernels even if available.
387+
static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_kleidiai";
388+
383389
// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
384390
// Refer to MatMulNBits op schema for more details.
385391
// If not provided, default is 4.

onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ template <typename T>
1919
AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& logger,
2020
int batch_size, int attn_context_depth, int attn_layer_depth,
2121
int inner_cell_hidden_size, bool has_attn_layer,
22-
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool)
22+
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool,
23+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config)
2324
: allocator_(alloc),
2425
logger_(logger),
2526
batch_size_(batch_size),
@@ -28,7 +29,8 @@ AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger&
2829
inner_cell_hidden_size_(inner_cell_hidden_size),
2930
has_attn_layer_(has_attn_layer),
3031
attention_mechanism_(attention_mechanism),
31-
ttp_(threadpool) {
32+
ttp_(threadpool),
33+
mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config) {
3234
auto mem_max_steps = attention_mechanism_.GetMaxMemorySteps();
3335
prev_alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, prev_alignments_ptr_, true);
3436
alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, alignments_ptr_, true);
@@ -45,7 +47,7 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
4547
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
4648
rnn_cell_output.data(), inner_cell_hidden_size_,
4749
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
48-
attn_states_.data(), attn_layer_depth_, ttp_);
50+
attn_states_.data(), attn_layer_depth_, ttp_, mlas_backend_kernel_selector_config_);
4951
}
5052

5153
// Get the context which is calculated within attention mechanism.
@@ -62,7 +64,7 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
6264
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
6365
attn_context_.data(), attn_context_depth_,
6466
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
65-
attn_states_.data(), attn_layer_depth_, ttp_);
67+
attn_states_.data(), attn_layer_depth_, ttp_, mlas_backend_kernel_selector_config_);
6668
}
6769
}
6870

onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "core/common/logging/logging.h"
1010
#include "core/framework/allocator.h"
1111
#include "core/platform/threadpool.h"
12+
#include "core/mlas/inc/mlas.h"
1213

1314
namespace onnxruntime {
1415
namespace contrib {
@@ -23,7 +24,8 @@ class AttentionWrapper {
2324
int attn_layer_depth,
2425
int inner_cell_hidden_size,
2526
bool has_attn_layer,
26-
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool);
27+
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool,
28+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config);
2729

2830
virtual ~AttentionWrapper() = default;
2931

@@ -71,6 +73,8 @@ class AttentionWrapper {
7173

7274
const IAttentionMechanism<T>& attention_mechanism_;
7375
concurrency::ThreadPool* ttp_;
76+
77+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config_;
7478
};
7579

7680
} // namespace contrib

onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ namespace contrib {
1919
template <typename T>
2020
BahdanauAttention<T>::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger,
2121
int batch_size, int max_memory_step, int memory_depth,
22-
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool)
23-
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool) {
22+
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool,
23+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config)
24+
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool), mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config) {
2425
values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true);
2526
keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true);
2627
processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true);
@@ -80,7 +81,7 @@ void BahdanauAttention<T>::PrepareMemory(
8081
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
8182
memory.data(), memory_depth_,
8283
memory_layer_weights_.data(), attn_depth_, T{0.0},
83-
keys_.data(), attn_depth_, ttp_);
84+
keys_.data(), attn_depth_, ttp_, mlas_backend_kernel_selector_config_);
8485
}
8586

8687
template <typename T>
@@ -123,7 +124,7 @@ void BahdanauAttention<T>::Compute(
123124
batch_size_, attn_depth_, query_depth_, T{1.0},
124125
queries.data(), query_depth_,
125126
query_layer_weights_.data(), attn_depth_, T{0.0},
126-
processed_query_.data(), attn_depth_, ttp_);
127+
processed_query_.data(), attn_depth_, ttp_, mlas_backend_kernel_selector_config_);
127128

128129
std::fill(aligns.begin(), aligns.end(), T{});
129130

@@ -154,7 +155,7 @@ void BahdanauAttention<T>::Compute(
154155
1, memory_depth_, max_memory_steps_, T{1.0},
155156
alignments, max_memory_steps_,
156157
values.data(), memory_depth_, T{0.0},
157-
outspan.data(), memory_depth_, ttp_);
158+
outspan.data(), memory_depth_, ttp_, mlas_backend_kernel_selector_config_);
158159
}
159160
}
160161

onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include "core/framework/allocator.h"
7+
#include "core/mlas/inc/mlas.h"
78
#include "core/providers/cpu/rnn/rnn_helpers.h"
89

910
#include "attention_mechanism.h"
@@ -23,7 +24,8 @@ class BahdanauAttention : public IAttentionMechanism<T> {
2324
int memory_depth,
2425
int query_depth,
2526
int attn_depth,
26-
bool normalize, concurrency::ThreadPool* threadpool);
27+
bool normalize, concurrency::ThreadPool* threadpool,
28+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config);
2729

2830
void SetWeights(
2931
const gsl::span<const T>& attn_weights,
@@ -78,6 +80,8 @@ class BahdanauAttention : public IAttentionMechanism<T> {
7880

7981
bool normalize_;
8082
concurrency::ThreadPool* ttp_;
83+
84+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config_;
8185
};
8286

8387
} // namespace contrib

onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
248248
memory_depth,
249249
query_depth,
250250
am_attn_size,
251-
false, thread_pool);
251+
false, thread_pool, &mlas_backend_kernel_selector_config_);
252252

253253
fam.SetWeights(
254254
FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
@@ -264,7 +264,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
264264
attn_layer_depth,
265265
hidden_size_,
266266
has_attention_layer,
267-
fam, thread_pool);
267+
fam, thread_pool, &mlas_backend_kernel_selector_config_);
268268
faw.SetWeights(FirstHalfSpan(attn_layer_weights_span));
269269

270270
UniDirectionalAttnLstm<T> fw(
@@ -275,7 +275,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
275275
activation_funcs_.Entries()[0],
276276
activation_funcs_.Entries()[1],
277277
activation_funcs_.Entries()[2],
278-
clip_, thread_pool);
278+
clip_, thread_pool, &mlas_backend_kernel_selector_config_);
279279

280280
BahdanauAttention<T> bam(
281281
alloc,
@@ -285,7 +285,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
285285
memory_depth,
286286
query_depth,
287287
am_attn_size,
288-
false, thread_pool);
288+
false, thread_pool, &mlas_backend_kernel_selector_config_);
289289
bam.SetWeights(
290290
SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
291291
SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
@@ -300,7 +300,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
300300
attn_layer_depth,
301301
hidden_size_,
302302
has_attention_layer,
303-
bam, thread_pool);
303+
bam, thread_pool, &mlas_backend_kernel_selector_config_);
304304
baw.SetWeights(SecondHalfSpan(attn_layer_weights_span));
305305

306306
UniDirectionalAttnLstm<T> bw(
@@ -311,7 +311,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
311311
activation_funcs_.Entries()[3],
312312
activation_funcs_.Entries()[4],
313313
activation_funcs_.Entries()[5],
314-
clip_, thread_pool);
314+
clip_, thread_pool, &mlas_backend_kernel_selector_config_);
315315

316316
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
317317
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
@@ -325,7 +325,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
325325
memory_depth,
326326
query_depth,
327327
am_attn_size,
328-
false, thread_pool);
328+
false, thread_pool, &mlas_backend_kernel_selector_config_);
329329

330330
fam.SetWeights(
331331
am_v_weights.DataAsSpan<T>(),
@@ -341,7 +341,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
341341
attn_layer_depth,
342342
hidden_size_,
343343
has_attention_layer,
344-
fam, thread_pool);
344+
fam, thread_pool, &mlas_backend_kernel_selector_config_);
345345

346346
faw.SetWeights(attn_layer_weights_span);
347347

@@ -353,7 +353,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
353353
activation_funcs_.Entries()[0],
354354
activation_funcs_.Entries()[1],
355355
activation_funcs_.Entries()[2],
356-
clip_, thread_pool);
356+
clip_, thread_pool, &mlas_backend_kernel_selector_config_);
357357

358358
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
359359
}

onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/common/narrow.h"
1111
#include "core/framework/op_kernel.h"
1212
#include "core/providers/cpu/rnn/rnn_helpers.h"
13+
#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h"
1314

1415
namespace onnxruntime {
1516
namespace contrib {
@@ -58,6 +59,8 @@ class DeepCpuAttnLstmOp final : public OpKernel {
5859
activation_funcs_ = ActivationFuncs(activation_func_names,
5960
activation_func_alphas,
6061
activation_func_betas);
62+
63+
SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions());
6164
}
6265

6366
Status Compute(OpKernelContext* context) const override;
@@ -92,6 +95,8 @@ class DeepCpuAttnLstmOp final : public OpKernel {
9295
bool input_forget_ = false;
9396

9497
ActivationFuncs activation_funcs_;
98+
99+
MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
95100
};
96101

97102
} // namespace contrib

onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ UniDirectionalAttnLstm<T>::UniDirectionalAttnLstm(AllocatorPtr allocator,
5151
const ActivationFuncs::Entry& activation_func_g,
5252
const ActivationFuncs::Entry& activation_func_h,
5353
const float clip,
54-
onnxruntime::concurrency::ThreadPool* ttp)
54+
onnxruntime::concurrency::ThreadPool* ttp,
55+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config)
5556
: allocator_(allocator),
5657
logger_(logger),
5758
seq_length_(seq_length),
@@ -64,7 +65,8 @@ UniDirectionalAttnLstm<T>::UniDirectionalAttnLstm(AllocatorPtr allocator,
6465
use_bias_(!bias.empty()),
6566
use_peepholes_(!peephole_weights.empty()),
6667
attention_wrapper_(attention_wrapper),
67-
ttp_(ttp) {
68+
ttp_(ttp),
69+
mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config) {
6870
activation_f_ = {deepcpu::ActivationFuncByName(activation_func_f.name),
6971
activation_func_f.alpha,
7072
activation_func_f.beta};
@@ -260,7 +262,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
260262
input_weights.begin(), input_weights.end(), // W[iofc]^T
261263
input_size_ + attention_size_, T{0.0},
262264
output_iofc_.begin(), output_iofc_.end(),
263-
hidden_size_x4, ttp_);
265+
hidden_size_x4, ttp_, mlas_backend_kernel_selector_config_);
264266

265267
DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);
266268

@@ -298,7 +300,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
298300
input_weights.begin() + input_size_, input_weights.end(), // WA[iofc]
299301
input_size_ + attention_size_, T{1.0},
300302
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
301-
hidden_size_x4, ttp_);
303+
hidden_size_x4, ttp_, mlas_backend_kernel_selector_config_);
302304

303305
// calculate Xt*(W[iofc]^T) + Ht-1*R[iofc]
304306
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, T{1.0},
@@ -307,7 +309,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
307309
recurrent_weights.begin(), recurrent_weights.end(), // R[iofc]
308310
hidden_size_, T{1.0},
309311
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
310-
hidden_size_x4, ttp_);
312+
hidden_size_x4, ttp_, mlas_backend_kernel_selector_config_);
311313

312314
span_T_iter batched_output, batched_output_end;
313315
if (output_sequence) {

onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class UniDirectionalAttnLstm {
5151
const ActivationFuncs::Entry& activation_func_g,
5252
const ActivationFuncs::Entry& activation_func_h,
5353
const float clip,
54-
onnxruntime::concurrency::ThreadPool* ttp);
54+
onnxruntime::concurrency::ThreadPool* ttp,
55+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config);
5556

5657
void Compute(const gsl::span<const T>& inputs,
5758
const gsl::span<const int>& sequence_lengths,
@@ -152,6 +153,8 @@ class UniDirectionalAttnLstm {
152153
AttentionWrapper<T>& attention_wrapper_;
153154

154155
onnxruntime::concurrency::ThreadPool* ttp_;
156+
157+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config_;
155158
};
156159

157160
} // namespace detail

0 commit comments

Comments
 (0)