Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/ngram_repeat_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <atomic>

#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/common/safeint.h"
Expand Down Expand Up @@ -44,6 +46,9 @@ class NGramRepeatBlock : public OpKernel {

const auto* input_ids_data = static_cast<const int64_t*>(input_ids->DataRaw(input_ids->DataType()));

std::atomic<bool> has_invalid_token{false};
std::atomic<int64_t> invalid_token_id{0};

auto lambda = [&](int64_t b) {
for (int64_t i = 0; i < cur_len; ++i) {
if (i + ngram_size_ > cur_len) {
Expand All @@ -62,7 +67,11 @@ class NGramRepeatBlock : public OpKernel {

if (is_banned) {
auto token_id = static_cast<int64_t>(input_ids_data[b * cur_len + i + ngram_size_ - 1]);
ORT_ENFORCE(token_id < vocab_size);
if (token_id < 0 || token_id >= vocab_size) {
has_invalid_token.store(true, std::memory_order_relaxed);
invalid_token_id.store(token_id, std::memory_order_relaxed);
return;
Comment thread
vraspar marked this conversation as resolved.
}
scores_target[b * vocab_size + token_id] = -std::numeric_limits<float>::infinity();
}
}
Expand All @@ -77,6 +86,12 @@ class NGramRepeatBlock : public OpKernel {
}
});

if (has_invalid_token.load(std::memory_order_relaxed)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"NGramRepeatBlock: token_id ", invalid_token_id.load(std::memory_order_relaxed),
" out of range [0, ", vocab_size, ")");
}

return Status::OK();
}

Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ __global__ void banRepeatedTokens(const int64_t* __restrict__ tokens,
}
if (is_banned == true) {
auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
lprobs[lprob_start + token_to_be_banned] = -std::numeric_limits<float>::infinity();
CUDA_KERNEL_ASSERT(token_to_be_banned >= 0 && token_to_be_banned < vocab_size);
// In release builds, silently skip OOB tokens rather than writing out of bounds.
// CUDA kernels cannot propagate Status errors to the host.
if (token_to_be_banned >= 0 && token_to_be_banned < vocab_size) {
Comment thread
vraspar marked this conversation as resolved.
lprobs[lprob_start + token_to_be_banned] = -std::numeric_limits<float>::infinity();
}
}
}

Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/test/contrib_ops/ngram_repeat_block_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,39 @@ TEST(NGramRepeatBlockTest, NGramSize_3) {
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// Regression test (CPU EP only): a negative id sitting in the "token to ban" position must make
// the run fail with an error instead of being used as an index into the scores buffer.
// The CUDA EP is left out because CUDA_KERNEL_ASSERT corrupts the device context in debug builds.
TEST(NGramRepeatBlockTest, NegativeTokenId) {
  // Restrict the run to the CPU execution provider.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only_providers;
  cpu_only_providers.emplace_back(DefaultCpuExecutionProvider());

  OpTester op_test("NGramRepeatBlock", 1, onnxruntime::kMSDomain);
  op_test.AddAttribute("ngram_size", static_cast<int64_t>(2));

  // For ngram_size=2 the operator compares input_ids[i] against the tail input_ids[cur_len-1]
  // and, on a match, bans input_ids[i+1]. input_ids[0]=1 equals input_ids[3]=1, so the token
  // selected for banning is input_ids[1] = -1000, which is not a valid vocabulary index.
  op_test.AddInput<int64_t>("input_ids", {1, 4}, {1, -1000, 0, 1});
  op_test.AddInput<float>("scores", {1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});

  // The declared output simply mirrors the input scores; the run itself is expected to fail.
  op_test.AddOutput<float>("scores_out", {1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});

  // The failure message is expected to mention "token_id".
  op_test.Run(OpTester::ExpectResult::kExpectFailure, "token_id", {}, nullptr, &cpu_only_providers);
}

// Regression test (CPU EP only): an id >= vocab_size sitting in the "token to ban" position must
// make the run fail with an error instead of being used as an index into the scores buffer.
TEST(NGramRepeatBlockTest, TokenIdExceedsVocabSize) {
  // Restrict the run to the CPU execution provider.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only_providers;
  cpu_only_providers.emplace_back(DefaultCpuExecutionProvider());

  OpTester op_test("NGramRepeatBlock", 1, onnxruntime::kMSDomain);
  op_test.AddAttribute("ngram_size", static_cast<int64_t>(2));

  // With ngram_size=2, input_ids[0]=1 matches the tail input_ids[3]=1, so the token selected
  // for banning is input_ids[1] = 100. vocab_size is 4 (taken from the scores shape), hence
  // 100 >= vocab_size and the operator should report the error.
  op_test.AddInput<int64_t>("input_ids", {1, 4}, {1, 100, 0, 1});
  op_test.AddInput<float>("scores", {1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});

  // The declared output simply mirrors the input scores; the run itself is expected to fail.
  op_test.AddOutput<float>("scores_out", {1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});

  // The failure message is expected to mention "token_id".
  op_test.Run(OpTester::ExpectResult::kExpectFailure, "token_id", {}, nullptr, &cpu_only_providers);
}

} // namespace test
} // namespace onnxruntime
Loading