@@ -387,6 +387,59 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) {
387387 ASSERT_EQ (val_2, 50 );
388388}
389389
390+ TEST_F (OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
391+ std::array<char , kKernelKeyBufSize > buf{};
392+ Error err = make_kernel_key (
393+ {{ScalarType::Long, {0 , 1 , 2 , 3 }}}, buf.data (), buf.size ());
394+ ASSERT_EQ (err, Error::Ok);
395+ KernelKey long_key = KernelKey (buf.data ());
396+
397+ std::array<Kernel, 2 > kernels = {
398+ Kernel (
399+ " test::provided_kernel_list" ,
400+ KernelKey{},
401+ [](KernelRuntimeContext& context, Span<EValue*> stack) {
402+ (void )context;
403+ *(stack[0 ]) = Scalar (50 );
404+ }),
405+ Kernel (
406+ " test::provided_kernel_list" ,
407+ long_key,
408+ [](KernelRuntimeContext& context, Span<EValue*> stack) {
409+ (void )context;
410+ *(stack[0 ]) = Scalar (100 );
411+ }),
412+ };
413+ Span<const Kernel> kernels_span (kernels.data (), kernels.size ());
414+
415+ std::array<Tensor::DimOrderType, 4 > dims = {0 , 1 , 2 , 3 };
416+ auto dim_order_type = Span<Tensor::DimOrderType>(dims.data (), dims.size ());
417+ std::array<TensorMeta, 1 > long_meta = {
418+ TensorMeta (ScalarType::Long, dim_order_type)};
419+ Span<const TensorMeta> long_kernel_key (long_meta.data (), long_meta.size ());
420+
421+ auto run_kernel = [](OpFunction func) {
422+ EValue value = Scalar (0 );
423+ std::array<EValue*, 1 > stack = {&value};
424+ KernelRuntimeContext context{};
425+ func (context, Span<EValue*>(stack.data (), stack.size ()));
426+ return value.toScalar ().to <int64_t >();
427+ };
428+
429+ Result<OpFunction> specialized_func = get_op_function_from_registry (
430+ " test::provided_kernel_list" , long_kernel_key, kernels_span);
431+ ASSERT_EQ (specialized_func.error (), Error::Ok);
432+ EXPECT_EQ (run_kernel (*specialized_func), 100 );
433+
434+ std::array<TensorMeta, 1 > float_meta = {
435+ TensorMeta (ScalarType::Float, dim_order_type)};
436+ Span<const TensorMeta> float_kernel_key (float_meta.data (), float_meta.size ());
437+ Result<OpFunction> fallback_func = get_op_function_from_registry (
438+ " test::provided_kernel_list" , float_kernel_key, kernels_span);
439+ ASSERT_EQ (fallback_func.error (), Error::Ok);
440+ EXPECT_EQ (run_kernel (*fallback_func), 50 );
441+ }
442+
390443TEST_F (OperatorRegistryTest, DoubleRegisterKernelsDies) {
391444 std::array<char , kKernelKeyBufSize > buf_long_contiguous;
392445 Error err = make_kernel_key (
0 commit comments