Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions cmake/patches/onnx/onnx.patch
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,64 @@ index a6a8a83..153da87 100644
.SetDoc(GroupNormalization_ver18_doc)
.Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, 1e-5f)
.Attr(
diff --git a/onnx/defs/math/old.cc b/onnx/defs/math/old.cc
index 6f83b4b..7b79012 100644
--- a/onnx/defs/math/old.cc
+++ b/onnx/defs/math/old.cc
@@ -3448,16 +3448,23 @@ static const char* TopK_ver10_doc = R"DOC(
// should be enforced)
if (nullptr != k && axis_dim.has_dim_value()) {
int64_t k_value = 0;
- if (k->dims_size() != 1 || k->dims(0) != 1) {
- fail_shape_inference("K input must be a one-dimensional tensor of size 1.");
+ int64_t k_element_count = 1;
+ for (int i = 0; i < k->dims_size(); ++i) {
+ k_element_count *= k->dims(i);
+ }
+ if (k_element_count != 1) {
+ fail_shape_inference("K input must contain exactly one element.");
}
-
+
if (k->data_type() == TensorProto::INT64) {
const auto data = ParseData<int64_t>(k);
k_value = data[0];
} else {
fail_shape_inference("K input must be of type int64.");
}
+ if (k_value < 0) {
+ fail_shape_inference("K input must not be negative.");
+ }
-
+
if (axis_dim.dim_value() < k_value) {
fail_shape_inference("Axis has less than the requested k elements.");
diff --git a/onnx/defs/math/utils.cc b/onnx/defs/math/utils.cc
index 58c5f5d..f5e18d5 100644
--- a/onnx/defs/math/utils.cc
+++ b/onnx/defs/math/utils.cc
@@ -113,15 +113,22 @@ void TopKOpSchemaGenerator(OpSchema& schema) {
// should be enforced)
if (nullptr != k && axis_dim.has_dim_value()) {
int64_t k_value = 0;
- if (k->dims_size() != 1 || k->dims(0) != 1) {
- fail_shape_inference("K input must be a one-dimensional tensor of size 1.");
+ int64_t k_element_count = 1;
+ for (int i = 0; i < k->dims_size(); ++i) {
+ k_element_count *= k->dims(i);
+ }
+ if (k_element_count != 1) {
+ fail_shape_inference("K input must contain exactly one element.");
}
if (k->data_type() == TensorProto::INT64) {
const auto data = ParseData<int64_t>(k);
k_value = data[0];
} else {
fail_shape_inference("K input must be of type int64.");
}
+ if (k_value < 0) {
+ fail_shape_inference("K input must not be negative.");
+ }
if (axis_dim.dim_value() < k_value) {
fail_shape_inference("Axis has less than the requested k elements.");
}
12 changes: 8 additions & 4 deletions onnxruntime/core/providers/cpu/math/top_k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <queue>
#include <algorithm>
#include <cmath>
#include <limits>
#include <core/common/safeint.h>

namespace onnxruntime {
Expand Down Expand Up @@ -481,17 +482,20 @@ static Status ComputeImplOpset1011(OpKernelContext* p_op_kernel_context, int axi
"the tensor to be processed and a tensor containing k value");
}

auto y_shape = Y->Shape().GetDims();
if (y_shape.size() != 1 || y_shape[0] != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k tensor should be a 1D tensor of size 1");
if (Y->Shape().Size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k tensor should contain exactly one element");
}

auto parsed_input_k = Y->Data<int64_t>()[0];
if (parsed_input_k < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k must not be negative");
}

return TopKImpl<T>(p_op_kernel_context, X, axis, gsl::narrow_cast<unsigned>(parsed_input_k), is_largest, is_sorted);
if (parsed_input_k > std::numeric_limits<unsigned>::max()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "value of k exceeds supported range");
}

return TopKImpl<T>(p_op_kernel_context, X, axis, static_cast<unsigned>(parsed_input_k), is_largest, is_sorted);
}

template <>
Expand Down
81 changes: 81 additions & 0 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,34 @@ struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::KernelRegistry> kernel_registry = std::make_shared<onnxruntime::KernelRegistry>();
onnxruntime::Status st;
};

std::filesystem::path FindRvqDecoderModelPath() {
constexpr const char* kModelFileName = "rvq_decoder_v1.onnx";

std::filesystem::path current_path = std::filesystem::current_path();
for (int i = 0; i < 8; ++i) {
auto candidate = current_path / kModelFileName;
if (std::filesystem::exists(candidate)) {
return candidate;
}

if (!current_path.has_parent_path() || current_path == current_path.parent_path()) {
break;
}

current_path = current_path.parent_path();
}

const std::filesystem::path source_file_path{__FILE__};
if (source_file_path.is_absolute()) {
auto candidate = source_file_path.parent_path().parent_path().parent_path().parent_path() / kModelFileName;
if (std::filesystem::exists(candidate)) {
return candidate;
}
}

return {};
}
} // namespace
namespace onnxruntime {

Expand Down Expand Up @@ -1039,6 +1067,59 @@ TEST(InferenceSessionTests, TestWithIstream) {
RunModel(session_object, run_options);
}

TEST(InferenceSessionTests, RvqDecoderRunsWithDynamicTopK) {
const auto model_path = FindRvqDecoderModelPath();
if (model_path.empty()) {
GTEST_SKIP() << "rvq_decoder_v1.onnx is not available in this checkout.";
}

SessionOptions session_options;
InferenceSession session_object{session_options, GetEnvironment()};
ASSERT_STATUS_OK(session_object.Load(model_path.string()));
ASSERT_STATUS_OK(session_object.Initialize());

auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
NameMLValMap feeds;

OrtValue pre_emb;
CreateMLValue<float>(allocator, std::vector<int64_t>{1, 128, 300}, std::vector<float>(1 * 128 * 300, 0.0f),
&pre_emb);
feeds.insert({"pre_emb", pre_emb});

OrtValue tokens;
CreateMLValue<int32_t>(allocator, std::vector<int64_t>{1, 300, 2}, std::vector<int32_t>(1 * 300 * 2, 0),
&tokens);
feeds.insert({"tokens", tokens});

OrtValue dur;
CreateMLValue<int32_t>(allocator, std::vector<int64_t>{1, 300}, std::vector<int32_t>(1 * 300, 1), &dur);
feeds.insert({"dur", dur});

OrtValue text_mask;
CreateMLValue<bool>(allocator, std::vector<int64_t>{1, 1, 300}, std::vector<bool>(1 * 1 * 300, true),
&text_mask);
feeds.insert({"text_mask", text_mask});

OrtValue mel_mask;
CreateMLValue<bool>(allocator, std::vector<int64_t>{1, 1, 4096}, std::vector<bool>(1 * 1 * 4096, true),
&mel_mask);
feeds.insert({"mel_mask", mel_mask});

OrtValue spk_id;
CreateMLValue<int32_t>(allocator, std::vector<int64_t>{1}, std::vector<int32_t>{0}, &spk_id);
feeds.insert({"spk_id", spk_id});

OrtValue mel_len;
CreateMLValue<int64_t>(allocator, std::vector<int64_t>{}, std::vector<int64_t>{4096}, &mel_len);
feeds.insert({"mel_len", mel_len});

RunOptions run_options;
std::vector<std::string> output_names{"mel"};
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
ASSERT_EQ(fetches.size(), 1U);
}

TEST(InferenceSessionTests, TestRegisterExecutionProvider) {
SessionOptions so;

Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/test/providers/cpu/math/topk_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/unittest_util/graph_transform_test_builder.h"

namespace onnxruntime {
namespace test {
Expand Down Expand Up @@ -75,6 +76,23 @@ TEST(TopKOperator, Top1DefaultAxisOpset9_double) {
RunTest(9, 1, input_vals, input_dimensions, expected_vals, expected_indices, expected_dimensions, false);
}

TEST(TopKOperator, DynamicScalarKFromShapeGather) {
auto build_test_case = [](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({1, 4}, {0.1f, 0.4f, 0.2f, 0.3f});
auto* shape_arg = builder.MakeIntermediate<int64_t>(std::vector<int64_t>{2});
auto* k_arg = builder.MakeIntermediate<int64_t>({});
auto* values_arg = builder.MakeOutput<float>(std::vector<int64_t>{1, 4});
auto* indices_arg = builder.MakeOutput<int64_t>(std::vector<int64_t>{1, 4});
auto* gather_index_arg = builder.MakeScalarInitializer<int64_t>(1);

builder.AddNode("Shape", {input_arg}, {shape_arg});
builder.AddNode("Gather", {shape_arg, gather_index_arg}, {k_arg});
builder.AddNode("TopK", {input_arg, k_arg}, {values_arg, indices_arg}).AddAttribute("axis", int64_t{-1});
};

TransformerTester(build_test_case, nullptr, TransformerLevel::Default, TransformerLevel::Default, 17);
}

TEST(TopKOperator, Top2DefaultAxisOpset9) {
std::vector<float> input_vals = {0.1f, 0.3f, 0.2f, 0.4f, 0.1f, 0.3f, 0.4f, 0.2f};
std::vector<int64_t> input_dimensions = {2, 4};
Expand Down