@@ -367,6 +367,84 @@ AOTITorchError aoti_torch_copy_(
   return Error::Ok;
 }

+// Check if a strided view is densely packed (no holes in memory).
+// A densely packed tensor's storage extent equals its numel.
+static bool is_packed_strides(
+    const std::vector<aten::SizesType>& sizes,
+    const std::vector<aten::StridesType>& strides) {
+  int64_t ndim = static_cast<int64_t>(sizes.size());
+  if (ndim == 0)
+    return true;
+
+  // Compute numel
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes[i];
+  }
+  if (numel <= 1)
+    return true;
+
+  // Compute storage extent: max offset + 1
+  int64_t max_offset = 0;
+  for (int64_t i = 0; i < ndim; i++) {
+    if (sizes[i] > 1) {
+      max_offset += (sizes[i] - 1) * strides[i];
+    }
+  }
+  return (max_offset + 1) == numel;
+}
+
+// Materialize a non-packed strided view into a new contiguous Metal buffer.
+// Copies elements from source using strided access. The caller must free the
+// returned buffer. On failure returns nullptr.
+static void* materialize_packed(
+    void* src,
+    const std::vector<aten::SizesType>& sizes,
+    const std::vector<aten::StridesType>& strides,
+    size_t element_size) {
+  int64_t ndim = static_cast<int64_t>(sizes.size());
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes[i];
+  }
+
+  void* dst = metal_allocate_buffer(numel * element_size);
+  if (!dst)
+    return nullptr;
+
+  // Ensure pending GPU writes to the source buffer are complete
+  if (metal_is_device_pointer(src)) {
+    auto* stream = getCurrentMetalStream();
+    if (stream) {
+      stream->synchronize(SyncType::COMMIT_AND_WAIT);
+    }
+  }
+
+  // Element-by-element strided copy
+  char* src_bytes = static_cast<char*>(src);
+  char* dst_bytes = static_cast<char*>(dst);
+  std::vector<int64_t> coord(ndim, 0);
+  for (int64_t flat = 0; flat < numel; flat++) {
+    // Compute source offset from strides
+    int64_t src_offset = 0;
+    for (int64_t d = 0; d < ndim; d++) {
+      src_offset += coord[d] * strides[d];
+    }
+    std::memcpy(
+        dst_bytes + flat * element_size,
+        src_bytes + src_offset * element_size,
+        element_size);
+
+    // Increment coordinate (last dim fastest)
+    for (int64_t d = ndim - 1; d >= 0; d--) {
+      if (++coord[d] < sizes[d])
+        break;
+      coord[d] = 0;
+    }
+  }
+  return dst;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     AOTITensorHandle self,
     int64_t ndim,
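For orientation, here is a CPU-only sketch of the two helpers added above. The names `is_packed` and `gather` are illustrative stand-ins for `is_packed_strides` and `materialize_packed` (the Metal allocation and stream sync are dropped). The example is the chunk case the change targets: splitting a contiguous [3, 4] tensor along dim 1 leaves the left [3, 2] half with strides {4, 1}, so its 6 elements span 10 storage slots and the view is not packed; the odometer-style copy then gathers them into a dense buffer.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Same storage-extent test as is_packed_strides above.
static bool is_packed(
    const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides) {
  int64_t numel = 1;
  for (int64_t s : sizes) numel *= s;
  if (numel <= 1) return true;
  int64_t max_offset = 0;
  for (size_t i = 0; i < sizes.size(); i++)
    if (sizes[i] > 1) max_offset += (sizes[i] - 1) * strides[i];
  return max_offset + 1 == numel;
}

// Same odometer-style gather as materialize_packed: the last dimension
// ticks fastest, and a carry resets it and advances the next one out.
static void gather(
    const float* src, float* dst,
    const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides) {
  int64_t ndim = sizes.size(), numel = 1;
  for (int64_t s : sizes) numel *= s;
  std::vector<int64_t> coord(ndim, 0);
  for (int64_t flat = 0; flat < numel; flat++) {
    int64_t off = 0;
    for (int64_t d = 0; d < ndim; d++) off += coord[d] * strides[d];
    dst[flat] = src[off];
    for (int64_t d = ndim - 1; d >= 0; d--) {
      if (++coord[d] < sizes[d]) break;
      coord[d] = 0;
    }
  }
}

int main() {
  // Contiguous [3, 4] source; its left [3, 2] chunk keeps strides {4, 1}.
  float src[12];
  for (int i = 0; i < 12; i++) src[i] = float(i);
  std::vector<int64_t> sizes{3, 2}, strides{4, 1};
  printf("packed: %d\n", is_packed(sizes, strides)); // 0: 6 elems in 10 slots
  float dst[6];
  gather(src, dst, sizes, strides);
  for (float v : dst) printf("%g ", v); // 0 1 4 5 8 9
  printf("\n");
}
```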
@@ -377,6 +455,12 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered");

   // Validate input parameters first
+  ET_CHECK_OR_RETURN_ERROR(
+      ndim >= 0,
+      InvalidArgument,
+      "aoti_torch__reinterpret_tensor failed: ndim must be >= 0, got %lld",
+      ndim);
+
   ET_CHECK_OR_RETURN_ERROR(
       self != nullptr,
       InvalidArgument,
@@ -430,8 +514,9 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       data_ptr);

   // Handle storage offset by adjusting the data pointer
-  void* adjusted_data = static_cast<char*>(data_ptr) +
-      (storage_offset * dtype_to_element_size(dtype));
+  size_t element_size = dtype_to_element_size(dtype);
+  void* adjusted_data =
+      static_cast<char*>(data_ptr) + (storage_offset * element_size);

   // Convert sizes using utility function from utils.h
   std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
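The offset arithmetic in this hunk is in elements, not bytes. A tiny sketch with hypothetical numbers: a float32 (4-byte) view at `storage_offset` 5 starts 20 bytes past the base pointer.

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  float storage[16] = {};
  void* data_ptr = storage;
  long long storage_offset = 5;
  size_t element_size = sizeof(float); // element size of float32 is 4
  void* adjusted_data =
      static_cast<char*>(data_ptr) + storage_offset * element_size;
  // Prints 20: the view begins 20 bytes past the base allocation.
  printf(
      "%td\n",
      static_cast<char*>(adjusted_data) - static_cast<char*>(data_ptr));
}
```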
@@ -440,14 +525,35 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   std::vector<aten::StridesType> strides =
       convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

-  // Create new tensor view that reinterprets the same memory with different
-  // shape/strides This creates a view, not a copy - the data pointer is shared
+  // If the view is not densely packed (e.g. chunk/split creating holes),
+  // materialize it into a new contiguous buffer.
+  void* tensor_data = adjusted_data;
+  bool owns_buffer = false;
+  if (!is_packed_strides(sizes, strides)) {
+    ET_LOG(
+        Debug,
+        "aoti_torch__reinterpret_tensor: non-packed strides, "
+        "materializing to packed buffer");
+    tensor_data =
+        materialize_packed(adjusted_data, sizes, strides, element_size);
+    ET_CHECK_OR_RETURN_ERROR(
+        tensor_data != nullptr,
+        MemoryAllocationFailed,
+        "Failed to materialize non-packed tensor");
+    owns_buffer = true;
+
+    // Compute contiguous strides for the packed buffer
+    strides.resize(ndim);
+    if (ndim > 0) {
+      strides[ndim - 1] = 1;
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        strides[i] = strides[i + 1] * sizes[i + 1];
+      }
+    }
+  }
+
   std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
-      adjusted_data, // Use adjusted data pointer with storage offset applied
-      sizes, // New sizes with explicit SizesType
-      strides, // New strides with explicit StridesType
-      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
-  );
+      tensor_data, sizes, strides, dtype_to_scalar_type(dtype));

   ET_CHECK_OR_RETURN_ERROR(
       tensor != nullptr,
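The stride recomputation after materialization is the standard row-major rule: the innermost stride is 1 and each outer stride is the product of all inner sizes. A short sketch (the helper name is illustrative, not from the diff) shows sizes {2, 3, 4} yielding strides {12, 4, 1}.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the loop added in the hunk above.
static std::vector<int64_t> contiguous_strides(
    const std::vector<int64_t>& sizes) {
  int64_t ndim = sizes.size();
  std::vector<int64_t> strides(ndim);
  if (ndim > 0) {
    strides[ndim - 1] = 1;
    for (int64_t i = ndim - 2; i >= 0; i--)
      strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}

int main() {
  for (int64_t s : contiguous_strides({2, 3, 4}))
    printf("%lld ", (long long)s); // 12 4 1
  printf("\n");
}
```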
@@ -456,32 +562,36 @@ AOTITorchError aoti_torch__reinterpret_tensor(

   // Store the tensor so it doesn't get destroyed
   tensors[tensor.get()] = tensor;
-
   *ret_new_tensor = tensor.get();

-  if (adjusted_data != data_ptr) {
-    ET_LOG(
-        Debug,
-        "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
-        data_ptr,
-        storage_offset,
-        dtype_to_element_size(dtype),
-        adjusted_data);
-
-    ET_CHECK_OR_RETURN_ERROR(
-        metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
-        Internal,
-        "metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
-        adjusted_data,
-        static_cast<size_t>(tensor->nbytes()));
-
-    memory_to_n_tensor[adjusted_data] = NOT_OWN;
-  }
+  if (owns_buffer) {
+    // The materialized buffer is a new allocation owned by this tensor
+    memory_to_n_tensor[tensor_data] = 1;
+  } else {
+    if (adjusted_data != data_ptr) {
+      ET_LOG(
+          Debug,
+          "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, "
+          "storage_offset=%lld, element_size=%zu, adjusted_data=%p",
+          data_ptr,
+          storage_offset,
+          element_size,
+          adjusted_data);
+
+      ET_CHECK_OR_RETURN_ERROR(
+          metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
+          Internal,
+          "metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
+          adjusted_data,
+          static_cast<size_t>(tensor->nbytes()));
+
+      memory_to_n_tensor[adjusted_data] = NOT_OWN;
+    }

-  // Increment the reference count for this memory address only if it is owned
-  // by tensor
-  if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
-    memory_to_n_tensor[data_ptr] += 1;
+    // Increment the reference count for this memory address only if it is owned
+    if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
+      memory_to_n_tensor[data_ptr] += 1;
+    }
   }

   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successful");
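Finally, a minimal sketch of the ownership bookkeeping the last hunk leans on, under stated assumptions: `NOT_OWN` is a sentinel defined elsewhere in this file (its value here is illustrative), and the function names are hypothetical. `memory_to_n_tensor` maps a data pointer to the number of tensors owning it; externally owned pointers stay pinned at the sentinel, so they are never counted or freed, while a materialized buffer enters the map with a count of 1.

```cpp
#include <cstdint>
#include <unordered_map>

// Illustrative sentinel; any value that cannot be a real count works.
constexpr int64_t NOT_OWN = -1;

std::unordered_map<void*, int64_t> refcounts; // stand-in for memory_to_n_tensor

// New tensor over existing memory: operator[] default-inserts 0 for unseen
// pointers, so a fresh owned allocation counts up from 0; externally owned
// memory stays pinned at NOT_OWN.
void on_tensor_created(void* data_ptr) {
  if (refcounts[data_ptr] != NOT_OWN) {
    refcounts[data_ptr] += 1;
  }
}

// Tensor destroyed: report whether the caller should free the buffer.
bool on_tensor_destroyed(void* data_ptr) {
  auto it = refcounts.find(data_ptr);
  if (it == refcounts.end() || it->second == NOT_OWN) {
    return false; // not tracked, or not ours to free
  }
  if (--it->second == 0) {
    refcounts.erase(it);
    return true; // last owner gone; free the underlying allocation
  }
  return false;
}
```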