Skip to content

Commit 6a23ab6

Browse files
Metal backend: Materialize non-packed tensor views in reinterpret_tensor
AOTI generates reinterpret_tensor views with non-packed strides (e.g. chunk/split for RoPE rotation) that have holes in memory. ExecuTorch's make_tensor_ptr requires densely packed layouts. When aoti_torch__reinterpret_tensor encounters non-packed strides, allocate a new contiguous Metal buffer and copy elements using strided access from the source. Authored with Claude.
1 parent 1d37abd commit 6a23ab6

1 file changed

Lines changed: 132 additions & 31 deletions

File tree

backends/apple/metal/runtime/shims/memory.cpp

Lines changed: 132 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -367,6 +367,82 @@ AOTITorchError aoti_torch_copy_(
367367
return Error::Ok;
368368
}
369369

370+
// Check if a strided view is densely packed (no holes in memory).
371+
// A densely packed tensor's storage extent equals its numel.
372+
static bool is_packed_strides(
373+
const std::vector<aten::SizesType>& sizes,
374+
const std::vector<aten::StridesType>& strides) {
375+
int64_t ndim = static_cast<int64_t>(sizes.size());
376+
if (ndim == 0)
377+
return true;
378+
379+
// Compute numel
380+
int64_t numel = 1;
381+
for (int64_t i = 0; i < ndim; i++) {
382+
numel *= sizes[i];
383+
}
384+
if (numel <= 1)
385+
return true;
386+
387+
// Compute storage extent: max offset + 1
388+
int64_t max_offset = 0;
389+
for (int64_t i = 0; i < ndim; i++) {
390+
if (sizes[i] > 1) {
391+
max_offset += (sizes[i] - 1) * strides[i];
392+
}
393+
}
394+
return (max_offset + 1) == numel;
395+
}
396+
397+
// Materialize a non-packed strided view into a new contiguous Metal buffer.
398+
// Copies elements from source using strided access. The caller must free the
399+
// returned buffer. On failure returns nullptr.
400+
static void* materialize_packed(
401+
void* src,
402+
const std::vector<aten::SizesType>& sizes,
403+
const std::vector<aten::StridesType>& strides,
404+
size_t element_size) {
405+
int64_t ndim = static_cast<int64_t>(sizes.size());
406+
int64_t numel = 1;
407+
for (int64_t i = 0; i < ndim; i++) {
408+
numel *= sizes[i];
409+
}
410+
411+
void* dst = metal_allocate_buffer(numel * element_size);
412+
if (!dst)
413+
return nullptr;
414+
415+
// Ensure pending GPU writes to the source buffer are complete
416+
auto* stream = getCurrentMetalStream();
417+
if (stream) {
418+
stream->synchronize(SyncType::COMMIT_AND_WAIT);
419+
}
420+
421+
// Element-by-element strided copy
422+
char* src_bytes = static_cast<char*>(src);
423+
char* dst_bytes = static_cast<char*>(dst);
424+
std::vector<int64_t> coord(ndim, 0);
425+
for (int64_t flat = 0; flat < numel; flat++) {
426+
// Compute source offset from strides
427+
int64_t src_offset = 0;
428+
for (int64_t d = 0; d < ndim; d++) {
429+
src_offset += coord[d] * strides[d];
430+
}
431+
std::memcpy(
432+
dst_bytes + flat * element_size,
433+
src_bytes + src_offset * element_size,
434+
element_size);
435+
436+
// Increment coordinate (last dim fastest)
437+
for (int64_t d = ndim - 1; d >= 0; d--) {
438+
if (++coord[d] < sizes[d])
439+
break;
440+
coord[d] = 0;
441+
}
442+
}
443+
return dst;
444+
}
445+
370446
AOTITorchError aoti_torch__reinterpret_tensor(
371447
AOTITensorHandle self,
372448
int64_t ndim,
@@ -430,8 +506,9 @@ AOTITorchError aoti_torch__reinterpret_tensor(
430506
data_ptr);
431507

432508
// Handle storage offset by adjusting the data pointer
509+
size_t element_size = dtype_to_element_size(dtype);
433510
void* adjusted_data = static_cast<char*>(data_ptr) +
434-
(storage_offset * dtype_to_element_size(dtype));
511+
(storage_offset * element_size);
435512

436513
// Convert sizes using utility function from utils.h
437514
std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
@@ -440,14 +517,34 @@ AOTITorchError aoti_torch__reinterpret_tensor(
440517
std::vector<aten::StridesType> strides =
441518
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
442519

443-
// Create new tensor view that reinterprets the same memory with different
444-
// shape/strides This creates a view, not a copy - the data pointer is shared
520+
// If the view is not densely packed (e.g. chunk/split creating holes),
521+
// materialize it into a new contiguous buffer.
522+
void* tensor_data = adjusted_data;
523+
bool owns_buffer = false;
524+
if (!is_packed_strides(sizes, strides)) {
525+
ET_LOG(
526+
Debug,
527+
"aoti_torch__reinterpret_tensor: non-packed strides, "
528+
"materializing to packed buffer");
529+
tensor_data = materialize_packed(adjusted_data, sizes, strides, element_size);
530+
ET_CHECK_OR_RETURN_ERROR(
531+
tensor_data != nullptr,
532+
MemoryAllocationFailed,
533+
"Failed to materialize non-packed tensor");
534+
owns_buffer = true;
535+
536+
// Compute contiguous strides for the packed buffer
537+
strides.resize(ndim);
538+
if (ndim > 0) {
539+
strides[ndim - 1] = 1;
540+
for (int64_t i = ndim - 2; i >= 0; i--) {
541+
strides[i] = strides[i + 1] * sizes[i + 1];
542+
}
543+
}
544+
}
545+
445546
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
446-
adjusted_data, // Use adjusted data pointer with storage offset applied
447-
sizes, // New sizes with explicit SizesType
448-
strides, // New strides with explicit StridesType
449-
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
450-
);
547+
tensor_data, sizes, strides, dtype_to_scalar_type(dtype));
451548

452549
ET_CHECK_OR_RETURN_ERROR(
453550
tensor != nullptr,
@@ -456,32 +553,36 @@ AOTITorchError aoti_torch__reinterpret_tensor(
456553

457554
// Store the tensor so it doesn't get destroyed
458555
tensors[tensor.get()] = tensor;
459-
460556
*ret_new_tensor = tensor.get();
461557

462-
if (adjusted_data != data_ptr) {
463-
ET_LOG(
464-
Debug,
465-
"aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
466-
data_ptr,
467-
storage_offset,
468-
dtype_to_element_size(dtype),
469-
adjusted_data);
470-
471-
ET_CHECK_OR_RETURN_ERROR(
472-
metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
473-
Internal,
474-
"metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
475-
adjusted_data,
476-
static_cast<size_t>(tensor->nbytes()));
477-
478-
memory_to_n_tensor[adjusted_data] = NOT_OWN;
479-
}
558+
if (owns_buffer) {
559+
// The materialized buffer is a new allocation owned by this tensor
560+
memory_to_n_tensor[tensor_data] = 1;
561+
} else {
562+
if (adjusted_data != data_ptr) {
563+
ET_LOG(
564+
Debug,
565+
"aoti_torch__reinterpret_tensor: Adjusted original_data=%p, "
566+
"storage_offset=%lld, element_size=%zu, adjusted_data=%p",
567+
data_ptr,
568+
storage_offset,
569+
element_size,
570+
adjusted_data);
571+
572+
ET_CHECK_OR_RETURN_ERROR(
573+
metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
574+
Internal,
575+
"metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
576+
adjusted_data,
577+
static_cast<size_t>(tensor->nbytes()));
578+
579+
memory_to_n_tensor[adjusted_data] = NOT_OWN;
580+
}
480581

481-
// Increment the reference count for this memory address only if it is owned
482-
// by tensor
483-
if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
484-
memory_to_n_tensor[data_ptr] += 1;
582+
// Increment the reference count for this memory address only if it is owned
583+
if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
584+
memory_to_n_tensor[data_ptr] += 1;
585+
}
485586
}
486587

487588
ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successful");

0 commit comments

Comments (0)