Skip to content

Commit a7fb87e

Browse files
lucylqGithub Executorch
andauthored
Fix cuda overflow (#19487)
Update copy_ method to check - this->sizes() == other.sizes() - this->dim() == other.dim() Use overflow-safe arithmetic. Co-authored-by: Github Executorch <github_executorch@arm.com>
1 parent 1992bdd commit a7fb87e

1 file changed

Lines changed: 34 additions & 14 deletions

File tree

backends/aoti/slim/core/slim_tensor.h

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -433,13 +433,19 @@ class SlimTensor {
433433
/**
434434
* Copy data from another tensor to this tensor.
435435
*
436-
* Both tensors must have the same numel and dtype.
437-
* Currently only supports CPU-to-CPU copy (contiguous tensors only).
436+
* Both tensors must have the same numel, sizes and dtype.
438437
*
439438
* @param other The source tensor to copy from
440439
* @return Reference to this tensor
441440
*/
442441
SlimTensor& copy_(const SlimTensor& other) {
442+
ET_CHECK_MSG(
443+
this->dim() == other.dim(),
444+
"copy_: dim of tensors must match (%zu vs %zu)",
445+
this->dim(),
446+
other.dim());
447+
ET_CHECK_MSG(
448+
this->sizes() == other.sizes(), "copy_: sizes of tensors must match");
443449
ET_CHECK_MSG(
444450
this->numel() == other.numel(), "copy_: numel of tensors must match");
445451
ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match");
@@ -463,29 +469,43 @@ class SlimTensor {
463469

464470
std::vector<int64_t> counter(this->dim(), 0);
465471
for (size_t i = 0; i < this->numel(); i++) {
466-
// Compute src offset in elements
467472
int64_t src_offset = 0;
468-
for (size_t d = 0; d < other.dim(); d++) {
469-
src_offset += counter[d] * other.stride(d);
470-
}
471-
472-
// Compute dst offset in elements
473473
int64_t dst_offset = 0;
474474
for (size_t d = 0; d < this->dim(); d++) {
475-
dst_offset += counter[d] * this->stride(d);
475+
int64_t src_term = 0;
476+
int64_t dst_term = 0;
477+
// src_offset = src_offset + counter[d] * other.stride(d)
478+
// dst_offset = dst_offset + counter[d] * this->stride(d)
479+
ET_CHECK_MSG(
480+
!::c10::mul_overflows(counter[d], other.stride(d), &src_term) &&
481+
!::c10::add_overflows(src_offset, src_term, &src_offset) &&
482+
!::c10::mul_overflows(counter[d], this->stride(d), &dst_term) &&
483+
!::c10::add_overflows(dst_offset, dst_term, &dst_offset),
484+
"copy_: offset computation overflow");
476485
}
486+
size_t src_byte_offset = 0;
487+
size_t dst_byte_offset = 0;
488+
// src_byte_offset = src_offset * elem_size
489+
// dst_byte_offset = dst_offset * elem_size
490+
ET_CHECK_MSG(
491+
src_offset >= 0 && dst_offset >= 0 &&
492+
!::c10::mul_overflows(
493+
static_cast<size_t>(src_offset),
494+
elem_size,
495+
&src_byte_offset) &&
496+
!::c10::mul_overflows(
497+
static_cast<size_t>(dst_offset), elem_size, &dst_byte_offset),
498+
"copy_: byte offset overflow");
477499

478500
// Copy elem_size bytes from src to dst
479501
if (this->device().is_cpu() && other.device().is_cpu()) {
480502
std::memcpy(
481-
dst_data + dst_offset * elem_size,
482-
src_data + src_offset * elem_size,
483-
elem_size);
503+
dst_data + dst_byte_offset, src_data + src_byte_offset, elem_size);
484504
} else if (this->device().is_cuda() || other.device().is_cuda()) {
485505
#if defined(CUDA_AVAILABLE)
486506
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
487-
dst_data + dst_offset * elem_size,
488-
src_data + src_offset * elem_size,
507+
dst_data + dst_byte_offset,
508+
src_data + src_byte_offset,
489509
elem_size,
490510
device(), // dst device
491511
other.device() // src device

0 commit comments

Comments
 (0)