Skip to content

Commit 118d124

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Add Span<const Kernel> overload to get_op_function_from_registry (#19519)
Summary: Expose logic for just scanning a passed in KernelRegistry. Differential Revision: D98079809
1 parent 1992bdd commit 118d124

3 files changed

Lines changed: 75 additions & 7 deletions

File tree

runtime/kernel/operator_registry.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ bool registry_has_op_function(
249249

250250
Result<OpFunction> get_op_function_from_registry(
251251
const char* name,
252-
Span<const TensorMeta> meta_list) {
252+
Span<const TensorMeta> meta_list,
253+
Span<const Kernel> kernel_list) {
253254
std::array<char, internal::kKernelKeyBufSize> key_string;
254255
Error err = internal::make_kernel_key_string(
255256
meta_list, key_string.data(), key_string.size());
@@ -260,24 +261,31 @@ Result<OpFunction> get_op_function_from_registry(
260261
KernelKey kernel_key = KernelKey(key_string.data());
261262

262263
int32_t fallback_idx = -1;
263-
for (size_t idx = 0; idx < num_registered_kernels; idx++) {
264-
if (strcmp(registered_kernels[idx].name_, name) == 0) {
265-
if (registered_kernels[idx].kernel_key_ == kernel_key) {
266-
return registered_kernels[idx].op_;
264+
for (size_t idx = 0; idx < kernel_list.size(); idx++) {
265+
if (strcmp(kernel_list[idx].name_, name) == 0) {
266+
if (kernel_list[idx].kernel_key_ == kernel_key) {
267+
return kernel_list[idx].op_;
267268
}
268-
if (registered_kernels[idx].kernel_key_.is_fallback()) {
269+
if (kernel_list[idx].kernel_key_.is_fallback()) {
269270
fallback_idx = idx;
270271
}
271272
}
272273
}
273274
if (fallback_idx != -1) {
274-
return registered_kernels[fallback_idx].op_;
275+
return kernel_list[fallback_idx].op_;
275276
}
276277
ET_LOG(Error, "kernel '%s' not found.", name);
277278
ET_LOG_TENSOR_META(meta_list);
278279
return Error::OperatorMissing;
279280
}
280281

282+
Result<OpFunction> get_op_function_from_registry(
283+
const char* name,
284+
Span<const TensorMeta> meta_list) {
285+
return get_op_function_from_registry(
286+
name, meta_list, get_registered_kernels());
287+
}
288+
281289
Span<const Kernel> get_registered_kernels() {
282290
return {registered_kernels, num_registered_kernels};
283291
}

runtime/kernel/operator_registry.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ ::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
233233
const char* name,
234234
Span<const TensorMeta> meta_list = {});
235235

236+
/**
237+
* Returns the operator with a given name and TensorMeta list from the provided
238+
* kernel list instead of the global registry.
239+
*/
240+
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
241+
const char* name,
242+
Span<const TensorMeta> meta_list,
243+
Span<const Kernel> kernel_list);
244+
236245
/**
237246
* Returns all registered kernels.
238247
*/

runtime/kernel/test/operator_registry_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,57 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) {
387387
ASSERT_EQ(val_2, 50);
388388
}
389389

390+
TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
391+
std::array<char, kKernelKeyBufSize> buf;
392+
Error err = make_kernel_key(
393+
{{ScalarType::Long, {0, 1, 2, 3}}}, buf.data(), buf.size());
394+
ASSERT_EQ(err, Error::Ok);
395+
KernelKey long_key = KernelKey(buf.data());
396+
397+
Kernel kernels[] = {
398+
Kernel(
399+
"test::provided_kernel_list",
400+
KernelKey{},
401+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
402+
(void)context;
403+
*(stack[0]) = Scalar(50);
404+
}),
405+
Kernel(
406+
"test::provided_kernel_list",
407+
long_key,
408+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
409+
(void)context;
410+
*(stack[0]) = Scalar(100);
411+
}),
412+
};
413+
Span<const Kernel> kernels_span(kernels);
414+
415+
Tensor::DimOrderType dims[] = {0, 1, 2, 3};
416+
auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
417+
TensorMeta long_meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
418+
Span<const TensorMeta> long_kernel_key(long_meta);
419+
420+
auto run_kernel = [](OpFunction func) {
421+
EValue value = Scalar(0);
422+
EValue* stack[] = {&value};
423+
KernelRuntimeContext context{};
424+
func(context, Span<EValue*>(stack));
425+
return value.toScalar().to<int64_t>();
426+
};
427+
428+
Result<OpFunction> specialized_func = get_op_function_from_registry(
429+
"test::provided_kernel_list", long_kernel_key, kernels_span);
430+
ASSERT_EQ(specialized_func.error(), Error::Ok);
431+
EXPECT_EQ(run_kernel(*specialized_func), 100);
432+
433+
TensorMeta float_meta[] = {TensorMeta(ScalarType::Float, dim_order_type)};
434+
Span<const TensorMeta> float_kernel_key(float_meta);
435+
Result<OpFunction> fallback_func = get_op_function_from_registry(
436+
"test::provided_kernel_list", float_kernel_key, kernels_span);
437+
ASSERT_EQ(fallback_func.error(), Error::Ok);
438+
EXPECT_EQ(run_kernel(*fallback_func), 50);
439+
}
440+
390441
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
391442
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
392443
Error err = make_kernel_key(

0 commit comments

Comments
 (0)