@@ -433,13 +433,19 @@ class SlimTensor {
433433 /* *
434434 * Copy data from another tensor to this tensor.
435435 *
436- * Both tensors must have the same numel and dtype.
437- * Currently only supports CPU-to-CPU copy (contiguous tensors only).
436+ * Both tensors must have the same dim, sizes (and therefore numel) and dtype.
438437 *
439438 * @param other The source tensor to copy from
440439 * @return Reference to this tensor
441440 */
442441 SlimTensor& copy_ (const SlimTensor& other) {
442+ ET_CHECK_MSG (
443+ this ->dim () == other.dim (),
444+ " copy_: dim of tensors must match (%zu vs %zu)" ,
445+ this ->dim (),
446+ other.dim ());
447+ ET_CHECK_MSG (
448+ this ->sizes () == other.sizes (), " copy_: sizes of tensors must match" );
443449 ET_CHECK_MSG (
444450 this ->numel () == other.numel (), " copy_: numel of tensors must match" );
445451 ET_CHECK_MSG (this ->dtype () == other.dtype (), " copy_: dtype must match" );
@@ -463,29 +469,43 @@ class SlimTensor {
463469
464470 std::vector<int64_t > counter (this ->dim (), 0 );
465471 for (size_t i = 0 ; i < this ->numel (); i++) {
466- // Compute src offset in elements
467472 int64_t src_offset = 0 ;
468- for (size_t d = 0 ; d < other.dim (); d++) {
469- src_offset += counter[d] * other.stride (d);
470- }
471-
472- // Compute dst offset in elements
473473 int64_t dst_offset = 0 ;
474474 for (size_t d = 0 ; d < this ->dim (); d++) {
475- dst_offset += counter[d] * this ->stride (d);
475+ int64_t src_term = 0 ;
476+ int64_t dst_term = 0 ;
477+ // src_offset = src_offset + counter[d] * other.stride(d)
478+ // dst_offset = dst_offset + counter[d] * this->stride(d)
479+ ET_CHECK_MSG (
480+ !::c10::mul_overflows (counter[d], other.stride (d), &src_term) &&
481+ !::c10::add_overflows (src_offset, src_term, &src_offset) &&
482+ !::c10::mul_overflows (counter[d], this ->stride (d), &dst_term) &&
483+ !::c10::add_overflows (dst_offset, dst_term, &dst_offset),
484+ " copy_: offset computation overflow" );
476485 }
486+ size_t src_byte_offset = 0 ;
487+ size_t dst_byte_offset = 0 ;
488+ // src_byte_offset = src_offset * elem_size
489+ // dst_byte_offset = dst_offset * elem_size
490+ ET_CHECK_MSG (
491+ src_offset >= 0 && dst_offset >= 0 &&
492+ !::c10::mul_overflows (
493+ static_cast <size_t >(src_offset),
494+ elem_size,
495+ &src_byte_offset) &&
496+ !::c10::mul_overflows (
497+ static_cast <size_t >(dst_offset), elem_size, &dst_byte_offset),
498+ " copy_: byte offset overflow" );
477499
478500 // Copy elem_size bytes from src to dst
479501 if (this ->device ().is_cpu () && other.device ().is_cpu ()) {
480502 std::memcpy (
481- dst_data + dst_offset * elem_size,
482- src_data + src_offset * elem_size,
483- elem_size);
503+ dst_data + dst_byte_offset, src_data + src_byte_offset, elem_size);
484504 } else if (this ->device ().is_cuda () || other.device ().is_cuda ()) {
485505#if defined(CUDA_AVAILABLE)
486506 DeviceTraits<c10::DeviceType::CUDA>::memcpy (
487- dst_data + dst_offset * elem_size ,
488- src_data + src_offset * elem_size ,
507+ dst_data + dst_byte_offset ,
508+ src_data + src_byte_offset ,
489509 elem_size,
490510 device (), // dst device
491511 other.device () // src device
0 commit comments