Skip to content

Commit df44eab

Browse files
authored
Introduce checks to prevent buffer overflow, add tests (#28713)
This pull request improves the correctness and robustness of the `InPlaceAccumulator` and `InPlaceAccumulatorV2` gradient accumulation kernels by adding new unit tests and enforcing stricter shape validation in both the CPU and CUDA implementations of `InPlaceAccumulatorV2`. The changes ensure that shape mismatches are caught early, and that an omitted optional output is handled properly by both the CPU and CUDA execution providers.

**Test coverage improvements:**

* Added a test to verify that `InPlaceAccumulator` correctly passes through the `old_sum` unchanged and does not consume the `value` input when the optional `update_signal` is `false`.
* Added tests for `InPlaceAccumulatorV2` to check that shape mismatches between `accumulation_buffer` and `value` are detected and handled as errors, covering both overwrite and accumulate branches.
* Added tests to verify that `InPlaceAccumulatorV2` correctly handles the case where the optional `accumulation_buffer_out` output is omitted, for both CPU and CUDA providers. [[1]](diffhunk://#diff-c62fec6d9ac7d24c7d6befe4e18317d2690385051c01a6412651f01e190de1beR2247-R2288) [[2]](diffhunk://#diff-c62fec6d9ac7d24c7d6befe4e18317d2690385051c01a6412651f01e190de1beR2332-R2346)

**Kernel validation improvements:**

* Added explicit shape validation to the CPU implementation of `InPlaceAccumulatorV2`, ensuring that the shapes of the accumulation buffer and the value tensor match.
* Added the same shape validation to the CUDA implementation of `InPlaceAccumulatorV2`, preventing out-of-bounds memory accesses.
1 parent d165fba commit df44eab

3 files changed

Lines changed: 90 additions & 0 deletions

File tree

orttraining/orttraining/test/gradient/gradient_ops_test.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,6 +2175,20 @@ TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) {
21752175
test.Run();
21762176
}
21772177

2178+
// When the optional update_signal is false, the kernel must pass old_sum through unchanged
2179+
// and must not consume the value input.
2180+
// Runs InPlaceAccumulator with update_signal = false and expects new_sum to be
// element-for-element equal to the old_sum input.
TEST(GradientUtilsTest, InPlaceAccumulatorFloat32_NoUpdate) {
2181+
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);
2182+
2183+
test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
2184+
// value differs from old_sum, so an unintended accumulation would change new_sum.
test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});
2185+
test.AddInput<bool>("update_signal", {1}, {false});
2186+
2187+
// Expected output repeats the old_sum data exactly.
test.AddOutput<float>("new_sum", {3}, {1.f, 2.f, 3.f});
2188+
2189+
test.Run();
2190+
}
2191+
21782192
void TestInPlaceAccumulatorV2(
21792193
const std::vector<int64_t>& tensor_dim,
21802194
const std::unordered_set<std::string>& excluded_providers,
@@ -2230,6 +2244,59 @@ TEST(GradientUtilsTest, InPlaceAccumulatorV2Overwrite) {
22302244
test.Run();
22312245
}
22322246

2247+
// Verify the kernel rejects mismatched shapes between accumulation_buffer and value
2248+
// instead of performing an out-of-bounds copy. Exercises both overwrite and accumulate branches.
2249+
// Shared driver for the shape-mismatch tests: feeds InPlaceAccumulatorV2 a 3-element old_sum
// and a 5-element value, and expects the run to fail.
//
// overwrite_flag: value supplied for the "overwrite" input.
// provider:       execution provider to run on; ownership is moved into the provider list
//                 that is handed to OpTester::Run.
static void RunInPlaceAccumulatorV2ShapeMismatch(bool overwrite_flag,
2250+
std::unique_ptr<IExecutionProvider> provider) {
2251+
OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);
2252+
2253+
test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
2254+
// value has more elements than old_sum; without validation the kernel would over-read/over-write.
2255+
test.AddInput<float>("value", {5}, {4.f, 5.f, 6.f, 7.f, 8.f});
2256+
test.AddInput<bool>("overwrite", {1}, {overwrite_flag});
2257+
// NOTE(review): the output values below look like placeholders, since the run is expected
// to fail before any output is compared -- confirm.
test.AddOutput<bool>("updated", {1}, {true});
2258+
test.AddOutput<float>("new_sum", {3}, {0.f, 0.f, 0.f});
2259+
2260+
std::vector<std::unique_ptr<IExecutionProvider>> providers;
2261+
providers.emplace_back(std::move(provider));
2262+
// "accumulation_buffer shape" is the expected-failure text; it appears in the message the
// kernels build for their shape check.
test.Run(OpTester::ExpectResult::kExpectFailure,
2263+
"accumulation_buffer shape", {}, nullptr, &providers);
2264+
}
2265+
2266+
// Shape-mismatch rejection with overwrite = true on the CPU execution provider.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_ShapeMismatch_Overwrite) {
2267+
RunInPlaceAccumulatorV2ShapeMismatch(/*overwrite_flag=*/true, DefaultCpuExecutionProvider());
2268+
}
2269+
2270+
// Shape-mismatch rejection with overwrite = false on the CPU execution provider.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_ShapeMismatch_Accumulate) {
2271+
RunInPlaceAccumulatorV2ShapeMismatch(/*overwrite_flag=*/false, DefaultCpuExecutionProvider());
2272+
}
2273+
2274+
#if defined(USE_CUDA)
2275+
// Shape-mismatch rejection with overwrite = true on the CUDA execution provider
// (compiled only when USE_CUDA is defined).
TEST(GradientUtilsTest, InPlaceAccumulatorV2_ShapeMismatch_Overwrite_GPU) {
2276+
RunInPlaceAccumulatorV2ShapeMismatch(/*overwrite_flag=*/true, DefaultCudaExecutionProvider());
2277+
}
2278+
2279+
// Shape-mismatch rejection with overwrite = false on the CUDA execution provider
// (compiled only when USE_CUDA is defined).
TEST(GradientUtilsTest, InPlaceAccumulatorV2_ShapeMismatch_Accumulate_GPU) {
2280+
RunInPlaceAccumulatorV2ShapeMismatch(/*overwrite_flag=*/false, DefaultCudaExecutionProvider());
2281+
}
2282+
#endif
2283+
2284+
// Exercise the path where the optional accumulation_buffer_out output is omitted.
2285+
// The kernel must still update the in-place accumulation_buffer and produce updated_flag.
2286+
// Runs InPlaceAccumulatorV2 on the CPU execution provider with matching 3-element inputs,
// overwrite = false, and an optional output edge in place of the second output.
// NOTE(review): only the "updated" output is asserted here; the in-place change to the
// accumulation buffer is not observed by this test -- confirm whether that is intended.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_NoAccumulationOutput_CPU) {
2287+
OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);
2288+
2289+
test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
2290+
test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});
2291+
test.AddInput<bool>("overwrite", {1}, {false});
2292+
test.AddOutput<bool>("updated", {1}, {true});
2293+
// No data is supplied for the second output; it is registered as an optional edge only.
test.AddOptionalOutputEdge<float>();
2294+
2295+
// Restrict the run to the CPU provider and expect success with no error text.
std::vector<std::unique_ptr<IExecutionProvider>> providers;
2296+
providers.emplace_back(DefaultCpuExecutionProvider());
2297+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
2298+
}
2299+
22332300
#if defined(USE_CUDA)
22342301
TEST(GradientUtilsTest, InPlaceAccumulatorV2_GPU) {
22352302
std::vector<std::vector<int64_t>> test_dims{
@@ -2273,6 +2340,21 @@ TEST(GradientUtilsTest, InPlaceAccumulatorV2_Float16) {
22732340
providers.emplace_back(DefaultCudaExecutionProvider());
22742341
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
22752342
}
2343+
2344+
// CUDA-side coverage for the omitted optional accumulation_buffer_out output.
2345+
// Runs InPlaceAccumulatorV2 on the CUDA execution provider with matching 3-element inputs,
// overwrite = false, and an optional output edge in place of the second output.
// NOTE(review): only the "updated" output is asserted here; the in-place change to the
// accumulation buffer is not observed by this test -- confirm whether that is intended.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_NoAccumulationOutput_GPU) {
2346+
OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);
2347+
2348+
test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
2349+
test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});
2350+
test.AddInput<bool>("overwrite", {1}, {false});
2351+
test.AddOutput<bool>("updated", {1}, {true});
2352+
// No data is supplied for the second output; it is registered as an optional edge only.
test.AddOptionalOutputEdge<float>();
2353+
2354+
// Restrict the run to the CUDA provider and expect success with no error text.
std::vector<std::unique_ptr<IExecutionProvider>> providers;
2355+
providers.emplace_back(DefaultCudaExecutionProvider());
2356+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
2357+
}
22762358
#endif
22772359

22782360
#if defined(USE_CUDA)

orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ Status InPlaceAccumulatorV2<T>::Compute(OpKernelContext* context) const {
9595
const Tensor* new_value = context->Input<Tensor>(1);
9696
const Tensor* overwrite_tensor = context->Input<Tensor>(2);
9797

98+
ORT_RETURN_IF_NOT(accumulation_buffer->Shape() == new_value->Shape(),
99+
"InPlaceAccumulatorV2: accumulation_buffer shape (", accumulation_buffer->Shape(),
100+
") must match value shape (", new_value->Shape(), ").");
101+
98102
void* accumulation_buffer_data = accumulation_buffer->template MutableData<T>();
99103
const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data<bool>()) : false;
100104

orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ Status InPlaceAccumulatorV2<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) co
119119
const Tensor* overwrite_tensor = ctx->Input<Tensor>(2);
120120
const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data<bool>()) : false;
121121

122+
ORT_RETURN_IF_NOT(left_addee_buffer.Shape() == right_addee_buffer.Shape(),
123+
"InPlaceAccumulatorV2: accumulation_buffer shape (", left_addee_buffer.Shape(),
124+
") must match value shape (", right_addee_buffer.Shape(), ").");
125+
122126
if (overwrite) {
123127
const T_GRAD* source = right_addee_buffer.template Data<T_GRAD>();
124128
T* target = left_addee_buffer.template MutableData<T>();

0 commit comments

Comments
 (0)