Commit 6be4fb5

Metal backend: Materialize non-packed tensor views in reinterpret_tensor (#19033)
AOTI generates reinterpret_tensor views with non-packed strides (e.g. chunk/split for RoPE rotation) that leave holes in memory, while ExecuTorch's make_tensor_ptr requires densely packed layouts. When aoti_torch__reinterpret_tensor encounters non-packed strides, it now allocates a new contiguous Metal buffer and copies the elements from the source using strided access.
1 parent 273888f commit 6be4fb5

2 files changed

Lines changed: 166 additions & 32 deletions
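
To see the problem concretely, here is a minimal PyTorch sketch (illustrative, not part of this commit) of the kind of view the commit message describes. A chunk along the last dimension keeps the parent tensor's strides, so each half spans more storage than it has elements:

    import torch

    x = torch.randn(2, 8)       # contiguous: sizes (2, 8), strides (8, 1)
    a, b = x.chunk(2, dim=-1)   # each half: sizes (2, 4), strides (8, 1)

    # Each chunk still steps over the full parent row, so its storage extent,
    # (2 - 1) * 8 + (4 - 1) * 1 + 1 = 12, exceeds its numel of 8: the view
    # has holes in memory and is not densely packed.
    print(a.is_contiguous())    # False
    print(a.stride())           # (8, 1)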

backends/apple/metal/runtime/shims/memory.cpp

Lines changed: 142 additions & 32 deletions
@@ -367,6 +367,84 @@ AOTITorchError aoti_torch_copy_(
   return Error::Ok;
 }
 
+// Check if a strided view is densely packed (no holes in memory).
+// A densely packed tensor's storage extent equals its numel.
+static bool is_packed_strides(
+    const std::vector<aten::SizesType>& sizes,
+    const std::vector<aten::StridesType>& strides) {
+  int64_t ndim = static_cast<int64_t>(sizes.size());
+  if (ndim == 0)
+    return true;
+
+  // Compute numel
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes[i];
+  }
+  if (numel <= 1)
+    return true;
+
+  // Compute storage extent: max offset + 1
+  int64_t max_offset = 0;
+  for (int64_t i = 0; i < ndim; i++) {
+    if (sizes[i] > 1) {
+      max_offset += (sizes[i] - 1) * strides[i];
+    }
+  }
+  return (max_offset + 1) == numel;
+}
+
+// Materialize a non-packed strided view into a new contiguous Metal buffer.
+// Copies elements from source using strided access. The caller must free the
+// returned buffer. On failure returns nullptr.
+static void* materialize_packed(
+    void* src,
+    const std::vector<aten::SizesType>& sizes,
+    const std::vector<aten::StridesType>& strides,
+    size_t element_size) {
+  int64_t ndim = static_cast<int64_t>(sizes.size());
+  int64_t numel = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    numel *= sizes[i];
+  }
+
+  void* dst = metal_allocate_buffer(numel * element_size);
+  if (!dst)
+    return nullptr;
+
+  // Ensure pending GPU writes to the source buffer are complete
+  if (metal_is_device_pointer(src)) {
+    auto* stream = getCurrentMetalStream();
+    if (stream) {
+      stream->synchronize(SyncType::COMMIT_AND_WAIT);
+    }
+  }
+
+  // Element-by-element strided copy
+  char* src_bytes = static_cast<char*>(src);
+  char* dst_bytes = static_cast<char*>(dst);
+  std::vector<int64_t> coord(ndim, 0);
+  for (int64_t flat = 0; flat < numel; flat++) {
+    // Compute source offset from strides
+    int64_t src_offset = 0;
+    for (int64_t d = 0; d < ndim; d++) {
+      src_offset += coord[d] * strides[d];
+    }
+    std::memcpy(
+        dst_bytes + flat * element_size,
+        src_bytes + src_offset * element_size,
+        element_size);
+
+    // Increment coordinate (last dim fastest)
+    for (int64_t d = ndim - 1; d >= 0; d--) {
+      if (++coord[d] < sizes[d])
+        break;
+      coord[d] = 0;
+    }
+  }
+  return dst;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     AOTITensorHandle self,
     int64_t ndim,
@@ -377,6 +455,12 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered");
 
   // Validate input parameters first
+  ET_CHECK_OR_RETURN_ERROR(
+      ndim >= 0,
+      InvalidArgument,
+      "aoti_torch__reinterpret_tensor failed: ndim must be >= 0, got %lld",
+      ndim);
+
   ET_CHECK_OR_RETURN_ERROR(
       self != nullptr,
       InvalidArgument,
@@ -430,8 +514,9 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       data_ptr);
 
   // Handle storage offset by adjusting the data pointer
-  void* adjusted_data = static_cast<char*>(data_ptr) +
-      (storage_offset * dtype_to_element_size(dtype));
+  size_t element_size = dtype_to_element_size(dtype);
+  void* adjusted_data =
+      static_cast<char*>(data_ptr) + (storage_offset * element_size);
 
   // Convert sizes using utility function from utils.h
   std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
@@ -440,14 +525,35 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   std::vector<aten::StridesType> strides =
       convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
 
-  // Create new tensor view that reinterprets the same memory with different
-  // shape/strides This creates a view, not a copy - the data pointer is shared
+  // If the view is not densely packed (e.g. chunk/split creating holes),
+  // materialize it into a new contiguous buffer.
+  void* tensor_data = adjusted_data;
+  bool owns_buffer = false;
+  if (!is_packed_strides(sizes, strides)) {
+    ET_LOG(
+        Debug,
+        "aoti_torch__reinterpret_tensor: non-packed strides, "
+        "materializing to packed buffer");
+    tensor_data =
+        materialize_packed(adjusted_data, sizes, strides, element_size);
+    ET_CHECK_OR_RETURN_ERROR(
+        tensor_data != nullptr,
+        MemoryAllocationFailed,
+        "Failed to materialize non-packed tensor");
+    owns_buffer = true;
+
+    // Compute contiguous strides for the packed buffer
+    strides.resize(ndim);
+    if (ndim > 0) {
+      strides[ndim - 1] = 1;
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        strides[i] = strides[i + 1] * sizes[i + 1];
+      }
+    }
+  }
+
   std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
-      adjusted_data, // Use adjusted data pointer with storage offset applied
-      sizes, // New sizes with explicit SizesType
-      strides, // New strides with explicit StridesType
-      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
-  );
+      tensor_data, sizes, strides, dtype_to_scalar_type(dtype));
 
   ET_CHECK_OR_RETURN_ERROR(
       tensor != nullptr,
@@ -456,32 +562,36 @@ AOTITorchError aoti_torch__reinterpret_tensor(
 
   // Store the tensor so it doesn't get destroyed
   tensors[tensor.get()] = tensor;
-
   *ret_new_tensor = tensor.get();
 
-  if (adjusted_data != data_ptr) {
-    ET_LOG(
-        Debug,
-        "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
-        data_ptr,
-        storage_offset,
-        dtype_to_element_size(dtype),
-        adjusted_data);
-
-    ET_CHECK_OR_RETURN_ERROR(
-        metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
-        Internal,
-        "metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
-        adjusted_data,
-        static_cast<size_t>(tensor->nbytes()));
-
-    memory_to_n_tensor[adjusted_data] = NOT_OWN;
-  }
+  if (owns_buffer) {
+    // The materialized buffer is a new allocation owned by this tensor
+    memory_to_n_tensor[tensor_data] = 1;
+  } else {
+    if (adjusted_data != data_ptr) {
+      ET_LOG(
+          Debug,
+          "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, "
+          "storage_offset=%lld, element_size=%zu, adjusted_data=%p",
+          data_ptr,
+          storage_offset,
+          element_size,
+          adjusted_data);
+
+      ET_CHECK_OR_RETURN_ERROR(
+          metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true),
+          Internal,
+          "metal_buffer_nocopy failed for adjusted_data=%p, nbytes=%zu",
+          adjusted_data,
+          static_cast<size_t>(tensor->nbytes()));
+
+      memory_to_n_tensor[adjusted_data] = NOT_OWN;
+    }
 
-  // Increment the reference count for this memory address only if it is owned
-  // by tensor
-  if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
-    memory_to_n_tensor[data_ptr] += 1;
+    // Increment the reference count for this memory address only if it is owned
+    if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
+      memory_to_n_tensor[data_ptr] += 1;
+    }
   }
 
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successful");
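
For tracing the logic without the C++ scaffolding, here is a hedged Python re-expression of the two helpers added above; the names mirror the C++ but are illustrative only, not ExecuTorch API:

    def is_packed_strides(sizes, strides):
        # Packed means the storage extent (max reachable offset + 1) equals numel.
        if not sizes:
            return True
        numel = 1
        for s in sizes:
            numel *= s
        if numel <= 1:
            return True
        max_offset = sum((s - 1) * st for s, st in zip(sizes, strides) if s > 1)
        return max_offset + 1 == numel

    def materialize_packed(src, sizes, strides):
        # Gather a strided view over the flat list `src` into a packed list,
        # walking coordinates odometer-style with the last dimension fastest.
        ndim = len(sizes)
        numel = 1
        for s in sizes:
            numel *= s
        dst = [None] * numel
        coord = [0] * ndim
        for flat in range(numel):
            dst[flat] = src[sum(c * st for c, st in zip(coord, strides))]
            for d in range(ndim - 1, -1, -1):
                coord[d] += 1
                if coord[d] < sizes[d]:
                    break
                coord[d] = 0
        return dst

    # Second half of each row of a 2x8 buffer, as chunk/split would produce
    # (storage offset 4, parent strides kept):
    buf = list(range(16))
    assert not is_packed_strides((2, 4), (8, 1))
    assert materialize_packed(buf[4:], (2, 4), (8, 1)) == [4, 5, 6, 7, 12, 13, 14, 15]

The C++ version additionally synchronizes the Metal stream before reading, since the source buffer may have pending GPU writes, and hands ownership of the new allocation to the created tensor via memory_to_n_tensor.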

backends/apple/metal/tests/test_modules.py

Lines changed: 24 additions & 0 deletions
@@ -664,6 +664,30 @@ def forward(
 }
 
 
+# -------------------------------------------------------------------------
+# Narrow (non-packed reinterpret_tensor materialization)
+# -------------------------------------------------------------------------
+
+
+class NarrowLastDim(nn.Module):
+    """Splits the last dimension into two halves via narrow, producing
+    non-packed strided views that the Metal backend must materialize
+    into contiguous buffers."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        half = x.shape[-1] // 2
+        a = x.narrow(-1, 0, half)
+        b = x.narrow(-1, half, half)
+        return a * 2.0 + b
+
+
+MODULE_REGISTRY["narrow_last_dim"] = {
+    "model_class": NarrowLastDim,
+    "input_shapes": [(2, 4, 16)],
+    "description": "Non-packed reinterpret_tensor views from last-dim split",
+}
+
+
 # -------------------------------------------------------------------------
 # Top-k (MoE expert routing)
 # -------------------------------------------------------------------------