@@ -387,6 +387,57 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) {
387387 ASSERT_EQ (val_2, 50 );
388388}
389389
390+ TEST_F (OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
391+ std::array<char , kKernelKeyBufSize > buf;
392+ Error err = make_kernel_key (
393+ {{ScalarType::Long, {0 , 1 , 2 , 3 }}}, buf.data (), buf.size ());
394+ ASSERT_EQ (err, Error::Ok);
395+ KernelKey long_key = KernelKey (buf.data ());
396+
397+ Kernel kernels[] = {
398+ Kernel (
399+ " test::provided_kernel_list" ,
400+ KernelKey{},
401+ [](KernelRuntimeContext& context, Span<EValue*> stack) {
402+ (void )context;
403+ *(stack[0 ]) = Scalar (50 );
404+ }),
405+ Kernel (
406+ " test::provided_kernel_list" ,
407+ long_key,
408+ [](KernelRuntimeContext& context, Span<EValue*> stack) {
409+ (void )context;
410+ *(stack[0 ]) = Scalar (100 );
411+ }),
412+ };
413+ Span<const Kernel> kernels_span (kernels);
414+
415+ Tensor::DimOrderType dims[] = {0 , 1 , 2 , 3 };
416+ auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4 );
417+ TensorMeta long_meta[] = {TensorMeta (ScalarType::Long, dim_order_type)};
418+ Span<const TensorMeta> long_kernel_key (long_meta);
419+
420+ auto run_kernel = [](OpFunction func) {
421+ EValue value = Scalar (0 );
422+ EValue* stack[] = {&value};
423+ KernelRuntimeContext context{};
424+ func (context, Span<EValue*>(stack));
425+ return value.toScalar ().to <int64_t >();
426+ };
427+
428+ Result<OpFunction> specialized_func = get_op_function_from_registry (
429+ " test::provided_kernel_list" , long_kernel_key, kernels_span);
430+ ASSERT_EQ (specialized_func.error (), Error::Ok);
431+ EXPECT_EQ (run_kernel (*specialized_func), 100 );
432+
433+ TensorMeta float_meta[] = {TensorMeta (ScalarType::Float, dim_order_type)};
434+ Span<const TensorMeta> float_kernel_key (float_meta);
435+ Result<OpFunction> fallback_func = get_op_function_from_registry (
436+ " test::provided_kernel_list" , float_kernel_key, kernels_span);
437+ ASSERT_EQ (fallback_func.error (), Error::Ok);
438+ EXPECT_EQ (run_kernel (*fallback_func), 50 );
439+ }
440+
390441TEST_F (OperatorRegistryTest, DoubleRegisterKernelsDies) {
391442 std::array<char , kKernelKeyBufSize > buf_long_contiguous;
392443 Error err = make_kernel_key (
0 commit comments