From 2ea8f4e62c159ffa1fc73402932ca4b73336f869 Mon Sep 17 00:00:00 2001 From: Damien Dooley Date: Tue, 26 May 2026 19:51:18 +0100 Subject: [PATCH] Loosed Top.K K input tensor condition --- cmake/patches/onnx/onnx.patch | 61 ++++++++++++++ onnxruntime/core/providers/cpu/math/top_k.cc | 12 ++- .../test/framework/inference_session_test.cc | 81 +++++++++++++++++++ .../test/providers/cpu/math/topk_op_test.cc | 18 +++++ 4 files changed, 168 insertions(+), 4 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 0a5680778790b..2158c558e1063 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -82,3 +82,64 @@ index a6a8a83..153da87 100644 .SetDoc(GroupNormalization_ver18_doc) .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, 1e-5f) .Attr( +diff --git a/onnx/defs/math/old.cc b/onnx/defs/math/old.cc +index 6f83b4b..7b79012 100644 +--- a/onnx/defs/math/old.cc ++++ b/onnx/defs/math/old.cc +@@ -3448,16 +3448,23 @@ static const char* TopK_ver10_doc = R"DOC( + // should be enforced) + if (nullptr != k && axis_dim.has_dim_value()) { + int64_t k_value = 0; +- if (k->dims_size() != 1 || k->dims(0) != 1) { +- fail_shape_inference("K input must be a one-dimensional tensor of size 1."); ++ int64_t k_element_count = 1; ++ for (int i = 0; i < k->dims_size(); ++i) { ++ k_element_count *= k->dims(i); ++ } ++ if (k_element_count != 1) { ++ fail_shape_inference("K input must contain exactly one element."); + } +- ++ + if (k->data_type() == TensorProto::INT64) { + const auto data = ParseData(k); + k_value = data[0]; + } else { + fail_shape_inference("K input must be of type int64."); + } ++ if (k_value < 0) { ++ fail_shape_inference("K input must not be negative."); ++ } +- ++ + if (axis_dim.dim_value() < k_value) { + fail_shape_inference("Axis has less than the requested k elements."); +diff --git a/onnx/defs/math/utils.cc b/onnx/defs/math/utils.cc +index 58c5f5d..f5e18d5 100644 +--- a/onnx/defs/math/utils.cc ++++ b/onnx/defs/math/utils.cc +@@ -113,15 +113,22 @@ void TopKOpSchemaGenerator(OpSchema& schema) { + // should be enforced) + if (nullptr != k && axis_dim.has_dim_value()) { + int64_t k_value = 0; +- if (k->dims_size() != 1 || k->dims(0) != 1) { +- fail_shape_inference("K input must be a one-dimensional tensor of size 1."); ++ int64_t k_element_count = 1; ++ for (int i = 0; i < k->dims_size(); ++i) { ++ k_element_count *= k->dims(i); ++ } ++ if (k_element_count != 1) { ++ fail_shape_inference("K input must contain exactly one element."); + } + if (k->data_type() == TensorProto::INT64) { + const auto data = ParseData(k); + k_value = data[0]; + } else { + fail_shape_inference("K input must be of type int64."); + } ++ if (k_value < 0) { ++ fail_shape_inference("K input must not be negative."); ++ } + if (axis_dim.dim_value() < k_value) { + fail_shape_inference("Axis has less than the requested k elements."); + } diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index ab28b2bacf1ab..3b5d44a6640d3 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include namespace onnxruntime { @@ -481,9 +482,8 @@ static Status ComputeImplOpset1011(OpKernelContext* p_op_kernel_context, int axi "the tensor to be processed and a tensor containing k value"); } - auto y_shape = Y->Shape().GetDims(); - if (y_shape.size() != 1 || y_shape[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k tensor should be a 1D tensor of size 1"); + if (Y->Shape().Size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k tensor should contain exactly one element"); } auto parsed_input_k = Y->Data()[0]; @@ -491,7 +491,11 @@ static Status ComputeImplOpset1011(OpKernelContext* p_op_kernel_context, int axi return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k must not be negative"); } - return TopKImpl(p_op_kernel_context, X, axis, gsl::narrow_cast(parsed_input_k), is_largest, is_sorted); + if (parsed_input_k > std::numeric_limits::max()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k exceeds supported range"); + } + + return TopKImpl(p_op_kernel_context, X, axis, static_cast(parsed_input_k), is_largest, is_sorted); } template <> diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index e478a23770afd..1dc4fa8eecd5e 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -73,6 +73,34 @@ struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); onnxruntime::Status st; }; + +std::filesystem::path FindRvqDecoderModelPath() { + constexpr const char* kModelFileName = "rvq_decoder_v1.onnx"; + + std::filesystem::path current_path = std::filesystem::current_path(); + for (int i = 0; i < 8; ++i) { + auto candidate = current_path / kModelFileName; + if (std::filesystem::exists(candidate)) { + return candidate; + } + + if (!current_path.has_parent_path() || current_path == current_path.parent_path()) { + break; + } + + current_path = current_path.parent_path(); + } + + const std::filesystem::path source_file_path{__FILE__}; + if (source_file_path.is_absolute()) { + auto candidate = source_file_path.parent_path().parent_path().parent_path().parent_path() / kModelFileName; + if (std::filesystem::exists(candidate)) { + return candidate; + } + } + + return {}; +} } // namespace namespace onnxruntime { @@ -1039,6 +1067,59 @@ TEST(InferenceSessionTests, TestWithIstream) { RunModel(session_object, run_options); } +TEST(InferenceSessionTests, RvqDecoderRunsWithDynamicTopK) { + const auto model_path = FindRvqDecoderModelPath(); + if (model_path.empty()) { + GTEST_SKIP() << "rvq_decoder_v1.onnx is not available in this checkout."; + } + + SessionOptions session_options; + InferenceSession session_object{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_path.string())); + ASSERT_STATUS_OK(session_object.Initialize()); + + auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + NameMLValMap feeds; + + OrtValue pre_emb; + CreateMLValue(allocator, std::vector{1, 128, 300}, std::vector(1 * 128 * 300, 0.0f), + &pre_emb); + feeds.insert({"pre_emb", pre_emb}); + + OrtValue tokens; + CreateMLValue(allocator, std::vector{1, 300, 2}, std::vector(1 * 300 * 2, 0), + &tokens); + feeds.insert({"tokens", tokens}); + + OrtValue dur; + CreateMLValue(allocator, std::vector{1, 300}, std::vector(1 * 300, 1), &dur); + feeds.insert({"dur", dur}); + + OrtValue text_mask; + CreateMLValue(allocator, std::vector{1, 1, 300}, std::vector(1 * 1 * 300, true), + &text_mask); + feeds.insert({"text_mask", text_mask}); + + OrtValue mel_mask; + CreateMLValue(allocator, std::vector{1, 1, 4096}, std::vector(1 * 1 * 4096, true), + &mel_mask); + feeds.insert({"mel_mask", mel_mask}); + + OrtValue spk_id; + CreateMLValue(allocator, std::vector{1}, std::vector{0}, &spk_id); + feeds.insert({"spk_id", spk_id}); + + OrtValue mel_len; + CreateMLValue(allocator, std::vector{}, std::vector{4096}, &mel_len); + feeds.insert({"mel_len", mel_len}); + + RunOptions run_options; + std::vector output_names{"mel"}; + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); + ASSERT_EQ(fetches.size(), 1U); +} + TEST(InferenceSessionTests, TestRegisterExecutionProvider) { SessionOptions so; diff --git a/onnxruntime/test/providers/cpu/math/topk_op_test.cc b/onnxruntime/test/providers/cpu/math/topk_op_test.cc index 8dbad50344ddf..4740d4409cc61 100644 --- a/onnxruntime/test/providers/cpu/math/topk_op_test.cc +++ b/onnxruntime/test/providers/cpu/math/topk_op_test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/common/cuda_op_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" namespace onnxruntime { namespace test { @@ -75,6 +76,23 @@ TEST(TopKOperator, Top1DefaultAxisOpset9_double) { RunTest(9, 1, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false); } +TEST(TopKOperator, DynamicScalarKFromShapeGather) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 4}, {0.1f, 0.4f, 0.2f, 0.3f}); + auto* shape_arg = builder.MakeIntermediate(std::vector{2}); + auto* k_arg = builder.MakeIntermediate({}); + auto* values_arg = builder.MakeOutput(std::vector{1, 4}); + auto* indices_arg = builder.MakeOutput(std::vector{1, 4}); + auto* gather_index_arg = builder.MakeScalarInitializer(1); + + builder.AddNode("Shape", {input_arg}, {shape_arg}); + builder.AddNode("Gather", {shape_arg, gather_index_arg}, {k_arg}); + builder.AddNode("TopK", {input_arg, k_arg}, {values_arg, indices_arg}).AddAttribute("axis", int64_t{-1}); + }; + + TransformerTester(build_test_case, nullptr, TransformerLevel::Default, TransformerLevel::Default, 17); +} + TEST(TopKOperator, Top2DefaultAxisOpset9) { std::vector input_vals = {0.1f, 0.3f, 0.2f, 0.4f, 0.1f, 0.3f, 0.4f, 0.2f}; std::vector input_dimensions = {2, 4};