diff --git a/backends/apple/metal/runtime/shims/v2/aoti_kernel.h b/backends/apple/metal/runtime/shims/v2/aoti_kernel.h
new file mode 100644
index 00000000000..bacb20978dc
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_kernel.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// AOTI kernel-dispatch C ABI for the v2 Metal backend.
+//
+// Symbols here are called by the AOTI .so when it has a hand-written
+// shader to dispatch (shader library + kernel function + arg encoding
+// + dispatch). Buffer mgmt and tensor lifecycle live in aoti_tensor.h;
+// op-registry fallbacks (mm/bmm) live in aoti_ops.h.
+
+#pragma once
+
+// NOTE(review): the include target below was lost in extraction
+// (angle-bracket contents stripped) — presumably the header declaring
+// AOTITorchError / AOTITensorHandle; verify against the repo.
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+// Opaque handle types. The AOTI-generated .so only ever sees pointers;
+// the structs behind them are defined privately in aoti_kernel.mm.
+struct AOTIMetalKernelFunctionOpaque;
+using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
+
+struct AOTIMetalShaderLibraryOpaque;
+using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Shader library
+
+// Stores `metal_shader_source` behind an opaque handle. Compilation is
+// deferred until aoti_torch_mps_get_kernel_function (see aoti_kernel.mm).
+AOTITorchError aoti_torch_mps_create_shader_library(
+    const char* metal_shader_source,
+    AOTIMetalShaderLibraryHandle* library_handle);
+
+AOTITorchError aoti_torch_mps_delete_shader_library(
+    AOTIMetalShaderLibraryHandle library_handle);
+
+// Compiles `kernel_name` from the library's stored source and returns an
+// opaque per-kernel dispatcher handle.
+AOTITorchError aoti_torch_mps_get_kernel_function(
+    AOTIMetalShaderLibraryHandle library_handle,
+    const char* kernel_name,
+    AOTIMetalKernelFunctionHandle* function_handle);
+
+// Kernel arg / dispatch
+
+// Resets any args accumulated on the handle since the last dispatch.
+AOTITorchError aoti_torch_mps_start_encoding(
+    AOTIMetalKernelFunctionHandle func);
+
+// Binds `tensor`'s buffer at argument slot `idx`.
+AOTITorchError aoti_torch_mps_set_arg_tensor(
+    AOTIMetalKernelFunctionHandle func,
+    unsigned idx,
+    AOTITensorHandle tensor);
+
+// Binds an int64 scalar at argument slot `idx`.
+AOTITorchError aoti_torch_mps_set_arg_int(
+    AOTIMetalKernelFunctionHandle func,
+    unsigned idx,
+    int64_t val);
+
+// 1-D dispatch over `length` threads; threadgroup size chosen by the impl.
+AOTITorchError aoti_torch_mps_dispatch_single(
+    AOTIMetalKernelFunctionHandle func,
+    uint64_t length);
+
+// 1-D dispatch with a caller-provided threadgroup size.
+AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    uint64_t length,
+    uint64_t group_size);
+
+// N-D dispatch (length_size = rank, up to 3 used by the impl).
+AOTITorchError aoti_torch_mps_dispatch_array(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length,
+    size_t length_size);
+
+AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length,
+    size_t length_size,
+    const uint64_t* group_size,
+    size_t group_size_size);
+
+// Command block
+typedef void (*aoti_torch_mps_command_block_callback_t)(
+    AOTIMetalKernelFunctionHandle func,
+    void* user_data);
+
+// Trampoline used by the AOTI .so to invoke a C++ callable via user_data.
+void aoti_torch_mps_shared_callback(
+    AOTIMetalKernelFunctionHandle func,
+    void* user_data);
+
+// Runs `callback(func, user_data)`; see aoti_kernel.mm for why no GPU
+// synchronization is performed around the call.
+AOTITorchError aoti_torch_mps_run_command_block(
+    AOTIMetalKernelFunctionHandle func,
+    aoti_torch_mps_command_block_callback_t callback,
+    void* user_data);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/aoti_kernel.mm b/backends/apple/metal/runtime/shims/v2/aoti_kernel.mm
new file mode 100644
index 00000000000..0278a991188
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_kernel.mm
@@ -0,0 +1,334 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// AOTI kernel-dispatch impl built directly on MetalStream.
+//
+// No ETMetalShaderLibrary / ETMetalKernelFunction / ETMetalStream classes.
+// Opaque AOTI handles point to the two lightweight structs below.
+
+// NOTE(review): include targets below were lost in extraction
+// (angle-bracket contents stripped) — verify against the repo.
+#import
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+using metal_v2::Arg;
+using metal_v2::MetalKernel;
+using metal_v2::MetalStream;
+using metal_v2::uvec3;
+
+// =========================================================================
+// Internal handle structs — not visible outside this file
+// =========================================================================
+
+// Holds shader source for deferred compilation. MetalKernelCompiler caches
+// compiled kernels internally, so we just need the source string.
+struct ShaderLibrary {
+  std::string source;
+};
+
+// Maximum argument count the dispatch switch below can forward; keep in
+// sync with the number of cases in KernelDispatcher::dispatch().
+constexpr size_t kMaxKernelArgs = 8;
+
+// Holds a compiled kernel + accumulated args between start_encoding and
+// dispatch.
+struct KernelDispatcher {
+  MetalKernel* kernel; // owned by MetalKernelCompiler cache
+  ShaderLibrary* parent; // kept alive by library_storage
+
+  // One pending kernel argument: either a raw device-visible buffer
+  // pointer or an 8-byte scalar payload.
+  struct DeferredArg {
+    enum Kind { BUFFER, SCALAR } kind;
+    void* ptr = nullptr;
+    size_t size = 0;
+    uint8_t scalar[8] = {};
+  };
+  std::vector<DeferredArg> args;
+
+  void clear() { args.clear(); }
+
+  // Record a buffer argument at binding slot `idx` (sparse set is fine;
+  // the vector grows to cover the highest slot used).
+  void setBuffer(unsigned idx, void* ptr, size_t size) {
+    if (idx >= args.size()) args.resize(idx + 1);
+    args[idx] = {DeferredArg::BUFFER, ptr, size, {}};
+  }
+
+  // Record an int64 scalar argument at binding slot `idx`.
+  void setScalarInt(unsigned idx, int64_t val) {
+    if (idx >= args.size()) args.resize(idx + 1);
+    auto& a = args[idx];
+    a.kind = DeferredArg::SCALAR;
+    a.size = sizeof(int64_t);
+    std::memcpy(a.scalar, &val, sizeof(val));
+  }
+
+  // Flush accumulated args through MetalStream::dispatch().
+  void dispatch(uvec3 grid, uvec3 block) {
+    auto* stream = getMetalStream();
+
+    // Register buffer args so MetalStream can map them to MTLBuffers.
+    for (auto& a : args) {
+      if (a.kind == DeferredArg::BUFFER && a.ptr) {
+        stream->registerExternalBuffer(a.ptr, a.size);
+      }
+    }
+
+    // Build Arg vector. MetalStream::dispatch takes initializer_list (Arg
+    // has a private default ctor, so we can't default-construct an array
+    // of them); build a std::vector via emplace_back, then dispatch via a
+    // switch on the count.
+    // TODO: add dispatch(kernel, span, grid, block) to MetalStream.
+    //
+    // FIX: previously args beyond 8 were *silently* dropped, which would
+    // dispatch the kernel with missing bindings. Surface it loudly.
+    if (args.size() > kMaxKernelArgs) {
+      ET_LOG(
+          Error,
+          "KernelDispatcher: %zu args set but only %zu supported; "
+          "extra args will be dropped",
+          args.size(),
+          kMaxKernelArgs);
+    }
+    std::vector<Arg> gpuArgs;
+    size_t n = std::min(args.size(), kMaxKernelArgs);
+    gpuArgs.reserve(n);
+    for (size_t i = 0; i < n; i++) {
+      auto& a = args[i];
+      if (a.kind == DeferredArg::BUFFER) {
+        gpuArgs.emplace_back(a.ptr, a.size);
+      } else {
+        int64_t v;
+        std::memcpy(&v, a.scalar, sizeof(v));
+        gpuArgs.emplace_back(v);
+      }
+    }
+
+    switch (n) {
+      case 0: stream->dispatch(kernel, {}, grid, block); break;
+      case 1: stream->dispatch(kernel, {gpuArgs[0]}, grid, block); break;
+      case 2: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1]}, grid, block); break;
+      case 3: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1], gpuArgs[2]}, grid, block); break;
+      case 4: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1], gpuArgs[2], gpuArgs[3]}, grid, block); break;
+      case 5: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1], gpuArgs[2], gpuArgs[3], gpuArgs[4]}, grid, block); break;
+      case 6: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1], gpuArgs[2], gpuArgs[3], gpuArgs[4], gpuArgs[5]}, grid, block); break;
+      case 7: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1], gpuArgs[2], gpuArgs[3], gpuArgs[4], gpuArgs[5], gpuArgs[6]}, grid, block); break;
+      case 8: stream->dispatch(kernel, {gpuArgs[0], gpuArgs[1], gpuArgs[2], gpuArgs[3], gpuArgs[4], gpuArgs[5], gpuArgs[6], gpuArgs[7]}, grid, block); break;
+    }
+
+    clear();
+  }
+};
+
+// Lifetime management — owns the heap objects behind opaque handles.
+static std::unordered_map<ShaderLibrary*, std::unique_ptr<ShaderLibrary>>
+    library_storage;
+static std::unordered_map<KernelDispatcher*, std::unique_ptr<KernelDispatcher>>
+    function_storage;
+
+// =========================================================================
+// Grid/block helpers
+// =========================================================================
+
+// 1-D threadgroup size: min(kernel max, length), clamped to >= 1 so a
+// zero-length dispatch cannot produce a zero block dimension (and a
+// divide-by-zero in computeGrid1D).
+static uvec3 computeBlock1D(MetalKernel* k, uint64_t length) {
+  uint32_t max = k->maxThreadsPerThreadgroup().x;
+  uint32_t bx = std::min(max, (uint32_t)length);
+  return uvec3(std::max(bx, 1u), 1, 1);
+}
+
+// Ceil-div grid size for a 1-D dispatch. blockX must be >= 1.
+static uvec3 computeGrid1D(uint64_t length, uint32_t blockX) {
+  return uvec3(((uint32_t)length + blockX - 1) / blockX, 1, 1);
+}
+
+// =========================================================================
+// C-ABI shim implementations
+// =========================================================================
+
+extern "C" {
+
+// --- Shader library ---
+
+AOTITorchError aoti_torch_mps_create_shader_library(
+    const char* source,
+    AOTIMetalShaderLibraryHandle* out) {
+  if (!source || !out) return Error::InvalidArgument;
+
+  // DIAGNOSTIC: dump the shader source so we can see exactly what AOTI
+  // generated for each kernel. Truncate to 1500 chars to keep logs sane.
+  size_t src_len = std::strlen(source);
+  ET_LOG(Info,
+      "[shader-src] create_shader_library len=%zu source:\n----\n%.1500s%s\n----",
+      src_len, source, src_len > 1500 ? "\n... [truncated]" : "");
+
+  auto lib = std::make_unique<ShaderLibrary>();
+  lib->source = source;
+  auto* raw = lib.get();
+  library_storage[raw] = std::move(lib);
+  *out = reinterpret_cast<AOTIMetalShaderLibraryHandle>(raw);
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_delete_shader_library(
+    AOTIMetalShaderLibraryHandle handle) {
+  if (!handle) return Error::InvalidArgument;
+  auto* lib = reinterpret_cast<ShaderLibrary*>(handle);
+  library_storage.erase(lib);
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_get_kernel_function(
+    AOTIMetalShaderLibraryHandle lib_handle,
+    const char* kernel_name,
+    AOTIMetalKernelFunctionHandle* out) {
+  if (!lib_handle || !kernel_name || !out) return Error::InvalidArgument;
+
+  auto* lib = reinterpret_cast<ShaderLibrary*>(lib_handle);
+  // MetalKernelCompiler caches compiled pipelines; `kernel` stays owned
+  // by that cache, the dispatcher only borrows it.
+  MetalKernel* kernel = getMetalStream()->compiler()->compile(
+      lib->source.c_str(), kernel_name);
+  if (!kernel) {
+    ET_LOG(Error, "Failed to compile kernel '%s'", kernel_name);
+    return Error::Internal;
+  }
+
+  auto disp = std::make_unique<KernelDispatcher>();
+  disp->kernel = kernel;
+  disp->parent = lib;
+  auto* raw = disp.get();
+  function_storage[raw] = std::move(disp);
+  *out = reinterpret_cast<AOTIMetalKernelFunctionHandle>(raw);
+  return Error::Ok;
+}
+
+// --- Encoding / args / dispatch ---
+
+AOTITorchError aoti_torch_mps_start_encoding(AOTIMetalKernelFunctionHandle func) {
+  if (!func) return Error::InvalidArgument;
+  reinterpret_cast<KernelDispatcher*>(func)->clear();
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_set_arg_tensor(
+    AOTIMetalKernelFunctionHandle func, unsigned idx, AOTITensorHandle tensor) {
+  if (!func || !tensor) return Error::InvalidArgument;
+  auto* t = reinterpret_cast<Tensor*>(tensor);
+  reinterpret_cast<KernelDispatcher*>(func)->setBuffer(
+      idx, t->data_ptr(), t->numel() * t->itemsize());
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_set_arg_int(
+    AOTIMetalKernelFunctionHandle func, unsigned idx, int64_t val) {
+  if (!func) return Error::InvalidArgument;
+  reinterpret_cast<KernelDispatcher*>(func)->setScalarInt(idx, val);
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_dispatch_single(
+    AOTIMetalKernelFunctionHandle func, uint64_t length) {
+  if (!func) return Error::InvalidArgument;
+  auto* d = reinterpret_cast<KernelDispatcher*>(func);
+  uvec3 block = computeBlock1D(d->kernel, length);
+  d->dispatch(computeGrid1D(length, block.x), block);
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
+    AOTIMetalKernelFunctionHandle func, uint64_t length, uint64_t group_size) {
+  if (!func) return Error::InvalidArgument;
+  auto* d = reinterpret_cast<KernelDispatcher*>(func);
+  uint32_t max = d->kernel->maxThreadsPerThreadgroup().x;
+  uint32_t bx = group_size > 0 ? std::min((uint32_t)group_size, max)
+                               : std::min(max, (uint32_t)length);
+  // FIX: guard against bx == 0 (length == 0 with no group_size) which
+  // previously divided by zero in computeGrid1D.
+  bx = std::max(bx, 1u);
+  d->dispatch(computeGrid1D(length, bx), uvec3(bx, 1, 1));
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_dispatch_array(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length, size_t length_size) {
+  if (!func || !length || length_size == 0) return Error::InvalidArgument;
+  auto* d = reinterpret_cast<KernelDispatcher*>(func);
+  uint32_t max = d->kernel->maxThreadsPerThreadgroup().x;
+
+  // FIX throughout: clamp every block dimension to >= 1 so zero-sized
+  // lengths can't yield a zero block dim / division by zero.
+  uvec3 grid, block;
+  if (length_size == 1) {
+    uint32_t bx = std::max(std::min(max, (uint32_t)length[0]), 1u);
+    block = uvec3(bx, 1, 1);
+    grid = uvec3(((uint32_t)length[0] + bx - 1) / bx, 1, 1);
+  } else if (length_size == 2) {
+    uint32_t bx = std::max(std::min(32u, (uint32_t)length[0]), 1u);
+    uint32_t by = std::max(max / bx, 1u);
+    block = uvec3(bx, by, 1);
+    grid = uvec3(((uint32_t)length[0] + bx - 1) / bx,
+                 ((uint32_t)length[1] + by - 1) / by, 1);
+  } else {
+    uint32_t bx = std::max(std::min(8u, (uint32_t)length[0]), 1u);
+    uint32_t by = std::max(std::min(8u, (uint32_t)length[1]), 1u);
+    uint32_t bz = std::max(max / (bx * by), 1u);
+    block = uvec3(bx, by, bz);
+    grid = uvec3(((uint32_t)length[0] + bx - 1) / bx,
+                 ((uint32_t)length[1] + by - 1) / by,
+                 ((uint32_t)(length_size > 2 ? length[2] : 1) + bz - 1) / bz);
+  }
+
+  d->dispatch(grid, block);
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
+    AOTIMetalKernelFunctionHandle func,
+    const uint64_t* length, size_t length_size,
+    const uint64_t* group_size, size_t group_size_size) {
+  if (!func || !length || length_size == 0) return Error::InvalidArgument;
+  auto* d = reinterpret_cast<KernelDispatcher*>(func);
+
+  uint32_t bx, by, bz;
+  if (length_size == 1) {
+    bx = group_size && group_size_size > 0 ? (uint32_t)group_size[0] : (uint32_t)length[0];
+    by = bz = 1;
+  } else if (length_size == 2) {
+    bx = group_size && group_size_size >= 2 ? (uint32_t)group_size[0] : 32;
+    by = group_size && group_size_size >= 2 ? (uint32_t)group_size[1] : 32;
+    bz = 1;
+  } else {
+    bx = group_size && group_size_size >= 3 ? (uint32_t)group_size[0] : 8;
+    by = group_size && group_size_size >= 3 ? (uint32_t)group_size[1] : 8;
+    bz = group_size && group_size_size >= 3 ? (uint32_t)group_size[2] : 8;
+  }
+
+  // FIX: a caller-provided group size of 0 (or length[0] == 0 in the 1-D
+  // case) previously divided by zero below. Clamp all dims to >= 1.
+  bx = std::max(bx, 1u);
+  by = std::max(by, 1u);
+  bz = std::max(bz, 1u);
+
+  uvec3 block(bx, by, bz);
+  uvec3 grid(((uint32_t)length[0] + bx - 1) / bx,
+             (length_size > 1 ? ((uint32_t)length[1] + by - 1) / by : 1),
+             (length_size > 2 ? ((uint32_t)length[2] + bz - 1) / bz : 1));
+  d->dispatch(grid, block);
+  return Error::Ok;
+}
+
+// --- Command block ---
+
+// user_data points at a std::function wrapper supplied by the C++ caller
+// (template-arg was stripped in extraction; reconstructed from the call
+// below — the wrapper is invoked with the kernel-function handle).
+void aoti_torch_mps_shared_callback(
+    AOTIMetalKernelFunctionHandle func, void* user_data) {
+  auto* wrapper =
+      static_cast<std::function<void(AOTIMetalKernelFunctionHandle)>*>(user_data);
+  if (wrapper) (*wrapper)(func);
+}
+
+AOTITorchError aoti_torch_mps_run_command_block(
+    AOTIMetalKernelFunctionHandle func,
+    aoti_torch_mps_command_block_callback_t callback,
+    void* user_data) {
+  if (!func || !callback) return Error::InvalidArgument;
+  // v1 used dispatch_sync_with_rethrow on the stream's serial GCD queue —
+  // that's CPU-side serialization, NOT a GPU drain. v2 (single-threaded
+  // AOTI) doesn't need either. The callback just encodes into the stream
+  // (typically into the ICB), no GPU sync required. The earlier
+  // getMetalStream()->wait() here was an over-translation: it triggered a
+  // full flush+wait between every kernel callback, which (a) defeated
+  // ICB's "encode many, drain at end" model and (b) caused stale ICB
+  // re-execution that corrupted MPSGraph+ICB mixed models.
+  callback(func, user_data);
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/aoti_ops.h b/backends/apple/metal/runtime/shims/v2/aoti_ops.h
new file mode 100644
index 00000000000..967c4802704
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_ops.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// AOTI op-registry-backed fallbacks for the v2 Metal backend.
+//
+// AOTI declares a small set of aten ops (mm, bmm, ...) as
+// supported_fallback_kernels, so the AOTI .so emits direct calls to these
+// symbols rather than generated shaders. v2 routes every fallback through
+// the metal_v2 MetalOpRegistry — if an op isn't registered there, we
+// return Error::NotImplemented (clean error rather than a missing-symbol
+// dlopen crash).
+
+#pragma once
+
+// NOTE(review): include target stripped in extraction — verify.
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// 2D matmul: out = self @ mat2.
+AOTITorchError aoti_torch_mps_mm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2);
+
+// 3D batched matmul: out = self @ mat2 (per batch).
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/aoti_ops.mm b/backends/apple/metal/runtime/shims/v2/aoti_ops.mm
new file mode 100644
index 00000000000..0ee5f290e98
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_ops.mm
@@ -0,0 +1,216 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// AOTI op-registry fallback impls.
+//
+// SlimTensor handles arrive from the AOTI .so. We materialize zero-copy
+// CPU-side ETensor views for them (because MetalOpRegistry consumes
+// EValue, which holds etensor::Tensor — not SlimTensor) and dispatch via
+// the registry. The stack-resident views keep the tensors alive across
+// dispatch.
+
+// NOTE(review): include targets below were stripped in extraction.
+#import
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+namespace {
+
+using executorch::backends::metal_v2::MetalOp;
+using executorch::backends::metal_v2::MetalOpRegistry;
+
+// Convert PyTorch dtype code → ExecuTorch's exec_aten ScalarType.
+// (Cast target types were stripped in extraction; an enum→int→enum
+// round-trip is reconstructed here — the two enums share numeric codes.)
+executorch::aten::ScalarType to_aten_scalar_type(
+    executorch::backends::aoti::slim::c10::ScalarType slim_dt) {
+  return static_cast<executorch::aten::ScalarType>(static_cast<int>(slim_dt));
+}
+
+// Maximum tensor rank we view in-place. Bump if any op exceeds this; today
+// our ops are all ≤4D (mm/bmm/add/relu).
+constexpr size_t kMaxTensorDim = 8;
+
+// A single-tensor view materialized entirely on the stack: sizes, strides,
+// dim_order arrays + a TensorImpl in placement-new storage. Construct via
+// makeView(); resulting Tensor wraps the placement-new'd TensorImpl, so
+// holding the StackTensorView in scope keeps the Tensor valid.
+//
+// This replaces extension::from_blob() which heap-allocates Storage +
+// TensorImpl + 2 shared_ptr control blocks per call. Per-dispatch we now
+// allocate zero heap.
+struct StackTensorView {
+  using SizesType = executorch::aten::SizesType;
+  using StridesType = executorch::aten::StridesType;
+  using DimOrderType = executorch::runtime::etensor::TensorImpl::DimOrderType;
+  using TensorImpl = executorch::runtime::etensor::TensorImpl;
+  using ETensor = executorch::runtime::etensor::Tensor;
+
+  std::array<SizesType, kMaxTensorDim> sizes;
+  std::array<StridesType, kMaxTensorDim> strides;
+  std::array<DimOrderType, kMaxTensorDim> dim_order;
+  alignas(TensorImpl) std::byte impl_storage[sizeof(TensorImpl)];
+  bool constructed = false;
+
+  // Construct the TensorImpl in-place over a SlimTensor's storage and
+  // return a wrapping ETensor (cheap: just a TensorImpl* copy).
+  ETensor makeView(const Tensor& s) {
+    const size_t dim = static_cast<size_t>(s.dim());
+    ET_CHECK_MSG(dim <= kMaxTensorDim,
+        "StackTensorView: tensor rank %zu exceeds kMaxTensorDim=%zu",
+        dim, kMaxTensorDim);
+    for (size_t i = 0; i < dim; ++i) {
+      sizes[i] = static_cast<SizesType>(s.sizes()[i]);
+      strides[i] = static_cast<StridesType>(s.strides()[i]);
+      // Identity dim_order: matches the contiguous/strided layout we pass.
+      dim_order[i] = static_cast<DimOrderType>(i);
+    }
+    auto* impl = new (impl_storage) TensorImpl(
+        to_aten_scalar_type(s.dtype()),
+        // NOTE(review): cast target stripped in extraction; TensorImpl's
+        // dim parameter is ssize_t in ExecuTorch — verify.
+        static_cast<ssize_t>(dim),
+        sizes.data(),
+        s.data_ptr(),
+        dim_order.data(),
+        strides.data(),
+        // DYNAMIC_BOUND so resize_tensor (if invoked by an op) is a no-op
+        // when the requested shape matches our current shape — matches the
+        // prior from_blob default.
+        executorch::runtime::TensorShapeDynamism::DYNAMIC_BOUND);
+    constructed = true;
+    return ETensor(impl);
+  }
+
+  ~StackTensorView() {
+    if (constructed) {
+      reinterpret_cast<TensorImpl*>(impl_storage)->~TensorImpl();
+    }
+  }
+
+  StackTensorView() = default;
+  StackTensorView(const StackTensorView&) = delete;
+  StackTensorView& operator=(const StackTensorView&) = delete;
+};
+
+// Maximum number of in/out tensors any registry op consumes. We stack-allocate
+// EValue + pointer storage sized to this so dispatch is heap-alloc-free.
+// Bump if a future op needs more — small overhead per call.
+constexpr size_t kMaxOpInputs = 8;
+constexpr size_t kMaxOpOutputs = 4;
+
+// Run a pre-resolved op against a set of SlimTensor handles. Materializes
+// ETensor views (zero-copy, in-place via StackTensorView), wraps them as
+// EValues, and dispatches. Zero heap allocations on the hot path: sizes,
+// strides, dim_order, TensorImpl, EValue, EValue* pointer arrays — all
+// stack-resident.
+AOTITorchError dispatchOp(
+    MetalOp* op,
+    std::initializer_list<Tensor*> inTensors,
+    std::initializer_list<Tensor*> outTensors) {
+  if (!op) return Error::NotImplemented;
+  if (inTensors.size() > kMaxOpInputs ||
+      outTensors.size() > kMaxOpOutputs) {
+    ET_LOG(Error,
+        "aoti_ops_v2: op '%s' exceeds max in=%zu/%zu, out=%zu/%zu",
+        op->name(), inTensors.size(), kMaxOpInputs,
+        outTensors.size(), kMaxOpOutputs);
+    return Error::InvalidArgument;
+  }
+
+  // Stack-resident ETensor views (each owns sizes/strides/dim_order arrays
+  // and a TensorImpl in placement-new storage). Holding these in scope
+  // keeps the ETensors valid for the duration of dispatch.
+  std::array<StackTensorView, kMaxOpInputs> inViews;
+  std::array<StackTensorView, kMaxOpOutputs> outViews;
+  std::array<executorch::runtime::EValue, kMaxOpInputs> inEValues;
+  std::array<executorch::runtime::EValue, kMaxOpOutputs> outEValues;
+  std::array<executorch::runtime::EValue*, kMaxOpInputs> inPtrs;
+  std::array<executorch::runtime::EValue*, kMaxOpOutputs> outPtrs;
+
+  size_t in_n = 0;
+  for (auto* t : inTensors) {
+    inEValues[in_n] = executorch::runtime::EValue(inViews[in_n].makeView(*t));
+    inPtrs[in_n] = &inEValues[in_n];
+    ++in_n;
+  }
+  size_t out_n = 0;
+  for (auto* t : outTensors) {
+    outEValues[out_n] =
+        executorch::runtime::EValue(outViews[out_n].makeView(*t));
+    outPtrs[out_n] = &outEValues[out_n];
+    ++out_n;
+  }
+
+  op->dispatch(
+      getMetalStream(),
+      MetalOp::EValuePtrSpan(inPtrs.data(), in_n),
+      MetalOp::EValuePtrSpan(outPtrs.data(), out_n));
+  return Error::Ok;
+}
+
+// Convenience wrapper: does the (string-keyed, allocating) registry lookup
+// each call. Prefer dispatchOp(MetalOp*, ...) at hot call sites where the
+// op name is fixed — see the cached pattern in aoti_torch_mps_mm_out.
+AOTITorchError dispatchRegistryOp(
+    const char* opName,
+    std::initializer_list<Tensor*> inTensors,
+    std::initializer_list<Tensor*> outTensors) {
+  auto* op = MetalOpRegistry::shared().get(opName);
+  if (!op) {
+    ET_LOG(Error, "aoti_ops_v2: op '%s' not found in MetalOpRegistry", opName);
+    return Error::NotImplemented;
+  }
+  return dispatchOp(op, inTensors, outTensors);
+}
+
+} // namespace
+
+extern "C" {
+
+AOTITorchError aoti_torch_mps_mm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2) {
+  if (!out || !self || !mat2) return Error::InvalidArgument;
+  // Cache the op pointer across calls to skip the registry lookup hot path
+  // (which hashes a std::string("aten::mm") on every call).
+  // FIX: retry the lookup while the cache is null — a one-shot
+  // `static MetalOp* op = get(...)` latched nullptr forever if the
+  // registry happened to be populated after the first call.
+  static MetalOp* op = nullptr;
+  if (op == nullptr) {
+    op = MetalOpRegistry::shared().get("aten::mm");
+  }
+  return dispatchOp(
+      op,
+      {reinterpret_cast<Tensor*>(self), reinterpret_cast<Tensor*>(mat2)},
+      {reinterpret_cast<Tensor*>(out)});
+}
+
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2) {
+  if (!out || !self || !mat2) return Error::InvalidArgument;
+  // Same retry-on-null caching as mm_out above.
+  static MetalOp* op = nullptr;
+  if (op == nullptr) {
+    op = MetalOpRegistry::shared().get("aten::bmm");
+  }
+  return dispatchOp(
+      op,
+      {reinterpret_cast<Tensor*>(self), reinterpret_cast<Tensor*>(mat2)},
+      {reinterpret_cast<Tensor*>(out)});
+}
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/aoti_tensor.cpp b/backends/apple/metal/runtime/shims/v2/aoti_tensor.cpp
new file mode 100644
index 00000000000..03650195e61
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_tensor.cpp
@@ -0,0 +1,585 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// AOTI tensor + memory layer impl for the v2 Metal backend.
+//
+// All buffer/memory work routes through the metal_* C ABI in runtime.h
+// so this file stays a .cpp (no Metal/Metal.h required).
+
+// NOTE(review): include targets below were stripped in extraction.
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+// Forward declare validate_dtype (defined in shims/utils.cpp). We don't
+// include shims/utils.h here because it pulls in v1's shims/types.h, which
+// defines `Tensor = etensor::Tensor` and conflicts with our SlimTensor
+// alias from aoti_types.h.
+namespace executorch {
+namespace backends {
+namespace metal {
+extern "C" AOTITorchError validate_dtype(int32_t dtype);
+} // namespace metal
+} // namespace backends
+} // namespace executorch
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+using namespace executorch::backends::aoti;
+namespace slim = executorch::backends::aoti::slim;
+
+extern "C" {
+
+// =====================================================================
+// Globals
+// =====================================================================
+
+// Owns every live SlimTensor created through this shim, keyed by the raw
+// pointer that doubles as the AOTI handle. (Template args were stripped
+// in extraction; reconstructed from register_tensor's usage.)
+std::unordered_map<Tensor*, std::unique_ptr<Tensor>> tensors;
+
+// Reference counting for memory addresses.
+// NOT_OWN (-1): tensor wraps externally-owned memory; never freed by us.
+// N >= 1     : N live tensor handles share this allocation; the last
+//              handle to be deleted frees the underlying buffer.
+constexpr int32_t NOT_OWN = -1;
+std::unordered_map<void*, int32_t> memory_to_n_tensor;
+
+namespace {
+
+// Convert int64_t sizes/strides arrays into std::vector for use
+// with slim::from_blob / IntArrayRef. A null `ptr` yields an empty vector
+// (0-dim tensor case).
+std::vector<int64_t> to_int64_vector(int64_t ndim, const int64_t* ptr) {
+  if (ptr == nullptr) {
+    return {};
+  }
+  return std::vector<int64_t>(ptr, ptr + ndim);
+}
+
+// Compute contiguous (row-major) strides if the caller didn't provide any.
+std::vector<int64_t> compute_or_copy_strides(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  if (strides_ptr != nullptr) {
+    return to_int64_vector(ndim, strides_ptr);
+  }
+  std::vector<int64_t> strides(ndim);
+  if (ndim > 0) {
+    strides[ndim - 1] = 1;
+    for (int64_t i = ndim - 2; i >= 0; i--) {
+      // Match v1 quirk: when next-dim size is 0, just propagate the previous
+      // stride rather than zeroing out (avoids degenerate stride patterns).
+      strides[i] = (sizes_ptr[i + 1] == 0)
+          ? strides[i + 1]
+          : strides[i + 1] * sizes_ptr[i + 1];
+    }
+  }
+  return strides;
+}
+
+// Insert a SlimTensor into the tensors map and return the raw pointer
+// used as the AOTI handle.
+Tensor* register_tensor(slim::SlimTensor&& t) {
+  auto owned = std::make_unique<Tensor>(std::move(t));
+  Tensor* raw = owned.get();
+  tensors.emplace(raw, std::move(owned));
+  return raw;
+}
+
+} // namespace
+
+// =====================================================================
+// Tensor lifecycle
+// =====================================================================
+
+AOTITorchError aoti_torch_create_tensor_from_blob_v2(
+    void* data,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    AOTITensorHandle* ret_new_tensor,
+    int32_t layout,
+    const uint8_t* opaque_metadata,
+    int64_t opaque_metadata_size) {
+  ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2[v2]: entered");
+
+  (void)device_type;
+  (void)device_index;
+  (void)opaque_metadata;
+  (void)layout;
+  (void)opaque_metadata_size;
+
+  ET_CHECK_OR_RETURN_ERROR(
+      data != nullptr, InvalidArgument, "data pointer is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      !(sizes_ptr == nullptr && ndim > 0),
+      InvalidArgument,
+      "sizes_ptr is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_new_tensor != nullptr, InvalidArgument, "ret_new_tensor is null");
+
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
+
+  // Apply storage_offset by adjusting the raw pointer; pass 0 storage_offset
+  // to from_blob (mirrors v1 behavior).
+  void* adjusted_data = static_cast<char*>(data) +
+      (storage_offset * dtype_to_element_size(dtype));
+
+  // FIX: validate the memory address BEFORE registering the tensor.
+  // Previously the tensor was inserted into `tensors` (and *ret_new_tensor
+  // set) first, so the duplicate-address error path returned with a stale
+  // handle leaked into the map.
+  auto memory_it = memory_to_n_tensor.find(adjusted_data);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it == memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is already being tracked by another tensor",
+      adjusted_data);
+
+  std::vector<int64_t> sizes = to_int64_vector(ndim, sizes_ptr);
+  std::vector<int64_t> strides =
+      compute_or_copy_strides(ndim, sizes_ptr, strides_ptr);
+
+  slim::SlimTensor t = slim::from_blob(
+      adjusted_data,
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      dtype_to_c10_scalar_type(dtype));
+
+  *ret_new_tensor = register_tensor(std::move(t));
+
+  // Register this address as externally-owned: tensor-from-blob never owns
+  // the memory it wraps.
+  memory_to_n_tensor[adjusted_data] = NOT_OWN;
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_empty_strided(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    AOTITensorHandle* ret_new_tensor) {
+  ET_LOG(Debug, "aoti_torch_empty_strided[v2]: entered");
+  (void)device_index;
+
+  void* ptr;
+  int64_t numel = 1;
+  for (int i = 0; i < ndim; i++) {
+    numel *= sizes_ptr[i];
+  }
+
+  size_t element_size = dtype_to_element_size(dtype);
+  ET_CHECK_OR_RETURN_ERROR(
+      element_size != 0,
+      InvalidArgument,
+      "Invalid element size for dtype: %d",
+      dtype);
+  int64_t nbytes = numel * element_size;
+
+  int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13
+  if (device_type == mps_device_type) {
+    ptr = metal_allocate_buffer(nbytes);
+    if (ptr == nullptr) {
+      ET_LOG(Error, "Failed to allocate %lld bytes on Metal", nbytes);
+      return Error::MemoryAllocationFailed;
+    }
+  } else if (device_type == 0) { // cpu
+    int result = posix_memalign(&ptr, 16, nbytes);
+    ET_CHECK_OR_RETURN_ERROR(
+        result == 0,
+        MemoryAllocationFailed,
+        "posix_memalign failed: error %d",
+        result);
+    ET_CHECK_OR_RETURN_ERROR(
+        ptr != nullptr, MemoryAllocationFailed, "posix_memalign returned null");
+  } else {
+    ET_CHECK_OR_RETURN_ERROR(
+        false,
+        NotImplemented,
+        "empty_strided not implemented for device type %d",
+        device_type);
+  }
+
+  std::vector<int64_t> sizes = to_int64_vector(ndim, sizes_ptr);
+  std::vector<int64_t> strides =
+      compute_or_copy_strides(ndim, sizes_ptr, strides_ptr);
+
+  slim::SlimTensor t = slim::from_blob(
+      ptr,
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      dtype_to_c10_scalar_type(dtype));
+
+  *ret_new_tensor = register_tensor(std::move(t));
+
+  // This tensor logically owns the buffer (we allocated it). Refcount=1.
+  memory_to_n_tensor[ptr] = 1;
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
+  if (tensor == nullptr) {
+    return Error::Ok;
+  }
+
+  auto it = tensors.find(tensor);
+  // Tensors not in the map are temporary views (e.g. CPU ETensor wrappers
+  // created by metal_backend_v2 for I/O). Nothing to free.
+  if (it == tensors.end()) {
+    return Error::Ok;
+  }
+
+  void* data_ptr = tensor->data_ptr();
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  if (memory_it != memory_to_n_tensor.end()) {
+    int32_t ref_count = memory_it->second;
+
+    if (ref_count == NOT_OWN) {
+      // Externally-owned memory: drop the handle, never free the buffer.
+      tensors.erase(it);
+      return Error::Ok;
+    } else if (ref_count == 1) {
+      // Last owner: free via the allocator that produced the buffer.
+      if (metal_is_device_pointer(data_ptr)) {
+        metal_deallocate_buffer(data_ptr);
+      } else {
+        free(data_ptr);
+      }
+      memory_to_n_tensor.erase(memory_it);
+    } else if (ref_count > 1) {
+      memory_to_n_tensor[data_ptr] = ref_count - 1;
+    } else {
+      ET_LOG(Error, "Invalid reference count %d for memory %p", ref_count, data_ptr);
+      return Error::Internal;
+    }
+  }
+
+  tensors.erase(it);
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_copy_(
+    AOTITensorHandle self,
+    AOTITensorHandle src,
+    int32_t non_blocking) {
+  (void)non_blocking;
+
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr, InvalidArgument, "self tensor is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      src != nullptr, InvalidArgument, "src tensor is null");
+
+  // Dtype compatibility check (same dtype required, like PyTorch copy_).
+  auto self_dtype = self->dtype();
+  auto src_dtype = src->dtype();
+  ET_CHECK_OR_RETURN_ERROR(
+      self_dtype == src_dtype,
+      InvalidArgument,
+      "dtype mismatch. self=%d, src=%d",
+      static_cast<int>(self_dtype),
+      static_cast<int>(src_dtype));
+
+  // Numel must match.
+  size_t self_numel = self->numel();
+  size_t src_numel = src->numel();
+  ET_CHECK_OR_RETURN_ERROR(
+      self_numel == src_numel,
+      InvalidArgument,
+      "numel mismatch. self=%zu, src=%zu",
+      self_numel,
+      src_numel);
+
+  // Device classification via the GPU pointer registry (not SlimTensor's
+  // own device tag — v2 SlimTensors are all CPU-tagged regardless of
+  // whether the buffer is GPU-accessible).
+  bool srcIsDevice = metal_is_device_pointer(src->data_ptr());
+  bool dstIsDevice = metal_is_device_pointer(self->data_ptr());
+
+  // Same-schema fast path. (TODO: catch (4,1,5) -> (4,5)-style cases.)
+  bool same_schema =
+      self->dim() == src->dim() && self->dtype() == src->dtype();
+  if (same_schema) {
+    auto self_strides = self->strides();
+    auto src_strides = src->strides();
+    // FIX: cast dim() for the comparison to avoid a signed/unsigned
+    // mismatch with the size_t loop index.
+    for (size_t i = 0; i < static_cast<size_t>(self->dim()); i++) {
+      if (self_strides[i] != src_strides[i]) {
+        same_schema = false;
+        break;
+      }
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  if (same_schema) {
+    int result = metal_copy_memory(
+        self->data_ptr(),
+        src->data_ptr(),
+        total_bytes,
+        srcIsDevice,
+        dstIsDevice);
+    ET_CHECK_OR_RETURN_ERROR(
+        result == 0, Internal, "metal_copy_memory failed: %d", result);
+  } else {
+    ET_LOG(Error, "Different schema copies are not implemented yet");
+    return Error::NotImplemented;
+  }
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch__reinterpret_tensor(
+    AOTITensorHandle self,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    AOTITensorHandle* ret_new_tensor) {
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr, InvalidArgument, "self tensor is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      !(sizes_ptr == nullptr && ndim > 0),
+      InvalidArgument,
+      "sizes_ptr is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_new_tensor != nullptr, InvalidArgument, "ret_new_tensor is null");
+
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
device_index); + + int32_t dtype = static_cast(self->dtype()); + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + void* data_ptr = self->data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked", + data_ptr); + + void* adjusted_data = static_cast(data_ptr) + + (storage_offset * dtype_to_element_size(dtype)); + + std::vector sizes = to_int64_vector(ndim, sizes_ptr); + std::vector strides = + compute_or_copy_strides(ndim, sizes_ptr, strides_ptr); + + slim::SlimTensor t = slim::from_blob( + adjusted_data, + slim::makeArrayRef(sizes), + slim::makeArrayRef(strides), + dtype_to_c10_scalar_type(dtype)); + + *ret_new_tensor = register_tensor(std::move(t)); + + if (adjusted_data != data_ptr) { + ET_CHECK_OR_RETURN_ERROR( + metal_buffer_nocopy(adjusted_data, (*ret_new_tensor)->nbytes(), true), + Internal, + "metal_buffer_nocopy failed for adjusted_data %p of size %zu", + adjusted_data, + (*ret_new_tensor)->nbytes()); + memory_to_n_tensor[adjusted_data] = NOT_OWN; + } + + // Bump refcount on the source pointer (only when it's owned, not borrowed). 
+ if (memory_to_n_tensor[data_ptr] != NOT_OWN) { + memory_to_n_tensor[data_ptr] += 1; + } + return Error::Ok; +} + +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle) { + ET_CHECK_OR_RETURN_ERROR( + orig_handle != nullptr, InvalidArgument, "orig_handle is null"); + ET_CHECK_OR_RETURN_ERROR( + new_handle != nullptr, InvalidArgument, "new_handle is null"); + + int32_t device_index = 0; + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_index(orig_handle, &device_index)); + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + int32_t dtype = static_cast(orig_handle->dtype()); + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + void* data_ptr = orig_handle->data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked", + data_ptr); + + // Mirror the original tensor's shape/strides/dtype, sharing storage. + std::vector sizes( + orig_handle->sizes().begin(), orig_handle->sizes().end()); + std::vector strides( + orig_handle->strides().begin(), orig_handle->strides().end()); + + slim::SlimTensor t = slim::from_blob( + data_ptr, + slim::makeArrayRef(sizes), + slim::makeArrayRef(strides), + orig_handle->dtype()); + + *new_handle = register_tensor(std::move(t)); + + // Refcount: only bump when the source memory is owned (not borrowed). + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + return Error::Ok; +} + +void cleanup_memory() { + // Use aoti_torch_delete_tensor_object so refcounts/buffer frees stay in + // sync. Collect keys first since deletion modifies the map. 
+ std::vector tensor_ptrs; + tensor_ptrs.reserve(tensors.size()); + for (const auto& entry : tensors) { + tensor_ptrs.push_back(entry.first); + } + for (Tensor* tensor_ptr : tensor_ptrs) { + aoti_torch_delete_tensor_object(tensor_ptr); + } + + tensors.clear(); + metal_cleanup_resources(); + + ET_LOG(Info, "[v2] Cleared all tensors and Metal resources"); +} + +// ===================================================================== +// MPS buffer shims +// +// All four route through the metal_* C ABI in runtime.h. This means +// allocations are device-pointer-tracked (so metal_is_device_pointer +// works correctly downstream) and the file stays a .cpp. +// ===================================================================== + +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes) { + if (num_bytes == 0) { + if (buffer) *buffer = nullptr; + return Error::Ok; + } + if (!buffer) return Error::InvalidArgument; + *buffer = metal_allocate_buffer(static_cast(num_bytes)); + return *buffer ? Error::Ok : Error::Internal; +} + +AOTITorchError aoti_torch_mps_free(void* ptr) { + if (ptr) metal_deallocate_buffer(ptr); + return Error::Ok; +} + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start) { + if (!buffer || !constants_start) return Error::InvalidArgument; + + auto* dst = static_cast(buffer) + constant_offset; + std::memcpy(dst, constants_start + bytes_read, data_size); + + // Register the sub-region so GPU can see it. + if (constant_offset != 0) { + metal_buffer_nocopy(dst, data_size, /*map_ptr_to_buffer=*/true); + } + return Error::Ok; +} + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset) { + if (!src_buffer || !dst_buffer) return Error::InvalidArgument; + // Unified memory — direct memcpy. 
  // Offsets are in bytes. On Apple-silicon unified memory the copy is a
  // plain CPU memcpy (no blit encoder needed).
  auto* src = static_cast<const uint8_t*>(src_buffer) + src_offset;
  auto* dst = static_cast<uint8_t*>(dst_buffer) + dst_offset;
  std::memcpy(dst, src, data_size);
  return Error::Ok;
}

// =====================================================================
// MPS device-type override
// =====================================================================

__attribute__((__visibility__("default"))) int32_t
aoti_torch_device_type_mps() {
  return 13; // Matches c10/core/DeviceType.h::MPS
}

// Always reports MPS, ignoring the tensor argument: in v2 every tensor that
// crosses this shim layer is conceptually an MPS tensor even though the
// SlimTensor itself is CPU-tagged (see the note in aoti_tensor.h).
AOTITorchError aoti_torch_get_device_type(
    AOTITensorHandle tensor,
    int32_t* ret_device_type) {
  (void)tensor;
  if (ret_device_type == nullptr) {
    return Error::InvalidArgument;
  }
  *ret_device_type = aoti_torch_device_type_mps();
  return Error::Ok;
}

} // extern "C"

// ---------------------------------------------------------------------
// Missing dtype shim (workaround for upstream gap)
//
// backends/aoti/common_shims_slim.cpp defines aoti_torch_dtype_float32(),
// _bfloat16(), _int*(), etc. but not _float16. Without this symbol the
// AOTI-generated .so for an fp16 model dlopens to a partially-resolved
// state and dies with SIGSEGV on first use of the missing trampoline.
//
// We define it locally so the symbol resolves at dlopen time. Note: actual
// fp16 execution is NOT supported because slim::c10::ScalarType doesn't
// include Half (creating a SlimTensor with dtype=5 will assert in
// check_supportive). This shim only satisfies the linker.
extern "C" {
int32_t aoti_torch_dtype_float16() {
  return 5; // PyTorch's float16 dtype code (c10::ScalarType::Half)
}
} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/aoti_tensor.h b/backends/apple/metal/runtime/shims/v2/aoti_tensor.h
new file mode 100644
index 00000000000..96492578559
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_tensor.h
@@ -0,0 +1,132 @@
/*
 * Copyright (c) Meta Platforms, Inc.
and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// AOTI tensor + memory layer for the v2 Metal backend.
//
// One-stop shop for everything tensor-and-memory the AOTI .so calls into:
// - Tensor lifecycle: create_tensor_from_blob_v2 / empty_strided / delete /
//   copy_ / _reinterpret_tensor / new_tensor_handle / cleanup_memory
// - MPS buffer shims: mps_malloc / mps_free / mps_memcpy / mps_copy_buffer
// - MPS device-type override: aoti_torch_get_device_type / device_type_mps
//
// All operate on SlimTensor handles (see aoti_types.h). Implementations
// route through the metal_* C ABI in runtime.h.
//
// NOTE: These symbols intentionally collide with v1's. v1 and v2 must
// live in separate static libraries; users link exactly one.

#pragma once

// NOTE(review): the original #include targets were lost in formatting;
// reconstructed from usage below (std::unordered_map, std::unique_ptr,
// fixed-width ints, and the v2 type aliases) — verify against the file.
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <executorch/backends/apple/metal/runtime/shims/v2/aoti_types.h>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// =====================================================================
// Global storage (definitions in aoti_tensor.cpp)
// =====================================================================

// Maps buffer base pointer -> reference count; NOT_OWN marks borrowed
// (caller-owned) memory that this layer must never free.
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
// Maps raw SlimTensor* -> unique_ptr for O(1) lookup/deletion.
extern std::unordered_map<Tensor*, std::unique_ptr<Tensor>> tensors;

// =====================================================================
// Tensor lifecycle
// =====================================================================

// Wraps caller-owned memory (NOT_OWN) in a new SlimTensor handle.
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AOTITensorHandle* ret_new_tensor,
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size);

// Allocates a new tensor (Metal or CPU); the tensor owns its buffer.
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AOTITensorHandle* ret_new_tensor);

// Drops one reference; frees the buffer when the last owner goes away.
AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor);

// Copies src into self (same dtype and numel required).
AOTITorchError aoti_torch_copy_(
    AOTITensorHandle self,
    AOTITensorHandle src,
    int32_t non_blocking);

// Creates a view over self's storage with new shape/strides/offset.
AOTITorchError aoti_torch__reinterpret_tensor(
    AOTITensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    AOTITensorHandle* ret_new_tensor);

// Creates a second handle sharing orig_handle's storage and metadata.
AOTITorchError aoti_torch_new_tensor_handle(
    Tensor* orig_handle,
    Tensor** new_handle);

// Deletes all registered tensors and releases Metal resources.
void cleanup_memory();

// =====================================================================
// MPS buffer shims (the AOTI .so calls these directly for raw buffers)
// =====================================================================

AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes);
AOTITorchError aoti_torch_mps_free(void* ptr);

AOTITorchError aoti_torch_mps_memcpy(
    void* buffer,
    size_t constant_offset,
    size_t bytes_read,
    size_t data_size,
    uint8_t* constants_start);

AOTITorchError aoti_torch_mps_copy_buffer(
    void* src_buffer,
    void* dst_buffer,
    size_t data_size,
    size_t src_offset,
    size_t dst_offset);

// =====================================================================
// MPS device-type override
// =====================================================================

// Returns the MPS device-type code (13). Stable across v1/v2.
int32_t aoti_torch_device_type_mps();

// Override common_shims_slim's default that reports the SlimTensor's
// actual device type (CPU). For v2, all tensors going through the AOTI
// shim layer are conceptually MPS regardless of how SlimTensor models
// them internally.
AOTITorchError aoti_torch_get_device_type(
    AOTITensorHandle tensor,
    int32_t* ret_device_type);

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/aoti_types.h b/backends/apple/metal/runtime/shims/v2/aoti_types.h
new file mode 100644
index 00000000000..a103180280d
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/aoti_types.h
@@ -0,0 +1,64 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// SlimTensor-flavored types for the v2 AOTI Metal backend.
//
// The v2 backend uses SlimTensor (executorch::backends::aoti::slim::SlimTensor)
// as its tensor type. This header is the v2 analogue of types.h.

#pragma once

// NOTE(review): original #include targets lost in formatting; reconstructed
// from the using-declarations below — verify against the file.
#include <cstdint>
#include <executorch/backends/aoti/common_shims_slim.h>
#include <executorch/runtime/core/error.h>

namespace executorch {
namespace backends {
namespace metal {

// common_shims_slim.h defines `using Tensor = slim::SlimTensor;` in the
// `executorch::backends::aoti` namespace. Re-export here so callers in
// `executorch::backends::metal` can write `Tensor` unqualified.
using executorch::runtime::Error;
using executorch::backends::aoti::Tensor;
using executorch::backends::aoti::AOTIRuntimeError;
using executorch::backends::aoti::AOTITorchError;

extern "C" {

// AOTI passes opaque tensor handles across the C ABI. In v2, these are
// SlimTensor pointers.
using AOTITensorHandle = Tensor*;

} // extern "C"

// Map int32_t dtype code (PyTorch convention) to slim::c10::ScalarType.
// Mirrors aoti::dtype_to_scalar_type but returns slim's enum instead of
// executorch::aten::ScalarType.
inline executorch::backends::aoti::slim::c10::ScalarType
dtype_to_c10_scalar_type(int32_t dtype) {
  using SST = executorch::backends::aoti::slim::c10::ScalarType;
  // Enum values match PyTorch's standard dtype encoding, so a value-cast is
  // safe for the supported set. Unsupported values (e.g. Half=5) will fail
  // SlimTensor's check_supportive assertion downstream.
  switch (dtype) {
    case 0: return SST::Byte;
    case 1: return SST::Char;
    case 2: return SST::Short;
    case 3: return SST::Int;
    case 4: return SST::Long;
    case 6: return SST::Float;
    case 11: return SST::Bool;
    case 15: return SST::BFloat16;
    default: return SST::Undefined;
  }
}

} // namespace metal
} // namespace backends
} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/delegate_handle.h b/backends/apple/metal/runtime/shims/v2/delegate_handle.h
new file mode 100644
index 00000000000..233d3729625
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/delegate_handle.h
@@ -0,0 +1,91 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// SlimTensor-flavored AOTI delegate handle for the v2 Metal backend.
//
// Mirrors aoti/aoti_delegate_handle.h, but lives in
// `executorch::backends::metal` so it doesn't collide with the v1 header's
// `aoti::Tensor = etensor::Tensor` typedef. (v1's header and slim's
// `common_shims_slim.h` both define `aoti::Tensor` to different types,
// so they can't be in the same TU.)
//
// At the ABI level the AOTI .so calls these function pointers with
// opaque pointers — the underlying type doesn't matter. We use SlimTensor
// here so the rest of the v2 code stays consistent.

#pragma once

// NOTE(review): original #include targets lost in formatting; reconstructed
// from usage (size_t, std::string, v2 type aliases) — verify against file.
#include <cstddef>
#include <string>
#include <executorch/backends/apple/metal/runtime/shims/v2/aoti_types.h>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Forward declarations for AOT Inductor model container
struct AOTInductorModelContainerOpaque;
using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
using AOTInductorStreamHandle = void*;
using AOTIProxyExecutorHandle = void*;

// Function pointer types for AOT Inductor model container operations.
// Use Tensor (= SlimTensor) for the run() handles.
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir);

using AOTInductorModelContainerDeleteFunc =
    AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle);

using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_inputs);

using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_outputs);

using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    Tensor** input_handles,
    size_t num_inputs,
    Tensor** output_handles,
    size_t n_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle);

using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    const uint8_t* weight_blob_ptr);

} // extern "C"

// AOTI Delegate Handle structure (v2: minimal subset used by
// metal_backend_v2.cpp; constant-management fields omitted).
struct AOTIDelegateHandle {
  void* so_handle;       // handle returned by dlopen of the AOTI .so
  std::string so_path;   // path the .so was loaded from
  AOTInductorModelContainerHandle container_handle;
  std::string method_name;

  // Entry points resolved from the .so at init time.
  AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
  AOTInductorModelContainerDeleteFunc delete_container;
  AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
  AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
  AOTInductorModelContainerRunFunc run;
  AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;
};

} // namespace metal
} // namespace backends
} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/runtime.h b/backends/apple/metal/runtime/shims/v2/runtime.h
new file mode 100644
index 00000000000..1f64d093ccd
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/runtime.h
@@ -0,0 +1,81 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Pure infrastructure layer for the v2 AOTI Metal backend.
//
// Wraps MetalStream (from portable/runtime/metal_v2/) and exposes:
// - getMetalStream() / getMetalDevice() / metal_set_flush_interval()
// - the "metal_*" C ABI for buffer management used by the AOTI shims
// - synchronize_metal_stream() drain
//
// No AOTI knowledge here. The aoti_* shims (aoti_tensor, aoti_kernel,
// aoti_ops) sit on top of this file.

#pragma once

// NOTE(review): original #include targets lost in formatting; reconstructed
// from usage (size_t and fixed-width ints) — verify against the file.
#include <cstddef>
#include <cstdint>

#ifdef __OBJC__
#import <Metal/Metal.h>
// NOTE(review): protocol name reconstructed from the typedef name — verify.
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLDevice_t;
#endif

namespace executorch::backends::metal_v2 {
class MetalStream;
} // namespace executorch::backends::metal_v2

namespace executorch {
namespace backends {
namespace metal {

// The single MetalStream backing all v2 AOTI Metal execution (thread-local).
metal_v2::MetalStream* getMetalStream();

// Free-function wrappers that .cpp files can call without pulling in
// MetalStream.h (which transitively imports Metal/Metal.h and won't
// compile in non-ObjC++ translation units).
void metal_set_flush_interval(int dispatches);

MTLDevice_t getMetalDevice();

#ifdef __cplusplus
extern "C" {
#endif

void* metal_allocate_buffer(long bytes);
// Like metal_allocate_buffer but does NOT register the returned pointer as
// a device pointer in g_device_ptrs. metal_is_device_pointer() will return
// false for the returned pointer. Used by aoti_torch_mps_malloc to match
// the pre-reorg behavior where mps_malloc'd buffers weren't tracked.
void* metal_allocate_buffer_untracked(long bytes);
void metal_deallocate_buffer(void* ptr);
bool metal_is_device_pointer(void* ptr);
int metal_copy_memory(
    void* dst,
    const void* src,
    size_t nbytes,
    bool src_is_device,
    bool dst_is_device);
void metal_cleanup_resources();
bool metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer);
// Like metal_buffer_nocopy but does NOT add the pointer to the
// device-pointer tracking set. Used by aoti_torch_mps_memcpy to register
// constant sub-regions without affecting metal_is_device_pointer().
bool metal_register_external_buffer_only(void* ptr, size_t nbytes);
void synchronize_metal_stream();

#ifdef __cplusplus
}
#endif

} // namespace metal
} // namespace backends
} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/v2/runtime.mm b/backends/apple/metal/runtime/shims/v2/runtime.mm
new file mode 100644
index 00000000000..9ecec6610bd
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/v2/runtime.mm
@@ -0,0 +1,158 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
+ */ + +// MetalStream wrapper + buffer C ABI for v2. +// +// Buffer management and stream access. All AOTI dispatch logic lives in +// the aoti_* files in this directory. +// +// Device-pointer tracking is kept here (rather than in MetalStream) so we +// don't have to extend the portable MetalStream API. We track every +// pointer we hand back from metal_allocate_buffer or successfully register +// via metal_buffer_nocopy and use that set as the source of truth for +// metal_is_device_pointer. + +#import + +#include +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +using metal_v2::MetalStream; + +namespace { + +// Pointers we know are GPU-accessible (allocated via alloc() or +// successfully registered via registerExternalBuffer()). +std::mutex g_device_ptrs_mutex; +std::unordered_set g_device_ptrs; + +void track_device_ptr(void* ptr) { + if (!ptr) return; + std::lock_guard lock(g_device_ptrs_mutex); + g_device_ptrs.insert(ptr); +} + +void untrack_device_ptr(void* ptr) { + if (!ptr) return; + std::lock_guard lock(g_device_ptrs_mutex); + g_device_ptrs.erase(ptr); +} + +bool is_tracked_device_ptr(void* ptr) { + if (!ptr) return false; + std::lock_guard lock(g_device_ptrs_mutex); + return g_device_ptrs.count(ptr) != 0; +} + +} // namespace + +MetalStream* getMetalStream() { + // Thread-local stream: each thread that calls into the v2 shim layer gets + // its own MetalStream. Avoids races on the shared command buffer when + // execute() is invoked concurrently from multiple threads. Trade-off: + // kernel cache and buffer pool are per-thread, so shaders are recompiled + // on each new thread. 
+ return MetalStream::getThreadLocal(); +} + +void metal_set_flush_interval(int dispatches) { + getMetalStream()->setFlushInterval(dispatches); +} + +MTLDevice_t getMetalDevice() { + return getMetalStream()->device(); +} + +extern "C" { + +void* metal_allocate_buffer(long bytes) { + if (bytes <= 0) return nullptr; + void* ptr = getMetalStream()->alloc(static_cast(bytes)); + if (ptr) track_device_ptr(ptr); + return ptr; +} + +void* metal_allocate_buffer_untracked(long bytes) { + if (bytes <= 0) return nullptr; + return getMetalStream()->alloc(static_cast(bytes)); +} + +void metal_deallocate_buffer(void* ptr) { + if (!ptr) return; + getMetalStream()->free(ptr); + untrack_device_ptr(ptr); +} + +bool metal_is_device_pointer(void* ptr) { + return is_tracked_device_ptr(ptr); +} + +int metal_copy_memory( + void* dst, + const void* src, + size_t nbytes, + bool src_is_device, + bool /*dst_is_device*/) { + if (!src || !dst || nbytes == 0) return -1; + + // Apple Silicon unified memory: CPU and GPU share the same address space. + // Just need to ensure GPU writes are visible before the CPU reads them. + // wait() = "drain whatever's in flight"; we don't want to *trigger* any + // new submission here, just block until the GPU is caught up. + if (src_is_device) { + getMetalStream()->wait(); + } + std::memcpy(dst, src, nbytes); + return 0; +} + +void metal_cleanup_resources() { + // MetalStream manages its own pool. Clear our local tracking so a fresh + // process state doesn't leak entries. 
+ std::lock_guard lock(g_device_ptrs_mutex); + g_device_ptrs.clear(); +} + +bool metal_buffer_nocopy(void* ptr, size_t nbytes, bool /*map_ptr_to_buffer*/) { + if (!ptr || nbytes == 0) return false; + bool ok = getMetalStream()->registerExternalBuffer(ptr, nbytes); + if (ok) track_device_ptr(ptr); + return ok; +} + +bool metal_register_external_buffer_only(void* ptr, size_t nbytes) { + if (!ptr || nbytes == 0) return false; + return getMetalStream()->registerExternalBuffer(ptr, nbytes); +} + +// Public C-ABI: drain the stream so GPU writes are visible to the CPU. +// Used by metal_backend_v2::execute() after handle->run(); also exported +// for any AOTI-emitted code that wants an explicit sync point. Marks the +// end of an AOTI execute — calls endExecute() to reset per-execute state +// (currentDispatchIdx_, icbRecordedThisIter_) needed for replay +// correctness across iterations. +void synchronize_metal_stream() { + getMetalStream()->wait(); + getMetalStream()->endExecute(); +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/portable/CMakeLists.txt b/backends/portable/CMakeLists.txt new file mode 100644 index 00000000000..3aa58eeb5fc --- /dev/null +++ b/backends/portable/CMakeLists.txt @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 

# Portable Backend Runtime
# Architecture:
# - runtime/PortableBackend.cpp: BackendInterface (init/execute/destroy)
# - runtime/GraphRuntime.h: Abstract interface for hardware runtimes
# - runtime/OpRegistry.h: Generic op registration (template-based, backend-agnostic)
# - runtime/cpu/CpuRuntime.cpp: CPU implementation of GraphRuntime
# - runtime_v2/cpu/CpuOps.cpp: CPU op implementations using portable kernels
# - runtime/metal/MetalRuntime.mm: Metal GPU implementation (Apple only)
# - runtime/metal_v2/: High-performance Metal with ICB replay (Apple only)

# Fix: Objective-C++ is only required for the .mm Metal sources, which are
# only added on Apple platforms. Enabling it unconditionally fails
# configuration on toolchains without an Objective-C++ compiler (Linux CI).
if(APPLE)
  enable_language(OBJCXX)
endif()

# Option to enable metal_v2 (high-performance Metal with ICB, ResidencySet, etc.)
option(EXECUTORCH_PORTABLE_USE_METAL_V2 "Use metal_v2 runtime with Metal 4 features" ON)

set(_portable_backend__srcs
    ${CMAKE_CURRENT_SOURCE_DIR}/runtime/PortableBackend.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu/CpuRuntime.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/runtime_v2/cpu/CpuOps.cpp
)

# Add Metal backend on Apple platforms
if(APPLE)
  if(EXECUTORCH_PORTABLE_USE_METAL_V2)
    # metal_v2: High-performance runtime with ICB replay, ResidencySet, Heap
    message(STATUS "Portable backend: Using metal_v2 (Metal 4 features enabled)")
    list(APPEND _portable_backend__srcs
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalStream.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalHeap.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalBufferPool.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalKernel.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalKernelCompiler.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalOp.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalOpRegistry.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalRuntime.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/BinaryOps.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/UnaryOps.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/MatMulOp.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/MPSGraphOp.mm
    )
  else()
    # Original Metal runtime
    message(STATUS "Portable backend: Using original Metal runtime")
    list(APPEND _portable_backend__srcs
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalRuntime.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalOp.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalOpRegistry.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalOps.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/ops/ElementwiseOps.mm
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/ops/MatMulOps.mm
    )
  endif()
endif()

add_library(portable_backend ${_portable_backend__srcs})

target_include_directories(portable_backend
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/runtime
    ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu
)

# Link executorch_core (the minimal runtime) and program_schema (for flatbuffer types)
target_link_libraries(portable_backend
    PRIVATE
    executorch_core
    program_schema
    portable_ops_lib
)

# Link Metal framework on Apple
if(APPLE)
  target_link_libraries(portable_backend
      PRIVATE
      "-framework Metal"
      "-framework Foundation"
      "-framework MetalPerformanceShaders"
      "-framework MetalPerformanceShadersGraph"
  )

  if(EXECUTORCH_PORTABLE_USE_METAL_V2)
    target_compile_definitions(portable_backend PRIVATE PORTABLE_HAS_METAL_V2=1)
  else()
    target_compile_definitions(portable_backend PRIVATE PORTABLE_HAS_METAL=1)
  endif()

  # Set .mm files to compile as Objective-C++ without ARC (code manages memory manually)
  set_source_files_properties(
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalStream.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalHeap.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalBufferPool.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalKernel.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalKernelCompiler.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalOp.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalOpRegistry.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/MetalRuntime.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/BinaryOps.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/UnaryOps.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/MatMulOp.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal_v2/ops/MPSGraphOp.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalRuntime.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalOp.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalOpRegistry.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/MetalOps.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/ops/ElementwiseOps.mm
      ${CMAKE_CURRENT_SOURCE_DIR}/runtime/metal/ops/MatMulOps.mm
      PROPERTIES
      LANGUAGE OBJCXX
      COMPILE_FLAGS "-fno-objc-arc"
  )
endif()

# Ensure schema is generated before compiling
add_dependencies(portable_backend program_schema)

install(
    TARGETS portable_backend
    EXPORT ExecuTorchTargets
    DESTINATION ${CMAKE_INSTALL_LIBDIR}
)

# v2 portable backend (new architecture per
# runtime/PORTABLE_BACKEND_API_PROPOSAL.md). Registers as
# "PortableBackend_v2" so it coexists with v1.
add_subdirectory(runtime_v2)
diff --git a/backends/portable/__init__.py b/backends/portable/__init__.py
new file mode 100644
index 00000000000..184be38853c
--- /dev/null
+++ b/backends/portable/__init__.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Portable Backend

A unified execution backend that can dispatch ops across multiple runtimes
(CPU, Metal, Vulkan) at runtime initialization based on device capabilities.

Key features:
1. Uses the standard ExecuTorch ExecutionPlan format (no custom serialization)
2. Runtime partitioning happens in C++ based on has_op() queries
3.
Supports automatic memory aliasing via AllocationDetails.memory_id +4. Reuses existing portable ops from kernels/portable/cpu/ + +Usage: + from executorch.exir import to_edge + from executorch.backends.portable import PortablePartitioner + + edge_program = to_edge(exported_program) + portable_program = edge_program.to_backend(PortablePartitioner()) +""" + +from .partitioner.portable_partitioner import PortablePartitioner +from .preprocess import PortableBackend + +__all__ = [ + "PortablePartitioner", + "PortableBackend", +] diff --git a/backends/portable/build_and_run.sh b/backends/portable/build_and_run.sh new file mode 100755 index 00000000000..437b16f22aa --- /dev/null +++ b/backends/portable/build_and_run.sh @@ -0,0 +1,253 @@ +#!/bin/bash +# Portable Backend Build & Test Script +# Usage: ./build_and_run.sh [generate|build|run|all] + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ET_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +BUILD_DIR="$ET_ROOT/cmake-out" +PYTHON="${PYTHON:-/Users/scroy/miniconda3/envs/et-testing/bin/python}" +MODEL_PATH="/tmp/add_delegated.pte" + +# Metal v2 options +USE_METAL_V2="${USE_METAL_V2:-1}" # Enable metal_v2 by default +METAL_HEAP_SIZE="${METAL_HEAP_SIZE:-536870912}" # 512MB default + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +log() { echo -e "${GREEN}[portable]${NC} $1"; } +warn() { echo -e "${YELLOW}[portable]${NC} $1"; } +error() { echo -e "${RED}[portable]${NC} $1"; } +info() { echo -e "${CYAN}[portable]${NC} $1"; } + +generate() { + log "Generating test models..." + cd "$ET_ROOT" + $PYTHON backends/portable/test_export.py + log "Models generated at /tmp/*_delegated.pte" +} + +configure() { + log "Configuring CMake..." 
+ mkdir -p "$BUILD_DIR" + cd "$BUILD_DIR" + + # Metal v2 option + local METAL_V2_OPT="ON" + if [ "$USE_METAL_V2" = "0" ]; then + METAL_V2_OPT="OFF" + info "Using original Metal runtime" + else + info "Using metal_v2 runtime with Metal 4 features" + fi + + cmake "$ET_ROOT" \ + -DEXECUTORCH_BUILD_PORTABLE_BACKEND=ON \ + -DEXECUTORCH_PORTABLE_USE_METAL_V2=$METAL_V2_OPT + + log "CMake configured" + if [ "$METAL_V2_OPT" = "ON" ]; then + info " metal_v2: ENABLED (ICB, ResidencySet, Heap, Binary Archives)" + info " Metal 4 features: GPU addresses, MTLResidencySet" + fi +} + +build() { + log "Building portable_backend and executor_runner..." + + # Configure if needed + if [ ! -f "$BUILD_DIR/CMakeCache.txt" ]; then + configure + fi + + cd "$BUILD_DIR" + cmake --build . --target portable_backend executor_runner -j4 + + if [ $? -eq 0 ]; then + log "Build succeeded!" + else + error "Build failed!" + exit 1 + fi +} + +run() { + local model="${1:-$MODEL_PATH}" + + if [ ! -f "$model" ]; then + error "Model not found: $model" + error "Run './build_and_run.sh generate' first" + exit 1 + fi + + if [ ! -f "$BUILD_DIR/executor_runner" ]; then + error "executor_runner not found" + error "Run './build_and_run.sh build' first" + exit 1 + fi + + log "Running model: $model" + if [ "$USE_METAL_V2" = "1" ]; then + info "Using metal_v2 runtime" + fi + echo "----------------------------------------" + "$BUILD_DIR/executor_runner" --model_path "$model" + echo "----------------------------------------" + log "Done!" +} + +bench() { + local model="${1:-$MODEL_PATH}" + local iterations="${2:-100}" + + if [ ! 
-f "$model" ]; then + error "Model not found: $model" + exit 1 + fi + + log "Benchmarking model: $model ($iterations iterations)" + if [ "$USE_METAL_V2" = "1" ]; then + info "metal_v2 enabled - expect fast replay after first inference" + fi + echo "----------------------------------------" + # Run with timing + time for i in $(seq 1 $iterations); do + "$BUILD_DIR/executor_runner" --model_path "$model" 2>/dev/null + done + echo "----------------------------------------" + log "Benchmark complete" +} + +all() { + generate + build + run +} + +# Run the suite of metal_v2 test models exported by test_export.py. +# Exits with non-zero if any model fails. +test_metal_v2() { + if [ ! -f "$BUILD_DIR/executor_runner" ]; then + configure + build + fi + + log "Running metal_v2 test suite..." + local models=( + "/tmp/add_delegated.pte" # binary vv (add/mul/sub kernels) + "/tmp/matmul_delegated.pte" # naive matmul (small) + "/tmp/matmul_large_delegated.pte" # matmul_simd (double-buffered, vec4) + "/tmp/gemv_delegated.pte" # gemv (N=1) + "/tmp/matmul_m1_delegated.pte" # gemv_t (M=1) + "/tmp/linear_delegated.pte" # matmul_nt via weight.t() + "/tmp/attention_qk_delegated.pte" # matmul_nt (Q @ K^T) + "/tmp/matmul_tn_delegated.pte" # matmul_tn (A^T @ B) + "/tmp/bmm_delegated.pte" # batched matmul_simd (tgid.z) + "/tmp/all_ops_delegated.pte" # mixed ops + ) + + local failed=0 + for m in "${models[@]}"; do + if [ ! -f "$m" ]; then + warn "Missing $m -- run './build_and_run.sh generate' first" + failed=1 + continue + fi + info " → $(basename $m)" + if ! "$BUILD_DIR/executor_runner" --model_path "$m" >/tmp/_etest.out 2>&1; then + error " FAILED ($m)" + tail -20 /tmp/_etest.out | sed 's/^/ /' + failed=1 + else + log " OK" + fi + done + + if [ $failed -ne 0 ]; then + error "metal_v2 test suite: FAIL" + exit 1 + fi + log "metal_v2 test suite: PASS" +} + +clean() { + log "Cleaning build directory..." 
+ rm -rf "$BUILD_DIR" + log "Clean complete" +} + +usage() { + echo "Portable Backend Build & Test Script" + echo "" + echo "Usage: $0 [options]" + echo "" + echo "Commands:" + echo " generate Export test models (add_delegated.pte, etc.)" + echo " configure Configure CMake build" + echo " build Build portable_backend and executor_runner" + echo " run [path] Run model (default: /tmp/add_delegated.pte)" + echo " bench [path] [n] Benchmark model (default: 100 iterations)" + echo " test_metal_v2 Run all metal_v2 test models (build + run sweep)" + echo " all Generate, build, and run" + echo " clean Remove build directory" + echo "" + echo "Environment Variables:" + echo " USE_METAL_V2=1 Enable metal_v2 runtime (default: 1)" + echo " USE_METAL_V2=0 Use original Metal runtime" + echo " METAL_HEAP_SIZE=N Heap size in bytes (default: 512MB)" + echo "" + echo "Metal v2 Features (when USE_METAL_V2=1):" + echo " - ICB (Indirect Command Buffer) for command replay" + echo " - GPU addresses via argument buffers (Metal 4)" + echo " - MTLResidencySet for GPU-resident memory (Metal 4)" + echo " - MTLHeap for fast allocation" + echo " - Binary Archives for pre-compiled shaders" + echo " - LRU Buffer Pool" + echo "" + echo "Examples:" + echo " $0 all # Full test cycle with metal_v2" + echo " USE_METAL_V2=0 $0 all # Use original Metal runtime" + echo " $0 bench /tmp/linear_delegated.pte 50 # Benchmark 50 iterations" +} + +# Main +case "${1:-}" in + generate) + generate + ;; + configure) + configure + ;; + build) + build + ;; + run) + run "$2" + ;; + bench) + bench "$2" "$3" + ;; + test_metal_v2) + test_metal_v2 + ;; + all) + all + ;; + clean) + clean + ;; + -h|--help|"") + usage + ;; + *) + error "Unknown command: $1" + usage + exit 1 + ;; +esac diff --git a/backends/portable/partitioner/__init__.py b/backends/portable/partitioner/__init__.py new file mode 100644 index 00000000000..68e2f6419d6 --- /dev/null +++ b/backends/portable/partitioner/__init__.py @@ -0,0 +1,9 @@ +# Copyright 
(c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .portable_partitioner import PortablePartitioner, PortableSupportedOperators + +__all__ = ["PortablePartitioner", "PortableSupportedOperators"] diff --git a/backends/portable/partitioner/portable_partitioner.py b/backends/portable/partitioner/portable_partitioner.py new file mode 100644 index 00000000000..b4ada33e314 --- /dev/null +++ b/backends/portable/partitioner/portable_partitioner.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Portable Backend Partitioner + +The portable partitioner marks ALL nodes as supported since the portable backend +has a CPU fallback that can execute any portable op. Runtime partitioning across +accelerators (Metal, Vulkan) happens in C++ based on has_op() queries. +""" + +from typing import final, List, Mapping, Tuple, Callable, Optional, Any, Dict + +import torch +from torch.fx import Node +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupportBase + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer + +from torch.export.exported_program import ExportedProgram + + +# Canonical list of ops the portable backend preserves (does NOT decompose) +# during edge lowering when used with `to_edge_transform_and_lower`. This is +# part of the universal-IR specification of the portable backend — it is +# maintained here, not user-configurable per call. 
+# +# Add an op here when you ship a dedicated C++ handler for it on the Metal +# (or other accelerator) side; otherwise the default decomposition pass at +# edge lowering will break it apart and we'll just dispatch the pieces. +# +# Note: ops marked "terminal" below are NOT in torch's core_aten or +# ExecuTorch's edge decomposition table today, so listing them is a no-op +# under current torch/ET. We list them anyway as defensive future-proofing +# and to document "we have a direct kernel for this op" intent. +_DEFAULT_PRESERVED_OPS = [ + torch.ops.aten.matmul.default, + torch.ops.aten.linear.default, + torch.ops.aten.scaled_dot_product_attention.default, +] + + + +class PortableSupportedOperators(OperatorSupportBase): + """ + Operator support checker for the Portable backend. + + Since the portable backend has a CPU fallback, it supports ALL operators. + The actual runtime partitioning across accelerators happens in C++. + """ + + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: Node + ) -> bool: + # Skip placeholder and output nodes - they shouldn't be partitioned + if node.op in ("placeholder", "output", "get_attr"): + return False + + # Portable backend supports all call_function ops via CPU fallback + return node.op == "call_function" + + +@final +class PortablePartitioner(Partitioner): + """ + Partitioner for the Portable Backend. + + Unlike other backend partitioners that only claim ops they can accelerate, + the portable partitioner claims ALL ops since: + 1. CPU fallback can execute any portable op + 2. 
Runtime partitioning in C++ handles dispatch to accelerators + + Two usage paths: + + A) Classic to_edge + to_backend (default decomposition runs first): + edge_program = to_edge(exported_program) + edge_program = edge_program.to_backend(PortablePartitioner()) + + B) to_edge_transform_and_lower (preserves listed ops from decomposition): + edge_program = to_edge_transform_and_lower( + exported_program, + partitioner=[PortablePartitioner()], + ) + # The ops in _DEFAULT_PRESERVED_OPS are kept intact (no decomposition) + # so our accelerator kernels see them whole. The list is maintained + # by this partitioner — extend `_DEFAULT_PRESERVED_OPS` (in code) when + # you ship a new dedicated handler. Per-call overrides are NOT + # exposed: the portable backend is a universal IR and the preserve + # list is part of its specification. + """ + + def __init__( + self, + compile_options: Optional[Dict[str, Any]] = None, + ) -> None: + self.options = compile_options or {} + compile_spec = self._parse_compile_options(self.options) + # Import here to avoid circular dependency + from ..preprocess import PortableBackend + self.delegation_spec = DelegationSpec(PortableBackend.__name__, compile_spec) + + def _parse_compile_options(self, options: Dict[str, Any]) -> List[CompileSpec]: + """Convert compile options dict to CompileSpec list.""" + compile_specs = [] + + for key, value in options.items(): + if isinstance(value, bool): + value_bytes = value.to_bytes(1, byteorder="little") + compile_specs.append(CompileSpec(key, value_bytes)) + elif isinstance(value, int): + value_bytes = value.to_bytes(4, byteorder="little") + compile_specs.append(CompileSpec(key, value_bytes)) + + return compile_specs + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[Node], bool]]]: + """ + Return ops that should NOT be decomposed during edge lowering. + + Called by `to_edge_transform_and_lower` BEFORE partitioning. 
The ops + returned here are kept whole (skipped by the default decomposition + pass), so our backend kernels see them in their original form. + + The preserve list (`_DEFAULT_PRESERVED_OPS`) is the canonical list + maintained by the portable backend — it is part of the universal IR + specification and is not user-configurable. Extend the constant in + this file when you ship a new dedicated handler. + + Behavior: + - If the graph is already partitioned (contains lowered_module + get_attr nodes), return empty — partitioning has run, decomposition + decisions are settled. + - Otherwise return the preserve list, intersected with ops that + actually appear in `ep` (no point listing ops the graph doesn't + have). + + The second tuple element (filter callable) is None: we apply the rule + uniformly to every node whose target is in the preserve list. + """ + # Already-partitioned graph -> nothing to preserve. + for node in ep.graph.nodes: + if node.op == "get_attr" and "lowered_module" in node.name: + return ([], None) + + # Intersect preserve list with ops actually present in the graph. + present: List[torch._ops.OpOverload] = [] + seen = set() + for node in ep.graph.nodes: + if node.op != "call_function": + continue + if not isinstance(node.target, torch._ops.OpOverload): + continue + if node.target in _DEFAULT_PRESERVED_OPS and node.target not in seen: + present.append(node.target) + seen.add(node.target) + + return (present, None) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Partition the exported program for the portable backend. + + Since portable supports everything, this partitions the entire graph + into a single delegation block. 
+ """ + partition_tags = {} + + # Use CapabilityBasedPartitioner with our "support everything" checker + capability_partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + PortableSupportedOperators(), + allows_single_node_partition=True, + ) + + partition_list = capability_partitioner.propose_partitions() + + for partition in partition_list: + for node in partition.nodes: + tag = f"tag{partition.id}" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + # Tag constant data for proper handling + tag_constant_data(exported_program) + # Tag mutated buffer placeholders so they're owned by our delegate + # (KV-cache style state flows through the delegated subgraph; the + # writeback copy_ collapses out of the top-level chain). + tag_mutated_buffer(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, + partition_tags=partition_tags, + ) diff --git a/backends/portable/preprocess_v2.py b/backends/portable/preprocess_v2.py new file mode 100644 index 00000000000..96340878e0d --- /dev/null +++ b/backends/portable/preprocess_v2.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +PortableBackend_v2 — AOT BackendDetails for the v2 portable runtime. + +The class name `PortableBackend_v2` is the backend_id used at runtime +to find the matching BackendInterface. C++ side registers it via +`register_backend({"PortableBackend_v2", ...})` in +runtime_v2/PortableBackend_v2.cpp. + +The preprocess() pipeline: +1. SpecPropPass to populate tensor specs. +2. (If buffer mutations exist) insert_write_back_for_buffers_pass to + add an explicit aten::copy_(buf, mut_src) at end of subgraph. +3. 
Spec-sharing for buffer mutations: make the mutation source's + TensorSpec be the SAME object as the buffer placeholder's TensorSpec. + The greedy memory planner walks specs by identity, so this yields a + single allocation slot for both — true in-place buffer mutation. + At runtime, the trailing aten::copy_ becomes a self-copy + (dispatcher's pointer-equality short-circuit makes it a no-op). +4. ExternalConstantsPass to tag constants for NDM storage. +5. Memory planning (greedy, allow_overlapping_allocations). +6. Emit and serialize. +""" + +from functools import partial +from typing import Any, Dict, final, List + +from executorch.exir.backend.backend_details import ( + BackendDetails, + CompileSpec, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.utils import DelegateMappingBuilder +from executorch.exir.emit import emit_program +from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite +from executorch.exir.passes import MemoryPlanningPass, SpecPropPass +from executorch.exir.passes.insert_write_back_for_buffers_pass import ( + insert_write_back_for_buffers_pass, +) +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.exir.program._program import _transform +from executorch.exir._serialize._program import serialize_pte_binary, PTEFile + +from torch._export.verifier import Verifier + + +class _AnyOp(Verifier): + """Permissive verifier that allows any op (skip functional check).""" + + dialect = "TRAINING" + + def allowed_op_types(self): + from typing import Callable + + return (Callable,) + + +def _apply_passes(program: ExportedProgram, passes) -> ExportedProgram: + """Apply a sequence of passes to an ExportedProgram.""" + from executorch.exir.pass_base import ExportPass, PassBase + + for p in passes: + if isinstance(p, MemoryPlanningPass) and hasattr(p, "run"): + p.run(program.graph_module) + elif issubclass(type(p), (ExportPass, PassBase)): + if hasattr(p, 
"_exported_program"): + p._exported_program = program + program = _transform(program, p, override_verifiers=[_AnyOp]) + if isinstance(p, SpecPropPass): + p.update_placeholder_tensor_specs(program, program.graph_module) + else: + program = p(program) + + return program + + +def _parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: + """Parse compile specs into options dict.""" + options = {} + for spec in compile_specs: + if spec.key == "skip_memory_planning": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + return options + + +@final +class PortableBackend_v2(BackendDetails): + """ + BackendDetails for the v2 portable backend. + + Class name `PortableBackend_v2` matches the runtime backend_id + registered in runtime_v2/PortableBackend_v2.cpp. + """ + + @classmethod + def preprocess( + cls, + program: ExportedProgram, + module_compile_spec: List[CompileSpec], + ) -> PreprocessResult: + """ + Preprocess the partitioned subgraph for v2 portable backend execution. + """ + compile_options = _parse_compile_spec(module_compile_spec) + skip_memory_planning = compile_options.get("skip_memory_planning", False) + + # Step 1: SpecPropPass to propagate tensor specs. + program = _apply_passes(program, [SpecPropPass()]) + + # Step 1b: Insert writeback copy_ ops for any mutable buffers that + # got pulled into our delegate by the partitioner's + # tag_mutated_buffer call. Then alias the mutation output's spec + # with the buffer placeholder's spec — so the AOT memory planner + # treats them as a single tensor and allocates ONE slot, achieving + # true in-place buffer mutation. The trailing aten::copy_ becomes + # a self-copy at runtime (handled by dispatcher's pointer check). 
+ from torch.export.graph_signature import InputKind, OutputKind + + has_buffer_mutation = any( + ospec.kind == OutputKind.BUFFER_MUTATION + for ospec in program.graph_signature.output_specs + ) + if has_buffer_mutation: + gm, new_sig = insert_write_back_for_buffers_pass(program) + program._graph_module = gm + program._graph_signature = new_sig + # Re-propagate specs onto the newly inserted copy_ nodes. + program = _apply_passes(program, [SpecPropPass()]) + + # Spec-sharing trick: for each (buffer_placeholder, mutation + # source) pair, make the mutation source's TensorSpec be the + # SAME object as the buffer's TensorSpec. The greedy memory + # planner walks specs by identity, so a shared spec yields a + # single allocation. No wasted slot, no runtime override + # needed (the .pte naturally reports both at the same offset). + import torch + + sig = program.graph_signature + nodes_by_name = { + n.name: n for n in program.graph_module.graph.nodes + } + buf_target_to_node = { + ispec.target: nodes_by_name.get(ispec.arg.name) + for ispec in sig.input_specs + if ispec.kind == InputKind.BUFFER and ispec.target + } + for ospec in sig.output_specs: + if ospec.kind != OutputKind.BUFFER_MUTATION: + continue + buf_node = buf_target_to_node.get(ospec.target) + wb_node = nodes_by_name.get(ospec.arg.name) + if ( + buf_node is None + or wb_node is None + or wb_node.op != "call_function" + or wb_node.target != torch.ops.aten.copy_.default + or len(wb_node.args) < 2 + ): + continue + src_node = wb_node.args[1] + if not hasattr(src_node, "meta"): + continue + buf_spec = buf_node.meta.get("spec") + if buf_spec is None or "spec" not in src_node.meta: + continue + # Alias: src now shares buf's spec object. + src_node.meta["spec"] = buf_spec + + # Step 2: External constants pass — tag weights for NDM storage. 
+ from executorch.exir.passes.external_constants_pass import ( + external_constants_pass, + ) + + external_constants_pass(program.graph_module) + + # Step 3: Memory planning (greedy, allows overlapping allocs). + if not skip_memory_planning: + greedy_memory_planning = partial( + greedy, allow_overlapping_allocations=True + ) + mem_planning_suite = MemoryPlanningAlgorithmSuite( + algo_list=[greedy_memory_planning] + ) + + # Workaround for memory planning without ToOutVarPass + program.graph_module.encounter_to_out_var_failure = True + + program = _apply_passes( + program, + [ + ConstraintBasedSymShapeEvalPass(), + MemoryPlanningPass(memory_planning_algo=mem_planning_suite), + ], + ) + + # Step 4: Emit the program. + delegate_mapping_builder = DelegateMappingBuilder(generated_identifiers=True) + emitter_output = emit_program(program) + + # Step 5: Build named data store from external constants. + from executorch.exir._serialize._named_data_store import NamedDataStore + + named_data_store = NamedDataStore() + if emitter_output.external_constant_buffer: + for tag, fqn_to_idx in emitter_output.external_constant_map.items(): + for fqn, idx in fqn_to_idx.items(): + data = emitter_output.external_constant_buffer[idx] + named_data_store.add_named_data(fqn, data) + + # Step 6: Serialize to bytes. + pte_file = PTEFile( + program=emitter_output.program, + mutable_data=emitter_output.mutable_data, + ) + serialized_bytes = bytes(serialize_pte_binary(pte_file)) + + return PreprocessResult( + processed_bytes=serialized_bytes, + debug_handle_map=emitter_output.debug_handle_map, + data_store_output=named_data_store.get_named_data_store_output(), + ) diff --git a/backends/portable/runtime/metal_v2/MetalBufferPool.mm b/backends/portable/runtime/metal_v2/MetalBufferPool.mm new file mode 100644 index 00000000000..a354c16e7c7 --- /dev/null +++ b/backends/portable/runtime/metal_v2/MetalBufferPool.mm @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#import "MetalStream.h"
+#include
+#include
+
+// NOTE(review): include paths and Objective-C protocol arguments (e.g.
+// id<MTLDevice>, id<MTLBuffer>) appear stripped by patch extraction —
+// restore from the original file before applying.
+// NOTE(review): no locking is visible in this file — confirm the pool is
+// only touched from a single thread/stream.
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+//===----------------------------------------------------------------------===//
+// MetalBufferPool
+//===----------------------------------------------------------------------===//
+
+// Takes a non-owning device handle and retains it for the pool's lifetime;
+// the destructor balances the retain after draining the cache.
+MetalBufferPool::MetalBufferPool(id device, size_t maxBytes)
+    : device_(device), maxBytes_(maxBytes) {
+  [device_ retain];
+}
+
+MetalBufferPool::~MetalBufferPool() {
+  clear();
+  [device_ release];
+}
+
+// Returns a +1 (caller-owned) buffer: either a pooled buffer whose size is
+// within [size, min(2*size, size + kMaxHeadroom)], or a freshly allocated
+// shared-storage buffer. Pooled buffers were stored retained, so ownership
+// semantics match the new-allocation path.
+// NOTE(review): unlike prewarm(), the fresh allocation is not nil-checked —
+// callers must tolerate a nil return under allocation failure.
+id MetalBufferPool::acquire(size_t size) {
+  // Find best fit with bounded headroom
+  auto it = sizeMap_.lower_bound(size);
+  // Overflow-guarded: clamp both candidate bounds to SIZE_MAX before min().
+  size_t doubleSize = (size > SIZE_MAX / 2) ? SIZE_MAX : 2 * size;
+  size_t sizeWithHeadroom = (size > SIZE_MAX - kMaxHeadroom) ? SIZE_MAX : size + kMaxHeadroom;
+  size_t maxAcceptable = std::min(doubleSize, sizeWithHeadroom);
+
+  if (it != sizeMap_.end() && it->first <= maxAcceptable) {
+    auto lruIt = it->second;
+    id buffer = lruIt->buffer;
+    cachedBytes_ -= lruIt->size;
+    lruList_.erase(lruIt);
+    sizeMap_.erase(it);
+    ET_LOG(Debug, "MetalBufferPool: reused %zu byte buffer (requested %zu)", [buffer length], size);
+    return buffer;
+  }
+
+  // Allocate new
+  id buffer = [device_ newBufferWithLength:size options:MTLResourceStorageModeShared];
+  ET_LOG(Debug, "MetalBufferPool: allocated new %zu byte buffer", size);
+  return buffer;
+}
+
+// Takes ownership of `buffer` back from the caller. Buffers larger than half
+// the cap are released immediately instead of pooled; otherwise the buffer is
+// inserted at the LRU head and the cache is trimmed back under maxBytes_.
+void MetalBufferPool::release(id buffer) {
+  size_t size = [buffer length];
+
+  // Don't pool very large buffers
+  if (size > maxBytes_ / 2) {
+    [buffer release];
+    return;
+  }
+
+  lruList_.push_front({buffer, size});
+  sizeMap_.insert({size, lruList_.begin()});
+  cachedBytes_ += size;
+
+  // Evict if over limit
+  while (cachedBytes_ > maxBytes_ && !lruList_.empty()) {
+    evictOldest();
+  }
+}
+
+// Drops the least-recently-used entry (list tail): removes its sizeMap_
+// record — matching by list iterator among equal-size entries — releases the
+// buffer, and unlinks it. Callers guard against an empty list.
+void MetalBufferPool::evictOldest() {
+  auto tail = std::prev(lruList_.end());
+  cachedBytes_ -= tail->size;
+
+  // Find and remove from sizeMap
+  auto range = sizeMap_.equal_range(tail->size);
+  for (auto it = range.first; it != range.second; ++it) {
+    if (it->second == tail) {
+      sizeMap_.erase(it);
+      break;
+    }
+  }
+
+  [tail->buffer release];
+  lruList_.erase(tail);
+}
+
+// Releases every pooled buffer and resets all bookkeeping to empty.
+void MetalBufferPool::clear() {
+  for (auto& entry : lruList_) {
+    [entry.buffer release];
+  }
+  lruList_.clear();
+  sizeMap_.clear();
+  cachedBytes_ = 0;
+}
+
+void MetalBufferPool::setMaxBytes(size_t bytes) {
+  maxBytes_ = bytes;
+  // Shrink immediately if cached > new cap.
+  while (cachedBytes_ > maxBytes_ && !lruList_.empty()) {
+    evictOldest();
+  }
+}
+
+// Seeds the pool with pre-allocated shared-storage buffers of the given
+// sizes so later acquire() calls hit the cache instead of allocating.
+void MetalBufferPool::prewarm(const std::vector& sizes) {
+  for (size_t size : sizes) {
+    if (size == 0) continue;
+    // Allocate the buffer and immediately seed it into the pool. Same
+    // semantics as alloc(size) followed by free(ptr) — but skipping the
+    // ptr → buffer round-trip and the residency-set add (caller can't
+    // know GPU addr yet anyway). Honors maxBytes by evicting LRU first.
+    while (cachedBytes_ + size > maxBytes_ && !lruList_.empty()) {
+      evictOldest();
+    }
+    if (cachedBytes_ + size > maxBytes_) {
+      // Single entry larger than cap — skip; user should bump capacity.
+      ET_LOG(Info,
+          "MetalBufferPool::prewarm: size %zu exceeds capacity %zu, skipping",
+          size, maxBytes_);
+      continue;
+    }
+    id buffer =
+        [device_ newBufferWithLength:size options:MTLResourceStorageModeShared];
+    if (!buffer) {
+      ET_LOG(Error, "MetalBufferPool::prewarm: alloc of %zu failed", size);
+      continue;
+    }
+    PoolEntry entry{buffer, size};
+    lruList_.push_front(entry);
+    sizeMap_.insert({size, lruList_.begin()});
+    cachedBytes_ += size;
+  }
+  ET_LOG(Info,
+      "MetalBufferPool::prewarm: seeded %zu buffers, cached=%zu/%zu bytes",
+      sizes.size(), cachedBytes_, maxBytes_);
+}
+
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
+
diff --git a/backends/portable/runtime/metal_v2/MetalHeap.mm b/backends/portable/runtime/metal_v2/MetalHeap.mm
new file mode 100644
index 00000000000..f3b8ebccb65
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalHeap.mm
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#import "MetalStream.h"
+#include
+
+// NOTE(review): include paths and Objective-C protocol arguments (e.g.
+// id<MTLDevice>, id<MTLBuffer>) appear stripped by patch extraction —
+// restore from the original file before applying.
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+//===----------------------------------------------------------------------===//
+// MetalHeap
+//===----------------------------------------------------------------------===//
+
+// Wraps an MTLHeap in shared (unified) memory. Sub-allocations come from
+// allocBuffer(); usedSize_ accumulates each buffer's allocatedSize against
+// totalSize_.
+MetalHeap::MetalHeap(id device, size_t size, bool aliasable) : totalSize_(size) {
+  MTLHeapDescriptor* desc = [[MTLHeapDescriptor alloc] init];
+  desc.size = size;
+  desc.storageMode = MTLStorageModeShared; // Unified memory
+  desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+
+  // Placement heap allows precise control over buffer placement
+  desc.type = MTLHeapTypePlacement;
+
+  // Hazard tracking mode
+  if (@available(macOS 10.15, iOS 13.0, *)) {
+    desc.hazardTrackingMode = MTLHazardTrackingModeTracked;
+  }
+
+  // Aliasable resources can share memory (reduces footprint)
+  if (@available(macOS 13.0, iOS 16.0, *)) {
+    if (aliasable) {
+      desc.type = MTLHeapTypeAutomatic; // Allows aliasing
+    }
+  }
+
+  heap_ = [device newHeapWithDescriptor:desc];
+  [desc release];
+
+  if (heap_) {
+    // FIX: no extra retain — a "new"-prefixed method already returns a +1
+    // reference (Cocoa ownership rule) and ~MetalHeap releases exactly once;
+    // the previous [heap_ retain] made the count +2 and leaked the heap.
+    ET_LOG(Info, "MetalHeap: Created %zu MB heap (aliasable=%d)", size / (1024*1024), aliasable);
+  } else {
+    ET_LOG(Error, "MetalHeap: Failed to create heap");
+  }
+}
+
+// Balances the +1 reference obtained from newHeapWithDescriptor:.
+MetalHeap::~MetalHeap() {
+  if (heap_) {
+    [heap_ release];
+  }
+}
+
+// Sub-allocates a shared-storage buffer from the heap. Returns nil (and a
+// caller-visible log) when the heap is missing or the request exceeds the
+// remaining capacity.
+// NOTE(review): desc.type is MTLHeapTypePlacement unless the aliasable path
+// above switched it; placement heaps require the
+// newBufferWithLength:options:offset: variant — confirm the offset-less call
+// is intended here.
+id MetalHeap::allocBuffer(size_t size) {
+  if (!heap_) return nil;
+
+  // Check if heap has space
+  if (usedSize_ + size > totalSize_) {
+    ET_LOG(Info, "MetalHeap: Out of space (need %zu, have %zu)",
+        size, totalSize_ - usedSize_);
+    return nil;
+  }
+
+  id buffer = [heap_ newBufferWithLength:size
+                                 options:MTLResourceStorageModeShared];
+  if (buffer) {
+    // Account for the heap's actual (aligned) footprint, not the request.
+    usedSize_ += [buffer allocatedSize];
+    ET_LOG(Debug, "MetalHeap: Allocated %zu bytes (used: %zu/%zu)",
+        size, usedSize_, totalSize_);
+  }
+
+  return buffer;
+}
+
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
+
diff --git 
a/backends/portable/runtime/metal_v2/MetalKernel.mm b/backends/portable/runtime/metal_v2/MetalKernel.mm
new file mode 100644
index 00000000000..397ebea1dec
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalKernel.mm
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#import "MetalStream.h"
+#include
+
+// NOTE(review): include paths and protocol/template arguments (e.g.
+// id<MTLComputePipelineState>, static_cast<uint32_t>) appear stripped by
+// patch extraction — restore from the original file before applying.
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+//===----------------------------------------------------------------------===//
+// MetalKernel
+//===----------------------------------------------------------------------===//
+
+// Retains the compute pipeline state for the kernel's lifetime and records
+// its name.
+// NOTE(review): name_ is initialized from a caller-owned C string — confirm
+// name_ is a copying type (std::string) rather than const char*, which
+// would dangle if the caller's buffer is freed.
+MetalKernel::MetalKernel(id pipeline, const char* name)
+    : pipeline_(pipeline), name_(name) {
+  [pipeline_ retain];
+}
+
+// Balances the constructor's retain on the pipeline state.
+MetalKernel::~MetalKernel() {
+  [pipeline_ release];
+}
+
+// Reports the pipeline's maxTotalThreadsPerThreadgroup as the x component
+// only, with y and z fixed at 1; callers wanting a 3-D threadgroup shape
+// must factor the total themselves.
+uvec3 MetalKernel::maxThreadsPerThreadgroup() const {
+  NSUInteger maxThreads = [pipeline_ maxTotalThreadsPerThreadgroup];
+  return uvec3(static_cast(maxThreads), 1, 1);
+}
+
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
+
diff --git a/backends/portable/runtime/metal_v2/MetalKernelCompiler.mm b/backends/portable/runtime/metal_v2/MetalKernelCompiler.mm
new file mode 100644
index 00000000000..209ee58bfc3
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalKernelCompiler.mm
@@ -0,0 +1,324 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#import "MetalStream.h" +#include +#include + +namespace executorch { +namespace backends { +namespace metal_v2 { + +//===----------------------------------------------------------------------===// +// MetalKernelCompiler +//===----------------------------------------------------------------------===// + +MetalKernelCompiler::MetalKernelCompiler(id device) : device_(device), binaryArchive_(nil) { + [device_ retain]; +} + +MetalKernelCompiler::~MetalKernelCompiler() { + if (binaryArchive_) { + [binaryArchive_ release]; + } +#if ET_METAL4_ENABLE + if (@available(macOS 26.0, iOS 26.0, *)) { + if (mtl4Compiler_) { + [mtl4Compiler_ release]; + mtl4Compiler_ = nil; + } + } +#endif + [device_ release]; +} + +bool MetalKernelCompiler::loadBinaryArchive(const char* path) { +#if ET_METAL4_AVAILABLE + if (@available(macOS 11.0, iOS 14.0, *)) { + @autoreleasepool { + NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:path]]; + + MTLBinaryArchiveDescriptor* desc = [[MTLBinaryArchiveDescriptor alloc] init]; + desc.url = url; + + NSError* error = nil; + id archive = [device_ newBinaryArchiveWithDescriptor:desc error:&error]; + [desc release]; + + if (archive) { + if (binaryArchive_) { + [binaryArchive_ release]; + } + binaryArchive_ = archive; + [binaryArchive_ retain]; + ET_LOG(Info, "MetalKernelCompiler: Loaded binary archive from %s", path); + return true; + } else { + ET_LOG(Debug, "MetalKernelCompiler: No binary archive at %s: %s", path, + error ? 
[[error localizedDescription] UTF8String] : "unknown"); + } + } + } +#endif + return false; +} + +bool MetalKernelCompiler::saveBinaryArchive(const char* path) { +#if ET_METAL4_AVAILABLE + if (@available(macOS 11.0, iOS 14.0, *)) { + if (!binaryArchive_) { + // Create new archive if none exists + MTLBinaryArchiveDescriptor* desc = [[MTLBinaryArchiveDescriptor alloc] init]; + NSError* error = nil; + binaryArchive_ = [device_ newBinaryArchiveWithDescriptor:desc error:&error]; + [desc release]; + + if (!binaryArchive_) { + ET_LOG(Error, "MetalKernelCompiler: Failed to create binary archive"); + return false; + } + [binaryArchive_ retain]; + } + + @autoreleasepool { + NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:path]]; + NSError* error = nil; + + if ([binaryArchive_ serializeToURL:url error:&error]) { + ET_LOG(Info, "MetalKernelCompiler: Saved binary archive to %s", path); + return true; + } else { + ET_LOG(Error, "MetalKernelCompiler: Failed to save binary archive: %s", + error ? [[error localizedDescription] UTF8String] : "unknown"); + } + } + } +#endif + return false; +} + +MetalKernel* MetalKernelCompiler::compile( + const char* source, + const char* functionName) { + // Cache key includes a hash of the source so different sources reusing + // the same function name (e.g., AOTI's "generated_kernel") don't collide. 
+ std::string key = std::to_string(std::hash{}( + std::string_view(source))) + + "/" + functionName; + auto it = cache_.find(key); + if (it != cache_.end()) { + return it->second.get(); + } + + @autoreleasepool { + NSString* sourceStr = [NSString stringWithUTF8String:source]; + NSError* error = nil; + + MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; + + // Metal 4: Use precise math mode +#if ET_METAL4_AVAILABLE + if (@available(macOS 15.0, iOS 18.0, *)) { + options.mathMode = MTLMathModeSafe; + options.mathFloatingPointFunctions = MTLMathFloatingPointFunctionsPrecise; + } +#endif + + id library = [device_ newLibraryWithSource:sourceStr options:options error:&error]; + [options release]; + + if (!library || error) { + ET_LOG(Error, "MetalKernelCompiler: failed to compile shader: %s", + error ? [[error localizedDescription] UTF8String] : "unknown"); + return nullptr; + } + + NSString* funcName = [NSString stringWithUTF8String:functionName]; + id function = [library newFunctionWithName:funcName]; + + if (!function) { + ET_LOG(Error, "MetalKernelCompiler: function '%s' not found", functionName); + [library release]; + return nullptr; + } + +#if ET_METAL4_ENABLE + // ----- Metal 4 dispatch path ----- + // Use MTL4Compiler when available. The output is still id + // which is the same protocol used by both legacy and MTL4 encoders. + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + if (!mtl4Compiler_) { + MTL4CompilerDescriptor* compilerDesc = [[MTL4CompilerDescriptor alloc] init]; + NSError* compilerErr = nil; + mtl4Compiler_ = [device_ newCompilerWithDescriptor:compilerDesc error:&compilerErr]; + [compilerDesc release]; + if (!mtl4Compiler_ || compilerErr) { + ET_LOG(Error, "MetalKernelCompiler: MTL4Compiler creation failed: %s", + compilerErr ? 
[[compilerErr localizedDescription] UTF8String] : "unknown"); + // Fall through to legacy path + } + } + + if (mtl4Compiler_) { + MTL4LibraryFunctionDescriptor* funcDesc = [[MTL4LibraryFunctionDescriptor alloc] init]; + funcDesc.name = funcName; + funcDesc.library = library; + + MTL4ComputePipelineDescriptor* mtl4PipelineDesc = [[MTL4ComputePipelineDescriptor alloc] init]; + mtl4PipelineDesc.computeFunctionDescriptor = funcDesc; + mtl4PipelineDesc.label = funcName; + mtl4PipelineDesc.supportIndirectCommandBuffers = MTL4IndirectCommandBufferSupportStateEnabled; + + NSError* mtl4Err = nil; + id mtl4Pipeline = + [mtl4Compiler_ newComputePipelineStateWithDescriptor:mtl4PipelineDesc + compilerTaskOptions:nil + error:&mtl4Err]; + [funcDesc release]; + [mtl4PipelineDesc release]; + + if (mtl4Pipeline && !mtl4Err) { + [function release]; + [library release]; + auto kernel = std::make_unique(mtl4Pipeline, functionName); + [mtl4Pipeline release]; + MetalKernel* result = kernel.get(); + cache_[key] = std::move(kernel); + ET_LOG(Info, "MetalKernelCompiler: compiled '%s' via MTL4Compiler", functionName); + return result; + } + ET_LOG(Error, "MetalKernelCompiler: MTL4 pipeline creation failed for '%s': %s. Falling back to legacy.", + functionName, + mtl4Err ? 
[[mtl4Err localizedDescription] UTF8String] : "unknown"); + // Fall through to legacy path + } + } + } +#endif + + // Create pipeline descriptor for binary archive support + MTLComputePipelineDescriptor* pipelineDesc = [[MTLComputePipelineDescriptor alloc] init]; + pipelineDesc.computeFunction = function; + pipelineDesc.label = funcName; + pipelineDesc.supportIndirectCommandBuffers = YES; // Enable ICB support + + id pipeline = nil; + +#if ET_METAL4_AVAILABLE + // Try to load from binary archive first (fast path) + if (@available(macOS 11.0, iOS 14.0, *)) { + if (binaryArchive_) { + // Try to get pre-compiled pipeline from archive + MTLPipelineOption pipelineOptions = MTLPipelineOptionNone; + pipeline = [device_ newComputePipelineStateWithDescriptor:pipelineDesc + options:pipelineOptions + reflection:nil + error:&error]; + + if (pipeline) { + ET_LOG(Debug, "MetalKernelCompiler: Loaded '%s' from binary archive", functionName); + } + } + } +#endif + + // Compile using descriptor (required for ICB support) + if (!pipeline) { + pipeline = [device_ newComputePipelineStateWithDescriptor:pipelineDesc + options:MTLPipelineOptionNone + reflection:nil + error:&error]; + + if (!pipeline || error) { + ET_LOG(Error, "MetalKernelCompiler: failed to create pipeline: %s", + error ? 
[[error localizedDescription] UTF8String] : "unknown"); + [function release]; + [library release]; + [pipelineDesc release]; + return nullptr; + } + +#if ET_METAL4_AVAILABLE + // Add to binary archive for future use + if (@available(macOS 11.0, iOS 14.0, *)) { + if (binaryArchive_) { + NSError* archiveError = nil; + if ([binaryArchive_ addComputePipelineFunctionsWithDescriptor:pipelineDesc error:&archiveError]) { + ET_LOG(Debug, "MetalKernelCompiler: Added '%s' to binary archive", functionName); + } + } + } +#endif + } + + [function release]; + [library release]; + [pipelineDesc release]; + + auto kernel = std::make_unique(pipeline, functionName); + [pipeline release]; + + MetalKernel* result = kernel.get(); + cache_[key] = std::move(kernel); + + ET_LOG(Info, "MetalKernelCompiler: compiled '%s'", functionName); + return result; + } +} + +MetalKernel* MetalKernelCompiler::loadFromLibrary( + const void* metallibData, + size_t metallibSize, + const char* functionName) { + @autoreleasepool { + NSError* error = nil; + dispatch_data_t data = dispatch_data_create( + metallibData, metallibSize, nullptr, DISPATCH_DATA_DESTRUCTOR_DEFAULT); + + id library = [device_ newLibraryWithData:data error:&error]; + dispatch_release(data); + + if (!library || error) { + ET_LOG(Error, "MetalKernelCompiler: failed to load metallib: %s", + error ? 
[[error localizedDescription] UTF8String] : "unknown"); + return nullptr; + } + + NSString* funcName = [NSString stringWithUTF8String:functionName]; + id function = [library newFunctionWithName:funcName]; + [library release]; + + if (!function) { + ET_LOG(Error, "MetalKernelCompiler: function '%s' not found in metallib", functionName); + return nullptr; + } + + id pipeline = [device_ newComputePipelineStateWithFunction:function error:&error]; + [function release]; + + if (!pipeline || error) { + return nullptr; + } + + auto kernel = std::make_unique(pipeline, functionName); + [pipeline release]; + + std::string key = std::string("lib:") + functionName; + MetalKernel* result = kernel.get(); + cache_[key] = std::move(kernel); + + return result; + } +} + + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch + diff --git a/backends/portable/runtime/metal_v2/MetalOp.h b/backends/portable/runtime/metal_v2/MetalOp.h new file mode 100644 index 00000000000..fecc3c1d835 --- /dev/null +++ b/backends/portable/runtime/metal_v2/MetalOp.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal_v2 { + +// Forward declarations — defined in MetalStream.h. We use forward decls +// here to avoid pulling Metal headers into pure-C++ op headers. 
+class MetalStream; +class MetalKernel; + +using runtime::etensor::Tensor; +using exec_aten::ArrayRef; +using exec_aten::SizesType; + +//===----------------------------------------------------------------------===// +// MetalOp - Base class for GPU operations +//===----------------------------------------------------------------------===// + +class MetalOp { +public: + virtual ~MetalOp() = default; + + /// Op name (e.g., "aten::add", "aten::mm") + virtual const char* name() const = 0; + + /// Check if this op supports the given dtype + virtual bool supports(ScalarType dtype) const { + return dtype == ScalarType::Float; + } + + /// Convenience alias for the EValue-pointer span used by op interfaces. + /// Caller provides storage (typically a stack std::array) so dispatch is + /// alloc-free per call. + using EValuePtrSpan = runtime::Span; + + /// Compute output shape from inputs (for resize) + /// Returns empty vector if output shape matches first input + virtual std::vector computeOutputShape( + EValuePtrSpan inputs) const { + return {}; + } + + /// Dispatch the op using stream + virtual void dispatch( + MetalStream* stream, + EValuePtrSpan inputs, + EValuePtrSpan outputs) = 0; + +protected: + /// Get or compile kernel by name (caches result) + MetalKernel* getKernel(MetalStream* stream, const char* kernelName); + + /// Kernel source code (subclass provides) + virtual const char* kernelSource() const = 0; + + /// Compute grid size from output tensor + uvec3 computeGrid(const Tensor& output, uint32_t blockSize = 256) const; + + /// Resize output tensor if needed + runtime::Error resizeOutput( + EValuePtrSpan inputs, + runtime::EValue* output) const; + +private: + std::unordered_map kernelCache_; +}; + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime/metal_v2/MetalOp.mm b/backends/portable/runtime/metal_v2/MetalOp.mm new file mode 100644 index 00000000000..df4b911dd5c --- /dev/null +++ 
b/backends/portable/runtime/metal_v2/MetalOp.mm
@@ -0,0 +1,99 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#import "MetalOp.h"
// NOTE(review): the three original #include targets were stripped in
// transcription; reconstructed from usage (resize_tensor, string, vector) —
// confirm against the repo.
#include <string>
#include <vector>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace executorch {
namespace backends {
namespace metal_v2 {

using runtime::Error;

//===----------------------------------------------------------------------===//
// MetalOp Base Implementation
//===----------------------------------------------------------------------===//

// Looks up `kernelName` in the per-op cache, compiling on first use.
// Aborts (ET_CHECK) rather than returning nullptr so a missing kernel can
// never silently no-op a dispatch.
MetalKernel* MetalOp::getKernel(MetalStream* stream, const char* kernelName) {
  auto it = kernelCache_.find(kernelName);
  if (it != kernelCache_.end()) {
    // Treat a previously cached null as a hard failure too — a cached null
    // means a prior compile attempt for this name failed.
    ET_CHECK_MSG(
        it->second != nullptr,
        "MetalOp '%s': previously failed to compile kernel '%s' (cached null). "
        "Most likely the kernel template wasn't instantiated for this dtype "
        "(check kernelSource() for the missing [[host_name(\"...\")]] entry).",
        name(), kernelName);
    return it->second;
  }

  MetalKernel* kernel = stream->compiler()->compile(kernelSource(), kernelName);
  kernelCache_[kernelName] = kernel;
  // Hard-fail on missing kernel rather than letting the dispatch silently
  // no-op. The previous behavior just logged at ERROR and returned, which
  // produced numerically wrong results that *looked* like a successful run
  // (e.g. fake "fast" perf because the kernel wasn't actually executing).
  ET_CHECK_MSG(
      kernel != nullptr,
      "MetalOp '%s': failed to compile/find kernel '%s'. "
      "Most likely the kernel template wasn't instantiated for this dtype "
      "(add a `template [[host_name(\"%s\")]] kernel void ...(...)` line "
      "to kernelSource()'s template instantiation block).",
      name(), kernelName, kernelName);
  return kernel;
}

// 1-D launch grid covering output.numel() elements at `blockSize` threads
// per block (ceil division).
uvec3 MetalOp::computeGrid(const Tensor& output, uint32_t blockSize) const {
  size_t numel = output.numel();
  return uvec3((uint32_t)((numel + blockSize - 1) / blockSize), 1, 1);
}

// Resizes *output to computeOutputShape(inputs), defaulting to the first
// input's shape when the op returns an empty shape. No-op when shapes match.
Error MetalOp::resizeOutput(
    EValuePtrSpan inputs,
    runtime::EValue* output) const {

  if (!output->isTensor()) {
    return Error::InvalidArgument;
  }

  auto& out_tensor = output->toTensor();
  auto new_shape = computeOutputShape(inputs);

  if (new_shape.empty()) {
    if (!inputs.empty() && inputs[0]->isTensor()) {
      auto& in_tensor = inputs[0]->toTensor();
      new_shape.assign(in_tensor.sizes().begin(), in_tensor.sizes().end());
    }
  }

  if (!new_shape.empty()) {
    auto current = out_tensor.sizes();
    bool needs_resize = (current.size() != new_shape.size());
    if (!needs_resize) {
      for (size_t i = 0; i < current.size(); i++) {
        if (current[i] != new_shape[i]) {
          needs_resize = true;
          break;
        }
      }
    }

    if (needs_resize) {
      // BUGFIX(transcription): ArrayRef's element type was stripped.
      return runtime::resize_tensor(
          out_tensor, ArrayRef<SizesType>(new_shape.data(), new_shape.size()));
    }
  }

  return runtime::Error::Ok;
}

} // namespace metal_v2
} // namespace backends
} // namespace executorch

diff --git a/backends/portable/runtime/metal_v2/MetalOpRegistry.h b/backends/portable/runtime/metal_v2/MetalOpRegistry.h
new file mode 100644
index 00000000000..9ae41584de9
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalOpRegistry.h
@@ -0,0 +1,45 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal_v2 { + +//===----------------------------------------------------------------------===// +// MetalOpRegistry — global singleton mapping op-name → MetalOp instance +// +// Populated at process start (constructor registers built-in ops). External +// callers look up ops via MetalOpRegistry::shared().get("aten::add"). +//===----------------------------------------------------------------------===// + +class MetalOpRegistry { + public: + static MetalOpRegistry& shared(); + + void registerOp(std::unique_ptr op); + MetalOp* get(const char* name) const; + MetalOp* get(const std::string& name) const { return get(name.c_str()); } + bool hasOp(const char* name) const; + bool hasOp(const std::string& name) const { return hasOp(name.c_str()); } + + private: + MetalOpRegistry(); + std::unordered_map> ops_; +}; + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime/metal_v2/MetalOpRegistry.mm b/backends/portable/runtime/metal_v2/MetalOpRegistry.mm new file mode 100644 index 00000000000..7ad011e174f --- /dev/null +++ b/backends/portable/runtime/metal_v2/MetalOpRegistry.mm @@ -0,0 +1,55 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "MetalOpRegistry.h" + +// Op implementations registered at construction time. 
#include "ops/BinaryOps.h"
#include "ops/UnaryOps.h"
#include "ops/MatMulOp.h"

namespace executorch {
namespace backends {
namespace metal_v2 {

// Meyers singleton — thread-safe initialization since C++11.
MetalOpRegistry& MetalOpRegistry::shared() {
  static MetalOpRegistry instance;
  return instance;
}

// NOTE(review): the concrete op class names inside make_unique<...>() were
// lost in transcription (angle brackets stripped). The names below are
// reconstructed from the section comments and the ops/*.h include names —
// TODO(review): restore the exact class names from ops/BinaryOps.h,
// ops/UnaryOps.h and ops/MatMulOp.h before landing.
MetalOpRegistry::MetalOpRegistry() {
  // Register binary ops
  registerOp(std::make_unique<AddOp>());
  registerOp(std::make_unique<SubOp>());
  registerOp(std::make_unique<MulOp>());

  // Register unary ops
  registerOp(std::make_unique<ReluOp>());

  // Register matmul ops
  registerOp(std::make_unique<MatMulOp>());
  registerOp(std::make_unique<AddmmOp>());
  registerOp(std::make_unique<BmmOp>());
}

void MetalOpRegistry::registerOp(std::unique_ptr<MetalOp> op) {
  // Last registration wins for a duplicate name.
  ops_[op->name()] = std::move(op);
}

MetalOp* MetalOpRegistry::get(const char* name) const {
  auto it = ops_.find(name);
  return it != ops_.end() ? it->second.get() : nullptr;
}

bool MetalOpRegistry::hasOp(const char* name) const {
  return ops_.find(name) != ops_.end();
}

} // namespace metal_v2
} // namespace backends
} // namespace executorch

diff --git a/backends/portable/runtime/metal_v2/MetalRuntime.h b/backends/portable/runtime/metal_v2/MetalRuntime.h
new file mode 100644
index 00000000000..41d148873f1
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalRuntime.h
@@ -0,0 +1,79 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace metal_v2 { + +// Forward declarations - avoid including Metal headers in C++ files +class MetalStream; +class MetalOp; +class MetalOpRegistry; + +//===----------------------------------------------------------------------===// +// MetalRuntime - GraphRuntime implementation using MetalStream +//===----------------------------------------------------------------------===// + +class MetalRuntime : public portable::GraphRuntime { +public: + MetalRuntime(); + ~MetalRuntime() override; + + //=== GraphRuntime interface === + + const char* name() const override { return "MetalRuntime_v2"; } + bool is_available() const override; + + bool has_op(const portable::OperatorCall& op, const portable::Graph& graph) const override; + + runtime::Error init( + const portable::Graph& graph, + runtime::ArrayRef segments) override; + + runtime::Error initialize_constants( + runtime::Span value_indices) override; + + runtime::Error initialize_buffers( + runtime::Span value_indices) override; + + runtime::Error execute_segment( + size_t segment_index, + runtime::Span values) override; + + runtime::Error upload_values( + runtime::Span cpu_values, + runtime::Span indices) override; + + runtime::Error download_values( + runtime::Span cpu_values, + runtime::Span indices) override; + + void destroy() override; + +private: + MetalStream* stream_; + const portable::Graph* graph_; + std::vector segments_; + + // Value index -> GPU buffer mapping + std::unordered_map valueBuffers_; + + // Cached sizes for detecting reallocation needs + std::unordered_map cachedSizes_; +}; + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime/metal_v2/MetalRuntime.mm b/backends/portable/runtime/metal_v2/MetalRuntime.mm new file mode 100644 index 00000000000..e96eb504e1e --- /dev/null +++ b/backends/portable/runtime/metal_v2/MetalRuntime.mm 
@@ -0,0 +1,202 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#import "MetalRuntime.h"
#import "GpuStream.h"
#import "MetalOp.h"
#import "MetalOpRegistry.h"
// NOTE(review): the two original #include targets were stripped in
// transcription; reconstructed from usage.
#include <array>
#include <vector>

namespace executorch {
namespace backends {
namespace metal_v2 {

using runtime::Error;
using runtime::EValue;
using runtime::Span;

// Binds to the process-default MetalStream; no GPU work happens here.
MetalRuntime::MetalRuntime() : stream_(nullptr), graph_(nullptr) {
  stream_ = MetalStream::getDefault();
}

MetalRuntime::~MetalRuntime() {
  destroy();
}

bool MetalRuntime::is_available() const {
  return stream_ != nullptr;
}

// NOTE(review): logs at Info on every query — consider Debug if this is hot.
bool MetalRuntime::has_op(const portable::OperatorCall& op, const portable::Graph& graph) const {
  const char* op_name = op.name();
  bool has = MetalOpRegistry::shared().hasOp(op_name);
  ET_LOG(Info, "MetalRuntime_v2: has_op('%s') = %d", op_name ? op_name : "null", has);
  return has;
}

// Captures the graph pointer and copies the segment descriptors.
Error MetalRuntime::init(
    const portable::Graph& graph,
    runtime::ArrayRef<portable::Segment> segments) {
  graph_ = &graph;
  segments_.assign(segments.begin(), segments.end());

  ET_LOG(Info, "MetalRuntime_v2: initialized with %zu segments", segments_.size());
  return Error::Ok;
}

Error MetalRuntime::initialize_constants(Span<const uint32_t> value_indices) {
  // TODO: Copy constants to GPU
  // For now, constants are handled by the EValue system
  ET_LOG(Info, "MetalRuntime_v2: initialize_constants called (%zu values)", value_indices.size());
  return Error::Ok;
}

Error MetalRuntime::initialize_buffers(Span<const uint32_t> value_indices) {
  // Pre-allocate GPU buffers for the values we'll use.
  // For unified memory, we just track which values we need.
  for (auto idx : value_indices) {
    // Mark that this runtime handles this value.
    // Actual allocation happens lazily or on first use.
    valueBuffers_[idx] = nullptr; // Will be set during upload
  }

  ET_LOG(Info, "MetalRuntime_v2: initialized %zu buffer slots", value_indices.size());
  return Error::Ok;
}

// Runs every instruction of one segment: resolves the op, gathers input /
// output EValue pointers, lazily allocates missing output storage from the
// stream, then dispatches. Deliberately does NOT sync — download_values()
// is the synchronization point so dispatches can batch.
Error MetalRuntime::execute_segment(size_t segment_index, Span<EValue> values) {
  if (segment_index >= segments_.size()) {
    return Error::InvalidArgument;
  }

  const auto& segment = segments_[segment_index];

  for (auto instr_idx : segment.instruction_indices) {
    // Get instruction from graph
    auto op = graph_->get_instruction(instr_idx);
    const char* op_name = op.name();

    MetalOp* gpuOp = MetalOpRegistry::shared().get(op_name);
    if (!gpuOp) {
      ET_LOG(Error, "MetalRuntime_v2: unsupported op '%s'", op_name ? op_name : "null");
      return Error::NotSupported;
    }

    // Gather inputs and outputs (negative / out-of-range indices skipped)
    std::vector<EValue*> inputs;
    std::vector<EValue*> outputs;

    for (size_t i = 0; i < op.num_inputs(); i++) {
      int32_t idx = op.input(i);
      if (idx >= 0 && static_cast<size_t>(idx) < values.size()) {
        inputs.push_back(&values[idx]);
      }
    }
    for (size_t i = 0; i < op.num_outputs(); i++) {
      int32_t idx = op.output(i);
      if (idx >= 0 && static_cast<size_t>(idx) < values.size()) {
        outputs.push_back(&values[idx]);
      }
    }

    // Ensure output tensors have allocated memory
    for (auto* output : outputs) {
      if (output->isTensor()) {
        auto& tensor = output->toTensor();
        if (!tensor.mutable_data_ptr() && tensor.nbytes() > 0) {
          // Allocate buffer for this output
          void* data = stream_->alloc(tensor.nbytes());
          if (!data) {
            ET_LOG(Error, "MetalRuntime_v2: failed to alloc %zu bytes for output", tensor.nbytes());
            return Error::MemoryAllocationFailed;
          }

          // Set the tensor's data pointer.
          // Note: This requires mutable access to the tensor implementation;
          // for ExecuTorch, we need to use the TensorImpl API.
          auto impl = tensor.unsafeGetTensorImpl();
          if (impl) {
            impl->set_data(data);
            ET_LOG(Info, "MetalRuntime_v2: allocated %zu bytes for output at %p", tensor.nbytes(), data);
          }
        }
      }
    }

    // Dispatch - MetalStream handles replay automatically
    gpuOp->dispatch(
        stream_,
        MetalOp::EValuePtrSpan(inputs.data(), inputs.size()),
        MetalOp::EValuePtrSpan(outputs.data(), outputs.size()));
  }

  // Note: don't sync here - let dispatches accumulate.
  // sync() happens in download_values().
  return Error::Ok;
}

// On Apple Silicon unified memory this is copy-free: just record the CPU
// pointers/sizes so later dispatches can resolve them.
Error MetalRuntime::upload_values(
    Span<EValue> cpu_values,
    Span<const uint32_t> indices) {
  for (size_t i = 0; i < indices.size(); i++) {
    uint32_t idx = indices[i];
    if (idx < cpu_values.size() && cpu_values[idx].isTensor()) {
      auto& tensor = cpu_values[idx].toTensor();
      void* data = const_cast<void*>(tensor.const_data_ptr());
      if (data) {
        valueBuffers_[idx] = data;
        cachedSizes_[idx] = tensor.nbytes();
      }
    }
  }

  ET_LOG(Debug, "MetalRuntime_v2: uploaded %zu values (unified memory)", indices.size());
  return Error::Ok;
}

// Blocks until all accumulated GPU work is done; no copy needed on unified
// memory because results are already CPU-visible.
Error MetalRuntime::download_values(
    Span<EValue> cpu_values,
    Span<const uint32_t> indices) {
  stream_->sync();

  ET_LOG(Debug, "MetalRuntime_v2: downloaded %zu values (unified memory)", indices.size());
  return Error::Ok;
}

void MetalRuntime::destroy() {
  // NOTE(review): intentionally best-effort — buffers recorded here either
  // belong to EValue tensors (unified memory, not ours to free) or to the
  // stream's allocator. TODO: track stream_->alloc() results separately so
  // they can be returned via stream_->free() here.
  valueBuffers_.clear();
  cachedSizes_.clear();
  segments_.clear();
  graph_ = nullptr;

  ET_LOG(Info, "MetalRuntime_v2: destroyed");
}

} // namespace metal_v2
} // namespace backends
} // namespace executorch

diff --git a/backends/portable/runtime/metal_v2/MetalStream.h b/backends/portable/runtime/metal_v2/MetalStream.h
new file mode 100644
index 00000000000..e873581df6e
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalStream.h
@@ -0,0 +1,604 @@
/*
 * Copyright (c) Meta
Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#import +#import + +// Metal 4 SDK availability check. Used to gate MTLResidencySet and other +// Metal-15+ APIs that exist as a runtime feature even when ET_METAL4_ENABLE +// (the build flag for the MTL4 dispatch path) is off. +#if (defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000) || \ + (defined(__IPHONE_OS_VERSION_MAX_ALLOWED) && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000) +#define ET_METAL4_AVAILABLE 1 +#else +#define ET_METAL4_AVAILABLE 0 +#endif + +// Forward decl: lambda signature for encodeWithLegacyCommandBuffer takes +// MPSCommandBuffer*. Actual MPSCommandBuffer.h import lives in MetalStream.mm +// to keep this header lightweight. +@class MPSCommandBuffer; + +#include +#include +#include +#include +#include +#include +#include +#include + +//===----------------------------------------------------------------------===// +// ET_METAL4_ENABLE +// +// Compile-time opt-in for the Metal 4 dispatch path (MTL4Compiler, +// MTL4ComputeCommandEncoder, MTL4ArgumentTable, MTL4CommandQueue, ...). +// +// Default: 0 (legacy path). Define ET_METAL4_ENABLE=1 in the build to compile +// the MTL4 code paths in. Even when compiled in, MTL4 is only used at runtime +// when the OS supports it (macOS 26+ / iOS 26+) -- see useMTL4() below. +// +// This is independent of ET_METAL4_AVAILABLE (defined in MetalStream.mm), +// which guards macOS-15-era APIs (MTLResidencySet, MTLBinaryArchive, ...). +//===----------------------------------------------------------------------===// +#ifndef ET_METAL4_ENABLE +#define ET_METAL4_ENABLE 0 +#endif + +namespace executorch { +namespace backends { +namespace metal_v2 { + +/// True iff the Metal 4 dispatch paths are both compiled in AND the runtime +/// OS supports them. 
Use this as the single check at every MTL4 call site:
/// if (useMTL4()) { ...mtl4 path... } else { ...legacy path... }
///
/// NOTE(review): this paste lost every `<...>` token (protocol-qualified
/// `id<...>`, template-argument lists). They are reconstructed below from the
/// Metal API surface the code visibly uses — confirm against the original
/// header before landing.
inline bool useMTL4() {
#if ET_METAL4_ENABLE
  if (@available(macOS 26.0, iOS 26.0, *)) {
    return true;
  }
#endif
  return false;
}

//===----------------------------------------------------------------------===//
// MetalHeap - Pre-allocated memory pool for fast sub-allocation
//===----------------------------------------------------------------------===//

class MetalHeap {
 public:
  /// Create a heap of `size` bytes on `device`. If `aliasable`, sub-buffers
  /// may alias (caller guarantees no overlapping live use).
  MetalHeap(id<MTLDevice> device, size_t size, bool aliasable = false);
  ~MetalHeap();

  /// Allocate buffer from heap (fast: ~100ns vs ~10µs for a device alloc).
  /// Returns nil when the heap cannot satisfy the request.
  id<MTLBuffer> allocBuffer(size_t size);

  /// Get current used size
  size_t usedSize() const { return usedSize_; }

  /// Get total heap size
  size_t totalSize() const { return totalSize_; }

  /// Reset heap (invalidates ALL previously returned buffers).
  void reset() { usedSize_ = 0; }

 private:
  id<MTLHeap> heap_;
  size_t totalSize_;
  size_t usedSize_ = 0;  // bump-allocator high-water mark
};

//===----------------------------------------------------------------------===//
// MetalBufferPool - LRU buffer pool with best-fit matching
//===----------------------------------------------------------------------===//

class MetalBufferPool {
 public:
  explicit MetalBufferPool(id<MTLDevice> device,
                           size_t maxBytes = 256 * 1024 * 1024);
  ~MetalBufferPool();

  /// Acquire a buffer of at least `size` bytes (best-fit from cache, or a
  /// fresh device allocation on miss).
  id<MTLBuffer> acquire(size_t size);

  /// Return buffer to pool
  void release(id<MTLBuffer> buffer);

  /// Clear all cached buffers
  void clear();

  /// Current bytes in pool
  size_t cachedBytes() const { return cachedBytes_; }

  /// Maximum bytes the pool will hold before evicting LRU entries.
  size_t maxBytes() const { return maxBytes_; }

  /// Update the cap. If new cap < current cachedBytes_, evicts LRU until
  /// under cap. Useful when caller knows memory budget at init.
  void setMaxBytes(size_t bytes);

  /// Pre-allocate buffers of these sizes and seed them into the cache so
  /// the first round of acquire() calls hit the cache instead of going to
  /// the device. Useful when the caller has a memory plan from AOTI.
  /// If total prewarmed bytes exceeds maxBytes_, oldest entries get evicted.
  void prewarm(const std::vector<size_t>& sizes);

 private:
  void evictOldest();

  id<MTLDevice> device_;
  size_t maxBytes_;
  size_t cachedBytes_ = 0;

  struct PoolEntry {
    id<MTLBuffer> buffer;
    size_t size;
  };

  std::list<PoolEntry> lruList_;  // newest at front
  // size -> LRU entry, for best-fit lookup in acquire().
  std::multimap<size_t, std::list<PoolEntry>::iterator> sizeMap_;

  // Max slack accepted by best-fit before preferring a fresh allocation.
  static constexpr size_t kMaxHeadroom = 32768;  // 32KB
};

//===----------------------------------------------------------------------===//
// MetalKernel - Compiled Metal compute pipeline
//===----------------------------------------------------------------------===//

class MetalKernel {
 public:
  MetalKernel(id<MTLComputePipelineState> pipeline, const char* name);
  ~MetalKernel();

  const char* name() const { return name_.c_str(); }
  uvec3 maxThreadsPerThreadgroup() const;
  /// Type-erased pipeline pointer for C-ABI consumers (no ownership change).
  void* nativeHandle() { return (__bridge void*)pipeline_; }

  id<MTLComputePipelineState> pipeline() const { return pipeline_; }

 private:
  id<MTLComputePipelineState> pipeline_;
  std::string name_;
};

//===----------------------------------------------------------------------===//
// MetalKernelCompiler
//===----------------------------------------------------------------------===//

class MetalKernelCompiler {
 public:
  explicit MetalKernelCompiler(id<MTLDevice> device);
  ~MetalKernelCompiler();

  /// Compile `functionName` from MSL `source`. Returns a cached kernel on
  /// repeat calls; nullptr on compile failure.
  MetalKernel* compile(const char* source, const char* functionName);

  /// Load `functionName` from a precompiled .metallib blob.
  MetalKernel* loadFromLibrary(const void* metallibData,
                               size_t metallibSize,
                               const char* functionName);

  //=== Binary Archive Support (Metal 4) ===

  /// Load binary archive from file (fast shader loading)
  bool loadBinaryArchive(const char* path);

  /// Save compiled shaders to binary archive
  bool saveBinaryArchive(const char* path);

  /// Check if binary archive is loaded
  bool hasBinaryArchive() const { return binaryArchive_ != nil; }

 private:
  id<MTLDevice> device_;
  id<MTLBinaryArchive> binaryArchive_;
  // function name -> compiled kernel (owning cache).
  std::unordered_map<std::string, std::unique_ptr<MetalKernel>> cache_;

#if ET_METAL4_ENABLE
  // Lazily-created MTL4 compiler. Used when useMTL4() is true. Reused across
  // pipeline creations so we don't pay the per-compiler setup more than once.
  id<MTL4Compiler> mtl4Compiler_ API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
#endif
};

//===----------------------------------------------------------------------===//
// DispatchSignature - For replay detection
//===----------------------------------------------------------------------===//

struct DispatchSignature {
  MetalKernel* kernel;
  std::vector<size_t> bufferSizes;
  uvec3 grid;
  uvec3 block;

  bool operator==(const DispatchSignature& other) const {
    return kernel == other.kernel &&
        bufferSizes == other.bufferSizes &&
        grid.x == other.grid.x && grid.y == other.grid.y &&
        grid.z == other.grid.z &&
        block.x == other.block.x && block.y == other.block.y &&
        block.z == other.block.z;
  }
};

//===----------------------------------------------------------------------===//
// MetalStream - Main implementation
//
// Uses Metal 4 APIs where available:
// - MTLResidencySet for GPU-resident buffers (macOS 15+, iOS 18+)
// - MTLIndirectCommandBuffer for command replay
//===----------------------------------------------------------------------===//

class MetalStream {
 public:
  MetalStream();
  ~MetalStream();

  //=== Core API ===
  void dispatch(
      MetalKernel* kernel,
      std::initializer_list<Arg> args,
      uvec3 grid,
      uvec3 block);

  void flush();
  void wait();
  void sync() { flush(); wait(); }
  void* alloc(size_t bytes);
  void free(void* ptr);

  // Register a host-allocated pointer with the stream so dispatches can
  // resolve it to an MTLBuffer via bufferForPtr(). On Apple Silicon
  // unified memory this is the cheap path for caller-storage tensors.
  //
  // strict_zero_copy:
  //  - false (default): tries newBufferWithBytesNoCopy first; if Metal
  //    refuses (typically because the pointer isn't page-aligned),
  //    falls back to newBufferWithBytes which COPIES the data once.
  //    The copy is fine for one-shot graph inputs/outputs but breaks
  //    true aliasing (subsequent writes to ptr won't be visible to
  //    the GPU buffer).
  //  - true: returns false instead of falling back to a copy. Use this
  //    when the caller needs a guaranteed zero-copy alias (e.g. router
  //    persistent-alias optimization for cross-runtime intermediates).
  bool registerExternalBuffer(
      void* ptr, size_t bytes, bool strict_zero_copy = false);

  //=== Optional Control ===
  void setFlushInterval(int dispatches);
  int flushInterval() const { return flushInterval_; }

  void setUseICB(bool enable) { useICB_ = enable; }
  bool useICB() const { return useICB_; }

  //=== Accessors ===
  MetalKernelCompiler* compiler() { return compiler_.get(); }

  //=== Static factories ===
  /// Get singleton default stream (NOT thread-safe for concurrent dispatch).
  static MetalStream* getDefault();
  /// Get thread-local stream (thread-safe — each thread gets its own).
  static MetalStream* getThreadLocal();
  /// Create a new independent stream (caller owns lifetime).
  static std::unique_ptr<MetalStream> create();

  id<MTLDevice> device() const { return device_; }

  //=== Per-execute lifecycle hook ===
  // Called by synchronize_metal_stream at the end of every AOTI forward
  // call. Resets per-execute bookkeeping that's tracked across iterations
  // for replay correctness:
  //  - currentDispatchIdx_ → 0   position in the replay sequence
  //  - icbRecordedThisIter_ → 0  count of ops submitted (cold or replay)
  //                              this iteration; used by partial-flush as
  //                              upper bound (so iter-2 partial flush at
  //                              matmul_K only drains ops [0..K), not the
  //                              whole ICB which would re-execute later
  //                              ops with stale iter-1 bindings).
  void endExecute() {
    currentDispatchIdx_ = 0;
    icbRecordedThisIter_ = 0;
  }

  //=== Buffer lookup ===
  // Resolve a host pointer to its registered MTLBuffer. Auto-registers if
  // unknown. Used by ops that need to wrap inputs/outputs as MPS-specific
  // buffer types (e.g. MPSGraphTensorData). Returns nil if registration
  // fails (e.g. strictness/allocation failure inside register).
  id<MTLBuffer> bufferForPtr(void* ptr, size_t bytes) {
    // Single lookup on the hot (already-registered) path; re-probe only
    // after a registration attempt.
    auto it = ptrToBuffer_.find(ptr);
    if (it == ptrToBuffer_.end()) {
      registerExternalBuffer(ptr, bytes);
      it = ptrToBuffer_.find(ptr);
    }
    return it == ptrToBuffer_.end() ? nil : it->second;
  }

  //=== MPSGraph integration ===
  // Encode work that requires a legacy MTLCommandBuffer (currently the only
  // such consumer is MPSGraph). MetalStream owns the entire orchestration:
  //
  // Under MTL3:
  //  - End any open compute encoder, drain pending ICB into the live cb
  //  - Get/create the live legacy commandBuffer_, wrap as MPSCommandBuffer
  //  - Invoke encode_fn(mpsCB) — caller does [graph encodeToCommandBuffer:]
  //  - Adopt back mpsCB.commandBuffer (may be replaced if MPS internally
  //    called commitAndContinue:)
  //
  // Under MTL4:
  //  - End any open mtl4 encoder, drain pending ICB into mtl4 cb
  //  - Commit current mtl4 cb to mtl4Queue, schedule queue-level signal
  //    on internal MPS-sync event (so MPS work can wait for prior mtl4
  //    work to complete on GPU)
  //  - Create a fresh legacy MTLCommandBuffer, wrap as MPSCommandBuffer
  //  - Encode cb-level wait on the MPS-sync event
  //  - Invoke encode_fn(mpsCB)
  //  - Encode cb-level signal on the MPS-sync event
  //  - Commit (fire-and-forget on legacy queue) and schedule a
  //    queue-level wait on mtl4Queue so the next mtl4 cb commit gates on
  //    MPS completion. wait() at execute end blocks on the MPS event.
  //
  // No CPU stall under MTL4 — both queues run concurrently, gated by the
  // cross-queue event. Caller's encode_fn just does its MPS encode and
  // returns; all sync is handled internally.
  //
  // NOTE(review): the callback's parameter type was stripped from this
  // paste; MPSCommandBuffer* matches the documented flow above ("wrap as
  // MPSCommandBuffer") — confirm against the original header.
  void encodeWithLegacyCommandBuffer(
      std::function<void(MPSCommandBuffer*)> encode_fn);

  //=== Binary Archive ===
  bool loadShaderArchive(const char* path) {
    return compiler_->loadBinaryArchive(path);
  }
  bool saveShaderArchive(const char* path) {
    return compiler_->saveBinaryArchive(path);
  }

  //=== Heap Allocation ===
  /// Enable heap-based allocation (faster, but fixed size)
  /// Call before any alloc() calls
  void enableHeap(size_t heapSizeBytes, bool aliasable = false);

  /// Check if heap is enabled
  bool heapEnabled() const { return useHeap_; }

  //=== Buffer pool tuning ===
  /// Set the LRU buffer-pool capacity (bytes). Default 256 MiB.
  /// If new cap < currently-cached bytes, evicts LRU until under cap.
  /// Useful when caller knows the model's working-set size.
  void setBufferPoolCapacity(size_t bytes) {
    if (bufferPool_) bufferPool_->setMaxBytes(bytes);
  }

  /// Pre-allocate buffers of these sizes and seed them into the pool.
  /// Useful when caller has a memory plan from AOTI / model compiler so
  /// the first iteration's alloc() calls hit the cache instead of cold-
  /// allocating from the device.
  void prewarmBufferPool(const std::vector<size_t>& sizes) {
    if (bufferPool_) bufferPool_->prewarm(sizes);
  }

  //=== Thread Safety ===
  /// Enable mutex protection for shared stream (default: false)
  void setThreadSafe(bool enabled) { threadSafe_ = enabled; }

 private:
  // Internal MPS-sync helpers — used only by encodeWithLegacyCommandBuffer.
  // (Were previously public methods named publicEndEncoder, publicQueue,
  // publicFlushPendingICB, publicCommandBuffer, publicAdoptCommandBuffer,
  // publicCommandBufferDone, publicMTL4CommitAndSignal, publicMTL4QueueWait,
  // publicNoteMpsEventValue. Made private as MPSGraphOp now goes through the
  // single high-level encodeWithLegacyCommandBuffer entry point.)

  // Drain any ICB ops recorded since the last partial/full ICB flush into
  // the live cb (without commit). Bounded by icbRecordedThisIter_ for
  // replay correctness — see field comment.
  void flushPendingICB();

  // Get/create the live legacy MTLCommandBuffer. Always creates from queue_
  // (under both MTL3 and MTL4) so callers needing a legacy cb get one.
  id<MTLCommandBuffer> getOrCreateLegacyCommandBuffer();

  // Adopt a (possibly replaced) live legacy cb after an external encoder
  // (e.g. MPSGraph encodeToCommandBuffer:) may have called commitAndContinue.
  void adoptLegacyCommandBuffer(id<MTLCommandBuffer> newCB);

  // Mark the live legacy cb as committed externally; next dispatch opens fresh.
  void releaseLegacyCommandBuffer();

#if ET_METAL4_ENABLE
  // Commit current mtl4 cb (if any work) to mtl4Queue and schedule
  // queue-level signal=value on `event` after committed work completes.
  // NOTE(review): event type reconstructed as MTLSharedEvent (the .mm's
  // completion event comes from -newSharedEvent) — confirm.
  void mtl4CommitAndSignal(id<MTLSharedEvent> event, uint64_t value);

  // Schedule queue-level wait on mtl4Queue for event=value before the next
  // committed mtl4 cb runs.
  void mtl4QueueWait(id<MTLSharedEvent> event, uint64_t value);

  // Lazy-create the per-stream MPS-sync event used by
  // encodeWithLegacyCommandBuffer to bracket MPS encode with cross-queue
  // wait/signal. One event per MetalStream is sufficient (each stream has
  // its own queues; no cross-stream MPS sync needed).
  id<MTLSharedEvent> getOrCreateMpsSyncEvent();
#endif

  void endEncoderInternal() { endEncoder(); }

  //=========================================================================
  // Internal — implementation details (continued)
  //=========================================================================
 private:
  void ensureCommandBuffer();
  void ensureEncoder();
  void endEncoder();

  // Reset the ICB record/replay cache. Called internally on signature
  // mismatch (next dispatch differs from what's recorded). Not exposed
  // publicly because external callers can't know when to invoke it —
  // signature tracking is internal. Future: if/when we expose graph
  // boundaries (e.g. "AOTI just compiled a new model variant, drop old
  // recording"), promote back to public.
  void invalidate();

  void encodeDispatch(MetalKernel* kernel,
                      const std::vector<Arg>& args,
                      uvec3 grid,
                      uvec3 block);
  void updateArgumentBuffer(size_t dispatchIdx, const std::vector<Arg>& args);
  DispatchSignature buildSignature(MetalKernel* kernel,
                                   const std::vector<Arg>& args,
                                   uvec3 grid,
                                   uvec3 block);

  // ICB helpers
  void setupICB();
  void setupArgumentBuffer(size_t numDispatches);
  void encodeIntoICB(MetalKernel* kernel,
                     const std::vector<Arg>& args,
                     uvec3 grid,
                     uvec3 block);
  void executeICB();
  bool supportsGPUAddress() const;

  // Metal 4: Residency helpers
  void setupResidencySet();
  void addToResidencySet(id<MTLBuffer> buffer);
  void commitResidencySet();

#if ET_METAL4_ENABLE
  // Initialize MTL4 queue/allocator/argument-table/scratch/event when
  // useMTL4(). Called from the constructor; nil-safe so legacy code path
  // works when MTL4 init fails.
  void setupMTL4() API_AVAILABLE(macos(26.0), ios(26.0));
  // Release all MTL4 objects. Called from destructor.
  void teardownMTL4() API_AVAILABLE(macos(26.0), ios(26.0));
#endif

  // Device flush interval based on architecture
  int getDefaultFlushInterval() const;

 private:
  id<MTLDevice> device_;
  id<MTLCommandQueue> queue_;
  id<MTLCommandBuffer> commandBuffer_;

  // In-flight (committed-but-not-drained) command buffer.
  id<MTLCommandBuffer> inFlightCommandBuffer_ = nil;
  id<MTLComputeCommandEncoder> encoder_;

#if ET_METAL4_ENABLE
  // ----- Metal 4 dispatch members -----
  // All gated by ET_METAL4_ENABLE at compile time and useMTL4() at runtime.
  // Created by setupMTL4() during construction. nil if useMTL4() is false.
  id<MTL4CommandQueue> mtl4Queue_ API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  id<MTL4CommandAllocator> mtl4Allocator_
      API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  id<MTL4CommandBuffer> mtl4CommandBuffer_
      API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  id<MTL4CommandBuffer> mtl4InFlightCommandBuffer_
      API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  id<MTL4ComputeCommandEncoder> mtl4Encoder_
      API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  // Single argument table reused across dispatches in the direct path.
  // Bindings are overwritten per dispatch, then setArgumentTable: + dispatch
  // captures the snapshot.
  id<MTL4ArgumentTable> mtl4ArgTable_
      API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  // Bump scratch buffer for inline scalars (replaces setBytes:atIndex:).
  // 1 MB ring; reset on flush().
  id<MTLBuffer> mtl4ScalarScratch_ = nil;
  size_t mtl4ScalarScratchOffset_ = 0;
  // Completion signal for wait(). Incremented on each commit; wait() blocks
  // until this value is reached.
  id<MTLSharedEvent> mtl4CompletionEvent_
      API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  uint64_t mtl4CompletionValue_ = 0;
  // Per-stream MPS-sync event (lazy-created via getOrCreateMpsSyncEvent()
  // when first MPS op runs). Used by encodeWithLegacyCommandBuffer to
  // bracket MPS encode with cross-queue wait/signal between mtl4Queue and
  // legacy queue_. Single event suffices: same stream, monotonic counter.
  id<MTLSharedEvent> mpsSyncEvent_ API_AVAILABLE(macos(26.0), ios(26.0)) = nil;
  uint64_t mpsSyncCounter_ = 1;
  // Tracks outstanding MPS work committed to the legacy queue_ during this
  // execute. Set internally by encodeWithLegacyCommandBuffer under MTL4
  // after committing each mps cb; wait() blocks on this value to ensure MPS
  // work completes before the CPU reads outputs.
  id<MTLSharedEvent> pendingMpsEvent_ = nil;
  uint64_t pendingMpsEventValue_ = 0;
  // Fence pool for cross-encoder dependencies (ICB segment chain).
  // Created lazily; small ring of MTLFence objects.
  std::vector<id<MTLFence>> mtl4Fences_;
#endif

  // Buffer management
  std::unique_ptr<MetalBufferPool> bufferPool_;
  std::unique_ptr<MetalHeap> heap_;  // Fast allocation for model buffers
  bool useHeap_ = false;
  std::unordered_map<void*, id<MTLBuffer>> ptrToBuffer_;
  // Track which buffers are external (not allocated by us).
  std::unordered_set<void*> externalBuffers_;

  // Kernel compiler
  std::unique_ptr<MetalKernelCompiler> compiler_;

  // ICB for replay
  id<MTLIndirectCommandBuffer> icb_;
  // Holds GPU addresses for all dispatches.
  id<MTLBuffer> argumentBuffer_;
  size_t argumentBufferSize_ = 0;
  size_t argumentBufferOffset_ = 0;
  bool icbValid_ = false;
  size_t icbDispatchCount_ = 0;
  // For ICB+MPS coexistence: tracks how many ICB commands have already been
  // executeCommandsInBuffer:'d into the live cmd buffer via partial flushes
  // (flushPendingICB called from MPSGraphOp before its encode). flush()
  // only runs commands at indices [icbExecutedCount_, icbDispatchCount_) so
  // we don't double-execute. Reset to 0 at end of flush() and on invalidate.
  size_t icbExecutedCount_ = 0;
  size_t maxICBCommands_ = 1024;  // Max dispatches per ICB

  // Per-dispatch argument layout
  static constexpr size_t kMaxBuffersPerDispatch = 8;
  static constexpr size_t kMaxScalarsPerDispatch = 8;
  struct DispatchArgLayout {
    size_t offset;  // Offset in argumentBuffer_
    size_t numBuffers;
    size_t numScalars;
  };
  std::vector<DispatchArgLayout> argLayouts_;
  // Per-ICB-command pipeline state, parallel to icbDispatchCount_. Needed
  // for MTL4 ICB execute path because the encoder must have a pipeline
  // state set BEFORE setArgumentTable: takes effect for the upcoming
  // dispatch (MTL4 binds-via-table model). The pipeline is also recorded
  // inside the ICB command itself (setComputePipelineState during encode),
  // but the encoder needs a copy too.
  std::vector<id<MTLComputePipelineState>> dispatchPipelines_;

  // Dependency tracking for ICB barriers
  std::unordered_set<void*> writtenBuffers_;  // Buffers written by previous ops
  std::vector<size_t> barrierIndices_;  // Dispatch indices where barriers needed

  // Signature tracking
  std::vector<DispatchSignature> signatures_;
  size_t currentDispatchIdx_ = 0;
  bool isReplaying_ = false;
  // Number of dispatches submitted (cold or replay) in the CURRENT iteration.
  // Reset to 0 by endExecute(). Used by flushPendingICB() as the upper bound
  // when partial-draining ICB at MPS dispatch boundaries — so on replay
  // iterations, we don't re-execute ICB ops that haven't been replayed yet
  // (which would re-run them with stale prior-iter bindings).
  size_t icbRecordedThisIter_ = 0;

  // Set by dispatch() (both encode and replay paths), cleared in flush()
  // after the command buffer has been committed. Gates flush()'s submission
  // step so a flush() called twice in a row with no intervening dispatch is
  // a no-op. The direct (non-ICB) path was implicitly idempotent because
  // commit set commandBuffer_ = nil; ICB needed an explicit flag because
  // icbDispatchCount_ and icb_ stay set across flush() for replay.
  bool hasPendingWork_ = false;

  // Auto-batching
  int flushInterval_ = 40;
  bool useICB_ = false;  // ICB disabled by default - enable with METAL_USE_ICB=1
  int dispatchCount_ = 0;

  // Thread safety (optional - use setThreadSafe(true) to enable)
  bool threadSafe_ = false;
  std::mutex mutex_;

  // Metal 4: ResidencySet for GPU-resident memory.
  // FIX(review): guarded by ET_METAL4_AVAILABLE to match every use site in
  // MetalStream.mm (setupResidencySet / add / commit / destructor all use
  // ET_METAL4_AVAILABLE). The previous __MAC_OS_X_VERSION_MAX_ALLOWED guard
  // could disagree with the .mm's guard and break the build whenever exactly
  // one of the two conditions held.
#if ET_METAL4_AVAILABLE
  id<MTLResidencySet> residencySet_ API_AVAILABLE(macos(15.0), ios(18.0));
#endif
  bool useResidencySet_ = false;
};

} // namespace metal_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/MetalStream.mm b/backends/portable/runtime/metal_v2/MetalStream.mm
new file mode 100644
index 00000000000..793e5bc9af2
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/MetalStream.mm
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// MetalStream — the runtime stream itself. Helper classes (MetalHeap,
// MetalBufferPool, MetalKernel, MetalKernelCompiler) live in their own .mm
// files alongside their declarations in MetalStream.h.

#import "MetalStream.h"
// NOTE(review): the header targets after MetalStream.h were stripped from
// this paste. Restored to what the code below visibly needs (Metal API,
// getenv/strcmp, std::max) — confirm against the original file.
#import <Metal/Metal.h>
#include <algorithm>
#include <cstdlib>
#include <cstring>

namespace executorch {
namespace backends {
namespace metal_v2 {

// Singleton storage. Both pointers are intentionally never deleted so the
// streams outlive all callers (process/thread teardown reclaims them).
// NOTE(review): getDefault() is not safe for concurrent FIRST use — two
// threads racing past the null check can each construct a stream. Confirm
// callers serialize first use, or switch to a function-local static.
static MetalStream* defaultStream_ = nullptr;
static thread_local MetalStream* threadLocalStream_ = nullptr;

MetalStream* MetalStream::getDefault() {
  if (!defaultStream_) {
    defaultStream_ = new MetalStream();
  }
  return defaultStream_;
}

MetalStream* MetalStream::getThreadLocal() {
  if (!threadLocalStream_) {
    threadLocalStream_ = new MetalStream();
  }
  return threadLocalStream_;
}

std::unique_ptr<MetalStream> MetalStream::create() {
  return std::make_unique<MetalStream>();
}

MetalStream::MetalStream() {
  @autoreleasepool {
    // FIX(review): MTLCreateSystemDefaultDevice() and -newCommandQueue follow
    // the Cocoa create/new ownership rule and return +1 references. The
    // previous extra [obj retain] after each made them +2 with only one
    // release in the destructor — a leak under MRC. Retains dropped.
    device_ = MTLCreateSystemDefaultDevice();
    if (!device_) {
      ET_LOG(Error, "MetalStream: failed to create Metal device");
      return;
    }

    queue_ = [device_ newCommandQueue];

    bufferPool_ = std::make_unique<MetalBufferPool>(device_);
    compiler_ = std::make_unique<MetalKernelCompiler>(device_);

    flushInterval_ = getDefaultFlushInterval();

    // Check env var for ICB mode (disabled by default)
    // ICB mode encodes commands for replay but doesn't handle data dependencies
    const char* icbEnv = getenv("METAL_USE_ICB");
    useICB_ = icbEnv && (strcmp(icbEnv, "1") == 0 || strcmp(icbEnv, "true") == 0);

    // Metal 4: Setup ResidencySet for GPU-resident memory
    setupResidencySet();

#if ET_METAL4_ENABLE
    // Metal 4 dispatch path: queue/allocator/arg-table/scratch/event
    if (useMTL4()) {
      if (@available(macOS 26.0, iOS 26.0, *)) {
        setupMTL4();
      }
    }
#endif

    ET_LOG(Info, "MetalStream: initialized with device '%s', flush interval=%d, ICB=%s",
           [[device_ name] UTF8String], flushInterval_, useICB_ ? "enabled" : "disabled");
  }
}

MetalStream::~MetalStream() {
  @autoreleasepool {
    sync();  // flush() + wait() — drains any pending and in-flight work.

    if (icb_) [icb_ release];
    if (argumentBuffer_) [argumentBuffer_ release];
    if (encoder_) [encoder_ release];
    if (commandBuffer_) [commandBuffer_ release];
    if (inFlightCommandBuffer_) [inFlightCommandBuffer_ release];

#if ET_METAL4_ENABLE
    if (@available(macOS 26.0, iOS 26.0, *)) {
      teardownMTL4();
    }
#endif

#if ET_METAL4_AVAILABLE
    if (@available(macOS 15.0, iOS 18.0, *)) {
      if (residencySet_) [residencySet_ release];
    }
#endif

    // Release all tracked buffers (pool-owned and external wraps alike hold
    // exactly one reference in ptrToBuffer_).
    for (auto& [ptr, buffer] : ptrToBuffer_) {
      [buffer release];
    }

    [queue_ release];
    [device_ release];

    ET_LOG(Debug, "MetalStream: Destroyed");
  }
}

int MetalStream::getDefaultFlushInterval() const {
  // Determine flush interval based on GPU architecture.
  // Architecture string ends with: 'p' = iPhone, 'g' = base, 's' = Max,
  // 'd' = Ultra.
  char suffix = 'g';

  if (@available(macOS 13.0, iOS 16.0, *)) {
    MTLArchitecture* architecture = [device_ architecture];
    if (architecture) {
      NSString* name = [architecture name];
      if (name && [name length] > 0) {
        suffix = [name characterAtIndex:[name length] - 1];
      }
    }
  }

  switch (suffix) {
    case 'p': return 20;  // iPhone - more conservative
    case 'g': return 40;  // Base/Pro
    case 's': return 50;  // Max
    case 'd': return 50;  // Ultra
    default:  return 40;
  }
}

void MetalStream::setupResidencySet() {
#if ET_METAL4_AVAILABLE
  if (@available(macOS 15.0, iOS 18.0, *)) {
    MTLResidencySetDescriptor* desc = [[MTLResidencySetDescriptor alloc] init];
    desc.label = @"GpuStream ResidencySet";
    // Initial capacity - will grow as needed
    desc.initialCapacity = 64;

    NSError* error = nil;
    // newResidencySetWithDescriptor: returns +1; released in the destructor.
    residencySet_ = [device_ newResidencySetWithDescriptor:desc error:&error];
    [desc release];

    if (residencySet_) {
      useResidencySet_ = true;
      ET_LOG(Info, "MetalStream: Metal 4 ResidencySet enabled");
    } else {
      ET_LOG(Info, "MetalStream: ResidencySet not available: %s",
             error ? [[error localizedDescription] UTF8String] : "unknown");
    }
  }
#endif
}

void MetalStream::addToResidencySet(id<MTLBuffer> buffer) {
#if ET_METAL4_AVAILABLE
  if (@available(macOS 15.0, iOS 18.0, *)) {
    if (useResidencySet_ && residencySet_) {
      [residencySet_ addAllocation:buffer];
    }
  }
#endif
}

void MetalStream::commitResidencySet() {
#if ET_METAL4_AVAILABLE
  if (@available(macOS 15.0, iOS 18.0, *)) {
    if (useResidencySet_ && residencySet_) {
      // BOTH calls are required:
      //   commit — applies pending addAllocation:/removeAllocation: changes
      //            to the set itself. Without this, allocations stay in a
      //            "pending" state and the set acts as if empty. For MTL4
      //            this manifests as kernels reading from never-resident
      //            memory → silent zeros / input-passthrough output.
      //   requestResidency — asks the OS to physically page-in the now-
      //            committed allocations. Best-effort; safe to call again.
      [residencySet_ commit];
      [residencySet_ requestResidency];
      ET_LOG(Debug, "MetalStream: Committed ResidencySet (size=%llu bytes)",
             (unsigned long long)[residencySet_ allocatedSize]);
    }
  }
#endif
}

void MetalStream::enableHeap(size_t heapSizeBytes, bool aliasable) {
  if (heap_) {
    ET_LOG(Info, "MetalStream: Heap already enabled");
    return;
  }

  heap_ = std::make_unique<MetalHeap>(device_, heapSizeBytes, aliasable);
  if (heap_ && heap_->totalSize() > 0) {
    useHeap_ = true;
    ET_LOG(Info, "MetalStream: Heap enabled (%zu MB)", heapSizeBytes / (1024*1024));
  }
}

bool MetalStream::supportsGPUAddress() const {
#if ET_METAL4_AVAILABLE
  // NOTE(review): gated on the macOS 15 SDK but queries MTLGPUFamilyMetal3 —
  // looks intentional (gpuAddress API appears with the newer SDK while the
  // hardware family requirement is Metal3), but confirm.
  if (@available(macOS 15.0, iOS 18.0, *)) {
    return [device_ supportsFamily:MTLGPUFamilyMetal3];
  }
#endif
  return false;
}

#if ET_METAL4_ENABLE
void MetalStream::setupMTL4() {
  if (@available(macOS 26.0, iOS 26.0, *)) {
    @autoreleasepool {
      NSError* err = nil;

      // FIX(review): every object below comes from a `new*` method and is
      // therefore already +1 (Cocoa ownership rule); the previous extra
      // [obj retain] after each assignment over-retained them to +2 with a
      // single release in teardownMTL4() — a leak under MRC. Retains dropped;
      // the error-path releases below stay balanced against the +1.

      // Command queue
      MTL4CommandQueueDescriptor* qDesc = [[MTL4CommandQueueDescriptor alloc] init];
      mtl4Queue_ = [device_ newMTL4CommandQueueWithDescriptor:qDesc error:&err];
      [qDesc release];
      if (!mtl4Queue_ || err) {
        ET_LOG(Error, "MetalStream: MTL4CommandQueue creation failed: %s",
               err ? [[err localizedDescription] UTF8String] : "unknown");
        mtl4Queue_ = nil;
        return;
      }

      // Add residency set to MTL4 queue too (so MTL4 cmd buffers see resident memory)
      if (residencySet_) {
        [mtl4Queue_ addResidencySet:residencySet_];
      }

      // Command allocator
      err = nil;
      MTL4CommandAllocatorDescriptor* aDesc = [[MTL4CommandAllocatorDescriptor alloc] init];
      mtl4Allocator_ = [device_ newCommandAllocatorWithDescriptor:aDesc error:&err];
      [aDesc release];
      if (!mtl4Allocator_ || err) {
        ET_LOG(Error, "MetalStream: MTL4CommandAllocator creation failed: %s",
               err ? [[err localizedDescription] UTF8String] : "unknown");
        [mtl4Queue_ release]; mtl4Queue_ = nil;
        mtl4Allocator_ = nil;
        return;
      }

      // Argument table (single, reused per dispatch)
      MTL4ArgumentTableDescriptor* atDesc = [[MTL4ArgumentTableDescriptor alloc] init];
      atDesc.maxBufferBindCount = kMaxBuffersPerDispatch;
      err = nil;
      mtl4ArgTable_ = [device_ newArgumentTableWithDescriptor:atDesc error:&err];
      [atDesc release];
      if (!mtl4ArgTable_ || err) {
        ET_LOG(Error, "MetalStream: MTL4ArgumentTable creation failed: %s",
               err ? [[err localizedDescription] UTF8String] : "unknown");
        [mtl4Allocator_ release]; mtl4Allocator_ = nil;
        [mtl4Queue_ release]; mtl4Queue_ = nil;
        mtl4ArgTable_ = nil;
        return;
      }

      // Inline-scalar bump scratch (1 MB, shared storage)
      constexpr size_t kScratchBytes = 1u << 20;
      mtl4ScalarScratch_ = [device_ newBufferWithLength:kScratchBytes
                                                options:MTLResourceStorageModeShared];
      addToResidencySet(mtl4ScalarScratch_);
      mtl4ScalarScratchOffset_ = 0;

      // Completion event for wait()
      mtl4CompletionEvent_ = [device_ newSharedEvent];
      mtl4CompletionValue_ = 0;

      ET_LOG(Info, "MetalStream: MTL4 dispatch path initialized "
             "(queue+allocator+arg-table+scratch+event)");
    }
  }
}

void MetalStream::teardownMTL4() {
  if (@available(macOS 26.0, iOS 26.0, *)) {
    if (mtl4Encoder_) { [mtl4Encoder_ release]; mtl4Encoder_ = nil; }
    if (mtl4CommandBuffer_) { [mtl4CommandBuffer_ release]; mtl4CommandBuffer_ = nil; }
    if (mtl4InFlightCommandBuffer_) { [mtl4InFlightCommandBuffer_ release]; mtl4InFlightCommandBuffer_ = nil; }
    if (mtl4ArgTable_) { [mtl4ArgTable_ release]; mtl4ArgTable_ = nil; }
    if (mtl4Allocator_) { [mtl4Allocator_ release]; mtl4Allocator_ = nil; }
    if (mtl4Queue_) { [mtl4Queue_ release]; mtl4Queue_ = nil; }
    if (mtl4ScalarScratch_) { [mtl4ScalarScratch_ release]; mtl4ScalarScratch_ = nil; }
    if (mtl4CompletionEvent_) { [mtl4CompletionEvent_ release]; mtl4CompletionEvent_ = nil; }
    if (mpsSyncEvent_) { [mpsSyncEvent_ release]; mpsSyncEvent_ = nil; }
    for (id<MTLFence> f : mtl4Fences_) [f release];
    mtl4Fences_.clear();
  }
}
#endif

void MetalStream::setupICB() {
  if (icb_) return;

  @autoreleasepool {
    MTLIndirectCommandBufferDescriptor* desc = [[MTLIndirectCommandBufferDescriptor alloc] init];
    desc.commandTypes = MTLIndirectCommandTypeConcurrentDispatch;
    // inheritBuffers depends on dispatch model:
    //  - MTL3 (legacy): NO. ICB commands carry their own setKernelBuffer:
    //    bindings, executed by MTLComputeCommandEncoder which has no
    //    argument-table model.
    //  - MTL4: YES. ICB commands inherit from the executing
    //    MTL4ComputeCommandEncoder's argument table. Metal validation
    //    rejects setKernelBuffer: on an inherit-mode ICB, so the encode
    //    path also conditionally skips setKernelBuffer: under MTL4.
    desc.inheritBuffers = useMTL4() ? YES : NO;
    desc.inheritPipelineState = NO;
    desc.maxKernelBufferBindCount = kMaxBuffersPerDispatch;

    // new... returns +1 — owned by icb_, released in the destructor.
    // (FIX(review): dropped the previous extra retain that over-retained
    // it to +2.)
    icb_ = [device_ newIndirectCommandBufferWithDescriptor:desc
                                           maxCommandCount:maxICBCommands_
                                                   options:MTLResourceStorageModeShared];
    [desc release];

    if (icb_) {
      ET_LOG(Info, "MetalStream: Created ICB with max %zu commands", maxICBCommands_);
      // Add the ICB itself to the residency set so it's GPU-resident under
      // MTL4 (which has no automatic residency tracking for resources
      // referenced via executeCommandsInBuffer:). MTLIndirectCommandBuffer
      // conforms to MTLAllocation so addAllocation: accepts it directly.
#if ET_METAL4_AVAILABLE
      if (@available(macOS 15.0, iOS 18.0, *)) {
        if (useResidencySet_ && residencySet_) {
          [residencySet_ addAllocation:icb_];
          ET_LOG(Info, "MetalStream: Added ICB to residency set");
        }
      }
#endif
    }
  }
}

void MetalStream::setupArgumentBuffer(size_t numDispatches) {
  // Calculate required size.
  // Each dispatch needs: 8 GPU addresses (64 bytes) + 8 scalars (64 bytes)
  // = 128 bytes.
  size_t bytesPerDispatch = (kMaxBuffersPerDispatch * sizeof(uint64_t)) +
                            (kMaxScalarsPerDispatch * sizeof(int64_t));
  size_t requiredSize = numDispatches * bytesPerDispatch;

  if (!argumentBuffer_ || argumentBufferSize_ < requiredSize) {
    if (argumentBuffer_) {
      [argumentBuffer_ release];
    }

    // Pre-allocate for 1024 dispatches so replay growth doesn't realloc.
    argumentBufferSize_ = std::max(requiredSize, (size_t)(1024 * bytesPerDispatch));
    // newBufferWithLength: returns +1 — owned here, released above / in dtor.
    argumentBuffer_ = [device_ newBufferWithLength:argumentBufferSize_
                                           options:MTLResourceStorageModeShared];

    addToResidencySet(argumentBuffer_);
    ET_LOG(Info, "MetalStream: Created argument buffer (%zu bytes)", argumentBufferSize_);
  }
}

void* MetalStream::alloc(size_t bytes) {
  id<MTLBuffer> buffer = nil;

  // Try heap first (faster: ~100ns vs ~10µs)
  if (useHeap_ && heap_) {
    buffer = heap_->allocBuffer(bytes);
  }

  // Fallback to buffer pool
  if (!buffer) {
    buffer = bufferPool_->acquire(bytes);
  }

  if (!buffer) {
    ET_LOG(Error, "MetalStream::alloc: failed to allocate %zu bytes", bytes);
    return nullptr;
  }

  void* ptr = [buffer contents];
  ptrToBuffer_[ptr] = buffer;
  // Keep alive while in ptrToBuffer_. NOTE(review): assumes pool/heap retain
  // their own reference and return +0 — confirm against MetalBufferPool.mm.
  [buffer retain];

  // Metal 4: Add to residency set for GPU-resident memory.
  // This ensures the buffer stays resident on GPU, avoiding page faults.
  addToResidencySet(buffer);

  ET_LOG(Debug, "MetalStream::alloc: allocated %zu bytes at %p (heap=%d)",
         bytes, ptr, useHeap_ && heap_);
  return ptr;
}

void MetalStream::free(void* ptr) {
  if (!ptr) return;

  auto it = ptrToBuffer_.find(ptr);
  if (it != ptrToBuffer_.end()) {
    id<MTLBuffer> buffer = it->second;
    ptrToBuffer_.erase(it);
    // Only return to pool if we allocated it (not external).
    // FIX(review): also erase the external marker — previously the entry
    // stayed in externalBuffers_ forever, so if the same address was later
    // handed out by alloc() it would wrongly be treated as external and
    // never returned to the pool. erase() returns the removed count, so a
    // zero means this was one of ours.
    if (externalBuffers_.erase(ptr) == 0) {
      bufferPool_->release(buffer);
    }
    [buffer release];
  }
}

bool MetalStream::registerExternalBuffer(
    void* ptr, size_t bytes, bool strict_zero_copy) {
  if (!ptr || bytes == 0) return false;

  // Check if already registered
  if (ptrToBuffer_.count(ptr)) {
    return true;
  }

  // Check alignment - page size is 16KB on ARM64 (hard-coded; matches Apple
  // Silicon — NOTE(review): consider getpagesize() if other targets matter).
  bool pageAligned = ((uintptr_t)ptr % 16384) == 0;
  ET_LOG(Info, "MetalStream: Registering external buffer %p (%zu bytes, page_aligned=%d, strict_zero_copy=%d)",
         ptr, bytes, pageAligned, strict_zero_copy);

  // For unified memory (Apple Silicon), wrap existing memory with MTLBuffer.
  // This allows GPU to access CPU-allocated memory directly.
  // Note: Memory must be page-aligned for newBufferWithBytesNoCopy.
  id<MTLBuffer> buffer = [device_ newBufferWithBytesNoCopy:ptr
                                                    length:bytes
                                                   options:MTLResourceStorageModeShared
                                               deallocator:nil];

  if (!buffer) {
    if (strict_zero_copy) {
      // Caller requires a true alias; refuse rather than silently copying.
      ET_LOG(Info,
             "MetalStream: zero-copy wrap failed for %p (%zu bytes); strict mode -> refusing fallback",
             ptr, bytes);
      return false;
    }
    // Fallback: copy to a new GPU buffer
    // WARNING: For output buffers, results won't be visible to CPU!
    ET_LOG(Info, "MetalStream: newBufferWithBytesNoCopy failed (not page-aligned?), copying to new buffer");
    buffer = [device_ newBufferWithBytes:ptr
                                  length:bytes
                                 options:MTLResourceStorageModeShared];
    if (!buffer) {
      ET_LOG(Error, "MetalStream: Failed to create buffer for external memory %p", ptr);
      return false;
    }
  } else {
    ET_LOG(Info, "MetalStream: Zero-copy buffer wrap succeeded");
  }

  // FIX(review): newBufferWith... returns +1 already; the previous extra
  // [buffer retain] made it +2 with only one release in free()/dtor — a leak
  // per external registration. The +1 is now owned by ptrToBuffer_.
  ptrToBuffer_[ptr] = buffer;
  externalBuffers_.insert(ptr);

  // Add to residency set for GPU access
  addToResidencySet(buffer);

  ET_LOG(Info, "MetalStream: Registered external buffer %p -> MTLBuffer %p", ptr, (__bridge void*)buffer);
  return true;
}

void MetalStream::ensureCommandBuffer() {
#if ET_METAL4_ENABLE
  if (useMTL4()) {
    if (@available(macOS 26.0, iOS 26.0, *)) {
      if (mtl4Queue_ && !mtl4CommandBuffer_) {
        // newCommandBuffer returns +1 (new rule) — no extra retain needed;
        // released in teardownMTL4()/flush paths.
        mtl4CommandBuffer_ = [device_ newCommandBuffer];
        [mtl4CommandBuffer_ beginCommandBufferWithAllocator:mtl4Allocator_];
      }
      if (mtl4Queue_) return;  // MTL4 path active; legacy buffer not needed
    }
  }
#endif
  if (!commandBuffer_) {
    // -commandBuffer is autoreleased (+0): the retain IS required here.
    commandBuffer_ = [queue_ commandBuffer];
    [commandBuffer_ retain];
  }
}

void MetalStream::ensureEncoder() {
#if ET_METAL4_ENABLE
  if (useMTL4()) {
    if (@available(macOS 26.0, iOS 26.0, *)) {
      if (mtl4Queue_) {
        if (!mtl4Encoder_) {
          ensureCommandBuffer();
          // -computeCommandEncoder is autoreleased (+0): retain required.
          mtl4Encoder_ = [mtl4CommandBuffer_ computeCommandEncoder];
          [mtl4Encoder_ retain];
          [mtl4Encoder_ setArgumentTable:mtl4ArgTable_];
        }
        return;
      }
    }
  }
#endif
  if (!encoder_) {
    ensureCommandBuffer();
    // -computeCommandEncoder is autoreleased (+0): retain required.
    encoder_ = [commandBuffer_ computeCommandEncoder];
    [encoder_ retain];
  }
}

void MetalStream::endEncoder() {
#if ET_METAL4_ENABLE
  if (@available(macOS 26.0, iOS 26.0, *)) {
    if (mtl4Encoder_) {
      [mtl4Encoder_ endEncoding];
      [mtl4Encoder_ release];
      mtl4Encoder_ = nil;
    }
  }
#endif
  if (encoder_) {
    [encoder_ endEncoding];
    [encoder_ release];
    encoder_ = nil;
  }
}

DispatchSignature MetalStream::buildSignature(
    MetalKernel* kernel,
    const std::vector<Arg>& args,
    uvec3 grid,
    uvec3 block) {
  DispatchSignature sig;
  sig.kernel = kernel;
  sig.grid = grid;
  sig.block = block;

  // Only buffer sizes participate in the signature; scalar values may change
  // between iterations without invalidating the recorded ICB.
  for (const auto& arg : args) {
    if (arg.type == Arg::BUFFER) {
      sig.bufferSizes.push_back(arg.buffer.size);
    }
  }

  return sig;
}

void MetalStream::dispatch(
    MetalKernel* kernel,
    std::initializer_list<Arg> argsList,
    uvec3 grid,
    uvec3 block) {

  // Optional thread safety
  std::unique_lock<std::mutex> lock(mutex_, std::defer_lock);
  if (threadSafe_) {
    lock.lock();
  }

  std::vector<Arg> args(argsList);
  DispatchSignature sig = buildSignature(kernel, args, grid, block);

  // Check if we can replay with argument buffer update only
  if (icbValid_ && currentDispatchIdx_ < signatures_.size() &&
      sig == signatures_[currentDispatchIdx_]) {
    // Fast path: just update GPU addresses in argument buffer
    isReplaying_ = true;
    updateArgumentBuffer(currentDispatchIdx_, args);
    currentDispatchIdx_++;
    icbRecordedThisIter_++;
    hasPendingWork_ = true;
    return;
  }

  // Slow path: need to encode
  if (icbValid_) {
    // Signature mismatch - invalidate
    invalidate();
  }

  isReplaying_ = false;

  // Setup ICB and argument buffer on first dispatch
  if (!icb_) {
    setupICB();
  }
  setupArgumentBuffer(maxICBCommands_);

  // Encode into ICB with argument buffer binding
  encodeIntoICB(kernel, args, grid, block);
  signatures_.push_back(sig);
  // Whether encodeIntoICB took the real ICB branch or fell through to the
  // direct encodeDispatch, sync() now has work to do — flag it.
+ hasPendingWork_ = true; +} + +void MetalStream::encodeIntoICB( + MetalKernel* kernel, + const std::vector& args, + uvec3 grid, + uvec3 block) { + + // Use direct encoding unless ICB is explicitly enabled + // ICB executes commands concurrently which breaks data dependencies + if (!useICB_) { + encodeDispatch(kernel, args, grid, block); + return; + } + + // ICB path - encode commands for potential replay + if (!icb_ || icbDispatchCount_ >= maxICBCommands_) { + ET_LOG(Info, "MetalStream: ICB full or missing, using direct encoding"); + encodeDispatch(kernel, args, grid, block); + return; + } + + auto* metalKernel = static_cast(kernel); + if (!metalKernel || !metalKernel->pipeline()) { + ET_LOG(Error, "MetalStream: Invalid kernel or pipeline for ICB"); + encodeDispatch(kernel, args, grid, block); + return; + } + + // Check if pipeline supports ICB + if (![metalKernel->pipeline() supportIndirectCommandBuffers]) { + ET_LOG(Info, "MetalStream: Pipeline doesn't support ICB, using direct encoding"); + encodeDispatch(kernel, args, grid, block); + return; + } + + // Dependency tracking: check if any input was written by a previous op + // If so, we need a barrier before this dispatch + bool needsBarrier = false; + void* outputBuffer = nullptr; + + // Find output buffer (last buffer arg is typically output) + for (const auto& arg : args) { + if (arg.type == Arg::BUFFER) { + // Check if this input was a previous output + if (writtenBuffers_.count(arg.buffer.ptr)) { + needsBarrier = true; + } + } + } + // Last buffer is output - track it + for (auto it = args.rbegin(); it != args.rend(); ++it) { + if (it->type == Arg::BUFFER) { + outputBuffer = it->buffer.ptr; + break; + } + } + + if (needsBarrier && icbDispatchCount_ > 0) { + // Record barrier point - we'll insert memory barrier before this dispatch + barrierIndices_.push_back(icbDispatchCount_); + ET_LOG(Info, "MetalStream: Barrier needed before ICB[%zu] (RAW dependency)", icbDispatchCount_); + } + + // Track this op's 
output for future dependency checks + if (outputBuffer) { + writtenBuffers_.insert(outputBuffer); + } + + // Get indirect compute command at current index + id icbCmd = [icb_ indirectComputeCommandAtIndex:icbDispatchCount_]; + + // Set pipeline state + [icbCmd setComputePipelineState:metalKernel->pipeline()]; + // Track for MTL4 flush path (encoder needs to know pipeline before dispatch). + dispatchPipelines_.push_back(metalKernel->pipeline()); + + // Calculate argument buffer offset for this dispatch + size_t bytesPerDispatch = (kMaxBuffersPerDispatch * sizeof(uint64_t)) + + (kMaxScalarsPerDispatch * sizeof(int64_t)); + size_t argOffset = icbDispatchCount_ * bytesPerDispatch; + + // Record layout for replay + DispatchArgLayout layout; + layout.offset = argOffset; + layout.numBuffers = 0; + layout.numScalars = 0; + + // Fill argument buffer with GPU addresses and scalars + char* argBase = (char*)[argumentBuffer_ contents] + argOffset; + uint64_t* gpuAddrs = (uint64_t*)argBase; + int64_t* scalars = (int64_t*)(argBase + kMaxBuffersPerDispatch * sizeof(uint64_t)); + + uint32_t bufferIndex = 0; + uint32_t scalarIndex = 0; + + for (const auto& arg : args) { + switch (arg.type) { + case Arg::BUFFER: { + // Auto-register external buffers + if (!ptrToBuffer_.count(arg.buffer.ptr)) { + registerExternalBuffer(arg.buffer.ptr, arg.buffer.size); + } + + auto it = ptrToBuffer_.find(arg.buffer.ptr); + if (it != ptrToBuffer_.end()) { + // Store GPU address in argument buffer +#if ET_METAL4_AVAILABLE + if (@available(macOS 15.0, iOS 18.0, *)) { + gpuAddrs[bufferIndex] = [it->second gpuAddress]; + } else { + gpuAddrs[bufferIndex] = (uint64_t)(__bridge void*)it->second; + } +#else + gpuAddrs[bufferIndex] = (uint64_t)(__bridge void*)it->second; +#endif + // Bind actual buffer to ICB command (legacy MTL3 path). + // Skip under MTL4: ICB was created with inheritBuffers=YES, and + // Metal validation rejects setKernelBuffer: on inherit-mode ICBs. 
+ // Bindings come from the encoder's arg table at execute time. + if (!useMTL4()) { + [icbCmd setKernelBuffer:it->second offset:0 atIndex:bufferIndex]; + } + bufferIndex++; + layout.numBuffers++; + } + break; + } + case Arg::SCALAR_INT: { + // For ICB, scalars need to be stored in a buffer + // Create a small buffer for the scalar value + scalars[scalarIndex] = arg.scalar_int; + + // Calculate offset into argument buffer for this scalar + size_t scalarOffset = argOffset + kMaxBuffersPerDispatch * sizeof(uint64_t) + scalarIndex * sizeof(int64_t); + if (!useMTL4()) { + [icbCmd setKernelBuffer:argumentBuffer_ offset:scalarOffset atIndex:bufferIndex]; + } + + bufferIndex++; + scalarIndex++; + layout.numScalars++; + break; + } + case Arg::SCALAR_FLOAT: { + scalars[scalarIndex] = (int64_t)arg.scalar_float; + size_t scalarOffset = argOffset + kMaxBuffersPerDispatch * sizeof(uint64_t) + scalarIndex * sizeof(int64_t); + if (!useMTL4()) { + [icbCmd setKernelBuffer:argumentBuffer_ offset:scalarOffset atIndex:bufferIndex]; + } + bufferIndex++; + scalarIndex++; + layout.numScalars++; + break; + } + case Arg::TENSOR: { + // For ICB, treat tensor like a buffer + if (!ptrToBuffer_.count(arg.tensor.ptr)) { + registerExternalBuffer(arg.tensor.ptr, arg.tensor.size); + } + auto it = ptrToBuffer_.find(arg.tensor.ptr); + if (it != ptrToBuffer_.end()) { +#if ET_METAL4_AVAILABLE + if (@available(macOS 15.0, iOS 18.0, *)) { + gpuAddrs[bufferIndex] = [it->second gpuAddress]; + } else { + gpuAddrs[bufferIndex] = (uint64_t)(__bridge void*)it->second; + } +#else + gpuAddrs[bufferIndex] = (uint64_t)(__bridge void*)it->second; +#endif + if (!useMTL4()) { + [icbCmd setKernelBuffer:it->second offset:0 atIndex:bufferIndex]; + } + bufferIndex++; + layout.numBuffers++; + } + break; + } + } + } + + argLayouts_.push_back(layout); + + // Set threadgroup size and dispatch + MTLSize mtlBlock = MTLSizeMake(block.x, block.y, block.z); + MTLSize mtlGrid = MTLSizeMake(grid.x, grid.y, grid.z); + + [icbCmd 
concurrentDispatchThreadgroups:mtlGrid threadsPerThreadgroup:mtlBlock]; + icbDispatchCount_++; + icbRecordedThisIter_++; + ET_LOG(Info, "MetalStream: Encoded into ICB[%zu]: kernel=%s, grid=(%u,%u,%u)", + icbDispatchCount_-1, kernel->name(), grid.x, grid.y, grid.z); +} + +void MetalStream::encodeDispatch( + MetalKernel* kernel, + const std::vector& args, + uvec3 grid, + uvec3 block) { + + ET_LOG(Info, "MetalStream::encodeDispatch: kernel=%s, args=%zu", + kernel ? kernel->name() : "null", args.size()); + + auto* metalKernel = static_cast(kernel); + if (!metalKernel || !metalKernel->pipeline()) { + ET_LOG(Error, "MetalStream::encodeDispatch: invalid kernel/pipeline"); + return; + } + + ensureEncoder(); + +#if ET_METAL4_ENABLE + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + if (mtl4Queue_ && mtl4Encoder_) { + [mtl4Encoder_ setComputePipelineState:metalKernel->pipeline()]; + + // Per-dispatch arg-table updates. Inline scalars go through the + // bump scratch buffer; their GPU address is set on the table slot. 
+ uint32_t bufferIndex = 0; + char* scratchPtr = (char*)[mtl4ScalarScratch_ contents]; + MTLGPUAddress scratchBase = [mtl4ScalarScratch_ gpuAddress]; + const size_t kScratchCap = (size_t)[mtl4ScalarScratch_ length]; + const size_t kAlign = 16; + + for (const auto& arg : args) { + switch (arg.type) { + case Arg::BUFFER: { + if (!ptrToBuffer_.count(arg.buffer.ptr)) { + registerExternalBuffer(arg.buffer.ptr, arg.buffer.size); + } + auto it = ptrToBuffer_.find(arg.buffer.ptr); + if (it != ptrToBuffer_.end()) { + MTLGPUAddress addr = [it->second gpuAddress]; + [mtl4ArgTable_ setAddress:addr atIndex:bufferIndex++]; + } else { + ET_LOG(Error, "MetalStream(MTL4): no buffer for ptr %p", arg.buffer.ptr); + bufferIndex++; + } + break; + } + case Arg::SCALAR_INT: { + // Align + bump scratch + size_t off = (mtl4ScalarScratchOffset_ + kAlign - 1) & ~(kAlign - 1); + if (off + sizeof(int64_t) > kScratchCap) { + ET_LOG(Error, "MetalStream(MTL4): scalar scratch exhausted (call flush more often)"); + bufferIndex++; + break; + } + memcpy(scratchPtr + off, &arg.scalar_int, sizeof(int64_t)); + mtl4ScalarScratchOffset_ = off + sizeof(int64_t); + [mtl4ArgTable_ setAddress:(scratchBase + off) atIndex:bufferIndex++]; + break; + } + case Arg::SCALAR_FLOAT: { + float f = static_cast(arg.scalar_float); + size_t off = (mtl4ScalarScratchOffset_ + kAlign - 1) & ~(kAlign - 1); + if (off + sizeof(float) > kScratchCap) { + ET_LOG(Error, "MetalStream(MTL4): scalar scratch exhausted"); + bufferIndex++; + break; + } + memcpy(scratchPtr + off, &f, sizeof(float)); + mtl4ScalarScratchOffset_ = off + sizeof(float); + [mtl4ArgTable_ setAddress:(scratchBase + off) atIndex:bufferIndex++]; + break; + } + case Arg::TENSOR: { + // For now: bind underlying buffer (same as legacy path) + if (!ptrToBuffer_.count(arg.tensor.ptr)) { + registerExternalBuffer(arg.tensor.ptr, arg.tensor.size); + } + auto it = ptrToBuffer_.find(arg.tensor.ptr); + if (it != ptrToBuffer_.end()) { + MTLGPUAddress addr = [it->second 
gpuAddress]; + [mtl4ArgTable_ setAddress:addr atIndex:bufferIndex++]; + } else { + bufferIndex++; + } + break; + } + } + } + + MTLSize mtlGrid = MTLSizeMake(grid.x, grid.y, grid.z); + MTLSize mtlBlock = MTLSizeMake(block.x, block.y, block.z); + ET_LOG(Info, "MetalStream(MTL4): dispatching grid=(%u,%u,%u), block=(%u,%u,%u)", + (uint)mtlGrid.width, (uint)mtlGrid.height, (uint)mtlGrid.depth, + (uint)mtlBlock.width, (uint)mtlBlock.height, (uint)mtlBlock.depth); + // Re-bind the (now-mutated) argument table just before dispatch. + // setArgumentTable: snapshots the table state — without re-binding + // after our setAddress: mutations, the encoder would dispatch with + // stale (encoder-creation-time) table contents. + [mtl4Encoder_ setArgumentTable:mtl4ArgTable_]; + [mtl4Encoder_ dispatchThreadgroups:mtlGrid threadsPerThreadgroup:mtlBlock]; + // Insert a memory barrier so the *next* dispatch in this encoder + // sees this dispatch's writes. MTL4 has no automatic hazard + // tracking — without an explicit intra-encoder barrier, multiple + // dispatches in the same encoder may run concurrently, violating + // RAW dependencies for any model where one op reads another op's + // output (matmul→matmul, conv→relu, etc.). This is conservative + // (every dispatch gets a barrier even when independent); a future + // optimization would track per-arg producer/consumer like the + // ICB path's barrierIndices_ logic does. 
+ [mtl4Encoder_ barrierAfterEncoderStages:MTLStageDispatch + beforeEncoderStages:MTLStageDispatch + visibilityOptions:MTL4VisibilityOptionDevice]; + icbDispatchCount_++; + icbRecordedThisIter_++; + return; + } + } + } +#endif + + // ----- Legacy path ----- + [encoder_ setComputePipelineState:metalKernel->pipeline()]; + + uint32_t bufferIndex = 0; + for (const auto& arg : args) { + switch (arg.type) { + case Arg::BUFFER: { + // Auto-register external buffers + if (!ptrToBuffer_.count(arg.buffer.ptr)) { + registerExternalBuffer(arg.buffer.ptr, arg.buffer.size); + } + + auto it = ptrToBuffer_.find(arg.buffer.ptr); + if (it != ptrToBuffer_.end()) { + [encoder_ setBuffer:it->second offset:0 atIndex:bufferIndex++]; + } else { + // Registration failed, use setBytes as last resort (copies data) + ET_LOG(Info, "MetalStream: Using setBytes for ptr %p (%zu bytes)", + arg.buffer.ptr, arg.buffer.size); + [encoder_ setBytes:arg.buffer.ptr length:arg.buffer.size atIndex:bufferIndex++]; + } + break; + } + case Arg::SCALAR_INT: + [encoder_ setBytes:&arg.scalar_int length:sizeof(int64_t) atIndex:bufferIndex++]; + break; + case Arg::SCALAR_FLOAT: { + float f = static_cast(arg.scalar_float); + [encoder_ setBytes:&f length:sizeof(float) atIndex:bufferIndex++]; + break; + } + case Arg::TENSOR: { +#if __has_include() && defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000 + if (@available(macOS 26.0, iOS 26.0, *)) { + // Metal 4.1: Create MTLTensor from buffer + // First ensure buffer is registered + if (!ptrToBuffer_.count(arg.tensor.ptr)) { + registerExternalBuffer(arg.tensor.ptr, arg.tensor.size); + } + + auto it = ptrToBuffer_.find(arg.tensor.ptr); + if (it != ptrToBuffer_.end()) { + // Create tensor descriptor + MTLTensorDescriptor* desc = [[MTLTensorDescriptor alloc] init]; + + // Set dimensions + NSInteger dims[8]; + NSInteger strides[8]; + for (int i = 0; i < arg.tensor.rank; i++) { + dims[i] = arg.tensor.dims[i]; + strides[i] = 
arg.tensor.strides[i]; + } + desc.dimensions = [[MTLTensorExtents alloc] initWithRank:arg.tensor.rank values:dims]; + desc.strides = [[MTLTensorExtents alloc] initWithRank:arg.tensor.rank values:strides]; + desc.dataType = (MTLTensorDataType)arg.tensor.dtype; + desc.usage = MTLTensorUsageCompute; + + NSError* error = nil; + id tensor = [device_ newTensorWithDescriptor:desc error:&error]; + if (tensor) { + // Tensor wraps the buffer - set it at the buffer index + // Note: MTLTensor conforms to MTLResource, bind via argument buffer or + // use the underlying buffer + [encoder_ setBuffer:it->second offset:0 atIndex:bufferIndex++]; + ET_LOG(Info, "MetalStream: Created tensor [%lld x %lld] at index %u", + arg.tensor.dims[0], arg.tensor.dims[1], bufferIndex - 1); + } else { + ET_LOG(Error, "MetalStream: Failed to create tensor: %s", + error ? [[error localizedDescription] UTF8String] : "unknown"); + // Fallback to buffer + [encoder_ setBuffer:it->second offset:0 atIndex:bufferIndex++]; + } + } + } else +#endif + { + // Fallback: treat as buffer for older Metal versions + if (!ptrToBuffer_.count(arg.tensor.ptr)) { + registerExternalBuffer(arg.tensor.ptr, arg.tensor.size); + } + auto it = ptrToBuffer_.find(arg.tensor.ptr); + if (it != ptrToBuffer_.end()) { + [encoder_ setBuffer:it->second offset:0 atIndex:bufferIndex++]; + } + } + break; + } + } + } + + MTLSize mtlGrid = MTLSizeMake(grid.x, grid.y, grid.z); + MTLSize mtlBlock = MTLSizeMake(block.x, block.y, block.z); + + ET_LOG(Info, "MetalStream: dispatching grid=(%u,%u,%u), block=(%u,%u,%u)", + (uint)mtlGrid.width, (uint)mtlGrid.height, (uint)mtlGrid.depth, + (uint)mtlBlock.width, (uint)mtlBlock.height, (uint)mtlBlock.depth); + + [encoder_ dispatchThreadgroups:mtlGrid threadsPerThreadgroup:mtlBlock]; + icbDispatchCount_++; + icbRecordedThisIter_++; +} + +void MetalStream::updateArgumentBuffer(size_t dispatchIdx, const std::vector& args) { + if (dispatchIdx >= argLayouts_.size()) return; + + const auto& layout = 
argLayouts_[dispatchIdx]; + char* argBase = (char*)[argumentBuffer_ contents] + layout.offset; + uint64_t* gpuAddrs = (uint64_t*)argBase; + int64_t* scalarSlots = + (int64_t*)(argBase + kMaxBuffersPerDispatch * sizeof(uint64_t)); + + // Fast-path replay. We skip re-encoding the ICB command, but the kernel ABI + // reads its inputs via [[buffer(N)]] bindings (set with setKernelBuffer + // during encoding), NOT via dereferencing the GPU addresses we stash in + // gpuAddrs[]. So when the caller hands us new MTLBuffer objects for the + // same logical args (typical AOTI pattern: fresh per-execute allocations), + // we must re-bind them on the existing ICB command — otherwise the kernel + // reads from the encode-time MTLBuffers, which by now have been recycled + // by the buffer pool and contain stale or unrelated data. + // + // Cost: a handful of setKernelBuffer calls per dispatch (~hundreds of ns + // each on Apple Silicon) — still well under a single re-encode. + id icbCmd = + icb_ ? [icb_ indirectComputeCommandAtIndex:dispatchIdx] : nil; + + size_t bufIdx = 0; // index into gpuAddrs[] (counts buffer / tensor args) + size_t scalarIdx = 0; // index into scalarSlots[] (counts scalar args) + uint32_t slot = 0; // slot on the ICB command (counts ALL kinds of args) + for (const auto& arg : args) { + switch (arg.type) { + case Arg::BUFFER: { + if (bufIdx < layout.numBuffers) { + auto it = ptrToBuffer_.find(arg.buffer.ptr); + if (it != ptrToBuffer_.end()) { +#if ET_METAL4_AVAILABLE + if (@available(macOS 15.0, iOS 18.0, *)) { + gpuAddrs[bufIdx] = [it->second gpuAddress]; + } else { + gpuAddrs[bufIdx] = (uint64_t)(__bridge void*)it->second; + } +#else + gpuAddrs[bufIdx] = (uint64_t)(__bridge void*)it->second; +#endif + // Re-bind the actual MTLBuffer on the ICB command. Without this, + // the kernel reads from the encode-time buffer (which may have + // been freed and reused by the pool). 
+ // Skip under MTL4: ICB created with inheritBuffers=YES; bindings + // come from the encoder's argument table (populated by the MTL4 + // partial-flush / final-flush logic from argLayouts_). + if (icbCmd && !useMTL4()) { + [icbCmd setKernelBuffer:it->second offset:0 atIndex:slot]; + } + bufIdx++; + } + } + slot++; + break; + } + case Arg::TENSOR: { + if (bufIdx < layout.numBuffers) { + auto it = ptrToBuffer_.find(arg.tensor.ptr); + if (it != ptrToBuffer_.end()) { +#if ET_METAL4_AVAILABLE + if (@available(macOS 15.0, iOS 18.0, *)) { + gpuAddrs[bufIdx] = [it->second gpuAddress]; + } else { + gpuAddrs[bufIdx] = (uint64_t)(__bridge void*)it->second; + } +#else + gpuAddrs[bufIdx] = (uint64_t)(__bridge void*)it->second; +#endif + if (icbCmd && !useMTL4()) { + [icbCmd setKernelBuffer:it->second offset:0 atIndex:slot]; + } + bufIdx++; + } + } + slot++; + break; + } + case Arg::SCALAR_INT: { + // Scalar values are written into argumentBuffer_; the binding to + // argumentBuffer_ itself doesn't change between iterations, but the + // VALUE the kernel reads might (e.g. dynamic-shape models passing + // changing M/K/N). Update the value in-place. + if (scalarIdx < layout.numScalars) { + scalarSlots[scalarIdx] = arg.scalar_int; + scalarIdx++; + } + slot++; + break; + } + case Arg::SCALAR_FLOAT: { + if (scalarIdx < layout.numScalars) { + // encodeIntoICB stored scalar_float as int64_t (cast from double). + // Match the same cast here so kernels see consistent bit patterns. + scalarSlots[scalarIdx] = (int64_t)arg.scalar_float; + scalarIdx++; + } + slot++; + break; + } + } + } +} + +// flush() submits any pending work to the GPU command queue. Non-blocking: +// returns as soon as the command buffer is committed; the GPU may still be +// processing when this returns. Idempotent: a no-op if no dispatch() has +// happened since the last flush. 
+// ============================================================================= +// MPS-bridge internal helpers — moved out of MetalStream.h to keep header +// lightweight and reduce inline-induced header-pull-in of Metal types. +// ============================================================================= + +void MetalStream::flushPendingICB() { + if (!useICB_ || !icb_) return; + size_t upper = std::min(icbRecordedThisIter_, icbDispatchCount_); + if (upper <= icbExecutedCount_) return; + ensureCommandBuffer(); + + std::vector ends; + for (size_t b : barrierIndices_) { + if (b > icbExecutedCount_ && b < upper) { + ends.push_back(b); + } + } + ends.push_back(upper); + +#if ET_METAL4_ENABLE + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + size_t start = icbExecutedCount_; + for (size_t i = 0; i < ends.size(); ++i) { + size_t end = ends[i]; + if (end > start) { + id enc = + [mtl4CommandBuffer_ computeCommandEncoder]; + if (start < dispatchPipelines_.size() && start < argLayouts_.size()) { + [enc setComputePipelineState:dispatchPipelines_[start]]; + const auto& layout = argLayouts_[start]; + char* argBase = (char*)[argumentBuffer_ contents] + layout.offset; + uint64_t* gpuAddrs = (uint64_t*)argBase; + MTLGPUAddress argBufBase = [argumentBuffer_ gpuAddress]; + size_t scalarBase = layout.offset + kMaxBuffersPerDispatch * sizeof(uint64_t); + for (size_t j = 0; j < layout.numBuffers; ++j) { + [mtl4ArgTable_ setAddress:gpuAddrs[j] atIndex:j]; + } + for (size_t j = 0; j < layout.numScalars; ++j) { + MTLGPUAddress sAddr = argBufBase + scalarBase + j * sizeof(int64_t); + [mtl4ArgTable_ setAddress:sAddr atIndex:layout.numBuffers + j]; + } + } + [enc setArgumentTable:mtl4ArgTable_]; + NSRange range = NSMakeRange(start, end - start); + [enc executeCommandsInBuffer:icb_ withRange:range]; + [enc endEncoding]; + } + start = end; + } + icbExecutedCount_ = upper; + hasPendingWork_ = true; + return; + } + } +#endif + + size_t start = icbExecutedCount_; + for 
(size_t i = 0; i < ends.size(); ++i) { + size_t end = ends[i]; + if (end > start) { + id enc = [commandBuffer_ computeCommandEncoder]; + if (argumentBuffer_) [enc useResource:argumentBuffer_ usage:MTLResourceUsageRead]; + for (auto& [ptr, buffer] : ptrToBuffer_) { + [enc useResource:buffer usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + NSRange range = NSMakeRange(start, end - start); + [enc executeCommandsInBuffer:icb_ withRange:range]; + if (i < ends.size() - 1) { + [enc memoryBarrierWithScope:MTLBarrierScopeBuffers]; + } + [enc endEncoding]; + } + start = end; + } + icbExecutedCount_ = upper; + hasPendingWork_ = true; +} + +id MetalStream::getOrCreateLegacyCommandBuffer() { + if (!commandBuffer_) { + commandBuffer_ = [queue_ commandBuffer]; + [commandBuffer_ retain]; + } + hasPendingWork_ = true; + return commandBuffer_; +} + +void MetalStream::adoptLegacyCommandBuffer(id newCB) { + if (newCB == commandBuffer_) { + hasPendingWork_ = true; + return; + } + if (commandBuffer_) { + [commandBuffer_ release]; + commandBuffer_ = nil; + } + if (newCB) { + commandBuffer_ = newCB; + [commandBuffer_ retain]; + } + hasPendingWork_ = true; +} + +void MetalStream::releaseLegacyCommandBuffer() { + if (commandBuffer_) { + [commandBuffer_ release]; + commandBuffer_ = nil; + } +} + +#if ET_METAL4_ENABLE +void MetalStream::mtl4CommitAndSignal(id event, uint64_t value) { + if (!useMTL4()) return; + if (@available(macOS 26.0, iOS 26.0, *)) { + flushPendingICB(); + commitResidencySet(); + if (mtl4CommandBuffer_) { + ++mtl4CompletionValue_; + [mtl4CommandBuffer_ endCommandBuffer]; + const id bufs[1] = { mtl4CommandBuffer_ }; + [mtl4Queue_ commit:bufs count:1]; + if (mtl4InFlightCommandBuffer_) [mtl4InFlightCommandBuffer_ release]; + mtl4InFlightCommandBuffer_ = mtl4CommandBuffer_; + mtl4CommandBuffer_ = nil; + mtl4ScalarScratchOffset_ = 0; + [mtl4Queue_ signalEvent:mtl4CompletionEvent_ value:mtl4CompletionValue_]; + [mtl4Queue_ signalEvent:event value:value]; + } else { + 
[mtl4Queue_ signalEvent:event value:value]; + } + } +} + +void MetalStream::mtl4QueueWait(id event, uint64_t value) { + if (!useMTL4()) return; + if (@available(macOS 26.0, iOS 26.0, *)) { + [mtl4Queue_ waitForEvent:event value:value]; + } +} + +id MetalStream::getOrCreateMpsSyncEvent() { + if (!mpsSyncEvent_) { + mpsSyncEvent_ = [device_ newSharedEvent]; + [mpsSyncEvent_ retain]; + } + return mpsSyncEvent_; +} +#endif + +// ============================================================================= +// encodeWithLegacyCommandBuffer — single high-level API for ops that need a +// legacy MTLCommandBuffer (currently only MPSGraph). +// +// Replaces what used to be a 9-method dance MPSGraphOp had to perform. +// All the wait/signal/commit/adopt orchestration lives here now. +// ============================================================================= +void MetalStream::encodeWithLegacyCommandBuffer( + std::function encode_fn) { + // Both paths first close any open compute encoder and drain pending ICB + // ops into the live cb (so they execute BEFORE the MPS work). + endEncoderInternal(); + flushPendingICB(); + +#if ET_METAL4_ENABLE + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + // ── MTL4 path: async cross-queue sync ── + // + // mtl4Queue commits prior MTL4 work + signals event=preValue. + // Fresh legacy cb encodes wait(preValue) → MPS body → signal(postValue), + // commits fire-and-forget. mtl4Queue then waits postValue before any + // subsequent mtl4 cb commit. wait() at execute end blocks on postValue. 
+ id event = getOrCreateMpsSyncEvent(); + uint64_t preValue = mpsSyncCounter_++; + uint64_t postValue = mpsSyncCounter_++; + + mtl4CommitAndSignal(event, preValue); + + id cb = [queue_ commandBuffer]; + MPSCommandBuffer* mpsCB = + [MPSCommandBuffer commandBufferWithCommandBuffer:cb]; + + [mpsCB.commandBuffer encodeWaitForEvent:event value:preValue]; + encode_fn(mpsCB); + [mpsCB.commandBuffer encodeSignalEvent:event value:postValue]; + + [mpsCB commit]; + + mtl4QueueWait(event, postValue); + // Track outstanding MPS event so wait() at end-of-execute blocks + // until MPS work signals. + pendingMpsEvent_ = event; + pendingMpsEventValue_ = postValue; + return; + } + } +#endif + + // ── MTL3 path: adopt-and-share single legacy cb ── + // + // MPSGraph encodes into MetalStream's live legacy cb. Subsequent kernel + // dispatches encode into the same cb. Single end-of-execute commit. + // MPSGraph may internally call commitAndContinue:, in which case the + // original cb gets sealed and mpsCB.commandBuffer points to a fresh one; + // we adopt that fresh one so subsequent dispatches go to the right place. + id cb = getOrCreateLegacyCommandBuffer(); + MPSCommandBuffer* mpsCB = + [MPSCommandBuffer commandBufferWithCommandBuffer:cb]; + encode_fn(mpsCB); + adoptLegacyCommandBuffer(mpsCB.commandBuffer); +} + +void MetalStream::flush() { + // Optional thread safety + std::unique_lock lock(mutex_, std::defer_lock); + if (threadSafe_) { + lock.lock(); + } + + if (!hasPendingWork_) { + // Nothing dispatched since last flush — keep flush() cheap & re-entrant. + return; + } + + // End any active direct encoder first. + endEncoder(); + + // Metal 4: Commit residency set before execution. + commitResidencySet(); + + // ICB path: encode the (already-recorded) ICB commands into segmented + // encoders inside our command buffer. Direct path skips this — its + // dispatchThreadgroups calls already populated commandBuffer_'s encoder. 
+ if (icbDispatchCount_ > 0 && icb_ && useICB_) { + ensureCommandBuffer(); + + // Build barrier points list (add end as final point) + std::vector barriers = barrierIndices_; + barriers.push_back(icbDispatchCount_); // Final segment ends at dispatch count + + bool icbExecuted = false; + +#if ET_METAL4_ENABLE + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + if (mtl4CommandBuffer_) { + // Lazily grow the fence pool. Need (numSegments - 1) fences to + // chain segment[i] -> segment[i+1]. + size_t neededFences = barriers.size() > 1 ? barriers.size() - 1 : 0; + while (mtl4Fences_.size() < neededFences) { + id f = [device_ newFence]; + mtl4Fences_.push_back(f); + } + + size_t segmentStart = icbExecutedCount_; // skip already-drained range + for (size_t barrierIdx = 0; barrierIdx < barriers.size(); barrierIdx++) { + size_t segmentEnd = barriers[barrierIdx]; + // Skip barriers that fall inside the already-drained range + // (set by a prior publicFlushPendingICB() call from MPSGraphOp). + if (segmentEnd <= segmentStart) continue; + if (segmentEnd > segmentStart) { + id enc = + [mtl4CommandBuffer_ computeCommandEncoder]; + + // Hypothesis: under MTL4, kernel [[buffer(N)]] reads come + // from the ENCODER'S argument table, not from per-command + // setKernelBuffer: bindings recorded in the legacy ICB. + // To verify, populate the arg table with the same bindings + // the ICB command at segmentStart was encoded with. This + // only works if all dispatches in the segment share the + // same bindings — which is true here because we put a + // barrier between every dependent dispatch (1 cmd/segment). + // Per-segment table state means the ICB command's per-cmd + // setKernelBuffer is functionally redundant on MTL4 — but + // necessary on MTL3. 
+ if (segmentStart < argLayouts_.size()) { + const auto& layout = argLayouts_[segmentStart]; + char* argBase = (char*)[argumentBuffer_ contents] + layout.offset; + uint64_t* gpuAddrs = (uint64_t*)argBase; + MTLGPUAddress argBufBase = [argumentBuffer_ gpuAddress]; + size_t scalarBase = layout.offset + kMaxBuffersPerDispatch * sizeof(uint64_t); + // Set the encoder's pipeline state to the first dispatch in + // this segment. MTL4 requires setComputePipelineState BEFORE + // setArgumentTable: takes effect for the upcoming dispatch. + if (segmentStart < dispatchPipelines_.size()) { + [enc setComputePipelineState:dispatchPipelines_[segmentStart]]; + } + // Buffers: slots [0, numBuffers) — addresses live in gpuAddrs[]. + for (size_t i = 0; i < layout.numBuffers; ++i) { + [mtl4ArgTable_ setAddress:gpuAddrs[i] atIndex:i]; + } + // Scalars: slots [numBuffers, numBuffers+numScalars) — addresses + // are inside argumentBuffer_ at scalarBase + i*8. + for (size_t i = 0; i < layout.numScalars; ++i) { + MTLGPUAddress sAddr = argBufBase + scalarBase + i * sizeof(int64_t); + [mtl4ArgTable_ setAddress:sAddr atIndex:layout.numBuffers + i]; + } + } + [enc setArgumentTable:mtl4ArgTable_]; + + // Wait on the fence signaled by the previous segment. + if (barrierIdx > 0) { + [enc waitForFence:mtl4Fences_[barrierIdx - 1] + beforeEncoderStages:MTLStageDispatch]; + } + + // No useResource: needed -- residency set added to mtl4Queue_ + // covers all our buffers. + NSRange range = NSMakeRange(segmentStart, segmentEnd - segmentStart); + [enc executeCommandsInBuffer:icb_ withRange:range]; + + // Signal a fence for the next segment to wait on. 
+ if (barrierIdx < barriers.size() - 1) { + [enc updateFence:mtl4Fences_[barrierIdx] + afterEncoderStages:MTLStageDispatch]; + } + + [enc endEncoding]; + + ET_LOG(Info, "MetalStream(MTL4)::flush: Executed ICB segment [%zu-%zu) (%zu commands)", + segmentStart, segmentEnd, segmentEnd - segmentStart); + } + segmentStart = segmentEnd; + } + ET_LOG(Info, "MetalStream(MTL4)::flush: Executed ICB with %zu commands, %zu fence-barriers", + icbDispatchCount_, barriers.size() > 1 ? barriers.size() - 1 : 0); + icbExecutedCount_ = icbDispatchCount_; // mark drained + icbValid_ = true; + icbExecuted = true; + } + } + } +#endif + + if (!icbExecuted) { + // Legacy ICB execution path. + // ICB+MPS coexistence: skip commands at indices < icbExecutedCount_; + // those were already encoded into the cmd buffer by an earlier + // publicFlushPendingICB() call (triggered by MPSGraphOp before its + // own encode). Only the tail [icbExecutedCount_, icbDispatchCount_) + // remains. + size_t segmentStart = icbExecutedCount_; + for (size_t barrierIdx = 0; barrierIdx < barriers.size(); barrierIdx++) { + size_t segmentEnd = barriers[barrierIdx]; + if (segmentEnd <= segmentStart) continue; // already executed by partial flush + + if (segmentEnd > segmentStart) { + id enc = [commandBuffer_ computeCommandEncoder]; + + if (argumentBuffer_) { + [enc useResource:argumentBuffer_ usage:MTLResourceUsageRead]; + } + for (auto& [ptr, buffer] : ptrToBuffer_) { + [enc useResource:buffer usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + + NSRange range = NSMakeRange(segmentStart, segmentEnd - segmentStart); + [enc executeCommandsInBuffer:icb_ withRange:range]; + + if (barrierIdx < barriers.size() - 1) { + [enc memoryBarrierWithScope:MTLBarrierScopeBuffers]; + } + + [enc endEncoding]; + + ET_LOG(Info, "MetalStream::flush: Executed ICB segment [%zu-%zu) (%zu commands)", + segmentStart, segmentEnd, segmentEnd - segmentStart); + } + segmentStart = segmentEnd; + } + + ET_LOG(Info, "MetalStream::flush: 
Executed ICB with %zu commands, %zu barriers (icbExecutedCount=%zu before flush)", + icbDispatchCount_, barrierIndices_.size(), icbExecutedCount_); + + icbExecutedCount_ = icbDispatchCount_; // tail now drained + // Mark ICB as valid for replay across future execute()s. + icbValid_ = true; + } + } + + // For replay across executes: reset icbExecutedCount_ so the next execute + // (which re-runs partial flushes from MPSGraphOp at the same op points) + // sees the right "nothing executed yet" baseline. Note icbDispatchCount_ + // also resets between executes via the existing replay-vs-encode logic. + icbExecutedCount_ = 0; + + // Commit (non-blocking). Stash the buffer so wait() can block on it later. +#if ET_METAL4_ENABLE + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + if (mtl4CommandBuffer_) { + // Drain prior MTL4 in-flight buffer if caller flushed twice without wait(). + if (mtl4InFlightCommandBuffer_) { + if (mtl4CompletionValue_ > 0) { + [mtl4CompletionEvent_ waitUntilSignaledValue:mtl4CompletionValue_ + timeoutMS:UINT64_MAX]; + } + [mtl4InFlightCommandBuffer_ release]; + mtl4InFlightCommandBuffer_ = nil; + } + // Order matters: signalEvent enqueues a signal that fires after all + // work *previously* committed completes. To signal "after our cmd + // buffer", we must commit FIRST, then signalEvent. + // (Earlier comment claimed signal-then-commit was verified — it was + // not; the standalone test used waitUntilCompleted on the cmd buffer + // directly, which worked for a different reason.) + ++mtl4CompletionValue_; + [mtl4CommandBuffer_ endCommandBuffer]; + const id bufs[1] = { mtl4CommandBuffer_ }; + [mtl4Queue_ commit:bufs count:1]; + [mtl4Queue_ signalEvent:mtl4CompletionEvent_ value:mtl4CompletionValue_]; + mtl4InFlightCommandBuffer_ = mtl4CommandBuffer_; + mtl4CommandBuffer_ = nil; + // Reset per-flush state for MTL4 path. 
+ mtl4ScalarScratchOffset_ = 0; + } + } + } +#endif + if (commandBuffer_) { + if (inFlightCommandBuffer_) { + // Caller flushed twice without an intervening wait(). Drain the older + // submission first so we don't leak completion ownership of two cbufs. + [inFlightCommandBuffer_ waitUntilCompleted]; + if ([inFlightCommandBuffer_ status] == MTLCommandBufferStatusError) { + ET_LOG(Error, "MetalStream: prior in-flight command buffer error: %s", + [[inFlightCommandBuffer_ error] localizedDescription].UTF8String); + } + [inFlightCommandBuffer_ release]; + inFlightCommandBuffer_ = nil; + } + [commandBuffer_ commit]; + inFlightCommandBuffer_ = commandBuffer_; // ownership transfer; will be released in wait() + commandBuffer_ = nil; + } + + // Reset per-batch state. Keep icbDispatchCount_, icbValid_, barrierIndices_ + // — they're needed on replay. On signature change, invalidate() clears them. + // Note: currentDispatchIdx_ and icbRecordedThisIter_ are reset in + // publicEndExecute() (called at execute boundary), not here. flush() can + // be called mid-execute (by metal_copy_memory or future paths) and must + // not clobber per-iteration state. + isReplaying_ = false; + hasPendingWork_ = false; + writtenBuffers_.clear(); +} + +// wait() blocks until all previously-flushed work has completed. Calls +// flush() implicitly so callers don't have to remember the pair. +// Idempotent: a no-op if nothing in flight and nothing pending. +void MetalStream::wait() { + // Push out anything still pending — keeps wait() the "drain" primitive + // most callers expect. flush() takes the mutex; we don't here, since the + // wait is on a property of the in-flight command buffer (not our mutable + // state). 
+ flush(); + +#if ET_METAL4_ENABLE + if (useMTL4()) { + if (@available(macOS 26.0, iOS 26.0, *)) { + if (mtl4InFlightCommandBuffer_) { + if (mtl4CompletionValue_ > 0) { + [mtl4CompletionEvent_ waitUntilSignaledValue:mtl4CompletionValue_ + timeoutMS:UINT64_MAX]; + } + [mtl4InFlightCommandBuffer_ release]; + mtl4InFlightCommandBuffer_ = nil; + } + } + } +#endif + + if (inFlightCommandBuffer_) { + [inFlightCommandBuffer_ waitUntilCompleted]; + + if ([inFlightCommandBuffer_ status] == MTLCommandBufferStatusError) { + ET_LOG(Error, "MetalStream: command buffer error: %s", + [[inFlightCommandBuffer_ error] localizedDescription].UTF8String); + } + + [inFlightCommandBuffer_ release]; + inFlightCommandBuffer_ = nil; + } + +#if ET_METAL4_ENABLE + // Wait for any outstanding MPS work committed during this execute on the + // legacy queue_ via MPSGraphOp's async path. We don't hold MTLCommandBuffer + // handles for those (fire-and-forget commits), so we sync via the shared + // event MPSGraphOp signals at end-of-MPS-encode. 
+ if (pendingMpsEvent_ && pendingMpsEventValue_ > 0) { + [pendingMpsEvent_ waitUntilSignaledValue:pendingMpsEventValue_ + timeoutMS:UINT64_MAX]; + pendingMpsEvent_ = nil; + pendingMpsEventValue_ = 0; + } +#endif +} + +void MetalStream::invalidate() { + icbValid_ = false; + signatures_.clear(); + argLayouts_.clear(); + dispatchPipelines_.clear(); + currentDispatchIdx_ = 0; + icbDispatchCount_ = 0; + argumentBufferOffset_ = 0; + isReplaying_ = false; + hasPendingWork_ = false; + writtenBuffers_.clear(); + barrierIndices_.clear(); +} + +void MetalStream::setFlushInterval(int dispatches) { + flushInterval_ = dispatches; +} + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime/metal_v2/MetalTypes.h b/backends/portable/runtime/metal_v2/MetalTypes.h new file mode 100644 index 00000000000..cc436720ae4 --- /dev/null +++ b/backends/portable/runtime/metal_v2/MetalTypes.h @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// Backend-shared types used by MetalStream, MetalKernel, MetalKernelCompiler, +// and op implementations. +// +// HISTORICAL NOTE: this file used to define abstract base classes +// `GpuStream`, `GpuKernel`, `GpuKernelCompiler` that pretended to be +// backend-agnostic. They weren't — terminology and assumptions baked in +// (ICB, metallib, MTLComputePipelineState references, etc.) made them +// Metal-specific. With Metal as the only impl, the abstraction added +// boilerplate without value, so the classes were collapsed into their +// `Metal*` concrete versions. This header keeps the genuinely-shared +// types — argument representation, kernel grid/block layout, dtype +// helpers — that are used across MetalStream / op files. 
+
+// NOTE(review): the four include targets below were garbled in the patch
+// (angle-bracket contents stripped). Reconstructed from usage — size_t,
+// fixed-width ints, exec_aten::ScalarType, and ET_CHECK_MSG — verify
+// against the original commit.
+#include <cstddef>
+#include <cstdint>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/platform/assert.h>
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+// Use exec_aten::ScalarType directly.
+using ScalarType = exec_aten::ScalarType;
+
+//===----------------------------------------------------------------------===//
+// Basic Types
+//===----------------------------------------------------------------------===//
+
+// Minimal 3-component unsigned vector (grid / threadgroup extents).
+// Components default to 1 so a partially-specified size is still dispatchable.
+struct uvec3 {
+  uint32_t x, y, z;
+  uvec3(uint32_t x = 1, uint32_t y = 1, uint32_t z = 1) : x(x), y(y), z(z) {}
+};
+
+//===----------------------------------------------------------------------===//
+// Dtype helpers
+//===----------------------------------------------------------------------===//
+
+/// Convert ScalarType to kernel name suffix (f32, f16, bf16, etc.)
+/// Aborts via ET_CHECK_MSG for dtypes with no kernel suffix.
+inline const char* dtypeSuffix(ScalarType dtype) {
+  switch (dtype) {
+    case ScalarType::Float: return "f32";
+    case ScalarType::Half: return "f16";
+    case ScalarType::BFloat16: return "bf16";
+    case ScalarType::Int: return "i32";
+    case ScalarType::Long: return "i64";
+    case ScalarType::Bool: return "bool";
+    default:
+      ET_CHECK_MSG(false, "Unsupported dtype for GPU kernel: %d",
+                   static_cast<int>(dtype));
+      return nullptr;
+  }
+}
+
+/// Check if dtype is a floating point type
+inline bool isFloatingPoint(ScalarType dtype) {
+  return dtype == ScalarType::Float ||
+      dtype == ScalarType::Half ||
+      dtype == ScalarType::BFloat16;
+}
+
+//===----------------------------------------------------------------------===//
+// Arg - Unified argument type for dispatch
+//===----------------------------------------------------------------------===//
+
+// Tagged union describing one kernel argument. All union members are
+// trivially copyable, so the anonymous union needs no special members.
+struct Arg {
+  enum Type { BUFFER, SCALAR_INT, SCALAR_FLOAT, TENSOR } type;
+
+  union {
+    struct { void* ptr; size_t size; } buffer;
+    int64_t scalar_int;
+    double scalar_float;
+    struct {
+      void* ptr;            // Data pointer
+      size_t size;          // Total size in bytes
+      int64_t dims[8];      // Dimension sizes (up to 8D)
+      int64_t strides[8];   // Strides in elements
+      int32_t rank;         // Number of dimensions
+      int32_t dtype;        // Data type (MTLTensorDataType under Metal)
+    } tensor;
+  };
+
+  Arg(void* ptr, size_t size) : type(BUFFER) {
+    buffer.ptr = ptr;
+    buffer.size = size;
+  }
+
+  // Integral scalars all widen into scalar_int.
+  Arg(int64_t val) : type(SCALAR_INT), scalar_int(val) {}
+  Arg(int32_t val) : type(SCALAR_INT), scalar_int(val) {}
+  Arg(uint32_t val) : type(SCALAR_INT), scalar_int(val) {}
+
+  Arg(float val) : type(SCALAR_FLOAT), scalar_float(val) {}
+  Arg(double val) : type(SCALAR_FLOAT), scalar_float(val) {}
+
+  // Build a 2-D row-major TENSOR arg; strides are derived as {dim1, 1}.
+  static Arg Tensor2D(void* ptr, size_t size, int64_t dim0, int64_t dim1, int32_t dtype) {
+    Arg arg;
+    arg.type = TENSOR;
+    arg.tensor.ptr = ptr;
+    arg.tensor.size = size;
+    arg.tensor.rank = 2;
+    arg.tensor.dtype = dtype;
+    arg.tensor.dims[0] = dim0;
+    arg.tensor.dims[1] = dim1;
+    arg.tensor.strides[0] = dim1; // Row-major
+    arg.tensor.strides[1] = 1;
+    return arg;
+  }
+
+private:
+  // Uninitialized-union constructor; only reachable from Tensor2D above.
+  Arg() : type(BUFFER) {}
+};
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/OpUtils.h b/backends/portable/runtime/metal_v2/OpUtils.h
new file mode 100644
index 00000000000..f8c9b693069
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/OpUtils.h
@@ -0,0 +1,382 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// Shared host-side helpers for MetalOps.
+//
+// Inspired by mlx/backend/common/binary.h and mlx/backend/common/utils.h.
+// All functions are pure, header-only, and do not allocate GPU resources --
+// they only inspect Tensor metadata (sizes/strides) and produce small host
+// buffers that ops can pass to their kernels.
+//
+// Naming convention mirrors the equivalent shader-side helpers in
+// metal_v2/kernels/accessors.metal.h (to be added separately).
+
+// NOTE(review): include targets were garbled in the patch; reconstructed
+// from usage (Tensor, ArrayRef/SizesType, std::reverse, std::strstr,
+// std::tuple/pair, std::vector) — verify against the original commit.
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/portable_type/tensor.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+using runtime::etensor::Tensor;
+using exec_aten::ArrayRef;
+using exec_aten::SizesType;
+
+//===----------------------------------------------------------------------===//
+// ElementwiseVariant
+//
+// Classifies an elementwise binary op's input layout to pick the fastest
+// kernel specialization. Mirrors mlx::core::BinaryOpType.
+//===----------------------------------------------------------------------===//
+
+enum class ElementwiseVariant {
+  ScalarScalar, // both inputs are 1-element
+  ScalarVector, // a is scalar, b is contiguous vector
+  VectorScalar, // a is contiguous vector, b is scalar
+  VectorVector, // both inputs are same shape and contiguous
+  General,      // arbitrary strides / broadcast required
+};
+
+// Short kernel-name prefix for each variant ("ss", "sv", "vs", "vv", "g").
+inline const char* variantPrefix(ElementwiseVariant v) {
+  switch (v) {
+    case ElementwiseVariant::ScalarScalar: return "ss";
+    case ElementwiseVariant::ScalarVector: return "sv";
+    case ElementwiseVariant::VectorScalar: return "vs";
+    case ElementwiseVariant::VectorVector: return "vv";
+    case ElementwiseVariant::General: return "g";
+  }
+  return "g";
+}
+
+// Returns true if all dims of `t` have stride matching a packed row-major
+// (innermost-fastest) layout. Equivalent to MLX's `flags().row_contiguous`.
+inline bool isRowContiguous(const Tensor& t) {
+  auto sizes = t.sizes();
+  auto strides = t.strides();
+  if (sizes.size() != strides.size()) return false;
+  if (sizes.empty()) return true; // 0-D scalar is trivially contiguous
+  int64_t expected = 1;
+  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
+    if (sizes[i] == 1) continue; // size-1 dim's stride is irrelevant
+    if (strides[i] != expected) return false;
+    expected *= sizes[i];
+  }
+  return true;
+}
+
+// Dim-by-dim shape equality; rank must match too.
+inline bool sameShape(const Tensor& a, const Tensor& b) {
+  auto as = a.sizes();
+  auto bs = b.sizes();
+  if (as.size() != bs.size()) return false;
+  for (size_t i = 0; i < as.size(); ++i) {
+    if (as[i] != bs[i]) return false;
+  }
+  return true;
+}
+
+// Pick the fastest elementwise specialization for the input pair (a, b).
+inline ElementwiseVariant classifyBinary(const Tensor& a, const Tensor& b) {
+  bool a_scalar = (a.numel() == 1);
+  bool b_scalar = (b.numel() == 1);
+  if (a_scalar && b_scalar) return ElementwiseVariant::ScalarScalar;
+  if (a_scalar && isRowContiguous(b)) return ElementwiseVariant::ScalarVector;
+  if (b_scalar && isRowContiguous(a)) return ElementwiseVariant::VectorScalar;
+  if (sameShape(a, b) && isRowContiguous(a) && isRowContiguous(b)) {
+    return ElementwiseVariant::VectorVector;
+  }
+  return ElementwiseVariant::General;
+}
+
+//===----------------------------------------------------------------------===//
+// broadcastStrides
+//
+// Compute strides for `shape` aligned to `out_shape` such that broadcast
+// dimensions (where the input has size 1 or is missing) get stride 0.
+// Used to pass per-input stride arrays to General kernels.
+//
+// Example: in_shape = [3, 1, 5], out_shape = [2, 3, 4, 5]
+//   -> strides = [0 (broadcast), 5 (was 3), 0 (was 1), 1]
+//===----------------------------------------------------------------------===//
+
+inline std::vector<int64_t> broadcastStrides(
+    ArrayRef<SizesType> in_shape,
+    const std::vector<SizesType>& out_shape) {
+  std::vector<int64_t> strides(out_shape.size(), 0);
+  // Right-align in_shape against out_shape (numpy broadcast convention).
+  int offset = static_cast<int>(out_shape.size()) -
+      static_cast<int>(in_shape.size());
+  int64_t stride = 1;
+  for (int i = static_cast<int>(in_shape.size()) - 1; i >= 0; --i) {
+    if (in_shape[i] == out_shape[i + offset]) {
+      strides[i + offset] = stride;
+      stride *= in_shape[i];
+    }
+    // else: input dim is 1 (broadcast) or missing -> stride stays 0
+  }
+  return strides;
+}
+
+//===----------------------------------------------------------------------===//
+// collapseContiguousDims
+//
+// Merges adjacent dims that are contiguous in *every* input's stride layout.
+// Reduces the effective ndim passed to the General kernel, which lowers
+// per-element index arithmetic and lets more cases hit the fast paths.
+//
+// Returns:
+//   .first  = collapsed shape (length <= original ndim)
+//   .second = vector of collapsed strides, one per input (same length as .first)
+//
+// For a single contiguous tensor this collapses everything to a 1-D shape.
+// For broadcast (one input has stride 0 along a dim) the dim is preserved.
+//
+// Mirrors mlx::core::collapse_contiguous_dims for the multi-array case.
+//===----------------------------------------------------------------------===//
+
+inline std::pair<std::vector<SizesType>, std::vector<std::vector<int64_t>>>
+collapseContiguousDims(
+    const std::vector<SizesType>& shape,
+    const std::vector<std::vector<int64_t>>& strides_per_input) {
+  const size_t ndim = shape.size();
+  const size_t nin = strides_per_input.size();
+
+  std::vector<SizesType> out_shape;
+  std::vector<std::vector<int64_t>> out_strides(nin);
+
+  if (ndim == 0) {
+    return {out_shape, out_strides};
+  }
+
+  // Start with the innermost dim.
+  out_shape.push_back(shape[ndim - 1]);
+  for (size_t k = 0; k < nin; ++k) {
+    out_strides[k].push_back(strides_per_input[k][ndim - 1]);
+  }
+
+  // Walk outward; merge dim i into the current group if EVERY input's
+  // stride[i] equals stride[i+1] * shape[i+1] (i.e. truly adjacent in memory)
+  // OR all inputs have stride 0 along both dims (still safe to merge).
+  for (int i = static_cast<int>(ndim) - 2; i >= 0; --i) {
+    SizesType inner_size = out_shape.back();
+    bool can_merge = (shape[i] == 1) || (inner_size == 1);
+    if (!can_merge) {
+      can_merge = true;
+      for (size_t k = 0; k < nin; ++k) {
+        int64_t outer = strides_per_input[k][i];
+        int64_t inner = out_strides[k].back();
+        bool both_zero = (outer == 0) && (inner == 0);
+        bool packed = (outer == inner * inner_size);
+        if (!both_zero && !packed) {
+          can_merge = false;
+          break;
+        }
+      }
+    }
+    if (can_merge) {
+      out_shape.back() = shape[i] * inner_size;
+      // strides for the merged group take the inner stride
+      // NOTE(review): when the merge happened because inner_size == 1, the
+      // surviving stride arguably should be the OUTER stride even if the
+      // stale inner stride is nonzero; this fix-up only handles inner == 0.
+      // Confirm against mlx::core::collapse_contiguous_dims.
+      for (size_t k = 0; k < nin; ++k) {
+        if (out_strides[k].back() == 0 && strides_per_input[k][i] != 0) {
+          out_strides[k].back() = strides_per_input[k][i];
+        }
+        // else keep inner stride
+      }
+    } else {
+      out_shape.push_back(shape[i]);
+      for (size_t k = 0; k < nin; ++k) {
+        out_strides[k].push_back(strides_per_input[k][i]);
+      }
+    }
+  }
+
+  // We built the lists innermost-first; reverse to outermost-first.
+  std::reverse(out_shape.begin(), out_shape.end());
+  for (auto& s : out_strides) {
+    std::reverse(s.begin(), s.end());
+  }
+  return {out_shape, out_strides};
+}
+
+//===----------------------------------------------------------------------===//
+// makeContiguousStrides
+//
+// Build packed row-major strides for `shape`.
+// e.g. shape [2, 3, 4] -> strides [12, 4, 1]
+//===----------------------------------------------------------------------===//
+
+inline std::vector<int64_t> makeContiguousStrides(
+    const std::vector<SizesType>& shape) {
+  std::vector<int64_t> strides(shape.size(), 1);
+  for (int i = static_cast<int>(shape.size()) - 1; i > 0; --i) {
+    strides[i - 1] = strides[i] * static_cast<int64_t>(shape[i]);
+  }
+  return strides;
+}
+
+inline std::vector<int64_t> makeContiguousStrides(ArrayRef<SizesType> shape) {
+  std::vector<int64_t> strides(shape.size(), 1);
+  for (int i = static_cast<int>(shape.size()) - 1; i > 0; --i) {
+    strides[i - 1] = strides[i] * static_cast<int64_t>(shape[i]);
+  }
+  return strides;
+}
+
+//===----------------------------------------------------------------------===//
+// isColContiguous
+//
+// True if `t` is column-major contiguous (innermost dim has the largest
+// stride). Mirror of isRowContiguous.
+//===----------------------------------------------------------------------===//
+
+inline bool isColContiguous(const Tensor& t) {
+  auto sizes = t.sizes();
+  auto strides = t.strides();
+  if (sizes.size() != strides.size()) return false;
+  if (sizes.empty()) return true;
+  int64_t expected = 1;
+  // Walk outermost-first: column-major means the FIRST dim is fastest.
+  for (size_t i = 0; i < sizes.size(); ++i) {
+    if (sizes[i] == 1) continue;
+    if (strides[i] != expected) return false;
+    expected *= sizes[i];
+  }
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// getBlockDims
+//
+// Pick a power-of-two threadgroup shape (block_x, block_y, block_z) that
+// fits dim0/dim1/dim2 and whose total thread count is at most 2^maxPow2.
+// Mirrors MLX's get_block_dims_common.
+//
+// Default maxPow2 = 10 (cap = 1024 threads/threadgroup, the Apple-Silicon
+// hardware limit).
+//===----------------------------------------------------------------------===//
+
+inline std::tuple<uint32_t, uint32_t, uint32_t> getBlockDims(
+    int dim0, int dim1, int dim2, int maxPow2 = 10) {
+  int pows[3] = {0, 0, 0};
+  int sum = 0;
+  // Round-robin: grow each axis by one power of two while it still fits
+  // its dim; stop when the budget is hit or nothing grew this pass.
+  while (true) {
+    int presum = sum;
+    if (dim0 >= (1 << (pows[0] + 1))) { pows[0]++; sum++; }
+    if (sum == maxPow2) break;
+    if (dim1 >= (1 << (pows[1] + 1))) { pows[1]++; sum++; }
+    if (sum == maxPow2) break;
+    if (dim2 >= (1 << (pows[2] + 1))) { pows[2]++; sum++; }
+    if (sum == presum || sum == maxPow2) break;
+  }
+  return std::make_tuple(
+      1u << pows[0], 1u << pows[1], 1u << pows[2]);
+}
+
+//===----------------------------------------------------------------------===//
+// get2DGridDims
+//
+// Factor a flat element count into a 2-D grid (gx, gy) where each axis fits
+// in uint32_t. Use this when `numel > UINT32_MAX` would overflow a 1-D grid.
+// Returned values multiplied together cover at least `numel / workPerThread`
+// threads (so divide your total work by `workPerThread` if each thread
+// processes more than one element).
+//
+// Mirrors MLX's get_2d_grid_dims_common (without strides — we handle only
+// the simple "flat numel" case here; broadcast/stride-aware factoring can be
+// added when needed).
+//===----------------------------------------------------------------------===//
+
+inline std::pair<uint32_t, uint32_t> get2DGridDims(
+    uint64_t numel, uint64_t workPerThread = 1) {
+  uint64_t threads = (numel + workPerThread - 1) / workPerThread;
+  if (threads == 0) {
+    return {1u, 1u};
+  }
+  if (threads <= UINT32_MAX) {
+    return {static_cast<uint32_t>(threads), 1u};
+  }
+  // Find the smallest gy such that ceil(threads / gy) fits in uint32_t.
+  uint64_t gy = (threads + UINT32_MAX - 1) / UINT32_MAX;
+  uint64_t gx = (threads + gy - 1) / gy;
+  if (gx > UINT32_MAX || gy > UINT32_MAX) {
+    // Caller must have an absurdly large tensor (>2^64 elements). Clamp.
+    gx = UINT32_MAX;
+    gy = UINT32_MAX;
+  }
+  return {static_cast<uint32_t>(gx), static_cast<uint32_t>(gy)};
+}
+
+//===----------------------------------------------------------------------===//
+// workPerThread
+//
+// Returns the recommended number of elements each thread should process for
+// elementwise kernels, based on dtype size. Smaller dtypes -> more elements
+// per thread (better memory bandwidth utilization, larger vectorized loads).
+// Mirrors mlx's WorkPerThread trait.
+//
+// Use as the `N` template parameter for the binary_v* / unary_v* kernels.
+//===----------------------------------------------------------------------===//
+
+inline int workPerThread(ScalarType dtype) {
+  switch (dtype) {
+    case ScalarType::Bool:
+    case ScalarType::Byte:
+    case ScalarType::Char:
+      return 8; // 1-byte: 8 elems = 8 bytes (one i64 load)
+    case ScalarType::Short:
+      return 8; // 2-byte: 8 elems = 16 bytes (one float4-equivalent)
+    case ScalarType::Half:
+      return 8; // 2-byte half: same
+    case ScalarType::Int:
+    case ScalarType::Float:
+      return 4; // 4-byte: 4 elems = 16 bytes (one float4)
+    case ScalarType::Long:
+    case ScalarType::Double:
+      return 2; // 8-byte: 2 elems = 16 bytes
+    default:
+      return 4;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// DeviceTier + getDeviceTier
+//
+// Coarse classification of the GPU's perf bucket. Used by ops that want
+// to pick different tile sizes / thresholds per device class. Mirrors v1's
+// MatMulConfig::forDevice but lifted to a generic helper.
+//
+// Caller passes the device name (e.g. [[device name] UTF8String]) to keep
+// this header free of Metal imports.
+//===----------------------------------------------------------------------===//
+
+enum class DeviceTier {
+  Phone,    // iPhone, iPad
+  MacBase,  // M-series base / Pro
+  MacUltra, // M-series Max / Ultra
+};
+
+// Classify a Metal device name into a coarse perf tier. Defaults to
+// MacBase for null or unrecognized names.
+inline DeviceTier getDeviceTierFromName(const char* deviceName) {
+  if (deviceName == nullptr) return DeviceTier::MacBase;
+  // Order matters: test phone-class markers BEFORE "Ultra"/"Max" so a
+  // phone-style name containing "Max" (e.g. a "... Pro Max" marketing name)
+  // is not misclassified as MacUltra. (The previous comment claimed an
+  // "M prefix" check that never existed in this function.)
+  if (std::strstr(deviceName, "iPhone") || std::strstr(deviceName, "iPad") ||
+      std::strstr(deviceName, "Apple A")) {
+    return DeviceTier::Phone;
+  }
+  if (std::strstr(deviceName, "Ultra") || std::strstr(deviceName, "Max")) {
+    return DeviceTier::MacUltra;
+  }
+  return DeviceTier::MacBase;
+}
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/TARGETS b/backends/portable/runtime/metal_v2/TARGETS
new file mode 100644
index 00000000000..a38c6acac0b
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/TARGETS
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Metal v2 Runtime - Unified GPU runtime with automatic command replay
+# Features: ICB replay, ResidencySet, Binary Archives, MTLHeap, Buffer Pool
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+# NOTE(review): target names still say "gpu_*" and GpuStream.h is still
+# exported, although MetalTypes.h says the Gpu* abstraction was collapsed
+# into the Metal* classes — confirm GpuStream.h still exists on disk.
+runtime.cxx_library(
+    name = "gpu_stream",
+    srcs = ["MetalStream.mm"],
+    exported_headers = [
+        "GpuStream.h",
+        "MetalStream.h",
+    ],
+    compiler_flags = [
+        "-fobjc-arc",
+    ],
+    frameworks = [
+        "Foundation",
+        "Metal",
+    ],
+    preprocessor_flags = [
+        "-DPORTABLE_HAS_METAL_V2=1",
+    ],
+    visibility = ["//executorch/..."],
+    deps = [
+        "//executorch/runtime/platform:platform",
+    ],
+)
+
+runtime.cxx_library(
+    name = "gpu_op",
+    srcs = ["MetalOp.mm"],
+    exported_headers = ["MetalOp.h"],
+    compiler_flags = [
+        "-fobjc-arc",
+    ],
+    visibility = ["//executorch/..."],
+    deps = [
+        ":gpu_stream",
+        "//executorch/runtime/core:core",
+        "//executorch/runtime/core/exec_aten:lib",
+    ],
+)
+
+runtime.cxx_library(
+    name = "metal_runtime_v2",
+    srcs = ["MetalRuntime.mm"],
+    exported_headers = ["MetalRuntime.h"],
+    compiler_flags = [
+        "-fobjc-arc",
+    ],
+    preprocessor_flags = [
+        "-DPORTABLE_HAS_METAL_V2=1",
+    ],
+    visibility = ["//executorch/..."],
+    deps = [
+        ":gpu_stream",
+        ":gpu_op",
+        "//executorch/backends/portable/runtime:graph_runtime",
+        "//executorch/runtime/platform:platform",
+    ],
+)
diff --git a/backends/portable/runtime/metal_v2/kernels/Accessors.h b/backends/portable/runtime/metal_v2/kernels/Accessors.h
new file mode 100644
index 00000000000..b8e0a65e715
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/kernels/Accessors.h
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// Shared Metal-shader-side helpers for index decoding (strided / broadcast
+// access). Mirrors mlx/backend/metal/kernels/utils.h.
+//
+// Usage from a host .mm:
+//   #include "metal_v2/kernels/Accessors.h"
+//   const char* MyOp::kernelSource() const {
+//     static const std::string source =
+//         std::string(kAccessorsMetalSource) + R"(
+//       // ... your kernel ...
+//     )";
+//     return source.c_str();
+//   }
+//
+// All helpers use `int` strides (matches our host conversion in OpUtils
+// where strides are i32 by the time they reach the kernel).
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+inline constexpr const char* kAccessorsMetalSource = R"METAL(
+//===----------------------------------------------------------------------===//
+// WorkPerThread
+//
+// How many elements each thread should process in elementwise kernels.
+// Smaller dtypes -> more elements/thread for better memory throughput.
+// Mirrors mlx::WorkPerThread.
+//
+// MUST stay in sync with the host-side OpUtils::workPerThread(dtype).
+// NOTE(review): the template parameter and the four specialized types were
+// garbled in the patch; reconstructed to cover 1-/2-byte types the host
+// maps to 8 — verify. The host also maps 8-byte types (Long/Double) to 2,
+// which has NO specialization here; confirm whether WorkPerThread<long>
+// (n = 2) is missing.
+//===----------------------------------------------------------------------===//
+
+template <typename T> struct WorkPerThread { static constant constexpr int n = 4; };
+template <> struct WorkPerThread<bool>  { static constant constexpr int n = 8; };
+template <> struct WorkPerThread<uchar> { static constant constexpr int n = 8; };
+template <> struct WorkPerThread<short> { static constant constexpr int n = 8; };
+template <> struct WorkPerThread<half>  { static constant constexpr int n = 8; };
+
+//===----------------------------------------------------------------------===//
+// elemToLoc family
+//
+// Convert a flat element index (or a multi-dim grid position) into a strided
+// byte/element offset into a tensor's underlying buffer. Mirrors MLX's
+// elem_to_loc_* in mlx/backend/metal/kernels/utils.h.
+//
+// Convention: shape[ndim - 1] is the innermost dimension.
+// Strides may be 0 to express broadcasting along that dimension.
+//===----------------------------------------------------------------------===//
+
+// Generic N-D: flat elem -> strided offset.
+inline int elemToLoc(
+    uint elem,
+    constant const int* shape,
+    constant const int* strides,
+    int ndim) {
+  int loc = 0;
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    loc += int(elem % uint(shape[i])) * strides[i];
+    elem /= uint(shape[i]);
+  }
+  return loc;
+}
+
+// 1-D specialization (no shape needed; just a stride).
+inline int elemToLoc1(uint elem, int stride) {
+  return int(elem) * stride;
+}
+
+// 2-D specialization. elem.x = innermost (cols), elem.y = outer (rows).
+// strides[0] = outer stride, strides[1] = inner stride (matches our
+// outermost-first convention from OpUtils::collapseContiguousDims).
+inline int elemToLoc2(uint2 elem, constant const int strides[2]) {
+  return int(elem.x) * strides[1] + int(elem.y) * strides[0];
+}
+
+// 3-D specialization. elem.x = innermost, elem.y = middle, elem.z = outer.
+inline int elemToLoc3(uint3 elem, constant const int strides[3]) {
+  return int(elem.x) * strides[2] + int(elem.y) * strides[1] +
+         int(elem.z) * strides[0];
+}
+
+//===----------------------------------------------------------------------===//
+// elemToLocBinary / elemToLocTernary
+//
+// Decode the same flat elem against multiple inputs at once (binary/ternary
+// general kernels). Returns one offset per input.
+//===----------------------------------------------------------------------===//
+
+inline int2 elemToLocBinary(
+    uint elem,
+    constant const int* shape,
+    constant const int* a_strides,
+    constant const int* b_strides,
+    int ndim) {
+  int2 loc = int2(0, 0);
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    int dim = int(elem % uint(shape[i]));
+    loc.x += dim * a_strides[i];
+    loc.y += dim * b_strides[i];
+    elem /= uint(shape[i]);
+  }
+  return loc;
+}
+
+inline int3 elemToLocTernary(
+    uint elem,
+    constant const int* shape,
+    constant const int* a_strides,
+    constant const int* b_strides,
+    constant const int* c_strides,
+    int ndim) {
+  int3 loc = int3(0, 0, 0);
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    int dim = int(elem % uint(shape[i]));
+    loc.x += dim * a_strides[i];
+    loc.y += dim * b_strides[i];
+    loc.z += dim * c_strides[i];
+    elem /= uint(shape[i]);
+  }
+  return loc;
+}
+
+)METAL";
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/kernels/TileLoad.h b/backends/portable/runtime/metal_v2/kernels/TileLoad.h
new file mode 100644
index 00000000000..d3e5a1bdb0e
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/kernels/TileLoad.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// Shared Metal-shader-side helpers for cooperative tile loading from
+// device memory into threadgroup memory. Used by GEMM-style kernels
+// (matmul_simd, matmul_nt, matmul_tn, conv2d, etc).
+//
+// Mirrors the load loops in mlx/backend/metal/kernels/steel/gemm/loader.h.
+//
+// Usage from a host .mm:
+//   #include "metal_v2/kernels/TileLoad.h"
+//   const char* MyOp::kernelSource() const {
+//     static const std::string source = std::string(kTileLoadMetalSource) + R"(
+//       // ... your kernel using cooperativeLoadTileVec4 ...
+//     )";
+//     return source.c_str();
+//   }
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+inline constexpr const char* kTileLoadMetalSource = R"METAL(
+//===----------------------------------------------------------------------===//
+// cooperativeLoadTileVec4
+//
+// Cooperatively load a ROWS x COLS tile from row-major device memory `src`
+// (full tensor of size srcRows x srcCols, row stride `srcStride`) into a
+// padded threadgroup-memory tile `smem` of shape ROWS x SMEM_STRIDE.
+//
+// The tile origin in `src` is (baseRow, baseCol). NUM_THREADS threads
+// (typically a full threadgroup of 128) cooperate; each thread loads
+// ceil(ROWS*COLS/4 / NUM_THREADS) vec chunks.
+//
+// Vectorized (vec) load is used when the 4-element column block is
+// fully in bounds; scalar fallback handles the boundary tile.
+//
+// Constraints:
+//   - COLS must be a multiple of 4
+//   - SMEM_STRIDE >= COLS (the extra columns are usually padding to avoid
+//     bank conflicts -- e.g. SMEM_STRIDE = COLS + 4)
+//   - NUM_THREADS should evenly divide ROWS*COLS/4 for best efficiency.
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the template-parameter lists and metal::vec element types in
+// both loaders were garbled in the patch; reconstructed from the bodies
+// (ROWS/COLS/SMEM_STRIDE/NUM_THREADS usage, vv[0..3]) — verify.
+template <typename T, int ROWS, int COLS, int SMEM_STRIDE, int NUM_THREADS>
+inline void cooperativeLoadTileVec4(
+    threadgroup T (&smem)[ROWS][SMEM_STRIDE],
+    device const T* src,
+    int srcRows, int srcCols, int srcStride,
+    int baseRow, int baseCol,
+    uint tid) {
+
+  static_assert(COLS % 4 == 0, "COLS must be a multiple of 4");
+  constexpr int VECS_PER_ROW = COLS / 4;
+  constexpr int TOTAL_VECS = ROWS * VECS_PER_ROW;
+
+  for (int v = int(tid); v < TOTAL_VECS; v += NUM_THREADS) {
+    int row = v / VECS_PER_ROW;
+    int col4 = v % VECS_PER_ROW;
+    int gRow = baseRow + row;
+    int gCol = baseCol + col4 * 4;
+
+    if (gRow < srcRows && gCol + 3 < srcCols) {
+      // Aligned 4-wide fast path
+      metal::vec<T, 4> vv = *((device metal::vec<T, 4>*)(src + gRow * srcStride + gCol));
+      smem[row][col4 * 4    ] = vv[0];
+      smem[row][col4 * 4 + 1] = vv[1];
+      smem[row][col4 * 4 + 2] = vv[2];
+      smem[row][col4 * 4 + 3] = vv[3];
+    } else {
+      // Edge: scalar fallback with per-element bounds check
+      for (int d = 0; d < 4; d++) {
+        int c = col4 * 4 + d;
+        smem[row][c] = (gRow < srcRows && (baseCol + c) < srcCols)
+            ? src[gRow * srcStride + baseCol + c]
+            : T(0);
+      }
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// cooperativeLoadTileTransposedVec4
+//
+// Same as cooperativeLoadTileVec4 but loads a logical ROWS x COLS tile from
+// a PHYSICALLY TRANSPOSED source: src is stored as [srcCols x srcRows] (i.e.
+// src[c * srcStride + r] = logical(r, c)). srcStride is the physical row
+// stride (in elements) of src, which equals the logical row count.
+//
+// Used by matmul_nt (loads B which is logically [K, N] but stored [N, K])
+// and matmul_tn (loads A which is logically [M, K] but stored [K, M]).
+//
+// Vec4 coalescing: we vectorize along the PHYSICAL row direction (= logical
+// row direction), so each thread loads 4 logical rows for one logical col.
+// This requires ROWS to be a multiple of 4.
+//===----------------------------------------------------------------------===//
+
+template <typename T, int ROWS, int COLS, int SMEM_STRIDE, int NUM_THREADS>
+inline void cooperativeLoadTileTransposedVec4(
+    threadgroup T (&smem)[ROWS][SMEM_STRIDE],
+    device const T* src,
+    int srcRows, int srcCols, int srcStride,
+    int baseRow, int baseCol,
+    uint tid) {
+
+  static_assert(ROWS % 4 == 0, "ROWS must be a multiple of 4 for transposed vec4");
+  constexpr int VECS_PER_COL = ROWS / 4;
+  constexpr int TOTAL_VECS = VECS_PER_COL * COLS;
+
+  for (int v = int(tid); v < TOTAL_VECS; v += NUM_THREADS) {
+    int col = v / VECS_PER_COL;   // logical col within tile
+    int row4 = v % VECS_PER_COL;  // group of 4 logical rows
+    int gRow = baseRow + row4 * 4; // first logical row in group
+    int gCol = baseCol + col;      // logical col in src
+
+    // Physical: src[gCol * srcStride + gRow .. gRow+3]
+    if (gCol < srcCols && gRow + 3 < srcRows) {
+      metal::vec<T, 4> vv = *((device metal::vec<T, 4>*)(src + gCol * srcStride + gRow));
+      smem[row4 * 4    ][col] = vv[0];
+      smem[row4 * 4 + 1][col] = vv[1];
+      smem[row4 * 4 + 2][col] = vv[2];
+      smem[row4 * 4 + 3][col] = vv[3];
+    } else {
+      for (int d = 0; d < 4; d++) {
+        int r = row4 * 4 + d;
+        smem[r][col] = (gCol < srcCols && (baseRow + r) < srcRows)
+            ? src[gCol * srcStride + baseRow + r]
+            : T(0);
+      }
+    }
+  }
+}
+
+)METAL";
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/ops/BinaryOps.h b/backends/portable/runtime/metal_v2/ops/BinaryOps.h
new file mode 100644
index 00000000000..eba65be6960
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/ops/BinaryOps.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
 */

#pragma once

// NOTE(review): the two include targets below were lost in extraction
// (angle-bracket contents stripped); restore them from the original file.
#include
#include

namespace executorch {
namespace backends {
namespace metal_v2 {

//===----------------------------------------------------------------------===//
// BinaryOp - Base class for elementwise binary operations
//
// Variant classification, broadcast strides, and contiguous-dim collapsing
// all live in OpUtils.h (shared with other ops).
//===----------------------------------------------------------------------===//

class BinaryOp : public MetalOp {
public:
  // Short op token used to compose kernel names, e.g. "add" -> "vv_add_f32";
  // must match the name argument of the INSTANTIATE_* macros in BinaryOps.mm.
  virtual const char* opName() const = 0;
  // True when the op's kernels take a trailing float alpha argument
  // (only AddOp overrides this to true in this file).
  virtual bool hasAlpha() const { return false; }

  // Only floating-point dtypes are supported — BinaryOps.mm instantiates
  // kernels for float ("f32") and half ("f16") only.
  bool supports(ScalarType dtype) const override {
    return isFloatingPoint(dtype);
  }

  // Broadcast shape of inputs[0] x inputs[1]; empty vector on error.
  // NOTE(review): the std::vector element type was stripped in extraction
  // (presumably std::vector<SizesType>); restore from the original file.
  std::vector computeOutputShape(
      EValuePtrSpan inputs) const override;

  // Resizes the output, classifies the elementwise variant, and launches the
  // matching Metal kernel on `stream`. See BinaryOps.mm.
  void dispatch(
      MetalStream* stream,
      EValuePtrSpan inputs,
      EValuePtrSpan outputs) override;

protected:
  // Full MSL source: shared accessor helpers + per-op kernel bodies.
  const char* kernelSource() const override;

  // Builds "<variantPrefix>_<opName>_<dtypeSuffix>", e.g. "sv_mul_f16".
  std::string kernelName(ElementwiseVariant variant, ScalarType dtype) const;
};

//===----------------------------------------------------------------------===//
// Concrete Binary Ops
//===----------------------------------------------------------------------===//

class AddOp : public BinaryOp {
public:
  const char* name() const override { return "aten::add"; }
  const char* opName() const override { return "add"; }
  bool hasAlpha() const override { return true; }
};

class MulOp : public BinaryOp {
public:
  const char* name() const override { return "aten::mul"; }
  const char* opName() const override { return "mul"; }
};

// NOTE(review): aten::sub also carries an alpha argument in ATen, but SubOp
// leaves hasAlpha() false (and BinaryOps.mm instantiates only the non-alpha
// sub kernels), so a non-default alpha would be ignored — confirm the
// exporter only emits alpha == 1 for sub.
class SubOp : public BinaryOp {
public:
  const char* name() const override { return "aten::sub"; }
  const char* opName() const override { return "sub"; }
};

} // namespace metal_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/ops/BinaryOps.mm b/backends/portable/runtime/metal_v2/ops/BinaryOps.mm
new file mode 100644
index
00000000000..1a439048dc4
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/ops/BinaryOps.mm
@@ -0,0 +1,409 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#import "BinaryOps.h"
// NOTE(review): the five include targets below were lost in extraction
// (angle-bracket contents stripped); restore them from the original file.
#include
#include
#include
#include
#include

namespace executorch {
namespace backends {
namespace metal_v2 {

using runtime::Error;
using torch::executor::resize_to_broadcast_target_size;
using torch::executor::get_broadcast_target_size;

//===----------------------------------------------------------------------===//
// Variant detection, prefix string, and broadcast strides come from
// metal_v2/OpUtils.h. Kernel-name builder is the only thing left here.
//===----------------------------------------------------------------------===//

// Builds "<variant>_<op>_<dtype>", e.g. "vv_add_f32". The result must match
// the [[host_name(...)]] labels produced by the INSTANTIATE_* macros in
// kernelSource() below, or getKernel() will fail to find the function.
std::string BinaryOp::kernelName(ElementwiseVariant variant, ScalarType dtype) const {
  return std::string(variantPrefix(variant)) + "_" + opName() + "_" + dtypeSuffix(dtype);
}

//===----------------------------------------------------------------------===//
// Output Shape Computation - using ET's broadcast utility
//===----------------------------------------------------------------------===//

// Computes the broadcast target shape of inputs[0] and inputs[1] via
// ExecuTorch's get_broadcast_target_size. Returns an empty vector when the
// inputs are not two tensors or the shapes are not broadcast-compatible
// (dispatch() treats that as an error via resizeOutput).
// NOTE(review): the std::vector template argument on the return type was
// stripped in extraction (presumably std::vector<SizesType>).
std::vector BinaryOp::computeOutputShape(
    EValuePtrSpan inputs) const {

  if (inputs.size() < 2 || !inputs[0]->isTensor() || !inputs[1]->isTensor()) {
    return {};
  }

  auto& a = inputs[0]->toTensor();
  auto& b = inputs[1]->toTensor();

  // Use ET's broadcast utility
  SizesType out_sizes[runtime::kTensorDimensionLimit];
  size_t out_dim = 0;

  Error err = get_broadcast_target_size(a, b, out_sizes, runtime::kTensorDimensionLimit, &out_dim);
  if (err != Error::Ok) {
    return {};
  }

  return std::vector(out_sizes, out_sizes + out_dim);
}

// (broadcastStrides + collapseContiguousDims live in OpUtils.h.)

//===----------------------------------------------------------------------===//
// Dispatch
//===----------------------------------------------------------------------===//

// Resizes the output to the broadcast shape, classifies the variant
// (scalar/vector/general), and launches the matching Metal kernel:
//   ss            -> 1x1x1 grid (single thread writes out[0])
//   sv / vs / vv  -> 1D grid, `elemPerThread` elements per thread
//   general       -> 1D grid, one element per thread, strided index decode
//
// NOTE(review): several static_cast template arguments in this function were
// stripped in extraction and appear as bare `static_cast(...)` below
// (presumably static_cast<uint32_t>); restore from the original file.
void BinaryOp::dispatch(
    MetalStream* stream,
    EValuePtrSpan inputs,
    EValuePtrSpan outputs) {

  auto& a = inputs[0]->toTensor();
  auto& b = inputs[1]->toTensor();
  auto& out = outputs[0]->toTensor();

  auto err = resizeOutput(inputs, outputs[0]);
  if (err != Error::Ok) {
    ET_LOG(Error, "BinaryOp: failed to resize output");
    return;
  }

  ScalarType dtype = out.scalar_type();
  ElementwiseVariant variant = classifyBinary(a, b);
  std::string kname = kernelName(variant, dtype);

  auto* kernel = getKernel(stream, kname.c_str());
  size_t numel = out.numel();

  ET_LOG(Info, "BinaryOp::dispatch(%s): variant=%s, kernel=%s, numel=%zu",
         name(), variantPrefix(variant), kname.c_str(), numel);

  constexpr uint32_t blockSize = 256;
  // Kernels are templated on N = WorkPerThread<T>::n; pick the matching N
  // here so the host launch matches the kernel's per-thread work.
  const uint32_t elemPerThread = static_cast(workPerThread(dtype));

  switch (variant) {
    case ElementwiseVariant::ScalarScalar: {
      // Both operands are single elements: one thread computes out[0].
      // NOTE(review): alpha is hard-coded to 1.0f in every hasAlpha() branch
      // of this function. If callers can pass a non-default alpha Scalar
      // (aten::add/sub have one), it is silently ignored — confirm the
      // exporter only emits alpha == 1, or plumb the value through.
      if (hasAlpha()) {
        float alpha = 1.0f;
        stream->dispatch(kernel, {
            {a.mutable_data_ptr(), a.nbytes()},
            {b.mutable_data_ptr(), b.nbytes()},
            {out.mutable_data_ptr(), out.nbytes()},
            alpha,
            static_cast(numel)
        }, uvec3(1, 1, 1), uvec3(1, 1, 1));
      } else {
        stream->dispatch(kernel, {
            {a.mutable_data_ptr(), a.nbytes()},
            {b.mutable_data_ptr(), b.nbytes()},
            {out.mutable_data_ptr(), out.nbytes()},
            static_cast(numel)
        }, uvec3(1, 1, 1), uvec3(1, 1, 1));
      }
      break;
    }

    case ElementwiseVariant::ScalarVector:
    case ElementwiseVariant::VectorScalar:
    case ElementwiseVariant::VectorVector: {
      // ceil(numel / (elemPerThread * blockSize)) threadgroups along X.
      uint32_t gridX = (uint32_t)((numel + elemPerThread * blockSize - 1) / (elemPerThread * blockSize));

      if (hasAlpha()) {
        float alpha = 1.0f;
        stream->dispatch(kernel, {
            {a.mutable_data_ptr(), a.nbytes()},
            {b.mutable_data_ptr(), b.nbytes()},
            {out.mutable_data_ptr(), out.nbytes()},
            alpha,
            static_cast(numel)
        }, uvec3(gridX, 1, 1), uvec3(blockSize, 1, 1));
      } else {
        stream->dispatch(kernel, {
            {a.mutable_data_ptr(), a.nbytes()},
            {b.mutable_data_ptr(), b.nbytes()},
            {out.mutable_data_ptr(), out.nbytes()},
            static_cast(numel)
        }, uvec3(gridX, 1, 1), uvec3(blockSize, 1, 1));
      }
      break;
    }

    case ElementwiseVariant::General: {
      // Broadcast / strided path: ship shape + per-input strides to the
      // kernel, which decodes a linear gid into two input offsets.
      auto out_shape = computeOutputShape(inputs);
      auto a_strides_full = broadcastStrides(a.sizes(), out_shape);
      auto b_strides_full = broadcastStrides(b.sizes(), out_shape);

      // Collapse adjacent dims that are contiguous in BOTH inputs to
      // shrink ndim and reduce per-element index math in the shader.
      auto [shape_collapsed, strides_collapsed] =
          collapseContiguousDims(out_shape, {a_strides_full, b_strides_full});
      int32_t ndim = static_cast(shape_collapsed.size());

      std::vector shape_i32(shape_collapsed.begin(), shape_collapsed.end());
      std::vector a_strides_i32(strides_collapsed[0].begin(), strides_collapsed[0].end());
      std::vector b_strides_i32(strides_collapsed[1].begin(), strides_collapsed[1].end());

      // One element per thread (no WorkPerThread batching on this path).
      uint32_t gridX = (uint32_t)((numel + blockSize - 1) / blockSize);

      stream->dispatch(kernel, {
          {a.mutable_data_ptr(), a.nbytes()},
          {b.mutable_data_ptr(), b.nbytes()},
          {out.mutable_data_ptr(), out.nbytes()},
          {shape_i32.data(), shape_i32.size() * sizeof(int32_t)},
          {a_strides_i32.data(), a_strides_i32.size() * sizeof(int32_t)},
          {b_strides_i32.data(), b_strides_i32.size() * sizeof(int32_t)},
          ndim,
          static_cast(numel)
      }, uvec3(gridX, 1, 1), uvec3(blockSize, 1, 1));
      break;
    }
  }
}

//===----------------------------------------------------------------------===//
// Kernel Source
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Kernel Source
//
// We prepend metal_v2/kernels/Accessors.h's shared shader helpers (elemToLoc
// family) to the per-op kernel body. The result is built once into a static
// std::string and returned by .c_str() so that MetalOp's `const char*` API is
// preserved.
+//===----------------------------------------------------------------------===// + +const char* BinaryOp::kernelSource() const { + static const std::string source = std::string(kAccessorsMetalSource) + R"( +#include +using namespace metal; + +// Op Functors +struct AddOp { + template T operator()(T a, T b) { return a + b; } + template T operator()(T a, T b, float alpha) { return a + T(alpha) * b; } +}; +struct MulOp { template T operator()(T a, T b) { return a * b; } }; +struct SubOp { template T operator()(T a, T b) { return a - b; } }; + +// ScalarScalar (ss) +template +kernel void binary_ss( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) { + if (gid == 0) out[0] = Op()(a[0], b[0]); +} + +template +kernel void binary_ss_alpha( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant float& alpha [[buffer(3)]], + constant uint& n [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + if (gid == 0) out[0] = Op()(a[0], b[0], alpha); +} + +// ScalarVector (sv) -- N elements per thread, dtype-aware. 
+template::n> +kernel void binary_sv( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) { + uint idx = gid * uint(N); + T scalar = a[0]; + if (N > 1 && idx + uint(N) > n) { + for (uint j = idx; j < n; j++) out[j] = Op()(scalar, b[j]); + } else { + for (int i = 0; i < N; ++i) out[idx + i] = Op()(scalar, b[idx + i]); + } +} + +template::n> +kernel void binary_sv_alpha( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant float& alpha [[buffer(3)]], + constant uint& n [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + uint idx = gid * uint(N); + T scalar = a[0]; + if (N > 1 && idx + uint(N) > n) { + for (uint j = idx; j < n; j++) out[j] = Op()(scalar, b[j], alpha); + } else { + for (int i = 0; i < N; ++i) out[idx + i] = Op()(scalar, b[idx + i], alpha); + } +} + +// VectorScalar (vs) -- N elements per thread, dtype-aware. 
+template::n> +kernel void binary_vs( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) { + uint idx = gid * uint(N); + T scalar = b[0]; + if (N > 1 && idx + uint(N) > n) { + for (uint j = idx; j < n; j++) out[j] = Op()(a[j], scalar); + } else { + for (int i = 0; i < N; ++i) out[idx + i] = Op()(a[idx + i], scalar); + } +} + +template::n> +kernel void binary_vs_alpha( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant float& alpha [[buffer(3)]], + constant uint& n [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + uint idx = gid * uint(N); + T scalar = b[0]; + if (N > 1 && idx + uint(N) > n) { + for (uint j = idx; j < n; j++) out[j] = Op()(a[j], scalar, alpha); + } else { + for (int i = 0; i < N; ++i) out[idx + i] = Op()(a[idx + i], scalar, alpha); + } +} + +// VectorVector (vv) -- N elements per thread, dtype-aware. +template::n> +kernel void binary_vv( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) { + uint idx = gid * uint(N); + if (N > 1 && idx + uint(N) > n) { + for (uint j = idx; j < n; j++) out[j] = Op()(a[j], b[j]); + } else { + for (int i = 0; i < N; ++i) out[idx + i] = Op()(a[idx + i], b[idx + i]); + } +} + +template::n> +kernel void binary_vv_alpha( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant float& alpha [[buffer(3)]], + constant uint& n [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + uint idx = gid * uint(N); + if (N > 1 && idx + uint(N) > n) { + for (uint j = idx; j < n; j++) out[j] = Op()(a[j], b[j], alpha); + } else { + for (int i = 0; i < N; ++i) out[idx + i] = Op()(a[idx + i], b[idx + i], alpha); + } +} + +// General (g) - Strided/broadcast. 
+// Uses elemToLocBinary from Accessors.h: decodes both inputs in one loop. +template +kernel void binary_g( + device const T* a [[buffer(0)]], + device const T* b [[buffer(1)]], + device T* out [[buffer(2)]], + constant int* shape [[buffer(3)]], + constant int* a_strides [[buffer(4)]], + constant int* b_strides [[buffer(5)]], + constant int& ndim [[buffer(6)]], + constant uint& n [[buffer(7)]], + uint gid [[thread_position_in_grid]]) { + if (gid >= n) return; + int2 idx = elemToLocBinary(gid, shape, a_strides, b_strides, ndim); + out[gid] = Op()(a[idx.x], b[idx.y]); +} + +// Instantiate +#define INSTANTIATE_SS(name, op, T, suffix) \ + template [[host_name("ss_" name "_" suffix)]] kernel void binary_ss(device const T*, device const T*, device T*, constant uint&, uint); + +#define INSTANTIATE_SS_ALPHA(name, op, T, suffix) \ + template [[host_name("ss_" name "_" suffix)]] kernel void binary_ss_alpha(device const T*, device const T*, device T*, constant float&, constant uint&, uint); + +#define INSTANTIATE_SV(name, op, T, suffix) \ + template [[host_name("sv_" name "_" suffix)]] kernel void binary_sv(device const T*, device const T*, device T*, constant uint&, uint); + +#define INSTANTIATE_SV_ALPHA(name, op, T, suffix) \ + template [[host_name("sv_" name "_" suffix)]] kernel void binary_sv_alpha(device const T*, device const T*, device T*, constant float&, constant uint&, uint); + +#define INSTANTIATE_VS(name, op, T, suffix) \ + template [[host_name("vs_" name "_" suffix)]] kernel void binary_vs(device const T*, device const T*, device T*, constant uint&, uint); + +#define INSTANTIATE_VS_ALPHA(name, op, T, suffix) \ + template [[host_name("vs_" name "_" suffix)]] kernel void binary_vs_alpha(device const T*, device const T*, device T*, constant float&, constant uint&, uint); + +#define INSTANTIATE_VV(name, op, T, suffix) \ + template [[host_name("vv_" name "_" suffix)]] kernel void binary_vv(device const T*, device const T*, device T*, constant uint&, uint); + +#define 
INSTANTIATE_VV_ALPHA(name, op, T, suffix) \ + template [[host_name("vv_" name "_" suffix)]] kernel void binary_vv_alpha(device const T*, device const T*, device T*, constant float&, constant uint&, uint); + +#define INSTANTIATE_G(name, op, T, suffix) \ + template [[host_name("g_" name "_" suffix)]] kernel void binary_g(device const T*, device const T*, device T*, constant int*, constant int*, constant int*, constant int&, constant uint&, uint); + +// Add (with alpha) +INSTANTIATE_SS_ALPHA("add", AddOp, float, "f32") +INSTANTIATE_SS_ALPHA("add", AddOp, half, "f16") +INSTANTIATE_SV_ALPHA("add", AddOp, float, "f32") +INSTANTIATE_SV_ALPHA("add", AddOp, half, "f16") +INSTANTIATE_VS_ALPHA("add", AddOp, float, "f32") +INSTANTIATE_VS_ALPHA("add", AddOp, half, "f16") +INSTANTIATE_VV_ALPHA("add", AddOp, float, "f32") +INSTANTIATE_VV_ALPHA("add", AddOp, half, "f16") +INSTANTIATE_G("add", AddOp, float, "f32") +INSTANTIATE_G("add", AddOp, half, "f16") + +// Mul +INSTANTIATE_SS("mul", MulOp, float, "f32") +INSTANTIATE_SS("mul", MulOp, half, "f16") +INSTANTIATE_SV("mul", MulOp, float, "f32") +INSTANTIATE_SV("mul", MulOp, half, "f16") +INSTANTIATE_VS("mul", MulOp, float, "f32") +INSTANTIATE_VS("mul", MulOp, half, "f16") +INSTANTIATE_VV("mul", MulOp, float, "f32") +INSTANTIATE_VV("mul", MulOp, half, "f16") +INSTANTIATE_G("mul", MulOp, float, "f32") +INSTANTIATE_G("mul", MulOp, half, "f16") + +// Sub +INSTANTIATE_SS("sub", SubOp, float, "f32") +INSTANTIATE_SS("sub", SubOp, half, "f16") +INSTANTIATE_SV("sub", SubOp, float, "f32") +INSTANTIATE_SV("sub", SubOp, half, "f16") +INSTANTIATE_VS("sub", SubOp, float, "f32") +INSTANTIATE_VS("sub", SubOp, half, "f16") +INSTANTIATE_VV("sub", SubOp, float, "f32") +INSTANTIATE_VV("sub", SubOp, half, "f16") +INSTANTIATE_G("sub", SubOp, float, "f32") +INSTANTIATE_G("sub", SubOp, half, "f16") +)"; + return source.c_str(); +} + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git 
a/backends/portable/runtime/metal_v2/ops/MPSGraphOp.h b/backends/portable/runtime/metal_v2/ops/MPSGraphOp.h
new file mode 100644
index 00000000000..bf2cc8fca63
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/ops/MPSGraphOp.h
@@ -0,0 +1,174 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// NOTE(review): all include/import targets below were lost in extraction
// (angle-bracket contents stripped); restore them from the original file.
#include
#include

#import
#import
#import

#include
#include
#include
#include
#include
#include

namespace executorch {
namespace backends {
namespace metal_v2 {

//===----------------------------------------------------------------------===//
// MPSGraphOp - Base class for ops that delegate compute to MPSGraph.
//
// Subclass contract:
//   - Override `buildGraph()` to construct an MPSGraph for the given input
//     shapes/dtypes. Return the graph + its input/output placeholder tensors.
//   - Override `cacheKey()` if the default (shape+dtype-based) hashing is
//     insufficient (e.g. when constants or attributes affect the graph).
//
// Base-class behavior (in dispatch()):
//   1. Build a cache key from inputs/outputs.
//   2. On miss: call buildGraph(), cache the entry.
//   3. End any active compute encoder on the stream.
//   4. Wrap our MTLCommandBuffer as MPSCommandBuffer.
//   5. Wrap each MTLBuffer as MPSGraphTensorData.
//   6. encodeToCommandBuffer.
//   7. Subsequent ops open a new encoder on the same cmd buffer.
//
// Constraints:
//   - Under MTL4: uses a dedicated singleton legacy queue + MTLSharedEvent
//     for cross-queue sync to MetalStream's MTL4 cb. (MPSGraph wraps
//     id legacy; not directly interoperable with the
//     MTL4 cb, so we go via a legacy cb + event.)
//   - Defeats ICB replay for this op (MPSGraph re-encodes per call).
//===----------------------------------------------------------------------===//

class MPSGraphOp : public MetalOp {
 public:
  MPSGraphOp() = default;
  ~MPSGraphOp() override;

  // Final dispatch: cache lookup + build-on-miss + encode.
  // Subclasses extend behavior via cacheKey() / buildGraph().
  void dispatch(
      MetalStream* stream,
      EValuePtrSpan inputs,
      EValuePtrSpan outputs) override final;

  // MPSGraph ops don't ship MSL kernel source.
  const char* kernelSource() const override { return ""; }

 protected:
  struct CachedGraph {
    MPSGraph* graph = nil; // strong (we retain)
    // NOTE(review): these were commented "ARC-managed", but MPSGraphOp.mm
    // calls retain/release manually (MRC). The raw placeholder pointers are
    // not retained here — they presumably stay alive only because the
    // retained `graph` owns its tensor nodes. Confirm that assumption.
    // The element type of both vectors (MPSGraphTensor*) was stripped in
    // extraction.
    std::vector inputPlaceholders;
    std::vector outputPlaceholders;
  };

  // Build the MPSGraph for the concrete input shapes. Called once per cache
  // miss. Subclass returns the placeholders (in input order) + outputs.
  virtual CachedGraph buildGraph(
      EValuePtrSpan inputs,
      EValuePtrSpan outputs) = 0;

  // Compute a human-readable cache key. Used only for log messages on cache
  // miss — the actual cache is keyed on a packed binary ShapeKey (below) for
  // perf. Subclasses can override to add op-specific signature info.
  virtual std::string cacheKey(
      EValuePtrSpan inputs,
      EValuePtrSpan outputs);

 private:
  // Binary cache key: packs (dtype, dim, sizes...) for each input then
  // each output into a small stack-resident array. Comparable + hashable
  // without any heap allocation, so the hot dispatch path is alloc-free.
  // NOTE(review): shape sets needing more than kMaxPacked words exceed this
  // key's capacity; verify packShapes' overflow handling keeps distinct
  // shape sets from comparing equal.
  struct ShapeKey {
    static constexpr size_t kMaxPacked = 64;
    // NOTE(review): std::array template arguments were stripped in
    // extraction (presumably std::array<uint64_t, kMaxPacked>).
    std::array data{};
    size_t len = 0;
    bool operator==(const ShapeKey& o) const {
      return len == o.len &&
          std::memcmp(data.data(), o.data.data(), len * sizeof(uint64_t)) ==
          0;
    }
  };
  struct ShapeKeyHash {
    size_t operator()(const ShapeKey& k) const noexcept {
      // FNV-1a 64-bit
      uint64_t h = 0xcbf29ce484222325ULL;
      for (size_t i = 0; i < k.len; ++i) {
        h ^= k.data[i];
        h *= 0x100000001b3ULL;
      }
      return static_cast(h);
    }
  };
  static ShapeKey packShapes(
      EValuePtrSpan inputs,
      EValuePtrSpan outputs);

  // NOTE(review): map template arguments were stripped in extraction
  // (presumably std::unordered_map<ShapeKey, CachedGraph, ShapeKeyHash>).
  std::unordered_map cache_;
  // Memo of the last looked-up entry. Most inference loops call dispatch
  // repeatedly with the same shapes, so this avoids even the hash + map
  // lookup on the hot path.
  ShapeKey last_key_;
  CachedGraph* last_entry_ = nullptr;
};

//===----------------------------------------------------------------------===//
// MPSGraphMatMulOp - aten::mm via MPSGraph.matrixMultiplicationWithPrimaryTensor
//
// Inputs:  A [M, K], B [K, N]
// Output:  C [M, N] = A @ B
//
// Useful for cases where MPSGraph's tile/algorithm selection beats our hand
// kernel (large matmul, mixed dtypes, etc). Per-shape graph cache keeps
// per-call overhead at the encode step (~50-100µs) after warm-up.
//===----------------------------------------------------------------------===//

class MPSGraphMatMulOp : public MPSGraphOp {
 public:
  const char* name() const override { return "aten::mm"; }
  bool supports(ScalarType dtype) const override {
    return dtype == ScalarType::Float || dtype == ScalarType::Half ||
        dtype == ScalarType::BFloat16;
  }
  // Output shape for mm(A[M,K], B[K,N]) is [M, N]. Required so MPSGraphOp's
  // resizeOutput sets the right shape on the output tensor — without this,
  // the base resizeOutput falls back to inputs[0]'s shape ([M, K]) and the
  // last (N - K) elements of the output are never written.
  // NOTE(review): the std::vector return type's element and the
  // static_cast template arguments below were stripped in extraction
  // (presumably SizesType).
  std::vector computeOutputShape(
      EValuePtrSpan inputs) const override {
    if (inputs.size() < 2 || !inputs[0]->isTensor() || !inputs[1]->isTensor()) {
      return {};
    }
    auto& A = inputs[0]->toTensor();
    auto& B = inputs[1]->toTensor();
    if (A.dim() < 2 || B.dim() < 2) return {};
    using SizesType = runtime::etensor::Tensor::SizesType;
    return {static_cast(A.size(A.dim() - 2)),
            static_cast(B.size(B.dim() - 1))};
  }

 protected:
  CachedGraph buildGraph(
      EValuePtrSpan inputs,
      EValuePtrSpan outputs) override;
};

} // namespace metal_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/ops/MPSGraphOp.mm b/backends/portable/runtime/metal_v2/ops/MPSGraphOp.mm
new file mode 100644
index 00000000000..050bbbafdda
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/ops/MPSGraphOp.mm
@@ -0,0 +1,244 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#import "MPSGraphOp.h"

// NOTE(review): the include targets below were lost in extraction.
#include

#include
#include
#include
#include

namespace executorch {
namespace backends {
namespace metal_v2 {

using executorch::runtime::EValue;
using executorch::runtime::etensor::ScalarType;
using executorch::runtime::etensor::Tensor;

namespace {

// Map ExecuTorch ScalarType to MPSDataType.
+MPSDataType toMPSDataType(ScalarType t) { + switch (t) { + case ScalarType::Float: return MPSDataTypeFloat32; + case ScalarType::Half: return MPSDataTypeFloat16; + case ScalarType::BFloat16: return MPSDataTypeBFloat16; + case ScalarType::Int: return MPSDataTypeInt32; + case ScalarType::Long: return MPSDataTypeInt64; + case ScalarType::Short: return MPSDataTypeInt16; + case ScalarType::Char: return MPSDataTypeInt8; + case ScalarType::Byte: return MPSDataTypeUInt8; + case ScalarType::Bool: return MPSDataTypeBool; + default: + ET_LOG(Error, "MPSGraphOp: unsupported dtype %d", static_cast(t)); + return MPSDataTypeFloat32; + } +} + +// NSArray from a Tensor's sizes. +NSArray* nsShape(const Tensor& t) { + NSMutableArray* a = + [NSMutableArray arrayWithCapacity:t.dim()]; + for (ssize_t i = 0; i < t.dim(); ++i) { + [a addObject:@(t.size(i))]; + } + return a; +} + +// Wrap an ExecuTorch tensor as MPSGraphTensorData (no copy, no retain on data). +MPSGraphTensorData* makeTensorData(MetalStream* stream, const Tensor& t) { + id buf = stream->bufferForPtr(t.mutable_data_ptr(), t.nbytes()); + return [[MPSGraphTensorData alloc] initWithMTLBuffer:buf + shape:nsShape(t) + dataType:toMPSDataType(t.scalar_type())]; +} + +} // anonymous namespace + +//===----------------------------------------------------------------------===// +// MPSGraphOp base +//===----------------------------------------------------------------------===// + +MPSGraphOp::~MPSGraphOp() { + for (auto& kv : cache_) { + if (kv.second.graph) [kv.second.graph release]; + } + cache_.clear(); +} + +std::string MPSGraphOp::cacheKey( + EValuePtrSpan inputs, + EValuePtrSpan outputs) { + std::ostringstream oss; + auto append = [&oss](const Tensor& t) { + oss << static_cast(t.scalar_type()) << ':'; + for (ssize_t i = 0; i < t.dim(); ++i) { + oss << t.size(i) << (i + 1 < t.dim() ? 
'x' : ';'); + } + }; + for (auto* e : inputs) append(e->toTensor()); + oss << "->"; + for (auto* e : outputs) append(e->toTensor()); + return oss.str(); +} + +MPSGraphOp::ShapeKey MPSGraphOp::packShapes( + EValuePtrSpan inputs, + EValuePtrSpan outputs) { + ShapeKey k; + size_t n = 0; + auto pack = [&](const Tensor& t) { + // dtype + dim + sizes; bail if we'd overflow (caller falls back to slow + // path, which is correct but allocates). + size_t need = 2 + static_cast(t.dim()); + if (n + need > ShapeKey::kMaxPacked) { + n = ShapeKey::kMaxPacked + 1; // sentinel: overflow + return; + } + k.data[n++] = static_cast(static_cast(t.scalar_type())); + k.data[n++] = static_cast(t.dim()); + for (ssize_t i = 0; i < t.dim(); ++i) { + k.data[n++] = static_cast(t.size(i)); + } + }; + for (auto* e : inputs) pack(e->toTensor()); + // Marker between inputs and outputs so [a,b] / [c] is distinct from [a] / [b,c]. + if (n < ShapeKey::kMaxPacked) k.data[n++] = ~0ULL; + for (auto* e : outputs) pack(e->toTensor()); + k.len = std::min(n, ShapeKey::kMaxPacked); + return k; +} + +void MPSGraphOp::dispatch( + MetalStream* stream, + EValuePtrSpan inputs, + EValuePtrSpan outputs) { + // Resize output(s) using the standard helper (subclass may also override + // computeOutputShape if non-trivial). + for (auto* outE : outputs) { + auto err = resizeOutput(inputs, outE); + if (err != runtime::Error::Ok) { + ET_LOG(Error, "MPSGraphOp(%s): output resize failed", name()); + return; + } + } + + // 1. Cache lookup. Two-level: a single-slot memo for the common case + // where the same shapes appear across consecutive calls (steady-state + // inference), and a hash-keyed map for the slow path. 
+ ShapeKey key = packShapes(inputs, outputs); + CachedGraph* g_ptr = nullptr; + if (last_entry_ && key == last_key_) { + g_ptr = last_entry_; + } else { + auto it = cache_.find(key); + if (it == cache_.end()) { + auto entry = buildGraph(inputs, outputs); + if (!entry.graph) { + ET_LOG(Error, "MPSGraphOp(%s): buildGraph returned null", name()); + return; + } + [entry.graph retain]; + it = cache_.emplace(key, std::move(entry)).first; + ET_LOG( + Info, + "MPSGraphOp(%s): built graph for key=%s", + name(), + cacheKey(inputs, outputs).c_str()); + } + last_key_ = key; + last_entry_ = &it->second; + g_ptr = &it->second; + } + const CachedGraph& g = *g_ptr; + + if (g.inputPlaceholders.size() != inputs.size() || + g.outputPlaceholders.size() != outputs.size()) { + ET_LOG(Error, + "MPSGraphOp(%s): cached graph arity mismatch (%zu/%zu vs %zu/%zu)", + name(), g.inputPlaceholders.size(), g.outputPlaceholders.size(), + inputs.size(), outputs.size()); + return; + } + + auto* mstream = static_cast(stream); + if (!mstream) { + ET_LOG(Error, "MPSGraphOp(%s): stream is not MetalStream", name()); + return; + } + + @autoreleasepool { + // Wrap inputs/outputs as MPSGraphTensorData (shared by both paths). + NSMutableDictionary* feeds = + [NSMutableDictionary dictionaryWithCapacity:inputs.size()]; + for (size_t i = 0; i < inputs.size(); ++i) { + feeds[g.inputPlaceholders[i]] = + makeTensorData(mstream, inputs[i]->toTensor()); + } + NSMutableDictionary* results = + [NSMutableDictionary dictionaryWithCapacity:outputs.size()]; + for (size_t i = 0; i < outputs.size(); ++i) { + results[g.outputPlaceholders[i]] = + makeTensorData(mstream, outputs[i]->toTensor()); + } + + // MetalStream handles all the cross-queue sync, ICB drain, encoder + // close, commit-and-continue adopt-back. We just encode our MPSGraph. 
+ mstream->encodeWithLegacyCommandBuffer([&](MPSCommandBuffer* mpsCB) { + [g.graph encodeToCommandBuffer:mpsCB + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:nil]; + }); + } +} + +//===----------------------------------------------------------------------===// +// MPSGraphMatMulOp - aten::mm +//===----------------------------------------------------------------------===// + +MPSGraphOp::CachedGraph MPSGraphMatMulOp::buildGraph( + EValuePtrSpan inputs, + EValuePtrSpan outputs) { + const Tensor& A = inputs[0]->toTensor(); + const Tensor& B = inputs[1]->toTensor(); + const Tensor& C = outputs[0]->toTensor(); + + MPSGraph* graph = [MPSGraph new]; + MPSDataType dt = toMPSDataType(C.scalar_type()); + + MPSGraphTensor* aPh = [graph placeholderWithShape:nsShape(A) + dataType:toMPSDataType(A.scalar_type()) + name:@"A"]; + MPSGraphTensor* bPh = [graph placeholderWithShape:nsShape(B) + dataType:toMPSDataType(B.scalar_type()) + name:@"B"]; + MPSGraphTensor* cTensor = + [graph matrixMultiplicationWithPrimaryTensor:aPh + secondaryTensor:bPh + name:@"C"]; + // If output dtype differs from inputs (e.g. fp16 inputs producing fp32), + // cast. Most matmuls keep the dtype. + if (toMPSDataType(A.scalar_type()) != dt) { + cTensor = [graph castTensor:cTensor toType:dt name:@"C_cast"]; + } + + return CachedGraph{ + .graph = graph, + .inputPlaceholders = {aPh, bPh}, + .outputPlaceholders = {cTensor}, + }; +} + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime/metal_v2/ops/MatMulOp.h b/backends/portable/runtime/metal_v2/ops/MatMulOp.h new file mode 100644 index 00000000000..ef4272bf43f --- /dev/null +++ b/backends/portable/runtime/metal_v2/ops/MatMulOp.h @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// NOTE(review): include target lost in extraction; restore from the
// original file.
#include

namespace executorch {
namespace backends {
namespace metal_v2 {

//===----------------------------------------------------------------------===//
// MatMul Kernel Types
//===----------------------------------------------------------------------===//

// Kernel-selection strategies for mm; MatMulOp::selectKernel picks one per
// (M, N, K) and kernelTypePrefix maps it to a kernel-name prefix.
enum class MatMulKernelType {
  Naive,      // Simple kernel for small matrices
  Tiled,      // Tiled with threadgroup memory (32x32)
  Simd,       // Simdgroup MMA (BM=64x64 tiles, 2x2 simdgroup layout)
  Simd_BN32,  // Simdgroup MMA, BM=64, BN=32, BK=32, 2x2 simd layout —
              // MLX's "small fp32 NN" tile. Smaller BN doubles the
              // tg count along N (more parallelism for small-N cases),
              // larger BK halves the K-tile barrier count.
  Simd_M32,   // Simdgroup MMA, BM=32, BN=64, BK=32, 1x4 simd layout
              // (for 2 <= M < 64, with bounds-check waste for small M)
  NT,         // Simd MMA, B is logically transposed (B.T view)
  TN,         // Simd MMA, A is logically transposed (A.T view)
  GEMV,       // y = A @ x (N == 1)
  GEMV_T,     // C = x @ B (M == 1, uses gemv_t with swapped operands)
  TensorOps   // Metal 4 tensor_ops::matmul2d (Apple9+, fastest path when supported)
};

//===----------------------------------------------------------------------===//
// MatMulOp - 2D matrix multiply (aten::mm)
//===----------------------------------------------------------------------===//

class MatMulOp : public MetalOp {
public:
  const char* name() const override { return "aten::mm"; }

  bool supports(ScalarType dtype) const override {
    return isFloatingPoint(dtype);
  }

  // [M, N] for mm(A[M,K], B[K,N]).
  // NOTE(review): the std::vector element type was stripped in extraction
  // (presumably std::vector<SizesType>).
  std::vector computeOutputShape(
      EValuePtrSpan inputs) const override;

  void dispatch(
      MetalStream* stream,
      EValuePtrSpan inputs,
      EValuePtrSpan outputs) override;

protected:
  const char* kernelSource() const override;

  // Heuristic kernel choice for the given problem size (see enum above).
  MatMulKernelType selectKernel(int64_t M, int64_t N, int64_t K) const;
  // Kernel-name prefix for a given kernel type (pairs with selectKernel).
  const char* kernelTypePrefix(MatMulKernelType type) const;
};

//===----------------------------------------------------------------------===//
// BatchedMatMulOp - 3D batched matrix multiply (aten::bmm)
//===----------------------------------------------------------------------===//

class BatchedMatMulOp : public MetalOp {
public:
  const char* name() const override { return "aten::bmm"; }

  bool supports(ScalarType dtype) const override {
    return isFloatingPoint(dtype);
  }

  // [B, M, N] for bmm(A[B,M,K], B[B,K,N]).
  // NOTE(review): std::vector element type stripped in extraction.
  std::vector computeOutputShape(
      EValuePtrSpan inputs) const override;

  void dispatch(
      MetalStream* stream,
      EValuePtrSpan inputs,
      EValuePtrSpan outputs) override;

protected:
  const char* kernelSource() const override;
};

//===----------------------------------------------------------------------===//
// AddMMOp - fused 2D matrix multiply with bias (aten::addmm)
//
// Computes: out = beta * input + alpha * (mat1 @ mat2)
//   inputs[0] = input (bias) [M, N] — must be 2D contiguous OR 1D-broadcast [N]
//   inputs[1] = mat1  [M, K]
//   inputs[2] = mat2  [K, N]
//   inputs[3] = beta  (Scalar, default 1)
//   inputs[4] = alpha (Scalar, default 1)
//
// Currently supports the common LLM/Linear case: alpha=1, beta=1, NN layout,
// bias is [M, N] contiguous OR [N] (1D-broadcast). Falls back to MatMulOp
// when constraints don't hold (caller-side check; for now we just assert).
//
// Saves an entire elementwise add pass over the matmul's output by fusing
// the bias add into the matmul kernel's epilogue (loaded as an 8x8 simdgroup
// fragment, added to the accumulator before simdgroup_store)
+//===----------------------------------------------------------------------===// + +class AddMMOp : public MetalOp { +public: + const char* name() const override { return "aten::addmm"; } + + bool supports(ScalarType dtype) const override { + return isFloatingPoint(dtype); + } + + std::vector computeOutputShape( + EValuePtrSpan inputs) const override; + + void dispatch( + MetalStream* stream, + EValuePtrSpan inputs, + EValuePtrSpan outputs) override; + +protected: + const char* kernelSource() const override; +}; + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime/metal_v2/ops/MatMulOp.mm b/backends/portable/runtime/metal_v2/ops/MatMulOp.mm new file mode 100644 index 00000000000..e1dbdc827e3 --- /dev/null +++ b/backends/portable/runtime/metal_v2/ops/MatMulOp.mm @@ -0,0 +1,1559 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
 */

#import "MatMulOp.h"
#include  // NOTE(review): the targets of these eight includes were stripped
#include  // by extraction (angle-bracket contents lost); restore them from VCS
#include  // before building.
#include
#include
#include
#include
#include

namespace executorch {
namespace backends {
namespace metal_v2 {

using runtime::Error;

//===----------------------------------------------------------------------===//
// Output Shape
//===----------------------------------------------------------------------===//

// Shape inference for aten::mm: [M, K] @ [K, N] -> [M, N].
// Returns {} (empty) when inputs are malformed — missing args, non-tensor
// EValues, or rank < 2; callers treat an empty shape as "cannot infer".
// NOTE(review): the inner dims (A's K vs B's K) are NOT checked here.
std::vector MatMulOp::computeOutputShape(
    EValuePtrSpan inputs) const {

  if (inputs.size() < 2 || !inputs[0]->isTensor() || !inputs[1]->isTensor()) {
    return {};
  }

  auto& A = inputs[0]->toTensor();
  auto& B = inputs[1]->toTensor();

  if (A.dim() < 2 || B.dim() < 2) {
    return {};
  }

  // Last-two-dims convention: rows of A, cols of B.
  SizesType M = A.size(A.dim() - 2);
  SizesType N = B.size(B.dim() - 1);

  return {M, N};
}

//===----------------------------------------------------------------------===//
// Kernel Selection
//
// Picks among Naive / Tiled / Simd / NT / TN / GEMV / GEMV_T based on:
//   - input layout (row-contig vs col-contig, where col-contig means the
//     tensor is a .T view of an underlying row-contig tensor)
//   - problem size (M, N, K)
//   - device tier (smaller thresholds on phones, larger on Ultra/Max)
//===----------------------------------------------------------------------===//

namespace {

// Per-device-tier size thresholds.
// NOTE(review): thresholdsForTier is not referenced anywhere in this chunk —
// selectKernel below uses hard-coded constants. Either wire this in or
// delete it.
struct MatMulThresholds {
  int simdMNK; // min M,N,K to pick Simd over Tiled/Naive
  int gemvMK;  // min M (or N for gemv_t) to use the simdgroup gemv path
};

constexpr MatMulThresholds thresholdsForTier(DeviceTier tier) {
  switch (tier) {
    case DeviceTier::Phone:    return {32, 16};
    case DeviceTier::MacUltra: return {64, 32};
    case DeviceTier::MacBase:
    default:                   return {48, 24};
  }
}

} // namespace

// Maps a kernel-type enum to the base name of the Metal kernel function;
// dispatch() appends a dtype suffix (via dtypeSuffix) to form the full name.
// The trailing return is unreachable (the switch is exhaustive) but keeps
// compilers that warn on missing return paths happy.
const char* MatMulOp::kernelTypePrefix(MatMulKernelType type) const {
  switch (type) {
    case MatMulKernelType::Naive:     return "matmul_naive";
    case MatMulKernelType::Tiled:     return "matmul_tiled";
    case MatMulKernelType::Simd:      return "matmul_simd_t_64_64_16_2_2_n";
    case MatMulKernelType::Simd_BN32: return "matmul_simd_t_64_32_32_2_2_n";
    case MatMulKernelType::Simd_M32:  return "matmul_simd_t_32_64_32_1_4_n";
    case MatMulKernelType::NT:        return "matmul_simd_t_64_64_16_2_2_t";
    case MatMulKernelType::TN:        return "matmul_simd_t_64_64_16_2_2_tn";
    case MatMulKernelType::GEMV:      return "gemv";
    case MatMulKernelType::GEMV_T:    return "gemv_t";
    case MatMulKernelType::TensorOps: return "matmul_tensor_ops";
  }
  return "matmul_naive";
}

// selectKernel only handles the size-based fallback ladder for the regular
// (NN) case. NT/TN/GEMV/GEMV_T are picked separately based on input layout.
//
// Tier ladder (ranges corrected to match the conditions in selectKernel):
//   Simd     : M, N >= 64, K >= 16. 64x64 output, 4 sg in 2x2.
//   Simd_M32 : 2 <= M < 64, N >= 64, K >= 16. 32x64 output, 4 sg in 1x4.
//              MLX-style "skinny" variant for prefill batches like Llama M=32.
//   Tiled    : M, N >= 32 (middle ground not caught above), legacy fallback.
//   Naive    : everything smaller.
//
// Variants compiled but NOT auto-routed (kept for future use / experimentation):
//   - Simd_M32_BN128: tried for compute-bound large-K cases. Theoretical AI
//     gain (10.7 -> 12.8 FLOPs/byte) is real, but in practice doubling the
//     per-sg register pressure (16 vs 8 simdgroup_matrix accumulators) and
//     threadgroup memory cuts wave-level occupancy by more than the AI gain
//     buys. Net regression on Apple M-series. Could be reconsidered if we
//     add register-blocked variants or tune for specific GPU families.
//   - Simd_M32_SplitK: didn't help compute-bound cases (the bottleneck is
//     arithmetic intensity, not parallelism).
// Size-based kernel ladder for the NN case. Thresholds are hard-coded from
// an empirical sweep (see the tier-ladder comment above).
MatMulKernelType MatMulOp::selectKernel(int64_t M, int64_t N, int64_t K) const {
  if (M >= 64 && N >= 64 && K >= 16) {
    // MLX-inspired heuristic for fp32 NN, refined empirically from sweep:
    //   N <= 1024            -> BN32 (need more tgs along N for parallelism)
    //   M >= 512 + K >= 4096 -> BN32 (BK=32 halves K-barrier count, big wins
    //                           when both M and K are large enough that
    //                           barrier overhead dominates)
    //   otherwise            -> Simd (BN=64 wins for moderate-M large-N where
    //                           tg-level data reuse beats parallelism)
    if (N <= 1024) return MatMulKernelType::Simd_BN32;
    if (M >= 512 && K >= 4096) return MatMulKernelType::Simd_BN32;
    return MatMulKernelType::Simd;
  }
  if (M >= 2 && N >= 64 && K >= 16) return MatMulKernelType::Simd_M32;
  if (M >= 32 && N >= 32) return MatMulKernelType::Tiled;
  return MatMulKernelType::Naive;
}

//===----------------------------------------------------------------------===//
// Dispatch
//===----------------------------------------------------------------------===//

// Dispatches one aten::mm. Steps:
//   1) env-var escape hatch to MPSGraph (benchmark / fallback),
//   2) resize output,
//   3) classify A/B layout (row- vs col-contiguous),
//   4) pick kernel: layout first (NT/TN/GEMV), then size ladder, then an
//      optional TensorOps upgrade on Apple9+ devices,
//   5) compute grid/threadgroup sizes and launch.
//
// NOTE(review): angle-bracket template arguments (e.g. the static_cast
// target types) were stripped from this chunk by extraction; the bare
// `static_cast(...)` calls below reflect that loss, not the original source.
void MatMulOp::dispatch(
    MetalStream* stream,
    EValuePtrSpan inputs,
    EValuePtrSpan outputs) {

  // TEMPORARY runtime switch: when METAL_USE_MPSGRAPH=1 (or =true), route ALL
  // matmul cases through MPSGraph instead of our hand-written kernels. Useful
  // for benchmarking and as a sanity-check fallback when the custom kernels
  // misbehave. Selection logic below is left intact (just bypassed).
  // Works under both MTL3 and MTL4 (MPSGraphOp branches internally on
  // useMTL4() — under MTL4 it uses a singleton legacy queue + shared event).
  // The env var is read once (function-local static lambda-init).
  static const bool kForceMPSGraph = []() {
    const char* env = getenv("METAL_USE_MPSGRAPH");
    return env && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
  }();
  if (kForceMPSGraph) {
    static MPSGraphMatMulOp mpsOp;
    ET_LOG(Info, "MatMulOp: forcing MPSGraph path (METAL_USE_MPSGRAPH=1)");
    mpsOp.dispatch(stream, inputs, outputs);
    return;
  }

  auto& A = inputs[0]->toTensor();
  auto& B = inputs[1]->toTensor();
  auto& C = outputs[0]->toTensor();

  auto err = resizeOutput(inputs, outputs[0]);
  if (err != Error::Ok) {
    ET_LOG(Error, "MatMulOp: failed to resize output");
    return;
  }

  // Layout classification: exactly one of {row-contig, col-contig} per input
  // is required; col-contig means the tensor is a .T view.
  const bool aRC = isRowContiguous(A);
  const bool bRC = isRowContiguous(B);
  const bool aCC = !aRC && isColContiguous(A);
  const bool bCC = !bRC && isColContiguous(B);

  if (!(aRC || aCC) || !(bRC || bCC)) {
    ET_LOG(Error, "MatMulOp: A and B must each be row- or column-contiguous");
    return;
  }
  if (aCC && bCC) {
    ET_LOG(Error, "MatMulOp: matmul_tt (both transposed) is not implemented");
    return;
  }

  int32_t M = static_cast(A.size(0));
  int32_t K = static_cast(A.size(1));
  int32_t N = static_cast(B.size(1));

  ScalarType dtype = C.scalar_type();

  // Pick kernel type from layout + size. GEMV paths only apply to the NN
  // layout; degenerate N==1 / M==1 take priority over the size ladder.
  MatMulKernelType kernelType;
  if (aRC && bRC) {
    if (N == 1)      kernelType = MatMulKernelType::GEMV;
    else if (M == 1) kernelType = MatMulKernelType::GEMV_T;
    else             kernelType = selectKernel(M, N, K);
  } else if (aRC && bCC) {
    kernelType = MatMulKernelType::NT;
  } else /* aCC && bRC — A is transposed (TN) */ {
    kernelType = MatMulKernelType::TN;
  }

  // Upgrade Simd -> TensorOps when device supports Apple9 family (M3+/A17 Pro+)
  // AND sizes meet matmul2d constraints (BM=BN=64, K aligned to 16). We don't
  // upgrade NT/TN — tensor_ops::matmul2d would need a different descriptor
  // (transpose flags). Could be added later if perf justifies it.
  if (kernelType == MatMulKernelType::Simd) {
    auto* metalStream = static_cast(stream);
    if (metalStream && metalStream->device() &&
        [metalStream->device() supportsFamily:MTLGPUFamilyApple9] &&
        (M % 64 == 0) && (N % 64 == 0) && (K % 16 == 0) &&
        (dtype == ScalarType::Float ||
         dtype == ScalarType::Half ||
         dtype == ScalarType::BFloat16)) {
      kernelType = MatMulKernelType::TensorOps;
    }
  }

  // Full kernel name = variant prefix + dtype suffix (e.g. "gemv_float").
  std::string kname = std::string(kernelTypePrefix(kernelType)) + "_" + dtypeSuffix(dtype);
  auto* kernel = getKernel(stream, kname.c_str());

  ET_LOG(Info, "MatMulOp: M=%d, K=%d, N=%d, kernel=%s", M, K, N, kname.c_str());

  uvec3 grid, block;

  switch (kernelType) {
    case MatMulKernelType::Naive:
      grid = uvec3((N + 7) / 8, (M + 7) / 8, 1);
      block = uvec3(8, 8, 1);
      break;

    case MatMulKernelType::Tiled:
      grid = uvec3((N + 31) / 32, (M + 31) / 32, 1);
      block = uvec3(32, 32, 1);
      break;

    case MatMulKernelType::Simd:
    case MatMulKernelType::NT:
    case MatMulKernelType::TN:
    case MatMulKernelType::TensorOps:
      // 64x64 output tile, 4 simdgroups (128 threads), grid.z=1 (no batch).
      grid = uvec3((N + 63) / 64, (M + 63) / 64, 1);
      block = uvec3(128, 1, 1);
      break;

    case MatMulKernelType::Simd_BN32:
      // 64x32 output tile (BM=64, BN=32), 4 simdgroups in 2x2 layout.
      // Doubles tg count along N vs Simd, helps small-N cases.
      grid = uvec3((N + 31) / 32, (M + 63) / 64, 1);
      block = uvec3(128, 1, 1);
      break;

    case MatMulKernelType::Simd_M32:
      // 32x64 output tile (BM=32, BN=64), 4 simdgroups in 1x4 layout (128
      // threads), grid.z=1 (no batch).
      grid = uvec3((N + 63) / 64, (M + 31) / 32, 1);
      block = uvec3(128, 1, 1);
      break;

    case MatMulKernelType::GEMV:
      // y = A @ x ; one simdgroup per output row (M outputs).
      grid = uvec3(M * 32, 1, 1);
      block = uvec3(32, 1, 1);
      break;

    case MatMulKernelType::GEMV_T:
      // C = A_row @ B ; MLX-style TM×TN tiled gemv_t. Each tg = 1 simdgroup
      // (32 threads) outputs SN*TN = 16 consecutive columns. Total tgs =
      // ceil(N / 16).
      grid = uvec3(((N + 15) / 16) * 32, 1, 1);
      block = uvec3(32, 1, 1);
      break;
  }

  // GEMV_T has swapped operand semantics: gemv_t(matrix=B, vector=A, out=C).
  if (kernelType == MatMulKernelType::GEMV_T) {
    stream->dispatch(kernel, {
        {B.mutable_data_ptr(), B.nbytes()}, // matrix [K,N]
        {A.mutable_data_ptr(), A.nbytes()}, // vector [K]
        {C.mutable_data_ptr(), C.nbytes()}, // output [N]
        M, K, N
    }, grid, block);
  } else {
    stream->dispatch(kernel, {
        {A.mutable_data_ptr(), A.nbytes()},
        {B.mutable_data_ptr(), B.nbytes()},
        {C.mutable_data_ptr(), C.nbytes()},
        M, K, N
    }, grid, block);
  }
}

//===----------------------------------------------------------------------===//
// Kernel Source
//===----------------------------------------------------------------------===//

// Kernel source: prepend kTileLoadMetalSource so matmul_simd / nt / tn can
// call cooperativeLoadTileVec4 + cooperativeLoadTileTransposedVec4. Built
// once into a static std::string. Shared between MatMulOp and
// BatchedMatMulOp so both can reference matmul_simd and bmm kernels.
+static const std::string& matmulKernelSource() { + static const std::string source = std::string(kTileLoadMetalSource) + R"( +#include +#include +using namespace metal; + +constant int TILE_SIZE = 32; + +//===----------------------------------------------------------------------===// +// Naive kernel (fallback for small matrices or older devices) +//===----------------------------------------------------------------------===// + +template +kernel void matmul_naive( + device const T* A [[buffer(0)]], + device const T* B [[buffer(1)]], + device T* C [[buffer(2)]], + constant int& M [[buffer(3)]], + constant int& K [[buffer(4)]], + constant int& N [[buffer(5)]], + uint2 gid [[thread_position_in_grid]]) { + int row = gid.y; + int col = gid.x; + if (row >= M || col >= N) return; + + T sum = T(0); + for (int k = 0; k < K; k++) { + sum += A[row * K + k] * B[k * N + col]; + } + C[row * N + col] = sum; +} + +//===----------------------------------------------------------------------===// +// Tiled kernel (medium matrices) +//===----------------------------------------------------------------------===// + +template +kernel void matmul_tiled( + device const T* A [[buffer(0)]], + device const T* B [[buffer(1)]], + device T* C [[buffer(2)]], + constant int& M [[buffer(3)]], + constant int& K [[buffer(4)]], + constant int& N [[buffer(5)]], + uint2 gid [[thread_position_in_grid]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 tgid [[threadgroup_position_in_grid]]) { + + threadgroup T As[TILE_SIZE][TILE_SIZE + 1]; + threadgroup T Bs[TILE_SIZE][TILE_SIZE + 1]; + + int row = tgid.y * TILE_SIZE + tid.y; + int col = tgid.x * TILE_SIZE + tid.x; + + T sum = T(0); + + for (int tileK = 0; tileK < K; tileK += TILE_SIZE) { + int aRow = row; + int aCol = tileK + tid.x; + As[tid.y][tid.x] = (aRow < M && aCol < K) ? A[aRow * K + aCol] : T(0); + + int bRow = tileK + tid.y; + int bCol = col; + Bs[tid.y][tid.x] = (bRow < K && bCol < N) ? 
B[bRow * N + bCol] : T(0); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 0; k < TILE_SIZE && (tileK + k) < K; k++) { + sum += As[tid.y][k] * Bs[k][tid.x]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (row < M && col < N) { + C[row * N + col] = sum; + } +} + +//===----------------------------------------------------------------------===// +// MMA helper: run one K-tile of multiply-accumulate using simdgroup_matrix. +// Loads FRAGS_M A-fragments × FRAGS_N B-fragments from threadgroup memory +// at (a_row, b_col) within the tile, then performs FRAGS_M × FRAGS_N MMAs +// into the existing C_frag accumulators. +// +// Templating on FRAGS_M / FRAGS_N lets matmul_simd (4×4) and +// matmul_simd_m32 (4×2) share the inner loop without duplication. +//===----------------------------------------------------------------------===// +template +inline void simdMMAKTile( + simdgroup_matrix C_frag[FRAGS_M][FRAGS_N], + threadgroup const T* As, // points at &As_buf[curBuf][a_row][k_off] + threadgroup const T* Bs, // points at &Bs_buf[curBuf][k_off][b_col] + int BK_) { + for (int k = 0; k < BK_; k += 8) { + simdgroup_matrix A_frag[FRAGS_M]; + #pragma clang loop unroll(full) + for (int i = 0; i < FRAGS_M; ++i) { + simdgroup_load(A_frag[i], As + i * 8 * SMEM_A_STRIDE + k, SMEM_A_STRIDE); + } + simdgroup_matrix B_frag[FRAGS_N]; + #pragma clang loop unroll(full) + for (int j = 0; j < FRAGS_N; ++j) { + simdgroup_load(B_frag[j], Bs + k * SMEM_B_STRIDE + j * 8, SMEM_B_STRIDE); + } + #pragma clang loop unroll(full) + for (int i = 0; i < FRAGS_M; ++i) { + #pragma clang loop unroll(full) + for (int j = 0; j < FRAGS_N; ++j) { + simdgroup_multiply_accumulate( + C_frag[i][j], A_frag[i], B_frag[j], C_frag[i][j]); + } + } + } +} + +//===----------------------------------------------------------------------===// +// BlockLoader: stateful, MLX-style cooperative tile loader. 
+// +// Drop-in replacement for cooperativeLoadTileVec4 with several improvements: +// +// 1) Auto-derives per-thread vec width from (BROWS * BCOLS / tgp_size). +// Our cooperativeLoadTileVec4 was hardcoded to vec4 — that meant a +// 64-thread tg loading a 16x64 tile had to do 4 vec4 loads/thread +// (16 elts each); BlockLoader auto-derives vec16 = 1 load/thread. +// +// 2) Stateful src pointer + next() advances by one K-tile worth of bytes, +// avoiding the per-call (gRow * srcStride + gCol) re-derivation. +// +// 3) Branch-free load_safe via predicate SELECT (not predicate FLOW): +// tmp_val[j] = src[in_bounds ? offset : 0]; +// tmp_val[j] = in_bounds ? tmp_val[j] : 0; +// The compiler can vectorize fully even at edges, no warp divergence. +// +// 4) reduction_dim template flag selects K-direction: +// reduction_dim=0 → K is the row dim (B's tile) → tile_stride = BROWS*src_ld +// reduction_dim=1 → K is the col dim (A's tile) → tile_stride = BCOLS +// +// 5) ReadVector POD struct for arbitrary vec_size loads (compiler lowers +// to underlying vec4/vec8 instructions). +// +// Constraints (static_assert): +// BROWS * BCOLS must be divisible by tgp_size (so n_reads is integer) +// BCOLS must be divisible by n_reads (so TCOLS is integer) +// +// dst is supplied per call (load_unsafe / load_safe) so that the same +// loader instance can target either of two threadgroup buffers in +// double-buffered K loops. +//===----------------------------------------------------------------------===// + +template < + typename T, short BROWS, short BCOLS, short dst_ld, + short reduction_dim, short tgp_size> +struct BlockLoader { + static_assert((BROWS * BCOLS) % tgp_size == 0, + "BROWS*BCOLS must be divisible by tgp_size"); + + // Compile-time-derived shape using enum (Metal does not allow + // static constexpr struct members in the default address space; enum + // constants work because they have no storage). 
+ enum : short { + n_reads = (BCOLS * BROWS) / tgp_size, + vec_size = n_reads, + TCOLS = BCOLS / n_reads, + TROWS = tgp_size / TCOLS, + }; + static_assert(BCOLS % n_reads == 0, + "BCOLS must be divisible by n_reads"); + + // Per-thread (bi, bj) within the tile. + const int src_ld; + const int tile_stride; + const short bi; + const short bj; + device const T* src; + + // POD-sized vector for raw byte copy. Compiler lowers to native vec4/8/16 + // load/store instructions as appropriate. + struct alignas(sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + inline BlockLoader(device const T* src_, int src_ld_, ushort tid) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld_), + bi(tid / TCOLS), + bj(vec_size * (tid % TCOLS)), + src(src_ + bi * src_ld_ + bj) {} + + // Branch-free load: assumes the entire tile is in-bounds. + inline void load_unsafe(threadgroup T* dst) const { + threadgroup T* dst_thread = dst + bi * dst_ld + bj; + #pragma clang loop unroll(full) + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(dst_thread + i * dst_ld)) = + *((device const ReadVector*)(src + i * src_ld)); + } + } + + // Bounds-checked load. src_tile_dim = (in-bounds-cols, in-bounds-rows) + // for the current tile (computed by caller from M/N/K and tile offsets). + // Out-of-bounds elements are zero-filled via predicate SELECT — no + // warp divergence even at edges. + inline void load_safe(threadgroup T* dst, short2 src_tile_dim) const { + threadgroup T* dst_thread = dst + bi * dst_ld + bj; + short2 my_dim = src_tile_dim - short2(bj, bi); + + // This thread is entirely past the tile edge → zero-fill. 
+ if (my_dim.x <= 0 || my_dim.y <= 0) { + #pragma clang loop unroll(full) + for (short i = 0; i < BROWS; i += TROWS) { + #pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst_thread[i * dst_ld + j] = T(0); + } + } + return; + } + + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + #pragma clang loop unroll(full) + for (short i = 0; i < BROWS; i += TROWS) { + #pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + tmp_idx[j] = (i < my_dim.y) && (j < my_dim.x); + } + // Predicate SELECT for the load: read from a safe address (offset 0) + // when out-of-bounds. Avoids reading past the buffer. + #pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + tmp_val[j] = src[tmp_idx[j] ? (i * src_ld + j) : 0]; + } + #pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + #pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst_thread[i * dst_ld + j] = tmp_val[j]; + } + } + } + + // Advance src to the next K-tile. + inline void next() { + src += tile_stride; + } +}; + +//===----------------------------------------------------------------------===// +// matmul_simd_t: templated GEMM kernel with tunable tile params. +// +// Subsumes matmul_simd / matmul_simd_m32 / matmul_nt via template params: +// BM, BN, BK : output / K tile dims (BM,BN multiples of 8; BK multiple of 8) +// WM, WN : simdgroup grid (WM × WN simdgroups per tg, total WM*WN) +// BM must be multiple of WM*8, BN multiple of WN*8. +// TRANSPOSE_B : if true, B is physically [N, K] (logical [K, N]); load via +// the transposed tile loader (used for matmul_nt). +// +// Shape constraints enforced via static_assert. Per-simdgroup output sub-tile +// is (BM/WM) × (BN/WN), broken into FRAGS_M × FRAGS_N fragments of 8×8. +// +// Threadgroup layout: WM * WN simdgroups × 32 = total threads. 
+// +// Notes: +// - Bounds-checked loaders used everywhere (cooperativeLoadTileVec4 zero- +// pads M/N/K edges). For NN-aligned shapes the compiler may DCE the +// edge predicates inside the inner loop. We do NOT yet have a separate +// "branch-free interior" kernel instantiation (item 3 in the MLX gap +// list); deferred for clarity. +// - Float accumulator is NOT used here — accumulator type follows T. For +// long K reductions in fp16/bf16 this may lose precision; revisit if +// accuracy issues appear. +//===----------------------------------------------------------------------===// + +template +kernel void matmul_simd_t( + device const T* A [[buffer(0)]], + device const T* B [[buffer(1)]], + device T* C [[buffer(2)]], + constant int& M [[buffer(3)]], + constant int& K [[buffer(4)]], + constant int& N [[buffer(5)]], + uint3 tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]]) { + + static_assert(BM % (WM * 8) == 0, "BM must be multiple of WM*8"); + static_assert(BN % (WN * 8) == 0, "BN must be multiple of WN*8"); + static_assert(BK % 8 == 0, "BK must be multiple of 8"); + static_assert(WM * WN >= 1, "Need at least 1 simdgroup"); + + constexpr int NUM_SIMDS = WM * WN; + constexpr int NUM_THREADS = NUM_SIMDS * 32; + constexpr int SUBROWS_PER_SG = BM / WM; + constexpr int SUBCOLS_PER_SG = BN / WN; + constexpr int FRAGS_M = SUBROWS_PER_SG / 8; + constexpr int FRAGS_N = SUBCOLS_PER_SG / 8; + constexpr int PAD = 4; + constexpr int SMEM_A = BK + PAD; + constexpr int SMEM_B = BN + PAD; + + A += int(tgid.z) * M * K; + B += int(tgid.z) * K * N; + C += int(tgid.z) * M * N; + + const int tileRow = int(tgid.y) * BM; + const int tileCol = int(tgid.x) * BN; + + const int simd_m = int(simd_gid) / WN; + const int simd_n = int(simd_gid) % WN; + const int subRow = tileRow + simd_m * SUBROWS_PER_SG; + const int subCol = tileCol + simd_n * 
SUBCOLS_PER_SG; + + threadgroup T As[2][BM][SMEM_A]; + threadgroup T Bs[2][BK][SMEM_B]; + + simdgroup_matrix C_frag[FRAGS_M][FRAGS_N]; + #pragma clang loop unroll(full) + for (int i = 0; i < FRAGS_M; ++i) { + #pragma clang loop unroll(full) + for (int j = 0; j < FRAGS_N; ++j) { + C_frag[i][j] = simdgroup_matrix(0); + } + } + + // BlockLoader for A (used unless TRANSPOSE_A; constructed unconditionally + // for simplicity — the unused state is just a few register slots). + // For TRANSPOSE_A=true, A is logically [M, K] but stored as [K, M]; loaded + // via cooperativeLoadTileTransposedVec4 directly (stateless). + BlockLoader + loader_a(A + tileRow * K, K, tid); + + const bool m_aligned = (tileRow + BM <= M); + const bool n_aligned = (tileCol + BN <= N); + const int m_inb = m_aligned ? BM : (M - tileRow); + const int n_inb = n_aligned ? BN : (N - tileCol); + + // A-load helper: dispatches to BlockLoader (NN/NT) or transposed helper (TN). + // BUF_IDX = 0 or 1 (which double-buffer slot); k_off = K-tile starting + // offset; k_inb = in-bounds K count for this tile. + #define LOAD_A_TILE(BUF_IDX, k_off, k_inb) \ + do { \ + if (TRANSPOSE_A) { \ + /* A is physical [K,M], stride M; load logical [BM,BK] tile */ \ + cooperativeLoadTileTransposedVec4( \ + As[(BUF_IDX)], A, K, M, M, (k_off), tileRow, tid); \ + } else if (m_aligned && (k_inb) == BK) { \ + loader_a.load_unsafe(&As[(BUF_IDX)][0][0]); \ + } else { \ + loader_a.load_safe(&As[(BUF_IDX)][0][0], short2((k_inb), m_inb)); \ + } \ + } while (0) + + if (TRANSPOSE_B) { + // Initial loads: A (NN or TN-transposed) + B via existing transposed helper. + LOAD_A_TILE(0, 0, min(int(BK), K)); + cooperativeLoadTileTransposedVec4( + Bs[0], B, K, N, K, 0, tileCol, tid); + } else { + // BlockLoader for B (non-transposed): K is the ROW dim. 
+ BlockLoader + loader_b(B + tileCol, N, tid); + + LOAD_A_TILE(0, 0, min(int(BK), K)); + if (n_aligned && BK <= K) loader_b.load_unsafe(&Bs[0][0][0]); + else loader_b.load_safe(&Bs[0][0][0], + short2(n_inb, min(int(BK), K))); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const int numKTiles = (K + BK - 1) / BK; + + for (int t = 0; t < numKTiles; t++) { + int curBuf = t & 1; + int nextBuf = curBuf ^ 1; + + if (t + 1 < numKTiles) { + if (!TRANSPOSE_A) loader_a.next(); + loader_b.next(); + int nextTileK = (t + 1) * BK; + int k_inb = min(int(BK), K - nextTileK); + LOAD_A_TILE(nextBuf, nextTileK, k_inb); + if (n_aligned && k_inb == BK) loader_b.load_unsafe(&Bs[nextBuf][0][0]); + else loader_b.load_safe(&Bs[nextBuf][0][0], short2(n_inb, k_inb)); + } + + simdMMAKTile( + C_frag, + &As[curBuf][simd_m * SUBROWS_PER_SG][0], + &Bs[curBuf][0][simd_n * SUBCOLS_PER_SG], + BK); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + #pragma clang loop unroll(full) + for (int i = 0; i < FRAGS_M; ++i) { + #pragma clang loop unroll(full) + for (int j = 0; j < FRAGS_N; ++j) { + int outRow = subRow + i * 8; + int outCol = subCol + j * 8; + if (outRow < M && outCol < N) { + simdgroup_store(C_frag[i][j], C + outRow * N + outCol, N); + } + } + } + return; + } + + // ========== TRANSPOSE_B=true K-loop ========== + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const int numKTiles = (K + BK - 1) / BK; + + for (int t = 0; t < numKTiles; t++) { + int curBuf = t & 1; + int nextBuf = curBuf ^ 1; + + if (t + 1 < numKTiles) { + if (!TRANSPOSE_A) loader_a.next(); + int nextTileK = (t + 1) * BK; + int k_inb = min(int(BK), K - nextTileK); + LOAD_A_TILE(nextBuf, nextTileK, k_inb); + cooperativeLoadTileTransposedVec4( + Bs[nextBuf], B, K, N, K, nextTileK, tileCol, tid); + } + + simdMMAKTile( + C_frag, + &As[curBuf][simd_m * SUBROWS_PER_SG][0], + &Bs[curBuf][0][simd_n * SUBCOLS_PER_SG], + BK); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + #pragma clang loop 
unroll(full) + for (int i = 0; i < FRAGS_M; ++i) { + #pragma clang loop unroll(full) + for (int j = 0; j < FRAGS_N; ++j) { + int outRow = subRow + i * 8; + int outCol = subCol + j * 8; + if (outRow < M && outCol < N) { + simdgroup_store(C_frag[i][j], C + outRow * N + outCol, N); + } + } + } + #undef LOAD_A_TILE +} + +//===----------------------------------------------------------------------===// +// Simdgroup helpers +//===----------------------------------------------------------------------===// + +// Sum-reduce a per-lane value across the 32-lane simdgroup using the +// shuffle-down ladder. After this, lane 0 holds the total; other lanes hold +// partial sums (don't rely on them). +// +// Note: simd_shuffle_down has overloads for float/half/int but NOT bfloat, +// so any kernel that uses this can't be instantiated for bfloat directly — +// see gemv below. Kernels that don't need cross-lane reduction (e.g. the +// new gemv_t) work for bfloat too. +template +inline T simdReduceSum(T x) { + #pragma clang loop unroll(full) + for (int offset = 16; offset > 0; offset /= 2) { + x += simd_shuffle_down(x, ushort(offset)); + } + return x; +} + +// Sum-reduce a per-lane value across a SUBSET of lanes within the simdgroup +// — specifically, lanes whose IDs differ only in some upper-stride bits. +// `stride` is the smallest distance between two lanes that should be merged +// (= the count of "fast" lanes that don't participate). `log2_count` is +// the number of merges = log2(participating lane count). +// +// Example layout: lane_id = sm * SN + sn, with sm in [0, SM) and sn in +// [0, SN). To reduce across SM lanes (different sm values, same sn), use +// stride=SN and log2_count=log2(SM). After the call, lanes with sm == 0 +// hold the reduced total for their sn group; other sm lanes hold garbage. 
+template +inline T simdReduceSumStrided(T x, ushort stride, ushort log2_count) { + for (ushort i = 0; i < log2_count; ++i) { + x += simd_shuffle_down(x, ushort(stride << i)); + } + return x; +} + +//===----------------------------------------------------------------------===// +// matmul_simd_addmm_t: NN matmul with FUSED bias add (epilogue fusion). +// +// CURRENTLY BROKEN — kept as a starting point for a future session. +// +// Problem: MSL's simdgroup_matrix does NOT support the '+' binary operator, +// so the naive `simdgroup_store(C_frag[i][j] + bias_frag, ...)` fails to +// compile. To make this work we'd need one of: +// +// 1) Store C_frag to TGSM via simdgroup_store, then have each thread +// do a scalar bias-add load→store pass to global C. ~30 LOC + ~16KB +// additional TGSM (overlay-able with the As/Bs buffers since K-loop +// is done; would need a union or careful sequencing). +// +// 2) Use simdgroup_multiply_accumulate(out, identity, bias_frag, C_frag) +// where identity is an 8x8 identity matrix loaded from a TGSM constant. +// Requires constructing the identity (no built-in MSL constructor). +// +// 3) Switch to a per-fragment lane-aware scalar epilogue using +// simd_shuffle to gather bias values per lane. +// +// All three are real options. Approach 1 is simplest correct; +// approach 2 keeps everything in registers (best perf); approach 3 is the +// most general (easy to extend to other epilogue ops like ReLU/SiLU). +// +// Additional non-kernel work needed for end-to-end addmm: +// - aoti_torch_mps_addmm_out shim function (AOTI C ABI boundary) +// - Update partition allow-list to include aten::addmm +// - Make decompose_linear_pass conditional on whether v2 supports addmm +// +// For now: this kernel template body is the un-fused matmul (no bias), +// kept compiled so the AddMMOp class wiring remains in place. It will +// produce INCORRECT results (matmul without the bias add) if invoked. 
+// Since AOTI rejects addmm at export time today (no shim), this is not
+// reachable from any test or model. Marked TODO for future work.
+//===----------------------------------------------------------------------===//
+
+template
+kernel void matmul_simd_addmm_t(
+    device const T* A [[buffer(0)]],
+    device const T* B [[buffer(1)]],
+    device T* C [[buffer(2)]],
+    constant int& M [[buffer(3)]],
+    constant int& K [[buffer(4)]],
+    constant int& N [[buffer(5)]],
+    device const T* BIAS [[buffer(6)]],
+    constant int& bias_stride_m [[buffer(7)]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint tid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lane [[thread_index_in_simdgroup]]) {
+
+  // NOTE(review): BIAS/bias_stride_m are bound by the host (AddMMOp) but
+  // discarded here — until the bias epilogue lands, this kernel computes a
+  // plain matmul, NOT addmm. Any caller routing addmm here gets bias-less
+  // output; keep that in sync with the host-side dispatch.
+  (void)BIAS; // unused while bias-fusion is TODO (see header)
+  (void)bias_stride_m;
+
+  static_assert(BM % (WM * 8) == 0, "BM must be multiple of WM*8");
+  static_assert(BN % (WN * 8) == 0, "BN must be multiple of WN*8");
+  static_assert(BK % 8 == 0, "BK must be multiple of 8");
+
+  constexpr int NUM_SIMDS = WM * WN;
+  constexpr int NUM_THREADS = NUM_SIMDS * 32;
+  constexpr int SUBROWS_PER_SG = BM / WM;
+  constexpr int SUBCOLS_PER_SG = BN / WN;
+  constexpr int FRAGS_M = SUBROWS_PER_SG / 8;
+  constexpr int FRAGS_N = SUBCOLS_PER_SG / 8;
+  constexpr int PAD = 4;
+  constexpr int SMEM_A = BK + PAD;
+  constexpr int SMEM_B = BN + PAD;
+
+  // tgid.z selects the batch slice; assumes densely packed [M,K]/[K,N]/[M,N]
+  // per batch (strides M*K, K*N, M*N).
+  A += int(tgid.z) * M * K;
+  B += int(tgid.z) * K * N;
+  C += int(tgid.z) * M * N;
+
+  const int tileRow = int(tgid.y) * BM;
+  const int tileCol = int(tgid.x) * BN;
+  const int simd_m = int(simd_gid) / WN;
+  const int simd_n = int(simd_gid) % WN;
+  const int subRow = tileRow + simd_m * SUBROWS_PER_SG;
+  const int subCol = tileCol + simd_n * SUBCOLS_PER_SG;
+
+  threadgroup T As[2][BM][SMEM_A];
+  threadgroup T Bs[2][BK][SMEM_B];
+
+  simdgroup_matrix C_frag[FRAGS_M][FRAGS_N];
+  #pragma clang loop unroll(full)
+  for (int i = 0; i < FRAGS_M; ++i) {
+    #pragma clang loop unroll(full)
+    for (int j = 0; j < FRAGS_N; ++j) {
+      C_frag[i][j] = simdgroup_matrix(0);
+    }
+  }
+
+  BlockLoader
+      loader_a(A + tileRow * K, K, tid);
+  BlockLoader
+      loader_b(B + tileCol, N, tid);
+
+  const bool m_aligned = (tileRow + BM <= M);
+  const bool n_aligned = (tileCol + BN <= N);
+  const int m_inb = m_aligned ? BM : (M - tileRow);
+  const int n_inb = n_aligned ? BN : (N - tileCol);
+
+  if (m_aligned && BK <= K) loader_a.load_unsafe(&As[0][0][0]);
+  else loader_a.load_safe(&As[0][0][0], short2(min(int(BK), K), m_inb));
+  if (n_aligned && BK <= K) loader_b.load_unsafe(&Bs[0][0][0]);
+  else loader_b.load_safe(&Bs[0][0][0], short2(n_inb, min(int(BK), K)));
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  const int numKTiles = (K + BK - 1) / BK;
+
+  // Double-buffered K loop: prefetch K-tile t+1 into the other smem buffer
+  // while the MMA consumes tile t; single barrier per iteration.
+  for (int t = 0; t < numKTiles; t++) {
+    int curBuf = t & 1;
+    int nextBuf = curBuf ^ 1;
+
+    if (t + 1 < numKTiles) {
+      loader_a.next();
+      loader_b.next();
+      int nextTileK = (t + 1) * BK;
+      int k_inb = min(int(BK), K - nextTileK);
+      if (m_aligned && k_inb == BK) loader_a.load_unsafe(&As[nextBuf][0][0]);
+      else loader_a.load_safe(&As[nextBuf][0][0], short2(k_inb, m_inb));
+      if (n_aligned && k_inb == BK) loader_b.load_unsafe(&Bs[nextBuf][0][0]);
+      else loader_b.load_safe(&Bs[nextBuf][0][0], short2(n_inb, k_inb));
+    }
+
+    simdMMAKTile(
+        C_frag,
+        &As[curBuf][simd_m * SUBROWS_PER_SG][0],
+        &Bs[curBuf][0][simd_n * SUBCOLS_PER_SG],
+        BK);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // TODO: bias-add epilogue (see header comment for the 3 implementation
+  // options). For now this just stores the unfused matmul accumulator.
+  // NOTE(review): the guard below checks only the 8x8 fragment ORIGIN;
+  // simdgroup_store still writes a full 8x8 tile, which looks like it can
+  // write past M/N when they are not multiples of 8 — confirm dispatch
+  // guarantees 8-aligned tails (or that C is padded) before relying on this.
+  #pragma clang loop unroll(full)
+  for (int i = 0; i < FRAGS_M; ++i) {
+    #pragma clang loop unroll(full)
+    for (int j = 0; j < FRAGS_N; ++j) {
+      int outRow = subRow + i * 8;
+      int outCol = subCol + j * 8;
+      if (outRow < M && outCol < N) {
+        simdgroup_store(C_frag[i][j], C + outRow * N + outCol, N);
+      }
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GEMV: Matrix-vector (N=1)
+//===----------------------------------------------------------------------===//
+
+template
+kernel void gemv(
+    device const T* A [[buffer(0)]],
+    device const T* x [[buffer(1)]],
+    device T* y [[buffer(2)]],
+    constant int& M [[buffer(3)]],
+    constant int& K [[buffer(4)]],
+    constant int& N [[buffer(5)]],
+    uint gid [[thread_position_in_grid]],
+    uint simd_lane [[thread_index_in_simdgroup]]) {
+
+  // One simdgroup (32 lanes) per output row: lanes stride over K, then the
+  // partials are folded with simdReduceSum and lane 0 writes the result.
+  // Accumulates in T (not float) — fine for f32/f16; bf16 is deliberately
+  // not instantiated (see the instantiation-table note below).
+  int row = gid / 32;
+  if (row >= M) return;
+
+  T sum = T(0);
+  for (int k = simd_lane; k < K; k += 32) {
+    sum += A[row * K + k] * x[k];
+  }
+  sum = simdReduceSum(sum);
+  if (simd_lane == 0) {
+    y[row] = sum;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GEMV transposed: y = A^T @ x, A is [K, N] row-major, x is [K], y is [N].
+// Used when M==1 in matmul (autoregressive decode).
+//
+// Design follows MLX's GEMVTKernel (mlx/backend/metal/kernels/gemv_masked.h):
+// Per-thread tile of TM K-rows × TN N-cols. Simdgroup is laid out as
+// SM × SN lanes (SM*SN=32) splitting the K dimension SM ways and the N
+// dimension SN ways. After the K loop, partial sums are reduced across
+// the SM K-lanes via simd_shuffle_down (handled by simdReduceSumStrided).
+//
+// tg layout        : 1 simdgroup = 32 threads
+// per-thread tile  : TM=4 K-rows × TN=4 N-cols
+// simdgroup tile   : SM=8 K-lanes × SN=4 N-lanes -> 32*TM K-rows / iter,
+//                    SN*TN=16 output cols / simdgroup
+// tg per N         : ceil(N / 16)
+//
+// Why TM×TN instead of "lane-per-col scalar":
+//   - More work per thread (16 fmas/iter vs 1) -> better load amortization,
+//     better ILP on the FMA pipeline.
+//   - K split SM=8 ways within the simdgroup -> 8x less per-thread K work
+//     for the same K, so scales to large K without becoming latency-bound.
+//   - Same memory pattern (lanes within a simdgroup access consecutive N
+//     cols for fixed K row) -> still fully coalesced.
+//
+// Accumulator promoted to float so reduction works for bf16 (Metal's
+// simd_shuffle_down has no bfloat overload).
+//===----------------------------------------------------------------------===//
+
+template
+kernel void gemv_t(
+    device const T* A [[buffer(0)]],
+    device const T* x [[buffer(1)]],
+    device T* y [[buffer(2)]],
+    constant int& M [[buffer(3)]],
+    constant int& K [[buffer(4)]],
+    constant int& N [[buffer(5)]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint simd_lane [[thread_index_in_simdgroup]]) {
+
+  constexpr int SM = 8;
+  constexpr int SN = 4;
+  constexpr int TM = 4;
+  constexpr int TN = 4;
+  static_assert(SM * SN == 32, "simdgroup must have 32 lanes");
+  constexpr int BLOCK_K = SM * TM;     // K rows consumed per outer iter (32)
+  constexpr int COLS_PER_SG = SN * TN; // output cols per simdgroup (16)
+
+  // Lane decomposition: sn is fast (changes every lane), sm is slow.
+  ushort sn = simd_lane % SN;
+  ushort sm = simd_lane / SN;
+
+  // Each tg owns COLS_PER_SG consecutive output columns. This thread's TN
+  // contiguous cols start here.
+  int col_base = int(tgid.x) * COLS_PER_SG + int(sn) * TN;
+  // Early exit diverges lanes, but col_base depends only on sn — so for any
+  // sn, either all SM lanes sharing it exit or none do; the stride-SN shuffle
+  // reduction below only ever reads lanes with the same sn, so no exited lane
+  // is ever a shuffle source for a live one.
+  if (col_base >= N) return;
+
+  // Per-thread accumulators in float for accuracy + bf16-safe reduction.
+  float results[TN] = {0.0f, 0.0f, 0.0f, 0.0f};
+
+  // Determine in-bounds TN for THIS thread (uniform across the K loop).
+  // Branch-free inner loop relies on this being checked once.
+  int valid_tn = TN;
+  if (col_base + TN > N) {
+    valid_tn = N - col_base; // 1..TN-1; we still wrote 'return' above for col_base >= N
+  }
+
+  // Whole BLOCK_K chunks (no per-K bounds check needed).
+  int k_full = (K / BLOCK_K) * BLOCK_K;
+  for (int k_block = 0; k_block < k_full; k_block += BLOCK_K) {
+    int k_start = k_block + int(sm) * TM;
+
+    float x_vals[TM];
+    #pragma clang loop unroll(full)
+    for (int tm = 0; tm < TM; ++tm) {
+      x_vals[tm] = float(x[k_start + tm]);
+    }
+    if (valid_tn == TN) {
+      // Hot path: full TN cols. Branch-free inner loop.
+      #pragma clang loop unroll(full)
+      for (int tm = 0; tm < TM; ++tm) {
+        int kk = k_start + tm;
+        #pragma clang loop unroll(full)
+        for (int tn = 0; tn < TN; ++tn) {
+          results[tn] += float(A[kk * N + col_base + tn]) * x_vals[tm];
+        }
+      }
+    } else {
+      // Edge tg: partial TN. Bounds-check tn; tm is fully in range.
+      #pragma clang loop unroll(full)
+      for (int tm = 0; tm < TM; ++tm) {
+        int kk = k_start + tm;
+        for (int tn = 0; tn < valid_tn; ++tn) {
+          results[tn] += float(A[kk * N + col_base + tn]) * x_vals[tm];
+        }
+      }
+    }
+  }
+
+  // K tail: remaining K rows < BLOCK_K (only when K is not a multiple of 32).
+  if (k_full < K) {
+    int k_start = k_full + int(sm) * TM;
+    if (k_start < K) {
+      int tm_max = min(TM, K - k_start);
+      for (int tm = 0; tm < tm_max; ++tm) {
+        float xv = float(x[k_start + tm]);
+        for (int tn = 0; tn < valid_tn; ++tn) {
+          results[tn] += float(A[(k_start + tm) * N + col_base + tn]) * xv;
+        }
+      }
+    }
+  }
+
+  // Reduce across SM K-lanes (different sm, same sn). After this, lanes
+  // with sm == 0 hold the total; other sm lanes hold garbage.
+  // NOTE(review): ushort(3) is presumably log2(SM)=3 shuffle rounds at
+  // stride SN — verify against simdReduceSumStrided's parameter contract;
+  // if SM/SN ever change, this literal must change with them.
+  #pragma clang loop unroll(full)
+  for (int tn = 0; tn < TN; ++tn) {
+    results[tn] = simdReduceSumStrided(results[tn], ushort(SN), ushort(3));
+  }
+
+  // First K-lane writes the in-bounds cols.
+  if (sm == 0) {
+    for (int tn = 0; tn < valid_tn; ++tn) {
+      y[col_base + tn] = T(results[tn]);
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Template instantiations
+//===----------------------------------------------------------------------===//
+
+template [[host_name("matmul_naive_f32")]] kernel void matmul_naive(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint2);
+template [[host_name("matmul_naive_f16")]] kernel void matmul_naive(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint2);
+template [[host_name("matmul_naive_bf16")]] kernel void matmul_naive(device const bfloat*, device const bfloat*, device bfloat*, constant int&, constant int&, constant int&, uint2);
+
+template [[host_name("matmul_tiled_f32")]] kernel void matmul_tiled(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint2, uint2, uint2);
+template [[host_name("matmul_tiled_f16")]] kernel void matmul_tiled(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint2, uint2, uint2);
+template [[host_name("matmul_tiled_bf16")]] kernel void matmul_tiled(device const bfloat*, device const bfloat*, device bfloat*, constant int&, constant int&, constant int&, uint2, uint2, uint2);
+
+// matmul_simd_t instantiations. Kernel name encodes tile params so the
+// host can compose the name from a TileSpec at dispatch time.
+// matmul_simd_t_______
+//
+// Currently registered:
+//   (32, 64, 32, 1, 4, n) — replicates matmul_simd_m32 (16 <= M < 64)
+//   (64, 64, 16, 2, 2, n) — replicates matmul_simd (M >= 64, NN)
+//   (64, 64, 16, 2, 2, t) — replicates matmul_nt (M >= 64, NT)
+// Each combo × 3 dtypes = 9 total. Add more here as we build out per-shape
+// or per-dtype tile tables. Each one adds compiled .metallib bytes; only
+// register what we'll route to.
+template [[host_name("matmul_simd_t_32_64_32_1_4_n_f32")]] kernel void matmul_simd_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_32_64_32_1_4_n_f16")]] kernel void matmul_simd_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_32_64_32_1_4_n_bf16")]] kernel void matmul_simd_t(device const bfloat*,device const bfloat*,device bfloat*,constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+
+template [[host_name("matmul_simd_t_64_64_16_2_2_n_f32")]] kernel void matmul_simd_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_64_16_2_2_n_f16")]] kernel void matmul_simd_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_64_16_2_2_n_bf16")]] kernel void matmul_simd_t(device const bfloat*,device const bfloat*,device bfloat*,constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+
+template [[host_name("matmul_simd_t_64_64_16_2_2_t_f32")]] kernel void matmul_simd_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_64_16_2_2_t_f16")]] kernel void matmul_simd_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_64_16_2_2_t_bf16")]] kernel void matmul_simd_t(device const bfloat*,device const bfloat*,device bfloat*,constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+
+// MLX's "small fp32 NN" tile: bm=64, bn=32, bk=32, wm=2, wn=2.
+// Use for fp32 medium-M cases where Simd's 64x64 tile produces too few
+// tgs (small N relative to M) — the smaller BN doubles tg count along N
+// and the bigger BK halves K-tile barriers.
+template [[host_name("matmul_simd_t_64_32_32_2_2_n_f32")]] kernel void matmul_simd_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_32_32_2_2_n_f16")]] kernel void matmul_simd_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_32_32_2_2_n_bf16")]] kernel void matmul_simd_t(device const bfloat*,device const bfloat*,device bfloat*,constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+
+// TN instantiations: TRANSPOSE_A=true, TRANSPOSE_B=false. A is physically
+// stored [K, M] (column-contiguous from PyTorch's view) and loaded via the
+// transposed helper; B is loaded normally.
+// matmul_simd_t______tn_
+template [[host_name("matmul_simd_t_64_64_16_2_2_tn_f32")]] kernel void matmul_simd_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_64_16_2_2_tn_f16")]] kernel void matmul_simd_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_t_64_64_16_2_2_tn_bf16")]] kernel void matmul_simd_t(device const bfloat*,device const bfloat*,device bfloat*,constant int&, constant int&, constant int&, uint3, uint, uint, uint);
+
+// Fused matmul+bias kernel (NN, single 64x64 tile). Used by AddMMOp.
+// NOTE(review): "fused" is aspirational — the kernel body currently casts
+// BIAS/bias_stride_m to void (bias epilogue is TODO), so these entry points
+// produce a plain matmul today.
+template [[host_name("matmul_simd_addmm_t_64_64_16_2_2_f32")]] kernel void matmul_simd_addmm_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, device const float*, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_addmm_t_64_64_16_2_2_f16")]] kernel void matmul_simd_addmm_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, device const half*, constant int&, uint3, uint, uint, uint);
+template [[host_name("matmul_simd_addmm_t_64_64_16_2_2_bf16")]] kernel void matmul_simd_addmm_t(device const bfloat*,device const bfloat*,device bfloat*,constant int&, constant int&, constant int&, device const bfloat*,constant int&, uint3, uint, uint, uint);
+
+template [[host_name("gemv_f32")]] kernel void gemv(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint, uint);
+template [[host_name("gemv_f16")]] kernel void gemv(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint, uint);
+// NOTE: no gemv_bf16: Metal's simd_shuffle_down used inside gemv has no
+// bfloat overload (only float/half/int), and instantiating gemv
+// would fail to compile the whole shader source — taking down every other
+// _bf16 kernel with it. MatMulOp::selectKernel only picks GEMV when N==1,
+// which our test models don't hit. If you need bf16 GEMV, refactor the
+// kernel to promote to float for the simd reduction.
+
+template [[host_name("gemv_t_f32")]] kernel void gemv_t(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, uint3, uint);
+template [[host_name("gemv_t_f16")]] kernel void gemv_t(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, uint3, uint);
+template [[host_name("gemv_t_bf16")]] kernel void gemv_t(device const bfloat*, device const bfloat*, device bfloat*, constant int&, constant int&, constant int&, uint3, uint);
+// gemv_t_bf16 works because the float accumulator + simdReduceSumStrided
+// promotes to float for the cross-lane reduction (Metal's
+// simd_shuffle_down has no bfloat overload).
+
+//===----------------------------------------------------------------------===//
+// matmul_tensor_ops — Metal 4 tensor_ops::matmul2d (Apple9+ / M3+ only)
+//
+// Uses the Apple-blessed pattern from example_matmul_metal4: each threadgroup
+// computes one BMxBN output tile via tensor_ops::matmul2d with
+// execution_simdgroups<4> (= 128 threads/threadgroup). K is dynamic so a
+// single kernel handles arbitrary K (K must still be a multiple of 16,
+// enforced by host dispatch).
+//
+// Gated on __METAL_VERSION__ >= 410 (MSL 4.1, ships with macOS 26 / iOS 26).
+// On older MSL versions this entire block is skipped so other kernels still
+// compile.
+//===----------------------------------------------------------------------===//
+// Note: dropped the __METAL_VERSION__ gate during debugging — re-add once the
+// runtime macro for MSL 4.1 is confirmed.
+// NOTE(review): with the gate removed, the #includes and kernel below will
+// presumably fail to compile this whole source string on any pre-MSL-4.1
+// toolchain — taking every other kernel down with it. Re-adding the guard
+// is load-bearing, not cosmetic.
+#include
+#include
+
+template
+kernel void matmul_tensor_ops(
+    device T* A_buf [[buffer(0)]],
+    device T* B_buf [[buffer(1)]],
+    device T* C_buf [[buffer(2)]],
+    constant int& M [[buffer(3)]],
+    constant int& K [[buffer(4)]],
+    constant int& N [[buffer(5)]],
+    uint2 tgid [[threadgroup_position_in_grid]])
+{
+  using namespace mpp::tensor_ops;
+  constexpr int BM = 64;
+  constexpr int BN = 64;
+  constexpr int BK = 16; // matmul2d's inner K-tile size
+
+  // Build inline tensors from raw buffer pointers + runtime dims.
+  // dextents arg order: (cols, rows) so the second dim is the M axis.
+  auto A = metal::tensor, metal::tensor_inline>(
+      A_buf, metal::dextents(K, M));
+  auto B = metal::tensor, metal::tensor_inline>(
+      B_buf, metal::dextents(N, K));
+  auto C = metal::tensor, metal::tensor_inline>(
+      C_buf, metal::dextents(N, M));
+
+  // Two ops: init (mode::multiply, overwrites C tile) + accumulate
+  // (mode::multiply_accumulate, += into C tile). matmul2d processes BK=16
+  // K-elements per run() call -> we loop K/BK times.
+  constexpr auto desc_init = matmul2d_descriptor(
+      BM, BN, BK,
+      false, false, false,
+      matmul2d_descriptor::mode::multiply);
+  constexpr auto desc_acc = matmul2d_descriptor(
+      BM, BN, BK,
+      false, false, false,
+      matmul2d_descriptor::mode::multiply_accumulate);
+  matmul2d> mm_init;
+  matmul2d> mm_acc;
+
+  // Output tile owned by this threadgroup; no bounds handling here — the
+  // host dispatch enforces M%64==0, N%64==0, K%16==0 (see header comment),
+  // so every slice is full-size.
+  auto c_tile = C.slice(tgid.x * BN, tgid.y * BM);
+
+  // First K-tile: initialize C
+  {
+    auto a = A.slice(0, tgid.y * BM);
+    auto b = B.slice(tgid.x * BN, 0);
+    mm_init.run(a, b, c_tile);
+  }
+  // Remaining K-tiles: accumulate
+  for (int k = BK; k < K; k += BK) {
+    auto a = A.slice(k, tgid.y * BM);
+    auto b = B.slice(tgid.x * BN, k);
+    mm_acc.run(a, b, c_tile);
+  }
+}
+
+template [[host_name("matmul_tensor_ops_f32")]] kernel void matmul_tensor_ops(device float*, device float*, device float*, constant int&, constant int&, constant int&, uint2);
+template [[host_name("matmul_tensor_ops_f16")]] kernel void matmul_tensor_ops(device half*, device half*, device half*, constant int&, constant int&, constant int&, uint2);
+template [[host_name("matmul_tensor_ops_bf16")]] kernel void matmul_tensor_ops(device bfloat*, device bfloat*, device bfloat*, constant int&, constant int&, constant int&, uint2);
+
+//===----------------------------------------------------------------------===//
+// Naive batched matmul fallback (small problems where SIMD MMA has poor
+// occupancy). [B, M, K] @ [B, K, N] -> [B, M, N], one thread per output.
+//===----------------------------------------------------------------------===// +template +kernel void bmm( + device const T* A [[buffer(0)]], + device const T* B [[buffer(1)]], + device T* C [[buffer(2)]], + constant int& batch [[buffer(3)]], + constant int& M [[buffer(4)]], + constant int& K [[buffer(5)]], + constant int& N [[buffer(6)]], + constant int& A_batch_stride [[buffer(7)]], + constant int& B_batch_stride [[buffer(8)]], + constant int& C_batch_stride [[buffer(9)]], + uint3 gid [[thread_position_in_grid]]) { + int col = gid.x; + int row = gid.y; + int b = gid.z; + if (row >= M || col >= N || b >= batch) return; + device const T* A_b = A + b * A_batch_stride; + device const T* B_b = B + b * B_batch_stride; + device T* C_b = C + b * C_batch_stride; + T sum = T(0); + for (int k = 0; k < K; k++) sum += A_b[row * K + k] * B_b[k * N + col]; + C_b[row * N + col] = sum; +} + +template [[host_name("bmm_f32")]] kernel void bmm(device const float*, device const float*, device float*, constant int&, constant int&, constant int&, constant int&, constant int&, constant int&, constant int&, uint3); +template [[host_name("bmm_f16")]] kernel void bmm(device const half*, device const half*, device half*, constant int&, constant int&, constant int&, constant int&, constant int&, constant int&, constant int&, uint3); +template [[host_name("bmm_bf16")]] kernel void bmm(device const bfloat*, device const bfloat*, device bfloat*, constant int&, constant int&, constant int&, constant int&, constant int&, constant int&, constant int&, uint3); +)"; + return source; +} + +const char* MatMulOp::kernelSource() const { + return matmulKernelSource().c_str(); +} + +//===----------------------------------------------------------------------===// +// AddMMOp (aten::addmm) — fused bias-matmul. 
+//
+// Schema: addmm(input, mat1, mat2, *, beta=1, alpha=1) -> Tensor
+//   inputs[0] = input (bias)  [M, N] OR broadcast (commonly [N])
+//   inputs[1] = mat1          [M, K]
+//   inputs[2] = mat2          [K, N]
+//   inputs[3] = beta  (Scalar, default 1) — IGNORED, must be 1 for now
+//   inputs[4] = alpha (Scalar, default 1) — IGNORED, must be 1 for now
+//
+// Constraints (currently enforced — caller must satisfy or use mm + add):
+//   - mat1 row-contiguous, mat2 row-contiguous (NN layout)
+//   - bias is [M, N] contiguous OR [N] (1D-broadcast)
+//   - beta == alpha == 1
+//
+// Falls through to MatMulOp's plain matmul kernel + a separate elementwise
+// add IF those constraints don't hold (rare in practice for nn.Linear).
+//===----------------------------------------------------------------------===//
+
+// Output shape is [mat1.rows, mat2.cols]; empty vector signals "can't infer"
+// (missing or non-tensor inputs).
+std::vector AddMMOp::computeOutputShape(
+    EValuePtrSpan inputs) const {
+  if (inputs.size() < 3 || !inputs[1]->isTensor() || !inputs[2]->isTensor()) {
+    return {};
+  }
+  const auto& mat1 = inputs[1]->toTensor();
+  const auto& mat2 = inputs[2]->toTensor();
+  return {static_cast(mat1.size(0)),
+          static_cast(mat2.size(1))};
+}
+
+void AddMMOp::dispatch(
+    MetalStream* stream,
+    EValuePtrSpan inputs,
+    EValuePtrSpan outputs) {
+
+  if (inputs.size() < 3) {
+    ET_LOG(Error, "AddMMOp: expected at least 3 inputs (input, mat1, mat2)");
+    return;
+  }
+
+  // NOTE(review): beta/alpha (inputs[3]/[4]) are documented above as
+  // "must be 1" but are never read or validated here — non-default scalars
+  // silently produce wrong results. Also, the addmm kernel this routes to
+  // currently discards its bias arguments (bias epilogue is TODO in the
+  // shader), so the dispatched output is a bias-less matmul — confirm the
+  // registry keeps addmm unreachable until both land.
+  auto& bias = inputs[0]->toTensor();
+  auto& A = inputs[1]->toTensor();
+  auto& B = inputs[2]->toTensor();
+  auto& C = outputs[0]->toTensor();
+
+  auto err = resizeOutput(inputs, outputs[0]);
+  if (err != Error::Ok) {
+    ET_LOG(Error, "AddMMOp: failed to resize output");
+    return;
+  }
+
+  // For unsupported alpha/beta or non-NN layouts: fall back to plain matmul
+  // followed by an elementwise add. (Unimplemented; just error for now —
+  // PyTorch's addmm with default scalars + nn.Linear bias hits the fast path.)
+  const bool aRC = isRowContiguous(A);
+  const bool bRC = isRowContiguous(B);
+  if (!aRC || !bRC) {
+    ET_LOG(Error, "AddMMOp: only NN layout (both row-contiguous) is supported "
+                  "currently; got A.RC=%d B.RC=%d", aRC, bRC);
+    return;
+  }
+
+  int32_t M = static_cast(A.size(0));
+  int32_t K = static_cast(A.size(1));
+  int32_t N = static_cast(B.size(1));
+  ScalarType dtype = C.scalar_type();
+
+  // Determine bias stride pattern. Two supported cases:
+  //   1) bias is 2D [M, N] contiguous → stride_m = N
+  //   2) bias is 1D [N] (broadcasts across M rows) → stride_m = 0
+  // We detect via dim() == 1; for 2D we trust it's contiguous (PyTorch addmm
+  // requires this OR will broadcast-expand before reaching us).
+  int32_t bias_stride_m;
+  if (bias.dim() == 1) {
+    if (bias.size(0) != N) {
+      ET_LOG(Error, "AddMMOp: 1D bias dim mismatch (got %lld, expected %d)",
+             (long long)bias.size(0), N);
+      return;
+    }
+    bias_stride_m = 0; // same row repeated
+  } else if (bias.dim() == 2 && bias.size(0) == M && bias.size(1) == N) {
+    bias_stride_m = N;
+  } else {
+    ET_LOG(Error, "AddMMOp: unsupported bias shape (dim=%zd)",
+           (ptrdiff_t)bias.dim());
+    return;
+  }
+
+  std::string kname =
+      std::string("matmul_simd_addmm_t_64_64_16_2_2_") + dtypeSuffix(dtype);
+  auto* kernel = getKernel(stream, kname.c_str());
+
+  ET_LOG(Info,
+         "AddMMOp: M=%d K=%d N=%d bias_stride_m=%d kernel=%s",
+         M, K, N, bias_stride_m, kname.c_str());
+
+  // Same dispatch grid as the Simd 64x64 MM tile.
+  uvec3 grid((N + 63) / 64, (M + 63) / 64, 1);
+  uvec3 block(128, 1, 1);
+
+  stream->dispatch(kernel, {
+      {A.mutable_data_ptr(), A.nbytes()},
+      {B.mutable_data_ptr(), B.nbytes()},
+      {C.mutable_data_ptr(), C.nbytes()},
+      M, K, N,
+      {bias.mutable_data_ptr(), bias.nbytes()},
+      bias_stride_m
+  }, grid, block);
+}
+
+const char* AddMMOp::kernelSource() const {
+  // The addmm kernel template lives inside the same source string as
+  // matmul_simd_t — both are compiled from matmulKernelSource(). Reuse it.
+  return matmulKernelSource().c_str();
+}
+
+//===----------------------------------------------------------------------===//
+// BatchedMatMulOp (aten::bmm) - [B, M, K] @ [B, K, N] -> [B, M, N]
+//===----------------------------------------------------------------------===//
+
+// Shape inference for bmm: both operands must be rank-3; empty vector
+// signals "can't infer". K-dimension agreement is not checked here — that
+// surfaces at dispatch time.
+std::vector BatchedMatMulOp::computeOutputShape(
+    EValuePtrSpan inputs) const {
+
+  if (inputs.size() < 2 || !inputs[0]->isTensor() || !inputs[1]->isTensor()) {
+    return {};
+  }
+
+  auto& A = inputs[0]->toTensor(); // [B, M, K]
+  auto& B = inputs[1]->toTensor(); // [B, K, N]
+
+  if (A.dim() != 3 || B.dim() != 3) {
+    return {};
+  }
+
+  SizesType batch = A.size(0);
+  SizesType M = A.size(1);
+  SizesType N = B.size(2);
+
+  return {batch, M, N};
+}
+
+void BatchedMatMulOp::dispatch(
+    MetalStream* stream,
+    EValuePtrSpan inputs,
+    EValuePtrSpan outputs) {
+
+  auto& A = inputs[0]->toTensor(); // [B, M, K]
+  auto& B = inputs[1]->toTensor(); // [B, K, N]
+  auto& C = outputs[0]->toTensor(); // [B, M, N]
+
+  auto err = resizeOutput(inputs, outputs[0]);
+  if (err != Error::Ok) {
+    ET_LOG(Error, "BatchedMatMulOp: failed to resize output");
+    return;
+  }
+
+  if (!isRowContiguous(A) || !isRowContiguous(B)) {
+    // Broadcast tolerated below; non-broadcast non-contig is an error.
+    if (!(isRowContiguous(A) && B.strides().size() >= 1 && B.strides()[0] == 0)) {
+      ET_LOG(Error, "BatchedMatMulOp: inputs must be row-contiguous (or B broadcast over batch)");
+      return;
+    }
+  }
+
+  int32_t batch = static_cast(A.size(0));
+  int32_t M = static_cast(A.size(1));
+  int32_t K = static_cast(A.size(2));
+  int32_t N = static_cast(B.size(2));
+
+  ScalarType dtype = C.scalar_type();
+
+  //----------------------------------------------------------------------
+  // Broadcast fast path: if B's per-batch stride is 0, B is just [K, N]
+  // replicated across batch. The whole bmm collapses to one 2D matmul:
+  //   (batch*M, K) @ (K, N) -> (batch*M, N)
+  // Bigger M => better tile occupancy, single kernel launch, and we get
+  // to ride MatMulOp's full ladder including TensorOps when aligned.
+  // (Not normally hit by aten::bmm — torch.bmm requires both operands to
+  // have an explicit batch dim — but defensive in case an upstream pass
+  // emits this pattern.)
+  //----------------------------------------------------------------------
+  if (B.strides().size() >= 1 && B.strides()[0] == 0 && batch > 1 &&
+      isRowContiguous(A)) {
+    int32_t M2 = batch * M;
+
+    // Mirror MatMulOp's kernel-selection ladder (size + Apple9 family).
+    // NOTE(review): this ladder is duplicated from MatMulOp by hand —
+    // consider extracting a shared selector so the two can't drift.
+    auto* metalStream = static_cast(stream);
+    const bool canTensorOps =
+        metalStream && metalStream->device() &&
+        [metalStream->device() supportsFamily:MTLGPUFamilyApple9] &&
+        (M2 % 64 == 0) && (N % 64 == 0) && (K % 16 == 0) &&
+        (dtype == ScalarType::Float || dtype == ScalarType::Half ||
+         dtype == ScalarType::BFloat16);
+
+    std::string kname;
+    uvec3 grid, block;
+    if (canTensorOps) {
+      kname = std::string("matmul_tensor_ops_") + dtypeSuffix(dtype);
+      grid = uvec3((N + 63) / 64, (M2 + 63) / 64, 1);
+      block = uvec3(128, 1, 1);
+    } else if (M2 >= 64 && N >= 64 && K >= 16) {
+      kname = std::string("matmul_simd_") + dtypeSuffix(dtype);
+      grid = uvec3((N + 63) / 64, (M2 + 63) / 64, 1);
+      block = uvec3(128, 1, 1);
+    } else if (M2 >= 32 && N >= 32) {
+      kname = std::string("matmul_tiled_") + dtypeSuffix(dtype);
+      grid = uvec3((N + 31) / 32, (M2 + 31) / 32, 1);
+      block = uvec3(32, 32, 1);
+    } else {
+      kname = std::string("matmul_naive_") + dtypeSuffix(dtype);
+      grid = uvec3((N + 7) / 8, (M2 + 7) / 8, 1);
+      block = uvec3(8, 8, 1);
+    }
+
+    ET_LOG(Info,
+           "BatchedMatMulOp: broadcast collapse batch=%d M=%d->%d K=%d N=%d kernel=%s",
+           batch, M, M2, K, N, kname.c_str());
+
+    auto* kernel = getKernel(stream, kname.c_str());
+    stream->dispatch(kernel, {
+        {A.mutable_data_ptr(), A.nbytes()},
+        {B.mutable_data_ptr(), B.nbytes()},
+        {C.mutable_data_ptr(), C.nbytes()},
+        M2, K, N
+    }, grid, block);
+    return;
+  }
+
+  // Prefer the SIMD MMA kernel (with tgid.z = batch) for large enough tiles;
+  // fall back to the naive batched kernel for small problems where SIMD
+  // would have low occupancy. matmul_simd assumes contiguous batched layout
+  // (A_stride = M*K, etc), which our row-contiguous check above guarantees.
+  const bool useSimd = (M >= 64) && (N >= 64) && (K >= 16);
+
+  if (useSimd) {
+    std::string kname = std::string("matmul_simd_") + dtypeSuffix(dtype);
+    auto* kernel = getKernel(stream, kname.c_str());
+    ET_LOG(Info, "BatchedMatMulOp: simd batch=%d M=%d K=%d N=%d",
+           batch, M, K, N);
+    uvec3 grid((N + 63) / 64, (M + 63) / 64, batch);
+    uvec3 block(128, 1, 1);
+    stream->dispatch(kernel, {
+        {A.mutable_data_ptr(), A.nbytes()},
+        {B.mutable_data_ptr(), B.nbytes()},
+        {C.mutable_data_ptr(), C.nbytes()},
+        M, K, N
+    }, grid, block);
+    return;
+  }
+
+  // Naive fallback (small problems).
+  int32_t A_batch_stride = M * K;
+  int32_t B_batch_stride = K * N;
+  int32_t C_batch_stride = M * N;
+  std::string kname = std::string("bmm_") + dtypeSuffix(dtype);
+  auto* kernel = getKernel(stream, kname.c_str());
+  ET_LOG(Info, "BatchedMatMulOp: naive batch=%d M=%d K=%d N=%d",
+         batch, M, K, N);
+  uvec3 grid((N + 7) / 8, (M + 7) / 8, batch);
+  uvec3 block(8, 8, 1);
+  stream->dispatch(kernel, {
+      {A.mutable_data_ptr(), A.nbytes()},
+      {B.mutable_data_ptr(), B.nbytes()},
+      {C.mutable_data_ptr(), C.nbytes()},
+      batch, M, K, N,
+      A_batch_stride, B_batch_stride, C_batch_stride
+  }, grid, block);
+}
+
+const char* BatchedMatMulOp::kernelSource() const {
+  // Share the full kernel source with MatMulOp so we can use both
+  // matmul_simd_ (fast path) and bmm_ (naive fallback).
+  return matmulKernelSource().c_str();
+}
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/ops/UnaryOps.h b/backends/portable/runtime/metal_v2/ops/UnaryOps.h
new file mode 100644
index 00000000000..b51c4bf9d12
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/ops/UnaryOps.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+namespace executorch {
+namespace backends {
+namespace metal_v2 {
+
+//===----------------------------------------------------------------------===//
+// ReluOp
+//===----------------------------------------------------------------------===//
+
+class ReluOp : public MetalOp {
+public:
+  const char* name() const override { return "aten::relu"; }
+
+  // NOTE(review): admits every floating dtype, but the .mm only registers
+  // relu_f32/relu_f16 host names — a bf16 tensor would presumably request a
+  // nonexistent "relu_bf16" kernel. Confirm routing excludes bf16 or tighten
+  // this predicate.
+  bool supports(ScalarType dtype) const override {
+    return isFloatingPoint(dtype);
+  }
+
+  void dispatch(
+      MetalStream* stream,
+      EValuePtrSpan inputs,
+      EValuePtrSpan outputs) override;
+
+protected:
+  const char* kernelSource() const override;
+};
+
+} // namespace metal_v2
+} // namespace backends
+} // namespace executorch
diff --git a/backends/portable/runtime/metal_v2/ops/UnaryOps.mm b/backends/portable/runtime/metal_v2/ops/UnaryOps.mm
new file mode 100644
index 00000000000..23d431b5eeb
--- /dev/null
+++ b/backends/portable/runtime/metal_v2/ops/UnaryOps.mm
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#import "UnaryOps.h" +#include +#include + +namespace executorch { +namespace backends { +namespace metal_v2 { + +using runtime::Error; + +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +void ReluOp::dispatch( + MetalStream* stream, + EValuePtrSpan inputs, + EValuePtrSpan outputs) { + + auto& input = inputs[0]->toTensor(); + auto& output = outputs[0]->toTensor(); + + auto err = resizeOutput(inputs, outputs[0]); + if (err != Error::Ok) { + ET_LOG(Error, "ReluOp: failed to resize output"); + return; + } + + ScalarType dtype = output.scalar_type(); + std::string kname = std::string("relu_") + dtypeSuffix(dtype); + + auto* kernel = getKernel(stream, kname.c_str()); + uint32_t numel = static_cast(input.numel()); + + stream->dispatch(kernel, { + {input.mutable_data_ptr(), input.nbytes()}, + {output.mutable_data_ptr(), output.nbytes()}, + numel + }, computeGrid(output), uvec3(256, 1, 1)); +} + +const char* ReluOp::kernelSource() const { + return R"( +#include +using namespace metal; + +template +kernel void relu_kernel( + device const T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant uint& numel [[buffer(2)]], + uint i [[thread_position_in_grid]]) { + if (i < numel) { + output[i] = max(input[i], T(0)); + } +} + +template [[host_name("relu_f32")]] kernel void relu_kernel(device const float*, device float*, constant uint&, uint); +template [[host_name("relu_f16")]] kernel void relu_kernel(device const half*, device half*, constant uint&, uint); +)"; +} + +} // namespace metal_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/CMakeLists.txt b/backends/portable/runtime_v2/CMakeLists.txt new file mode 100644 index 00000000000..174328c3739 --- /dev/null +++ b/backends/portable/runtime_v2/CMakeLists.txt @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# PortableBackend_v2 — new architecture per PORTABLE_BACKEND_API_PROPOSAL.md +# +# Layout: +# api/ — pure C++ headers + ProviderRegistry.cpp +# cpu/ — CpuProvider, CpuInstance, HostBuffer, CpuEvent +# metal/ — MetalProvider, MetalInstance, MetalBuffer, MetalEvent +# (Apple only; reuses ../runtime/metal_v2 sources without +# depending on the dead v1 MetalRuntime adapter) +# routers/ — GreedyRouter +# PortableBackend_v2.cpp — BackendInterface adapter; registers as +# "PortableBackend_v2" so it coexists with +# the existing "PortableBackend". +# +# Reuses the existing portable runtime's: +# - Graph / OperatorCall (api/GraphTypes.h) +# - cpu_op_registry() (../runtime_v2/cpu/CpuOpRegistry.h) +# So this target depends on portable_backend (existing) for those. + +set(_portable_backend_v2__srcs + ${CMAKE_CURRENT_SOURCE_DIR}/PortableBackend_v2.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/api/ProviderRegistry.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpu/CpuInstance.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpu/CpuProvider.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/routers/GreedyRouter.cpp +) + +# Apple-only Metal provider. Reuses the existing metal_v2 .mm sources +# without copying — we just compile them into this library directly. +# (We don't depend on portable_backend's metal_v2 wiring because that's +# behind EXECUTORCH_PORTABLE_USE_METAL_V2 and is currently broken on a +# stale GpuStream.h import in MetalRuntime.mm — which we don't use.) 
+if(APPLE) + enable_language(OBJCXX) + set(_metal_v2_dir ${CMAKE_CURRENT_SOURCE_DIR}/../runtime/metal_v2) + list(APPEND _portable_backend_v2__srcs + # New runtime_v2/metal sources + ${CMAKE_CURRENT_SOURCE_DIR}/metal/MetalBuffer.mm + ${CMAKE_CURRENT_SOURCE_DIR}/metal/MetalProvider.mm + ${CMAKE_CURRENT_SOURCE_DIR}/metal/MetalInstance.mm + # Reused metal_v2 sources (compiled once into this lib) + ${_metal_v2_dir}/MetalStream.mm + ${_metal_v2_dir}/MetalKernel.mm + ${_metal_v2_dir}/MetalKernelCompiler.mm + ${_metal_v2_dir}/MetalBufferPool.mm + ${_metal_v2_dir}/MetalHeap.mm + ${_metal_v2_dir}/MetalOp.mm + ${_metal_v2_dir}/MetalOpRegistry.mm + ${_metal_v2_dir}/ops/BinaryOps.mm + ${_metal_v2_dir}/ops/UnaryOps.mm + ${_metal_v2_dir}/ops/MatMulOp.mm + ${_metal_v2_dir}/ops/MPSGraphOp.mm + ) +endif() + +add_library(portable_backend_v2 ${_portable_backend_v2__srcs}) + +# C++17 needed for std::variant / std::optional / structured bindings. +target_compile_features(portable_backend_v2 PRIVATE cxx_std_17) + +# We #include from runtime/ for Graph and cpu_op_registry — needs the +# runtime_v2/ root on the include path so our own headers resolve, and +# the parent (backends/portable/) for the existing runtime/ tree. +target_include_directories(portable_backend_v2 + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/.. # so #include "runtime_v2/api/..." works + ${CMAKE_CURRENT_SOURCE_DIR}/.. # so #include "runtime_v2/api/GraphTypes.h" works +) + +# Link the v1 portable_backend so we can reuse its Graph types and +# CpuOpRegistry — note we depend on the *library*, not just headers, +# because cpu_op_registry() is defined in CpuOps.cpp. +target_link_libraries(portable_backend_v2 + PRIVATE + executorch_core + program_schema + portable_ops_lib + portable_backend +) + +if(APPLE) + # Metal frameworks for the v2 metal provider + reused metal_v2 sources. 
+ target_link_libraries(portable_backend_v2 + PRIVATE + "-framework Metal" + "-framework Foundation" + "-framework MetalPerformanceShaders" + "-framework MetalPerformanceShadersGraph" + ) + # Tell PortableBackend_v2.cpp it can register MetalProvider. + target_compile_definitions(portable_backend_v2 PRIVATE PORTABLE_V2_HAS_METAL=1) + # All .mm files compile as Objective-C++ without ARC (memory managed + # manually; matches the existing metal_v2 build settings in + # backends/portable/CMakeLists.txt). + set(_objcxx_no_arc_files + ${CMAKE_CURRENT_SOURCE_DIR}/metal/MetalBuffer.mm + ${CMAKE_CURRENT_SOURCE_DIR}/metal/MetalProvider.mm + ${CMAKE_CURRENT_SOURCE_DIR}/metal/MetalInstance.mm + ${_metal_v2_dir}/MetalStream.mm + ${_metal_v2_dir}/MetalKernel.mm + ${_metal_v2_dir}/MetalKernelCompiler.mm + ${_metal_v2_dir}/MetalBufferPool.mm + ${_metal_v2_dir}/MetalHeap.mm + ${_metal_v2_dir}/MetalOp.mm + ${_metal_v2_dir}/MetalOpRegistry.mm + ${_metal_v2_dir}/ops/BinaryOps.mm + ${_metal_v2_dir}/ops/UnaryOps.mm + ${_metal_v2_dir}/ops/MatMulOp.mm + ${_metal_v2_dir}/ops/MPSGraphOp.mm + ) + set_source_files_properties(${_objcxx_no_arc_files} + PROPERTIES LANGUAGE OBJCXX COMPILE_FLAGS "-fno-objc-arc" + ) +endif() + +# Ensure schema is generated first. +add_dependencies(portable_backend_v2 program_schema) + +# Force-link static-init for backend registration. Without this, the +# static auto _register = register_backend(...) in PortableBackend_v2.cpp +# may be stripped by the linker. +if(APPLE) + set_target_properties(portable_backend_v2 PROPERTIES + LINK_FLAGS "-Wl,-force_load" + ) +elseif(UNIX) + # GCC / Clang on Linux: --whole-archive is applied at link site by the + # consuming target, not here. Document this in the install/usage notes. 
+endif() + +install( + TARGETS portable_backend_v2 + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +#====================================================================== +# Integration test: batch-varying execute() calls +#====================================================================== +# Standalone binary that loads /tmp/dyn_linear_v2.pte and runs +# forward() with batch sizes {1, 3, 5, 8} against the same loaded +# delegate, verifying true runtime-varying dynamic-shape behavior. +add_executable(test_dyn_shapes_v2 + ${CMAKE_CURRENT_SOURCE_DIR}/test_dyn_shapes.cpp +) +target_compile_features(test_dyn_shapes_v2 PRIVATE cxx_std_17) +target_link_libraries(test_dyn_shapes_v2 + PRIVATE + executorch + extension_module + extension_tensor + extension_data_loader + extension_flat_tensor + portable_kernels + portable_ops_lib + portable_backend + "$" +) +if(APPLE) + target_link_libraries(test_dyn_shapes_v2 + PRIVATE + "-framework Metal" + "-framework Foundation" + "-framework MetalPerformanceShaders" + "-framework MetalPerformanceShadersGraph" + ) +endif() + +# Stateful (mutable buffer) test — observes how mutable buffers flow +# through the v2 runtime today. 
+add_executable(test_stateful_v2 + ${CMAKE_CURRENT_SOURCE_DIR}/test_stateful.cpp +) +target_compile_features(test_stateful_v2 PRIVATE cxx_std_17) +target_link_libraries(test_stateful_v2 + PRIVATE + executorch + extension_module + extension_tensor + extension_data_loader + extension_flat_tensor + portable_kernels + portable_ops_lib + portable_backend + "$" +) +if(APPLE) + target_link_libraries(test_stateful_v2 + PRIVATE + "-framework Metal" + "-framework Foundation" + "-framework MetalPerformanceShaders" + "-framework MetalPerformanceShadersGraph" + ) +endif() diff --git a/backends/portable/runtime_v2/PortableBackend_v2.cpp b/backends/portable/runtime_v2/PortableBackend_v2.cpp new file mode 100644 index 00000000000..175059ed187 --- /dev/null +++ b/backends/portable/runtime_v2/PortableBackend_v2.cpp @@ -0,0 +1,865 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * PortableBackend_v2 — ExecuTorch BackendInterface adapter for the v8.2 + * portable runtime architecture. + * + * Registers "PortableBackend_v2" with the runtime, separate from the + * existing "PortableBackend" so the two can coexist during migration. + * + * See PORTABLE_BACKEND_API_PROPOSAL.md (§3, §6) for the architecture. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PORTABLE_V2_HAS_METAL +#include +#endif + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +namespace { + +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::Backend; +using ::executorch::runtime::BackendExecutionContext; +using ::executorch::runtime::BackendInitContext; +using ::executorch::runtime::CompileSpec; +using ::executorch::runtime::DelegateHandle; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::FreeableBuffer; +using ::executorch::runtime::Result; +using ValueType = ::executorch::backends::portable::ValueType; +using ::executorch::runtime::Span; + +using ::executorch::aten::DimOrderType; +using ::executorch::aten::ScalarType; +using ::executorch::aten::SizesType; +using ::executorch::aten::StridesType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; + +using ::executorch::backends::portable::Graph; + +/** + * Per-program state held across init/execute/destroy. One per loaded + * delegate. See LoadedDelegate in §6 of the design doc. + */ +struct LoadedDelegate { + // Lifetime root of providers/instances/registry. + std::unique_ptr registry; + std::vector> owned_instances; + + // Parsed program — wrapped behind Graph. Downstream code reaches the + // serialized program ONLY through `graph`; nothing else holds a raw + // flatbuffer pointer. See §3 of PORTABLE_BACKEND_API_PROPOSAL.md. + std::unique_ptr graph; + + // Universal value array, indexed by value_id. Holds EValues for + // everything (scalars, lists, tensors). 
For tensor EValues, the + // TensorImpl::data_ptr is a denormalized cache of the bound Buffer's + // host_ptr (kept consistent at bind time). + std::vector values; + + // Side table: tensor storage backings only. + BindingTable bindings; + + // Frozen routing decision. + Plan plan; + + // TensorImpl storage for tensor EValues we materialize from the + // flatbuffer (sizes / dim_order / strides arrays + the TensorImpl + // structs themselves). RAII-cleaned via vector. + struct TensorMeta { + std::unique_ptr sizes; + std::unique_ptr dim_order; + std::unique_ptr strides; + std::unique_ptr impl; + }; + std::vector tensor_metas; + + bool poisoned = false; + + ~LoadedDelegate() { + // Drain (no-op for CPU; matters when GPU instances are present). + for (auto* inst : plan.instances) { + if (inst) inst->drain(); + } + // Release owned buffers via their owning Instance. + for (auto& ob : plan.owned_buffers) { + if (ob.owner && ob.buf) ob.owner->release_buffer(ob.buf); + } + } +}; + +// Build the default ProviderSet. v1: CpuProvider at index 0 (host slot). +// v1 supports at most ONE non-CPU provider per process. Selection rules +// (in order): +// - If env PORTABLE_V2_USE_FAKE_ACCEL=1 → register FakeAccel CpuProvider +// as second slot (test mode for routing). Skip Metal even on Apple. +// - Else on Apple, when PORTABLE_V2_HAS_METAL is defined and +// PORTABLE_V2_DISABLE_METAL is unset → try MetalProvider; register +// iff the underlying MetalStream constructed successfully. +// - Else → CpuProvider only (single-provider mode). 
+std::vector> make_default_providers() { + std::vector> ps; + ps.push_back(std::make_unique()); + + const char* fake = std::getenv("PORTABLE_V2_USE_FAKE_ACCEL"); + if (fake && fake[0] == '1') { + ET_LOG(Info, + "PortableBackend_v2: registering second CpuProvider (\"fake_accel\") for routing tests"); + std::unordered_set allow = { + "aten::add", "aten::add.Tensor", "aten::mul", "aten::mul.Tensor", + }; + ps.push_back(std::make_unique("fake_accel", std::move(allow))); + return ps; + } + +#ifdef PORTABLE_V2_HAS_METAL + const char* disable_metal = std::getenv("PORTABLE_V2_DISABLE_METAL"); + if (!disable_metal || disable_metal[0] != '1') { + auto metal = std::make_unique(); + if (metal->stream_ready()) { + ET_LOG(Info, "PortableBackend_v2: registering MetalProvider"); + ps.push_back(std::move(metal)); + } else { + ET_LOG(Info, + "PortableBackend_v2: MetalProvider unavailable; CPU-only mode"); + } + } +#endif + return ps; +} + +// Initialize one tensor EValue from the flatbuffer Tensor metadata. +// Build a TensorImpl + EValue for value_id from the Graph adapter's +// typed metadata. The data_ptr starts at nullptr; the router/executor +// sets it later from the bound Buffer. 
+Error initialize_tensor_evalue(LoadedDelegate* d, uint32_t value_id) { + const auto& graph = *d->graph; + auto sizes = graph.tensor_sizes(value_id); + auto dim_order_in = graph.tensor_dim_order(value_id); + size_t dim = sizes.size(); + + LoadedDelegate::TensorMeta meta; + meta.sizes.reset(new SizesType[dim]); + meta.dim_order.reset(new DimOrderType[dim]); + meta.strides.reset(new StridesType[dim]); + + for (size_t i = 0; i < dim; ++i) { + meta.sizes[i] = sizes[i]; + } + if (dim_order_in.size() == dim) { + for (size_t i = 0; i < dim; ++i) { + meta.dim_order[i] = dim_order_in[i]; + } + } else { + for (size_t i = 0; i < dim; ++i) { + meta.dim_order[i] = static_cast(i); + } + } + auto status = ::executorch::runtime::dim_order_to_stride( + meta.sizes.get(), meta.dim_order.get(), dim, meta.strides.get()); + if (status != Error::Ok) return status; + + ScalarType dtype = graph.tensor_dtype(value_id); + meta.impl.reset(new TensorImpl( + dtype, + static_cast(dim), + meta.sizes.get(), + /*data=*/nullptr, + meta.dim_order.get(), + meta.strides.get(), + // DYNAMIC_BOUND so kernels' resize_tensor() calls take effect. + // The flatbuffer's `sizes` is the AOT memory-plan max bound; + // runtime kernels may shrink to actual. + ::executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND)); + + d->values[value_id] = EValue(Tensor(meta.impl.get())); + d->tensor_metas.push_back(std::move(meta)); + return Error::Ok; +} + +// Initialize all value EValues. After this returns, every value_id in +// the graph has a populated EValue: +// - Tensors: TensorImpl with dtype/sizes/dim_order/strides; data_ptr +// starts as nullptr and is set later by prebind_owned_buffers / +// bind_inputs / bind_outputs. +// - Scalars / lists: actual values from the flatbuffer. +// - Inputs/outputs are also materialized here as shell tensor EValues +// using the flatbuffer-declared (max-shape) sizes, so allocate_all +// can read their tensor metadata. 
+Error initialize_values( + LoadedDelegate* d, + ::executorch::runtime::MemoryAllocator* runtime_alloc) { + using ::executorch::runtime::BoxedEvalueList; + + size_t n = d->graph->num_values(); + if (n == 0) return Error::Ok; + + // d->values is already sized by the init() caller. Don't shrink. + if (d->values.size() < n) { + return Error::InvalidState; + } + + for (uint32_t i = 0; i < n; ++i) { + switch (d->graph->value_type(i)) { + case ValueType::None: + d->values[i] = EValue(); + break; + case ValueType::Int: + d->values[i] = EValue(d->graph->int_value(i)); + break; + case ValueType::Double: + d->values[i] = EValue(d->graph->double_value(i)); + break; + case ValueType::Bool: + d->values[i] = EValue(d->graph->bool_value(i)); + break; + case ValueType::Tensor: + // Materialize EVERY tensor (including IO). For IO, this creates + // a shell TensorImpl with the flatbuffer-declared (max-shape) + // sizes; bind_inputs/bind_outputs will resize per execute and + // upload_from_host will re-alias the bound Buffer's data_ptr. + if (auto e = initialize_tensor_evalue(d, i); e != Error::Ok) { + return e; + } + break; + case ValueType::IntList: { + // Mirror Method::init's BoxedEvalueList setup + // (see runtime/executor/method.cpp). 
+ auto items = d->graph->int_list_member_ids(i); + size_t cnt = items.size(); + EValue** evalp_list = runtime_alloc->allocateList(cnt); + int64_t* int_list = runtime_alloc->allocateList(cnt); + if (!evalp_list || !int_list) { + return Error::MemoryAllocationFailed; + } + for (size_t j = 0; j < cnt; ++j) { + int64_t vidx = items[j]; + if (vidx < 0 || static_cast(vidx) >= n) { + return Error::InvalidProgram; + } + evalp_list[j] = &d->values[static_cast(vidx)]; + } + auto* boxed_mem = + runtime_alloc->allocateInstance>(); + if (!boxed_mem) return Error::MemoryAllocationFailed; + auto* boxed = new (boxed_mem) + BoxedEvalueList(evalp_list, int_list, cnt); + d->values[i] = EValue(boxed); + ET_LOG(Debug, + "initialize_values: value_id=%u IntList(%zu items)", i, cnt); + } break; + case ValueType::Other: + // String / OptionalTensor / BoolList / DoubleList / etc. + // Adapter doesn't surface these yet; default-construct. + d->values[i] = EValue(); + break; + } + } + return Error::Ok; +} + +// Pre-bind tensor EValues to their owned Buffers. The router populates +// `Plan::owned_buffers` with per-Instance allocations, constants, AND +// cross-runtime transfer destinations (which use synthetic value_ids). +// We walk those and bind each (value_id -> Buffer*) into the BindingTable +// AND sync the EValue's TensorImpl::data_ptr to the Buffer's host_ptr. +// +// Pre-condition: synthetic value EValues have already been materialized by +// materialize_synthetic_values (called before this). +void prebind_owned_buffers(LoadedDelegate* d) { + for (auto& ob : d->plan.owned_buffers) { + if (!ob.buf) continue; + d->bindings.bind(ob.value_id, ob.buf); + + // Look up the owning Provider's name for clearer logs. 
+ std::string prov_name = "?"; + for (size_t i = 0; i < d->plan.instances.size(); ++i) { + if (d->plan.instances[i] == ob.owner) { + prov_name = std::string(d->plan.providers[i]->name()); + break; + } + } + if (ob.value_id < d->values.size() && + d->values[ob.value_id].isTensor()) { + void* hp = ob.buf->host_ptr(); + if (hp) { + d->values[ob.value_id] + .toTensor() + .unsafeGetTensorImpl() + ->set_data(hp); + } + ET_LOG(Debug, + "[mem] bind: value_id=%u provider=%s host_ptr=%p bytes=%zu (tensor)", + ob.value_id, prov_name.c_str(), hp, ob.buf->size_bytes()); + } else { + ET_LOG(Debug, + "[mem] bind: value_id=%u provider=%s bytes=%zu (non-tensor or unset)", + ob.value_id, prov_name.c_str(), ob.buf->size_bytes()); + } + } +} + +// Materialize EValues for router-synthesized value_ids (cross-runtime +// transfer destinations). Each synthesized id inherits dtype/shape/strides +// from its source value_id by cloning TensorImpl metadata. The data_ptr +// is set later by prebind_owned_buffers when we know which Buffer. +// +// Pre-condition: d->values has already been sized to fit max(synthetic id)+1 +// (done in init() before initialize_values to avoid invalidating +// BoxedEvalueList EValue* pointers). 
+Error materialize_synthetic_values(LoadedDelegate* d) { + if (d->plan.synthetic_values.empty()) return Error::Ok; + + for (const auto& sv : d->plan.synthetic_values) { + if (sv.new_id >= d->values.size()) { + ET_LOG(Error, + "materialize_synthetic_values: new_id=%u >= values.size()=%zu", + sv.new_id, d->values.size()); + return Error::InvalidState; + } + if (sv.source_id >= d->values.size() || + !d->values[sv.source_id].isTensor()) { + ET_LOG(Error, + "materialize_synthetic_values: source value_id=%u is not a tensor", + sv.source_id); + return Error::InvalidProgram; + } + + auto& src = d->values[sv.source_id].toTensor(); + auto* src_impl = src.unsafeGetTensorImpl(); + size_t dim = src.dim(); + + LoadedDelegate::TensorMeta tm; + tm.sizes.reset(new SizesType[dim]); + tm.dim_order.reset(new DimOrderType[dim]); + tm.strides.reset(new StridesType[dim]); + for (size_t i = 0; i < dim; ++i) { + tm.sizes[i] = src.size(i); + tm.dim_order[i] = src_impl->dim_order()[i]; + tm.strides[i] = src.strides()[i]; + } + tm.impl.reset(new TensorImpl( + src.scalar_type(), + static_cast(dim), + tm.sizes.get(), + /*data=*/nullptr, + tm.dim_order.get(), + tm.strides.get(), + // DYNAMIC_BOUND so per-execute TransferStep's resize_tensor() + // can update sizes when actual shape varies per execute. + ::executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND)); + + d->values[sv.new_id] = EValue(Tensor(tm.impl.get())); + d->tensor_metas.push_back(std::move(tm)); + ET_LOG(Debug, + "materialize_synthetic_values: new_id=%u (clone of source_id=%u)", + sv.new_id, sv.source_id); + } + return Error::Ok; +} + +// Single host-first allocation pass. Walks each provider's alloc_plan in +// provider index order (host slot 0 first). For each provider, patches +// host_alias on synthetic AllocRequests to point at the source value's +// (already-allocated) Buffer, then calls allocate_all. Pushes resulting +// Buffers into plan.owned_buffers and updates the value→Buffer ledger +// for the next provider. 
+// +// Pre-condition: initialize_values + materialize_synthetic_values have +// run (every value_id has an EValue with valid TensorImpl metadata). +Error allocate_buffers(LoadedDelegate* d) { + // Build a value_id → source value_id map for synthetic mirrors. + std::unordered_map synth_to_source; + for (const auto& sv : d->plan.synthetic_values) { + synth_to_source[sv.new_id] = sv.source_id; + } + + // value_id → Buffer ledger; populated as we allocate. + std::unordered_map value_to_buf; + // Seed with constants already in plan.owned_buffers (uploaded during + // route() via upload_constant). + for (const auto& ob : d->plan.owned_buffers) { + value_to_buf[ob.value_id] = ob.buf; + } + + // Iterate providers in plan order (host slot 0 first). Earlier providers' + // Buffers become visible as host_alias targets for later providers' + // synthetic AllocRequests. + for (size_t p = 0; p < d->plan.alloc_plans.size(); ++p) { + auto& reqs = d->plan.alloc_plans[p]; + if (reqs.empty()) continue; + if (p >= d->plan.instances.size() || !d->plan.instances[p]) { + return Error::InvalidState; + } + Instance* inst = d->plan.instances[p]; + + // Patch host_alias on synthetic AllocRequests. + for (auto& req : reqs) { + auto sit = synth_to_source.find(req.value_id); + if (sit == synth_to_source.end()) continue; + auto vit = value_to_buf.find(sit->second); + if (vit == value_to_buf.end() || !vit->second) { + ET_LOG(Error, + "allocate_buffers: synthetic value_id=%u source value_id=%u " + "not yet allocated (provider %zu allocates before its sources)", + req.value_id, sit->second, p); + return Error::InvalidState; + } + req.host_alias = vit->second; + } + + // Call allocate_all. + std::vector out(reqs.size(), nullptr); + auto e = inst->allocate_all( + Span(reqs.data(), reqs.size()), + Span(d->values.data(), d->values.size()), + Span(out.data(), out.size())); + if (e != Error::Ok) return e; + + // Record results. 
+ for (size_t i = 0; i < reqs.size(); ++i) { + d->plan.owned_buffers.push_back({out[i], inst, reqs[i].value_id}); + value_to_buf[reqs[i].value_id] = out[i]; + } + } + return Error::Ok; +} + +// bind_inputs: per-execute. For each graph input value_id: +// - Overwrites d->values[vid] with caller's EValue (shares caller's +// TensorImpl, so per-execute kernel resize on our slot is visible +// to caller). +// - Calls upload_from_host on the pre-allocated destination Buffer +// to re-alias it to caller's pointer. +// +// Note: the shell TensorImpl materialized at init time (in initialize_values) +// served its purpose at allocate_all (provided dtype/sizes). Per-execute +// it's superseded by caller's TensorImpl via the EValue copy. The shell +// stays alive in d->tensor_metas until ~LoadedDelegate. +Error bind_inputs(LoadedDelegate* d, Span args) { + size_t n_in = d->plan.inputs.size(); + for (size_t i = 0; i < n_in && i < args.size(); ++i) { + const auto& ib = d->plan.inputs[i]; + uint32_t vid = ib.value_id; + if (vid >= d->values.size()) return Error::InvalidArgument; + + // Replace shell EValue with caller's (shares TensorImpl). + d->values[vid] = *args[i]; + + if (!d->values[vid].isTensor()) continue; + + auto& tensor = d->values[vid].toTensor(); + void* host_data = tensor.mutable_data_ptr(); + size_t nbytes = tensor.nbytes(); + if (!host_data || nbytes == 0) continue; + + Buffer* buf = d->bindings.get(vid); + if (!buf) { + ET_LOG(Error, + "bind_inputs: no pre-allocated Buffer for input value_id=%u", + vid); + return Error::InvalidState; + } + Instance* dst_inst = d->plan.instances[0]; + + if (auto e = dst_inst->upload_from_host( + d->values[vid], host_data, + d->values[vid], buf, + QueueKind::Compute, + Span(), + nullptr); + e != Error::Ok) { + return e; + } + ET_LOG(Debug, + "[mem] bind_input: value_id=%u alias_into=cpu(host) caller_ptr=%p bytes=%zu", + vid, host_data, nbytes); + } + return Error::Ok; +} + +// bind_outputs: symmetric. 
+Error bind_outputs(LoadedDelegate* d, Span args) { + size_t n_in = d->plan.inputs.size(); + size_t n_out = d->plan.outputs.size(); + for (size_t i = 0; i < n_out && (n_in + i) < args.size(); ++i) { + const auto& ob = d->plan.outputs[i]; + uint32_t vid = ob.value_id; + if (vid >= d->values.size()) return Error::InvalidArgument; + + EValue* arg_ev = args[n_in + i]; + if (!arg_ev) continue; + + d->values[vid] = *arg_ev; + // Producer-side graph-output mirrors (synthetic value_ids whose + // source is this output) must share the caller's TensorImpl so + // that: + // (a) the producer kernel writes to caller's data_ptr, and + // (b) resize_tensor calls inside the kernel update caller's + // TensorImpl (the one our caller will read sizes from). + // The TransferStep host -> producer (emitted by the router for + // each producer-side mirror) handles re-aliasing the producer's + // Buffer to caller's pointer per-execute. + for (const auto& sv : d->plan.synthetic_values) { + if (sv.source_id == vid && sv.new_id < d->values.size()) { + d->values[sv.new_id] = *arg_ev; + } + } + if (!d->values[vid].isTensor()) continue; + + auto& tensor = d->values[vid].toTensor(); + void* host_data = tensor.mutable_data_ptr(); + size_t nbytes = tensor.nbytes(); + if (!host_data || nbytes == 0) continue; + + Buffer* buf = d->bindings.get(vid); + if (!buf) { + ET_LOG(Error, + "bind_outputs: no pre-allocated Buffer for output value_id=%u", + vid); + return Error::InvalidState; + } + Instance* dst_inst = d->plan.instances[0]; + + if (auto e = dst_inst->upload_from_host( + d->values[vid], host_data, + d->values[vid], buf, + QueueKind::Compute, + Span(), + nullptr); + e != Error::Ok) { + return e; + } + ET_LOG(Debug, + "[mem] bind_output: value_id=%u alias_into=cpu(host) caller_ptr=%p bytes=%zu", + vid, host_data, nbytes); + } + return Error::Ok; +} + +// Execute one Step. +Error execute_step(LoadedDelegate* d, const Step& step) { + // Pre-resolve wait_for / signal Event*. 
v1 has no events emitted by the + // router, so wait_for is empty and signal is kNoEvent. Keep the + // resolution code shape so adding events later is mechanical. + std::vector waits_storage; + Span waits(waits_storage.data(), static_cast(0)); + Event* signal = nullptr; + + return std::visit( + [&](auto&& s) -> Error { + using T = std::decay_t; + if constexpr (std::is_same_v) { + // Resolve wait_for from EventIds (skipped in v1; empty). + (void)s.wait_for; + (void)s.signal; + Instance* inst = d->plan.instances[s.runtime_idx]; + return inst->execute( + s.segment, + Span(d->values.data(), d->values.size()), + d->bindings, + waits, + signal); + } else if constexpr (std::is_same_v) { + if (s.src_value_id >= d->values.size() || + s.dst_value_id >= d->values.size()) { + return Error::InvalidState; + } + EValue& src_ev = d->values[s.src_value_id]; + EValue& dst_ev = d->values[s.dst_value_id]; + Buffer* src_buf = d->bindings.get(s.src_value_id); + Buffer* dst_buf = d->bindings.get(s.dst_value_id); + Instance* src_inst = d->plan.instances[s.src_idx]; + Instance* dst_inst = d->plan.instances[s.dst_idx]; + // Per-execute trace line: ties the runtime transfer back to + // the router's "[mem] router: cross-runtime transfer ..." + // setup line so memory cost is auditable. + std::string src_pname = "?"; + std::string dst_pname = "?"; + if (s.src_idx < d->plan.providers.size() && + d->plan.providers[s.src_idx]) { + src_pname = std::string(d->plan.providers[s.src_idx]->name()); + } + if (s.dst_idx < d->plan.providers.size() && + d->plan.providers[s.dst_idx]) { + dst_pname = std::string(d->plan.providers[s.dst_idx]->name()); + } + size_t xfer_bytes = src_ev.isTensor() ? 
src_ev.toTensor().nbytes() : 0; + ET_LOG(Debug, + "[mem] step: TransferStep src_value_id=%u (%s) -> dst_value_id=%u (%s) bytes=%zu", + s.src_value_id, src_pname.c_str(), + s.dst_value_id, dst_pname.c_str(), + xfer_bytes); + // Direction-specific dispatch: the device (non-host) Instance + // owns the cross-runtime move. By convention slot 0 is host + // (CPU), so whichever side != 0 is the device side. + // src is host → upload to device's Buffer + // dst is host → download from device's Buffer + // Two non-host runtimes is not supported (and never emitted by + // the router); a runtime↔runtime transfer would route through + // host as two TransferSteps. + constexpr RuntimeIndex kHostIdx = 0; + if (s.src_idx == kHostIdx && s.dst_idx != kHostIdx) { + void* host_src_ptr = src_ev.isTensor() + ? src_ev.toTensor().mutable_data_ptr() + : nullptr; + return dst_inst->upload_from_host( + src_ev, host_src_ptr, + dst_ev, dst_buf, + s.queue, waits, signal); + } else if (s.dst_idx == kHostIdx && s.src_idx != kHostIdx) { + void* host_dst_ptr = dst_ev.isTensor() + ? 
          dst_ev.toTensor().mutable_data_ptr()
          : nullptr;
          return src_inst->download_to_host(
              src_ev, src_buf,
              dst_ev, host_dst_ptr,
              s.queue, waits, signal);
        } else {
          ET_LOG(Error,
              "TransferStep with neither side on host (src_idx=%u dst_idx=%u) is unsupported",
              s.src_idx, s.dst_idx);
          return Error::NotSupported;
        }
      }
      return Error::Internal;
    },
    step);
}

} // namespace

// NOTE(review): several template-argument lists in this hunk appear to have
// been dropped by patch extraction (e.g. `Result init`, `ArrayRef`,
// `std::make_unique(...)`, `Span(...)`, `static_cast(handle)`,
// `std::max(total_size, sv.new_id + 1)`); restore them from the original
// commit before applying this patch.
class PortableBackendV2 final : public ::executorch::runtime::BackendInterface {
 public:
  ~PortableBackendV2() override = default;

  bool is_available() const override { return true; }

  // Parses the delegate payload (an executorch flatbuffer Program), routes
  // the graph across the available provider Instances, and materializes all
  // EValues/Buffers up front. Returns the LoadedDelegate* as the opaque
  // DelegateHandle. All error paths after placement-new explicitly run
  // ~LoadedDelegate() because the backing memory is bump-allocated.
  Result init(
      BackendInitContext& ctx,
      FreeableBuffer* processed,
      ArrayRef /*compile_specs*/) const override {
    if (!processed || !processed->data() || processed->size() == 0) {
      return Error::InvalidArgument;
    }

    auto* program = executorch_flatbuffer::GetProgram(processed->data());
    if (!program) return Error::InvalidProgram;

    auto plans = program->execution_plan();
    if (!plans || plans->size() == 0) return Error::InvalidProgram;

    // Allocate the LoadedDelegate from the runtime allocator (lifetime =
    // method). We placement-new because MemoryAllocator::allocate returns
    // raw bytes.
    auto* runtime_alloc = ctx.get_runtime_allocator();
    if (!runtime_alloc) return Error::InvalidState;
    void* mem = runtime_alloc->allocate(
        sizeof(LoadedDelegate), alignof(LoadedDelegate));
    if (!mem) return Error::MemoryAllocationFailed;
    auto* d = new (mem) LoadedDelegate();

    // Wrap the parsed flatbuffer in a Graph immediately. From this point
    // on, NO downstream code (Router, Instances, initialize_values,
    // materialize_synthetic_values, prebind_owned_buffers, executor) reaches
    // into the flatbuffer directly — Graph is the only handle.
    d->graph = std::make_unique(plans->Get(0));

    // Build the provider set and registry.
    d->registry =
        std::make_unique(make_default_providers());

    // Instantiate one Instance per available provider.
    auto avail = d->registry->available();
    d->owned_instances.reserve(avail.size());
    std::vector raw_instances;
    raw_instances.reserve(avail.size());
    std::vector raw_providers;
    raw_providers.reserve(avail.size());
    for (auto* p : avail) {
      auto inst = p->instantiate();
      raw_instances.push_back(inst.get());
      raw_providers.push_back(p);
      d->owned_instances.push_back(std::move(inst));
    }

    // Route FIRST so we know the synthetic-value count before sizing
    // d->values. (initialize_values builds BoxedEvalueList's
    // that store EValue* pointers into d->values; if we resize() after
    // that, those pointers are invalidated.)
    GreedyRouter router;
    RouterOptions opts;
    auto plan_result = router.route(
        *d->graph,
        Span(raw_providers.data(), raw_providers.size()),
        Span(raw_instances.data(), raw_instances.size()),
        ctx.get_named_data_map(),
        opts);
    if (!plan_result.ok()) {
      d->~LoadedDelegate();
      return plan_result.error();
    }
    d->plan = std::move(plan_result.get());

    // Pre-size d->values to fit both original and synthetic value_ids.
    // Reserve THEN resize to the final size up-front: subsequent
    // initialize_values builds BoxedEvalueList's that store
    // EValue* pointers into d->values; if the vector reallocates later,
    // those pointers dangle.
    size_t num_orig = d->graph->num_values();
    size_t total_size = num_orig;
    for (const auto& sv : d->plan.synthetic_values) {
      total_size = std::max(total_size, sv.new_id + 1);
    }
    d->values.reserve(total_size);
    d->values.resize(total_size);

    // Materialize EValues for all original value_ids (including IO with
    // shell TensorImpls at flatbuffer-declared sizes).
    if (auto e = initialize_values(d, runtime_alloc); e != Error::Ok) {
      d->~LoadedDelegate();
      return e;
    }
    // Materialize EValues for synthetic value_ids (clone of source's
    // TensorImpl; data_ptr null until allocate_buffers + prebind).
    if (auto e = materialize_synthetic_values(d); e != Error::Ok) {
      d->~LoadedDelegate();
      return e;
    }
    // Now allocate Buffers for everything via allocate_all (host-first
    // single pass; synthetic AllocRequests' host_alias gets patched to
    // the source's just-allocated Buffer before each device call).
    if (auto e = allocate_buffers(d); e != Error::Ok) {
      d->~LoadedDelegate();
      return e;
    }
    // Bind allocated Buffers into the BindingTable + sync TensorImpl
    // data_ptr to each Buffer's host_ptr.
    prebind_owned_buffers(d);

    ET_LOG(Info, "PortableBackend_v2: initialized with %zu steps",
        d->plan.steps.size());
    return reinterpret_cast(d);
  }

  // Runs one inference: rebinds caller IO, issues every plan step, drains
  // all instances once, then copies back non-tensor (scalar) outputs.
  Error execute(
      BackendExecutionContext& /*ctx*/,
      DelegateHandle* handle,
      Span args) const override {
    auto* d = static_cast(handle);
    if (!d) return Error::InvalidState;

    // Failure model: a Failed/Poisoned event from any prior execute
    // leaves the delegate sticky-Poisoned. Subsequent calls bail.
    if (d->poisoned) return Error::Internal; // DelegatePoisoned

    // Bindings are persistent (allocated once at init); per-execute
    // bind_inputs / bind_outputs re-aliases the existing Buffers in
    // place via upload_from_host.

    if (auto e = bind_inputs(d, args); e != Error::Ok) return e;
    if (auto e = bind_outputs(d, args); e != Error::Ok) return e;

    // Issue every step. Status flows through events; the first failing
    // step's error is recorded but subsequent steps are still issued so
    // poison can propagate through their wait chains. (v1 has no events,
    // so this is just sequencing.)
    Error first_err = Error::Ok;
    for (const auto& step : d->plan.steps) {
      Error e = execute_step(d, step);
      if (e != Error::Ok && first_err == Error::Ok) {
        first_err = e;
      }
    }

    // Single drain at end.
    for (auto* inst : d->plan.instances) {
      if (inst) inst->drain();
    }

    if (first_err != Error::Ok) {
      d->poisoned = true;
      return first_err;
    }

    // Copy back scalar outputs (rare). Tensor outputs were written
    // in-place into caller storage by either the producing kernel (host
    // outputs) or a terminal device->host TransferStep.
    size_t n_in = d->plan.inputs.size();
    size_t n_out = d->plan.outputs.size();
    for (size_t i = 0; i < n_out && (n_in + i) < args.size(); ++i) {
      const auto& ob = d->plan.outputs[i];
      uint32_t vid = ob.value_id;
      if (vid >= d->values.size()) continue;
      EValue* arg_ev = args[n_in + i];
      if (!arg_ev) continue;
      if (!d->values[vid].isTensor()) {
        *arg_ev = d->values[vid];
      }
    }
    return Error::Ok;
  }

  void destroy(DelegateHandle* handle) const override {
    if (!handle) return;
    auto* d = static_cast(handle);
    d->~LoadedDelegate();
    // The underlying memory was bump-allocated from the runtime
    // allocator and is reclaimed when the Method is destroyed; no free
    // here.
  }
};

namespace {
PortableBackendV2 g_backend_v2;
Backend g_backend_record{"PortableBackend_v2", &g_backend_v2};
auto g_register =
    ::executorch::runtime::register_backend(g_backend_record);
} // namespace

} // namespace portable_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime_v2/api/BindingTable.h b/backends/portable/runtime_v2/api/BindingTable.h
new file mode 100644
index 00000000000..b1494edf627
--- /dev/null
+++ b/backends/portable/runtime_v2/api/BindingTable.h
@@ -0,0 +1,55 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include

#include
#include

namespace executorch {
namespace backends {
namespace portable_v2 {

/**
 * Side table mapping value_id -> Buffer* for tensor storage backings only.
 * Non-tensor EValues (scalars, lists) live in LoadedDelegate::values; not
 * in this table. See §4.3 of PORTABLE_BACKEND_API_PROPOSAL.md.
+ * + * The TensorImpl::data_ptr in the corresponding EValue is a denormalized + * cache of the bound Buffer's host_ptr(); kept consistent at bind time. + * + * Bindings are populated once at init (prebind_owned_buffers) and live + * for the LoadedDelegate's lifetime. Per-execute IO does NOT change + * bindings — upload_from_host re-aliases the bound Buffer in place. + */ +class BindingTable { + public: + // Returns the storage Buffer for this value's tensor data. + // Returns nullptr if value_id is not a tensor (e.g. a scalar) or hasn't + // been bound yet. + Buffer* get(uint32_t value_id) const { + auto it = map_.find(value_id); + return it != map_.end() ? it->second : nullptr; + } + + void bind(uint32_t value_id, Buffer* buf) { map_[value_id] = buf; } + + // Used during init for diagnostics / iteration. + size_t size() const { return map_.size(); } + + private: + std::unordered_map map_; +}; + +using BindingView = const BindingTable&; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Buffer.h b/backends/portable/runtime_v2/api/Buffer.h new file mode 100644 index 00000000000..4ae53465ae4 --- /dev/null +++ b/backends/portable/runtime_v2/api/Buffer.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Opaque storage handle. Lives on a Location. Concrete subclasses + * (HostBuffer, MetalBuffer, VulkanBuffer) are private to each runtime; + * the router and executor see only Buffer*. + * + * Ownership: Buffers are owned by their Provider's RuntimeContext (its + * pool). Instance::allocate returns a non-owning Buffer*. 
 The Plan records
 * which Instance allocated which Buffer; on destruction asks each Instance
 * to release.
 *
 * Only host_ptr() is virtual; location() and size_bytes() read base
 * member storage.
 *
 * Note: lifetime of underlying storage (e.g., a held FreeableBuffer for
 * NDM-aliased constants, or a pool slot) is the concrete subclass's
 * responsibility.
 */
class Buffer {
 public:
  virtual ~Buffer() = default;

  // The Location this buffer lives on (set once at construction).
  Location location() const { return location_; }
  // Capacity of the underlying storage in bytes.
  size_t size_bytes() const { return size_bytes_; }

  // Non-null iff host-addressable. CPU: always. Metal on Apple Silicon:
  // usually (MTLStorageModeShared). Discrete GPU: nullptr.
  virtual void* host_ptr() { return nullptr; }

 protected:
  Buffer(Location loc, size_t bytes) : location_(loc), size_bytes_(bytes) {}

  // Allow derived classes (e.g., recycled HostBuffer slots) to re-set size
  // when re-aliasing the underlying storage.
  void set_size_bytes(size_t bytes) { size_bytes_ = bytes; }

private:
  Location location_;
  size_t size_bytes_;
};

} // namespace portable_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime_v2/api/Event.h b/backends/portable/runtime_v2/api/Event.h
new file mode 100644
index 00000000000..a68f5bea1f3
--- /dev/null
+++ b/backends/portable/runtime_v2/api/Event.h
@@ -0,0 +1,89 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include

#include

namespace executorch {
namespace backends {
namespace portable_v2 {

/**
 * Async coordination primitives. See §4.6 of the design doc.
 */

// Opaque index into Plan::events.
using EventId = uint16_t;
// Sentinel "no event" value (max uint16_t).
inline constexpr EventId kNoEvent = 0xFFFF;

// Sticky completion status. Once non-Pending, only prepare_signal() can
// clear it.
enum class EventStatus : uint8_t {
  Pending, // not yet signaled
  Complete, // signaled successfully
  Failed, // producer reported a backend error
  Poisoned, // an upstream event in this event's wait_for chain Failed;
            // this step short-circuited and never executed
};

// Logical execution lane, picked by the runtime. Two flavors used by the
// router; runtimes may add private ones.
enum class QueueKind : uint8_t {
  Compute, // kernel dispatch
  Transfer, // DMA / copy. Vulkan: dedicated transfer queue. Metal: same
            // as Compute. CPU: same as Compute.
};

/**
 * Backend-defined opaque async-completion handle.
 * CPU: trivially-completed flag.
 * Metal: MTLEvent (monotonic signal value).
 * Vulkan: timeline VkSemaphore (host wait via VkFence).
 */
class Event {
 public:
  virtual ~Event() = default;

  // Prepare the event to receive a new signal in this execute() pass. On
  // monotonic backends, advances the expected next-signal value. On
  // reset-style backends, resets the fence. On CPU, clears the flag.
  // Pre-condition: status() != Failed && status() != Poisoned (executor
  // enforces by refusing to start the next execute() if delegate is
  // Poisoned).
  virtual void prepare_signal() = 0;

  // Cheap, lock-free read. Memory ordering: ACQUIRE on the loaded status;
  // pairs with the release-store inside the producing signal path.
  virtual EventStatus status() const = 0;

  bool is_complete() const { return status() == EventStatus::Complete; }

  // Valid iff status() == Failed or Poisoned.
  virtual ::executorch::runtime::Error error() const = 0;

  // Signaling — promoted to base so call sites can drive any Event*
  // without dynamic_cast'ing to a concrete subclass.
  //
  // signal_complete: producer finished successfully.
  // signal_failed:   producer encountered a backend error; propagate
  //                  `e` as the event's error().
  // signal_poisoned: an upstream event in this event's wait_for chain
  //                  Failed, so this step short-circuited and never
  //                  actually executed; `upstream_error` is the
  //                  upstream's error() value.
  // All three are release-stores that pair with status()'s acquire load.
  virtual void signal_complete() = 0;
  virtual void signal_failed(::executorch::runtime::Error e) = 0;
  virtual void signal_poisoned(::executorch::runtime::Error upstream_error) = 0;
};

} // namespace portable_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime_v2/api/GraphTypes.h b/backends/portable/runtime_v2/api/GraphTypes.h
new file mode 100644
index 00000000000..badb41ec940
--- /dev/null
+++ b/backends/portable/runtime_v2/api/GraphTypes.h
@@ -0,0 +1,743 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

/**
 * Graph: backend-facing IR adapter.
 *
 * These types provide our backends' view of the program IR. The current
 * implementation adapts ExecuTorch's flatbuffer
 * (executorch_flatbuffer::ExecutionPlan / KernelCall) but the API is
 * intentionally backing-agnostic — a different serialization could
 * replace the underlying storage without changing this header or any
 * backend that uses it.
 *
 * Some methods are documented IR concepts that the current adapter does
 * not yet back (e.g., mutable_buffer_ids, version): they are placeholders
 * pending serialization support.
+ * + * ---------------------------------------------------------------------- + * Fictional IR schema (what Graph effectively presents) + * ---------------------------------------------------------------------- + * If the IR were serialized in its own format (independent of the + * underlying ExecuTorch flatbuffer), it would look like this: + * + * // Top-level + * table Graph { + * version: string; // schema/program version + * values: [Value]; // dense pool indexed by value_id (uint) + * inputs: [uint]; // graph input value_ids + * outputs: [uint]; // graph output value_ids + * mutable_buffers: [uint]; // values that persist across executes + * operators: [OperatorDef]; // op-name registry (deduped) + * chains: [Chain]; // chains[0] is main; others used by + * // future control flow (if/while/call) + * } + * + * table OperatorDef { + * name: string; // e.g. "aten.add.Tensor" + * } + * + * // Op chains + * table Chain { + * instructions: [Instruction]; + * } + * + * table Instruction { + * // Today only one variant; future variants for control flow + * // (if/while/call) would extend this union. + * body: KernelCall; + * } + * + * table KernelCall { + * op_index: uint; // → Graph.operators[op_index].name + * args: [uint]; // value_ids: args[0..n-2] = inputs, + * // args[n-1] = output. Single-output + * // assumed today; multi-output is a + * // future extension. + * } + * + * // Values + * union Value { + * None, + * Int { v: int64; }, + * Double { v: float64; }, + * Bool { v: bool; }, + * String { v: string; }, + * Tensor, + * IntList, + * DoubleList, + * BoolList, + * OptionalTensor, + * // ... 
+ * } + * + * table Tensor { + * scalar_type: ScalarType; // dtype + * sizes: [int32]; // for DYNAMIC_BOUND, this is max-shape + * dim_order: [uint8]; // permutation defining memory layout + * shape_dynamism: ShapeDynamism; // STATIC | DYNAMIC_BOUND | DYNAMIC_UNBOUND + * allocation: AllocationInfo?;// null = no AOT plan + * data: TensorData; + * } + * + * union TensorData { + * None, + * Inline { buffer_idx: uint; }, // bytes embedded in program + * External { ndm_key: string;}, // bytes in NamedDataMap, FQN-keyed + * } + * + * table AllocationInfo { + * pool_id: int32; + * offset: uint64; // raw byte offset within pool_id + * } + * + * table IntList { + * member_ids: [int64]; // value_ids of list elements + * } + * + * ---------------------------------------------------------------------- + * Derived views computed by the adapter (not in the schema itself) + * ---------------------------------------------------------------------- + * mem_obj_id(vid) sort-and-index over (pool_id, offset) + * → dense small int identifying shared + * storage slots. Two values with the + * same id were memory-planned to share + * storage (used by router for + * AllocRequest grouping; used by + * OperatorCall::output_alias_input for + * no-op clone elision). + * value_kind(vid) from membership in inputs/outputs + + * the data field + * → INPUT / OUTPUT / CONSTANT / + * INTERMEDIATE / MUTABLE_BUFFER. + * tensor_constant_data_key(vid) convenience accessor for + * TensorData.External.ndm_key. + * tensor_nbytes_max(vid) dtype_size × prod(sizes); upper bound + * for DYNAMIC_BOUND tensors. + * OperatorCall::output_alias_input(i) + * mem_obj_id matching between an op's + * output and its inputs → backend can + * elide no-op clones. 
 *
 * ----------------------------------------------------------------------
 * Adapter cost
 * ----------------------------------------------------------------------
 * All accessors are inline and compile down to essentially the same
 * machine code as direct flatbuffer access. The Graph constructor pays
 * a one-time O(N log N) precompute (over tensor values) for mem_obj_id,
 * and O(num_inputs + num_outputs) for the value_kind sets. Per-call
 * overhead in the runtime hot path is a few inline indirections plus
 * predictable ET_CHECK branches; release builds compile away these
 * checks entirely under -DNDEBUG.
 */

#include
#include
#include
#include
#include
#include

#include
#include
#include

#include

namespace executorch {
namespace backends {
namespace portable {

// Forward declare
class Graph;

// NOTE(review): `#include` targets and several template-argument lists and
// cast targets in this hunk (e.g. `runtime::Span(...)`, `static_cast(...)`,
// `std::vector entries;`, `std::unordered_set produced_vids;`) were dropped
// by patch extraction; restore them from the original commit before applying.

/**
 * Value kind — matches design doc's TensorKind
 */
enum class ValueKind : uint8_t {
  INPUT = 0, // Graph input (user provides)
  OUTPUT, // Graph output (user reads)
  CONSTANT, // Immutable weight
  MUTABLE_BUFFER, // Mutable state (e.g., KV cache)
  INTERMEDIATE, // Temporary (produced/consumed internally)
};

/**
 * Type of an EValue stored at a value_id. Mirrors the runtime EValue
 * sum-type but in adapter-level form (no flatbuffer types in the API).
 */
enum class ValueType : uint8_t {
  None = 0,
  Int,
  Double,
  Bool,
  Tensor,
  IntList,
  Other, // String, OptionalTensor, BoolList, DoubleList, ...
         // Adapter doesn't surface these yet; executor falls back to
         // default-constructed EValue.
};

/**
 * Thin wrapper around ExecuTorch's flatbuffer KernelCall providing
 * convenient access to op name and input/output value_ids.
 *
 * ExecuTorch packs all op args into a single `args` array per the
 * convention "the last arg is the (single) output." This wrapper
 * exposes inputs/outputs accordingly without copying — `inputs()` /
 * `output()` return Spans / values that reference the underlying
 * flatbuffer storage directly.
 *
 * Per-call cost: just stores two pointers + an int. No allocations.
 * Safe to construct repeatedly in hot dispatch loops.
 *
 * NOTE: Multi-output ops (e.g., `aten.split`, `aten.max.dim`) are not
 * supported by this wrapper — `num_outputs()` always returns 0 or 1.
 * Adding multi-output support requires per-op schema knowledge that
 * isn't in the flatbuffer. ET_CHECK guards the single-output access
 * path.
 */
class OperatorCall {
 public:
  explicit OperatorCall(
      const executorch_flatbuffer::KernelCall* call,
      const Graph* graph)
      : call_(call), graph_(graph) {}

  // node_id for error messages/profiling
  uint32_t node_id() const { return node_id_; }
  void set_node_id(uint32_t id) { node_id_ = id; }

  // Op name (e.g., "aten.add.Tensor")
  const char* name() const;

  // All op args (flat). Last entry is the output; the rest are inputs.
  // Optional args are NOT indicated by -1 — instead the index points to
  // a value with isNone() == true. (ExecuTorch convention.)
  runtime::Span args() const {
    auto* a = call_->args();
    return a ? runtime::Span(a->data(), a->size())
             : runtime::Span{};
  }

  // Inputs = args[0..n-2]. Returns empty span if op has no args.
  runtime::Span inputs() const {
    auto a = args();
    return a.empty() ? a
                     : runtime::Span(a.data(), a.size() - 1);
  }

  size_t num_inputs() const {
    auto a = args();
    return a.empty() ? 0 : a.size() - 1;
  }
  uint32_t input(size_t i) const {
    ET_CHECK_MSG(i < num_inputs(),
        "OperatorCall::input: index %zu >= num_inputs()=%zu",
        i, num_inputs());
    return static_cast(args()[i]);
  }

  // Single-output assumption (see class doc). Multi-output ops will
  // break here.
  size_t num_outputs() const { return args().empty() ? 0 : 1; }
  uint32_t output(size_t i) const {
    ET_CHECK_MSG(i < num_outputs(),
        "OperatorCall::output: index %zu >= num_outputs()=%zu",
        i, num_outputs());
    auto a = args();
    return static_cast(a[a.size() - 1]);
  }

  // Returns the input index that output `output_idx` aliases, or -1 if
  // none. "Aliases" here means the AOT memory planner placed them in
  // the same (pool_id, offset) slot — which the planner only does when
  // the op's semantics call for it (e.g., clone, view, expand_as).
  //
  // Backends can use this to optimize: if a clone's output aliases its
  // input, the dispatch can be a no-op (no bytes need to move).
  //
  // Computed on demand (O(num_inputs) per call); not stored. mem_obj_id
  // itself is O(1) (precomputed in Graph constructor).
  int32_t output_alias_input(size_t output_idx) const;

 private:
  const executorch_flatbuffer::KernelCall* call_;
  const Graph* graph_;
  uint32_t node_id_ = 0;
};

/**
 * IR view of a program. Adapts executorch_flatbuffer::ExecutionPlan;
 * exposes value metadata, input/output IDs, the operator table, and
 * chains of operator calls. See the file-header comment for the
 * adapter-pattern rationale.
 */
class Graph {
 public:
  explicit Graph(const executorch_flatbuffer::ExecutionPlan* plan)
      : plan_(plan) {
    // Precompute input/output value_id sets for O(1) value_kind lookup.
    if (auto* in = plan_->inputs()) {
      input_ids_.reserve(in->size());
      for (size_t i = 0; i < in->size(); ++i) {
        input_ids_.insert(static_cast(in->Get(i)));
      }
    }
    if (auto* out = plan_->outputs()) {
      output_ids_.reserve(out->size());
      for (size_t i = 0; i < out->size(); ++i) {
        output_ids_.insert(static_cast(out->Get(i)));
      }
    }

    // Precompute mem_obj_id for every tensor value.
    //
    // Algorithm: collect (pool_id, offset) keys for all aliasable tensor
    // values; sort the unique keys; assign mem_obj_id = sort rank. Two
    // values with the same (pool_id, offset) get the same id (they share
    // storage). Sort-and-index is deterministic across runs and depends
    // only on the AOT memory plan.
    size_t n_vals = num_values();
    mem_obj_ids_.assign(n_vals, -1);
    if (n_vals == 0) return;

    // 1. Collect (key, value_id) entries for tensor values with allocation_info.
    struct Entry {
      uint64_t key; // (pool_id << 32) | offset
      uint32_t value_id;
    };
    std::vector entries;
    entries.reserve(n_vals);
    for (uint32_t i = 0; i < n_vals; ++i) {
      auto* val = value_meta(i);
      if (!val ||
          val->val_type() != executorch_flatbuffer::KernelTypes::Tensor) {
        continue;
      }
      auto* t = val->val_as_Tensor();
      if (!t) continue;
      auto* alloc = t->allocation_info();
      if (!alloc) continue;
      uint64_t pool = static_cast(alloc->memory_id());
      // NOTE(review): only memory_offset_low is keyed — offsets >= 4 GiB
      // (memory_offset_high != 0) would collide in the (pool << 32) | off
      // packing. Confirm the AOT planner never emits such offsets.
      uint64_t off = alloc->memory_offset_low();
      entries.push_back({(pool << 32) | off, i});
    }
    if (entries.empty()) return;

    // 2. Sort by key (lex order on (pool_id, offset)).
    std::sort(entries.begin(), entries.end(),
        [](const Entry& a, const Entry& b) { return a.key < b.key; });

    // 3. Assign mem_obj_id = sort rank (same key → same id).
    int32_t next_id = -1;
    uint64_t prev_key = ~0ULL;
    for (const auto& e : entries) {
      if (e.key != prev_key) {
        ++next_id;
        prev_key = e.key;
      }
      mem_obj_ids_[e.value_id] = next_id;
    }

    // Precompute mutable_buffer_ids_: tensor values with allocation_info,
    // not graph IO, not constants, and NOT produced by any op (i.e.
    // placeholders that aren't graph inputs). These are mutable buffer
    // placeholders pulled into the delegate by tag_mutated_buffer; their
    // state persists across execute() calls.
    //
    // Used by the router to distinguish semantic alias groups (buffer
    // mutation: AOT spec-shared the buffer placeholder with its mutation
    // source) from lifetime-reuse aliasing (the planner happened to put
    // two values at the same offset because their lifetimes don't
    // overlap). Only semantic groups need the "all touching ops on same
    // runtime else home=host" coordination.
    {
      // Collect all op-output value_ids across all chains.
      std::unordered_set produced_vids;
      for (size_t ci = 0; ci < num_chains(); ++ci) {
        for (size_t oi = 0; oi < num_ops_in_chain(ci); ++oi) {
          OperatorCall op = get_op(ci, oi);
          for (size_t k = 0; k < op.num_outputs(); ++k) {
            produced_vids.insert(op.output(k));
          }
        }
      }
      for (uint32_t i = 0; i < n_vals; ++i) {
        if (mem_obj_ids_[i] < 0) continue; // not an allocated tensor
        if (input_ids_.count(i) > 0) continue; // graph input
        if (output_ids_.count(i) > 0) continue; // graph output
        if (tensor_constant_data_key(i) != nullptr) continue; // constant
        if (produced_vids.count(i) > 0) continue; // produced by an op
        // Tensor with alloc, not IO, not constant, not produced → it's a
        // mutable buffer placeholder.
        mutable_buffer_ids_.push_back(i);
      }
    }
  }

  //===------------------------------------------------------------------===//
  // Version
  //===------------------------------------------------------------------===//

  const char* version() const {
    // TODO: Return actual version when available
    return "1.0";
  }

  //===------------------------------------------------------------------===//
  // Values
  //===------------------------------------------------------------------===//

  size_t num_values() const {
    auto v = plan_->values();
    return v ? v->size() : 0;
  }

  // Access serialized value metadata.
  // NOTE: returns the raw flatbuffer EValue. This is the construction-
  // seam escape hatch — backends and routers should prefer the typed
  // accessors below (value_type, int_value, tensor_*, etc) so they
  // don't couple to the underlying serialization.
  const executorch_flatbuffer::EValue* value_meta(uint32_t value_id) const {
    auto values = plan_->values();
    if (!values || value_id >= values->size()) return nullptr;
    return values->Get(value_id);
  }

  // Value metadata helpers
  ValueKind value_kind(uint32_t value_id) const;
  int32_t mem_obj_id(uint32_t value_id) const;

  //===------------------------------------------------------------------===//
  // Typed value accessors (adapter-level — no flatbuffer types leak)
  //===------------------------------------------------------------------===//

  // Returns the kind of the EValue stored at value_id.
  ValueType value_type(uint32_t value_id) const;

  // Scalar accessors — ET_CHECK if the value isn't of the expected kind.
  int64_t int_value(uint32_t value_id) const;
  double double_value(uint32_t value_id) const;
  bool bool_value(uint32_t value_id) const;

  // Tensor accessors — ET_CHECK if the value isn't a tensor.
  ::executorch::aten::ScalarType tensor_dtype(uint32_t value_id) const;
  ::executorch::runtime::Span tensor_sizes(uint32_t value_id) const;
  ::executorch::runtime::Span tensor_dim_order(uint32_t value_id) const;
  ::executorch::aten::TensorShapeDynamism tensor_shape_dynamism(uint32_t value_id) const;
  // Returns NDM key (FQN) for an external constant tensor, or nullptr
  // if the tensor isn't an NDM-stored constant.
  const char* tensor_constant_data_key(uint32_t value_id) const;
  // dtype-size × prod(sizes); 0 if not a tensor or sizes empty.
  size_t tensor_nbytes_max(uint32_t value_id) const;

  // IntList accessors — ET_CHECK if the value isn't an IntList.
  // Returns the EValue indices that the list elements reference (stored
  // as int64 in the serialization); the caller resolves them through
  // the values array.
  ::executorch::runtime::Span int_list_member_ids(uint32_t value_id) const;

  //===------------------------------------------------------------------===//
  // Input/Output IDs
  //===------------------------------------------------------------------===//

  size_t num_input_ids() const {
    auto in = plan_->inputs();
    return in ? in->size() : 0;
  }

  uint32_t input_id(size_t i) const {
    auto in = plan_->inputs();
    ET_CHECK_MSG(in && i < in->size(),
        "Graph::input_id(%zu) out of range (have %zu inputs)",
        i, in ? in->size() : 0);
    return static_cast(in->Get(i));
  }

  size_t num_output_ids() const {
    auto out = plan_->outputs();
    return out ? out->size() : 0;
  }

  uint32_t output_id(size_t i) const {
    auto out = plan_->outputs();
    ET_CHECK_MSG(out && i < out->size(),
        "Graph::output_id(%zu) out of range (have %zu outputs)",
        i, out ? out->size() : 0);
    return static_cast(out->Get(i));
  }

  //===------------------------------------------------------------------===//
  // Mutable Buffer IDs (values that persist across execute() calls)
  //===------------------------------------------------------------------===//

  size_t num_mutable_buffer_ids() const {
    return mutable_buffer_ids_.size();
  }

  uint32_t mutable_buffer_id(size_t i) const {
    ET_CHECK_MSG(i < mutable_buffer_ids_.size(),
        "Graph::mutable_buffer_id: index %zu out of range "
        "(have %zu mutable buffers)",
        i, mutable_buffer_ids_.size());
    return mutable_buffer_ids_[i];
  }

  //===------------------------------------------------------------------===//
  // Operators (for op name lookup)
  //===------------------------------------------------------------------===//

  size_t num_operators() const {
    auto ops = plan_->operators();
    return ops ? ops->size() : 0;
  }

  const char* operator_name(size_t idx) const {
    auto ops = plan_->operators();
    if (!ops || idx >= ops->size()) return nullptr;
    auto op = ops->Get(idx);
    return op && op->name() ? op->name()->c_str() : nullptr;
  }

  //===------------------------------------------------------------------===//
  // Chains
  //===------------------------------------------------------------------===//

  size_t num_chains() const {
    auto chains = plan_->chains();
    return chains ? chains->size() : 0;
  }

  int32_t main_chain_idx() const {
    return 0; // Default: first chain is main
  }

  // Get number of ops in a chain
  size_t num_ops_in_chain(size_t chain_idx) const {
    auto chains = plan_->chains();
    ET_CHECK_MSG(chains && chain_idx < chains->size(),
        "Graph::num_ops_in_chain(%zu) out of range (have %zu chains)",
        chain_idx, chains ? chains->size() : 0);
    auto instrs = chains->Get(chain_idx)->instructions();
    return instrs ? instrs->size() : 0;
  }

  // Get OperatorCall for op in chain
  OperatorCall get_op(size_t chain_idx, size_t op_idx) const {
    auto chains = plan_->chains();
    ET_CHECK_MSG(chains && chain_idx < chains->size(),
        "Graph::get_op: chain_idx=%zu out of range (have %zu chains)",
        chain_idx, chains ? chains->size() : 0);
    auto instrs = chains->Get(chain_idx)->instructions();
    ET_CHECK_MSG(instrs && op_idx < instrs->size(),
        "Graph::get_op: op_idx=%zu out of range in chain %zu "
        "(have %zu ops)",
        op_idx, chain_idx, instrs ? instrs->size() : 0);
    auto instr = instrs->Get(op_idx);
    auto kernel = static_cast(instr->instr_args());
    return OperatorCall(kernel, this);
  }

  //===------------------------------------------------------------------===//
  // Convenience: main chain accessors
  //===------------------------------------------------------------------===//

  size_t num_instructions() const {
    return num_ops_in_chain(main_chain_idx());
  }

  OperatorCall get_instruction(size_t idx) const {
    return get_op(main_chain_idx(), idx);
  }

 private:
  const executorch_flatbuffer::ExecutionPlan* plan_;
  // Precomputed at construction for O(1) value_kind lookup.
  std::unordered_set input_ids_;
  std::unordered_set output_ids_;
  // mem_obj_ids_[value_id] = dense small int identifying the storage slot
  // (sort rank of (pool_id, offset) pairs across all aliasable tensor
  // values). -1 for non-tensor / non-allocated values. Same id ⇒ same
  // storage. Computed once at construction; O(1) lookup at use sites.
  std::vector mem_obj_ids_;

  // Mutable buffer placeholder value_ids: tensor values with allocation
  // info that aren't graph IO, aren't constants, and aren't produced by
  // any op. These persist across execute() calls (their storage is
  // preserved between invocations). Identified by tag_mutated_buffer at
  // AOT time.
  std::vector mutable_buffer_ids_;
};

// Implement OperatorCall::name() after Graph is defined
inline const char* OperatorCall::name() const {
  // In ExecuTorch, op names are in the operators table, indexed by op_index
  return graph_->operator_name(call_->op_index());
}

inline int32_t OperatorCall::output_alias_input(size_t output_idx) const {
  if (output_idx >= num_outputs()) return -1;
  int32_t out_mid = graph_->mem_obj_id(output(output_idx));
  if (out_mid < 0) return -1;
  size_t n_in = num_inputs();
  for (size_t i = 0; i < n_in; ++i) {
    if (graph_->mem_obj_id(input(i)) == out_mid) {
      return static_cast(i);
    }
  }
  return -1;
}

// Implement value metadata accessors
inline ValueKind Graph::value_kind(uint32_t value_id) const {
  if (input_ids_.count(value_id)) return ValueKind::INPUT;
  if (output_ids_.count(value_id)) return ValueKind::OUTPUT;

  // Constant if the tensor has a baked data buffer.
  auto val = value_meta(value_id);
  if (val && val->val_type() == executorch_flatbuffer::KernelTypes::Tensor) {
    auto* tensor = val->val_as_Tensor();
    if (tensor && tensor->data_buffer_idx() > 0) {
      return ValueKind::CONSTANT;
    }
  }
  // NOTE(review): NDM-external constants (data_buffer_idx == 0, key held in
  // extra_tensor_info) fall through to INTERMEDIATE here, and values in
  // mutable_buffer_ids_ are never reported as MUTABLE_BUFFER — confirm that
  // callers use tensor_constant_data_key()/mutable_buffer_id() for those
  // distinctions instead of value_kind().
  return ValueKind::INTERMEDIATE;
}

inline int32_t Graph::mem_obj_id(uint32_t value_id) const {
  // Out-of-range ids map to -1 ("no shared-storage slot").
  return value_id < mem_obj_ids_.size() ? mem_obj_ids_[value_id] : -1;
}

//===----------------------------------------------------------------------===//
// Typed value accessors
//===----------------------------------------------------------------------===//

inline ValueType Graph::value_type(uint32_t value_id) const {
  auto* val = value_meta(value_id);
  if (!val) return ValueType::None;
  using KT = executorch_flatbuffer::KernelTypes;
  switch (val->val_type()) {
    case KT::Null: return ValueType::None;
    case KT::Int: return ValueType::Int;
    case KT::Double: return ValueType::Double;
    case KT::Bool: return ValueType::Bool;
    case KT::Tensor: return ValueType::Tensor;
    case KT::IntList: return ValueType::IntList;
    default: return ValueType::Other;
  }
}

// NOTE(review): the `static_cast(val->val())` casts below lost their target
// types (flatbuffer table pointers, e.g. `const executorch_flatbuffer::Int*`)
// during patch extraction; restore from the original commit.
inline int64_t Graph::int_value(uint32_t value_id) const {
  auto* val = value_meta(value_id);
  ET_CHECK_MSG(val && val->val_type() == executorch_flatbuffer::KernelTypes::Int,
      "Graph::int_value(%u): value is not an Int", value_id);
  return static_cast(val->val())->int_val();
}

inline double Graph::double_value(uint32_t value_id) const {
  auto* val = value_meta(value_id);
  ET_CHECK_MSG(val && val->val_type() == executorch_flatbuffer::KernelTypes::Double,
      "Graph::double_value(%u): value is not a Double", value_id);
  return static_cast(val->val())->double_val();
}

inline bool Graph::bool_value(uint32_t value_id) const {
  auto* val = value_meta(value_id);
  ET_CHECK_MSG(val && val->val_type() == executorch_flatbuffer::KernelTypes::Bool,
      "Graph::bool_value(%u): value is not a Bool", value_id);
  return static_cast(val->val())->bool_val();
}

+namespace detail { +inline const executorch_flatbuffer::Tensor* tensor_or_null( + const executorch_flatbuffer::EValue* val) { + if (!val || val->val_type() != executorch_flatbuffer::KernelTypes::Tensor) { + return nullptr; + } + return val->val_as_Tensor(); +} +} // namespace detail + +inline ::executorch::aten::ScalarType Graph::tensor_dtype( + uint32_t value_id) const { + auto* t = detail::tensor_or_null(value_meta(value_id)); + ET_CHECK_MSG(t, "Graph::tensor_dtype(%u): value is not a Tensor", value_id); + return static_cast<::executorch::aten::ScalarType>(t->scalar_type()); +} + +inline ::executorch::runtime::Span +Graph::tensor_sizes(uint32_t value_id) const { + auto* t = detail::tensor_or_null(value_meta(value_id)); + ET_CHECK_MSG(t, "Graph::tensor_sizes(%u): value is not a Tensor", value_id); + auto* s = t->sizes(); + return s ? ::executorch::runtime::Span(s->data(), s->size()) + : ::executorch::runtime::Span{}; +} + +inline ::executorch::runtime::Span +Graph::tensor_dim_order(uint32_t value_id) const { + auto* t = detail::tensor_or_null(value_meta(value_id)); + ET_CHECK_MSG(t, "Graph::tensor_dim_order(%u): value is not a Tensor", value_id); + auto* d = t->dim_order(); + return d ? 
::executorch::runtime::Span(d->data(), d->size()) + : ::executorch::runtime::Span{}; +} + +inline ::executorch::aten::TensorShapeDynamism +Graph::tensor_shape_dynamism(uint32_t value_id) const { + auto* t = detail::tensor_or_null(value_meta(value_id)); + ET_CHECK_MSG(t, "Graph::tensor_shape_dynamism(%u): value is not a Tensor", + value_id); + return static_cast<::executorch::aten::TensorShapeDynamism>( + t->shape_dynamism()); +} + +inline const char* Graph::tensor_constant_data_key(uint32_t value_id) const { + auto* t = detail::tensor_or_null(value_meta(value_id)); + if (!t) return nullptr; + auto* eti = t->extra_tensor_info(); + if (!eti) return nullptr; + if (eti->location() != executorch_flatbuffer::TensorDataLocation::EXTERNAL) { + return nullptr; + } + auto* fqn = eti->fully_qualified_name(); + return (fqn && fqn->size() > 0) ? fqn->c_str() : nullptr; +} + +inline size_t Graph::tensor_nbytes_max(uint32_t value_id) const { + auto* t = detail::tensor_or_null(value_meta(value_id)); + if (!t || !t->sizes()) return 0; + size_t numel = 1; + for (size_t i = 0; i < t->sizes()->size(); ++i) { + int dim = t->sizes()->Get(i); + if (dim < 0) return 0; + numel *= static_cast(dim); + } + auto stype = static_cast<::executorch::aten::ScalarType>(t->scalar_type()); + return numel * ::executorch::runtime::elementSize(stype); +} + +inline ::executorch::runtime::Span +Graph::int_list_member_ids(uint32_t value_id) const { + auto* val = value_meta(value_id); + ET_CHECK_MSG(val && val->val_type() == executorch_flatbuffer::KernelTypes::IntList, + "Graph::int_list_member_ids(%u): value is not an IntList", + value_id); + auto* items = + static_cast(val->val())->items(); + return items + ? 
::executorch::runtime::Span(items->data(), items->size()) + : ::executorch::runtime::Span{}; +} + +} // namespace portable +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Instance.h b/backends/portable/runtime_v2/api/Instance.h new file mode 100644 index 00000000000..a03c9ac1c17 --- /dev/null +++ b/backends/portable/runtime_v2/api/Instance.h @@ -0,0 +1,242 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include // reuse existing Graph + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Per-program backend-private compiled state. Holds whatever the runtime + * needs to dispatch a contiguous run of instructions: encoded ICB, + * pipeline state objects, kernel handles, etc. + * + * Held by the Instance; referenced by ComputeStep. + */ +class CompiledSegment { + public: + virtual ~CompiledSegment() = default; +}; + +/** + * Per-program execution state for one runtime. Owns the compiled segments + * and the per-program buffers. + * + * All work-issuing methods are async-by-default. They return immediately + * after enqueueing on the requested QueueKind and signal the provided + * Event* on completion. CPU runtime returns with the event already + * complete; per-call overhead is two branches. + * + * See §4.7 of the design doc. + */ +class Instance { + public: + virtual ~Instance() = default; + + // Compile a contiguous run of instructions. Encodes ICB / shaders / + // pipelines as appropriate. Returned CompiledSegment* is owned by + // Instance. + // SYNCHRONOUS (init-time only). 
+ // + // value_remap: optional rewrite from "graph value_id" to "value_id this + // segment should look up in the BindingTable." Used when the router + // synthesizes new value_ids for cross-runtime transfer destinations: + // the segment's kernels were exported referencing the original value_id + // V, but the router needs them to read from V' (a destination Buffer + // on this runtime). The mapping is applied per op-arg at execute time. + // Pass an empty Span for single-runtime / no-rewrite case. + virtual ::executorch::runtime::Result compile_segment( + const ::executorch::backends::portable::Graph& graph, + ::executorch::runtime::Span instruction_indices, + ::executorch::runtime::Span input_value_ids, + ::executorch::runtime::Span output_value_ids, + ::executorch::runtime::Span> + value_remap) = 0; + + // Batched allocation request. The router builds a list of these (one + // per value_id this Instance is asked to back) and calls allocate_all + // once at init. + // + // The backend reads the tensor's metadata (dtype, sizes, dim_order) + // from `values[req.value_id].toTensor()`. Backends own all memory + // planning internally: + // - CPU / Apple Silicon Metal: typically loop and allocate per-request. + // - Vulkan: group by mem_obj_id, aggregate VkMemoryRequirements (max + // size, max alignment, AND'd memoryTypeBits), allocate one VMA + // allocation per group, bind each user's VkBuffer to it. + // + // Released at Plan tear-down via release_buffer(). + struct AllocRequest { + uint32_t value_id; // index into `values` for this request + int32_t mem_obj_id; // -1 = dedicated allocation; else share with same id + + // Set ONLY for cross-runtime synthetic values (segment IO between + // segments on different providers). Carries the source value's + // Buffer — the host-side end of the transfer. Since all cross- + // runtime transfers go through host, segment IO always has a host + // buffer. 
+ // + // If host_alias != nullptr AND host_alias->host_ptr() is non-null, + // the backend MAY back this allocation by aliasing + // host_alias->host_ptr() (zero-copy). If host_alias is null + // (non-synthetic) or host_alias->host_ptr() is null (source is + // VRAM-only, e.g., Vulkan), the backend allocates fresh and + // per-execute upload_from_host / download_to_host copies bytes. + Buffer* host_alias = nullptr; + }; + + // Allocate Buffers for the requested value_ids. out_buffers[i] is the + // Buffer backing requests[i].value_id. Multiple requests sharing the + // same mem_obj_id MAY yield Buffers that alias the same underlying + // memory; that's the backend's choice. + // + // SYNCHRONOUS (init-time only). + virtual ::executorch::runtime::Error allocate_all( + ::executorch::runtime::Span requests, + ::executorch::runtime::Span values, + ::executorch::runtime::Span out_buffers) = 0; + + // Materialize a graph constant on this runtime; persistent. The Instance + // reads from the NamedDataMap directly so it can choose how to + // materialize: + // - CPU: HostBuffer aliases the FreeableBuffer's region (zero-copy). + // - Apple-Silicon Metal: registers the FreeableBuffer's region with + // MetalStream; MetalBuffer aliases (zero-copy). + // - Discrete GPU: allocate device buffer + load_data_into directly + // (avoids host roundtrip). + // The Instance owns the resulting Buffer AND the FreeableBuffer (so the + // mmap'd region stays alive as long as the Buffer aliases it). + // SYNCHRONOUS (init-time only). Constants do NOT go through + // upload_from_host — they have their own dedicated path. + virtual ::executorch::runtime::Result upload_constant( + const ::executorch::runtime::NamedDataMap& ndm, + std::string_view key) = 0; + + // Provide the storage backing one Event slot. Called once per slot at + // Plan construction. 
+ virtual std::unique_ptr make_event() = 0; + + //=== Async work issuance + sync helpers =================================== + // + // Public hot-path API: `execute` (intra-runtime kernel dispatch) and the + // host↔device pair `upload_from_host` / `download_to_host` (cross-runtime + // moves between CPU and a device runtime). + // + // The transfer methods only ever live on the **device** Instance: + // * upload_from_host: CPU produced a value, this device consumes it. + // * download_to_host: this device produced a value, CPU consumes it. + // CpuInstance never has these called (CPU is always the host side); its + // overrides return NotImplemented. + // + // We never need a runtime↔runtime path because v1 has at most one + // non-host runtime. Any transfer between two non-host runtimes (future) + // would route through host as two steps. + + // Make `dev_dst_buf` reflect `host_src_ptr`'s bytes by the time + // `signal` reaches Complete. Implementation strategy is the runtime's + // call: + // + // - Host-addressable runtimes (CPU, Apple-Silicon Metal) typically + // re-alias `dev_dst_buf` to point at `host_src_ptr` (zero-copy). + // The "skip if dev_dst_buf->host_ptr() == host_src_ptr" check + // makes repeated executes with the same caller pointer free. + // If zero-copy alias fails (e.g. Metal refuses unaligned ptr), + // fall back to memcpy into the existing Buffer's storage. + // - Discrete GPU (Vulkan): real copy from host_src_ptr into + // pre-allocated VRAM via vkCmdCopyBuffer. + // + // Also propagates shape from src to dst's TensorImpl before signaling + // (per the shape-on-event contract). + // + // Default: NotImplemented. Override in non-host Instances. 
+ virtual ::executorch::runtime::Error upload_from_host( + ::executorch::runtime::EValue& /*host_src_ev*/, + void* /*host_src_ptr*/, + ::executorch::runtime::EValue& /*dev_dst_ev*/, + Buffer* /*dev_dst_buf*/, + QueueKind /*queue*/, + ::executorch::runtime::Span /*wait_for*/, + Event* /*signal*/) { + return ::executorch::runtime::Error::NotImplemented; + } + + // Symmetric: make `host_dst_ptr`'s bytes reflect `dev_src_buf`'s + // contents. Same rebind-or-copy contract as upload_from_host. + // + // Default: NotImplemented. Override in non-host Instances. + virtual ::executorch::runtime::Error download_to_host( + ::executorch::runtime::EValue& /*dev_src_ev*/, + Buffer* /*dev_src_buf*/, + ::executorch::runtime::EValue& /*host_dst_ev*/, + void* /*host_dst_ptr*/, + QueueKind /*queue*/, + ::executorch::runtime::Span /*wait_for*/, + Event* /*signal*/) { + return ::executorch::runtime::Error::NotImplemented; + } + + // Inputs guaranteed on this runtime; outputs MUST end up here. + // values: the LoadedDelegate's EValue array (carries dtype/shape and + // scalar inputs). bindings: tensor storage backings (Buffer*) for + // tensor value_ids. Both are needed by kernels: scalars come from + // values[idx], tensor data comes from bindings.get(idx). + // + // SHAPE-ON-EVENT CONTRACT: by the time `signal` reaches + // EventStatus::Complete, every output value's TensorImpl shape AND + // bound Buffer bytes MUST be valid for downstream consumers. + // Backends that determine output shapes synchronously inside execute() + // (CPU portable kernels, Metal-with-MPSGraph metadata) update shape + // before encoding the kernel. Backends whose output shapes are only + // known after the GPU runs (e.g., Metal kernels that determine shape + // post-execution) must register a completion handler that updates the + // output TensorImpls' shape arrays AND THEN signals the event last. 
+ virtual ::executorch::runtime::Error execute( + CompiledSegment* segment, + ::executorch::runtime::Span<::executorch::runtime::EValue> values, + BindingView bindings, + ::executorch::runtime::Span wait_for, + Event* signal) = 0; + + // Block CPU until event signals. Returns Error::Ok iff event reaches + // EventStatus::Complete; returns the stored error if Failed/Poisoned. + virtual ::executorch::runtime::Error wait(Event* event) = 0; + + // Stable per-Instance id within this Instance's RuntimeContext. + virtual InstanceId id() const = 0; + + // Drain in-flight work *issued by this Instance only*. Blocks until + // every currently outstanding submission this Instance has issued via + // copy_*/execute reaches a terminal state. Idempotent. + virtual void drain() = 0; + + // Forwards to the RuntimeContext's pool / recycler. + virtual void release_buffer(Buffer* buf) = 0; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/InstanceUtils.h b/backends/portable/runtime_v2/api/InstanceUtils.h new file mode 100644 index 00000000000..80f4b65635e --- /dev/null +++ b/backends/portable/runtime_v2/api/InstanceUtils.h @@ -0,0 +1,159 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Helpers shared by all backend Instance implementations. + * + * These deliberately operate against the public Event interface (no + * dynamic_cast to backend-specific subclasses). + */ + +// If any event in `wait_for` is in a terminal failure state +// (Failed / Poisoned), poison `signal` with that upstream error and +// return InternalError. 
Otherwise return Ok. +// +// Rationale: callers should short-circuit their work when an upstream +// dependency failed; the poisoned status propagates the failure +// downstream so consumers don't read garbage. +inline ::executorch::runtime::Error check_async_dependencies( + ::executorch::runtime::Span wait_for, + Event* signal) { + for (Event* dep : wait_for) { + if (!dep) continue; + EventStatus s = dep->status(); + if (s == EventStatus::Failed || s == EventStatus::Poisoned) { + if (signal) signal->signal_poisoned(dep->error()); + return ::executorch::runtime::Error::Internal; + } + } + return ::executorch::runtime::Error::Ok; +} + +/** + * Verify (don't sync) the data_ptr invariant for every tensor arg of every + * op in a CompiledSegment: TensorImpl::data_ptr MUST equal + * bindings.get(remap(vid))->host_ptr() for every tensor argument. + * + * The data_ptr is established earlier (prebind_owned_buffers / + * bind_inputs / upload_from_host); by the time we get here it must + * already be in sync. Any mismatch indicates an upstream bug; we log + * and signal failure rather than silently overwriting. + * + * Templated over CompiledSegment subclass (which exposes + * `instruction_indices()` and `remap(vid)`); avoids virtual dispatch and + * a base class. + * + * On any violation: logs an Error with `backend_name` prefix, signals + * `signal->signal_failed(err)` (if non-null), and returns the error. + * + * Non-tensor args (scalars, IntLists) are silently skipped — they + * legitimately don't have Buffer bindings and the kernel reads them + * from the EValue directly. 
+ */ +template +::executorch::runtime::Error verify_segment_bindings( + const SegmentT* seg, + const ::executorch::backends::portable::Graph* graph, + ::executorch::runtime::Span<::executorch::runtime::EValue> values, + BindingView bindings, + Event* signal, + std::string_view backend_name) { + using Err = ::executorch::runtime::Error; + auto fail = [&](Err e) -> Err { + if (signal) signal->signal_failed(e); + return e; + }; + + for (uint32_t instr_idx : seg->instruction_indices()) { + if (instr_idx >= graph->num_instructions()) { + ET_LOG(Error, + "%.*s::execute: instruction index %u out of range " + "(graph has %zu instructions)", + static_cast(backend_name.size()), backend_name.data(), + instr_idx, graph->num_instructions()); + return fail(Err::InvalidProgram); + } + auto op = graph->get_instruction(instr_idx); + + auto check_arg = [&](uint32_t vid_orig) -> Err { + uint32_t vid = seg->remap(vid_orig); + if (vid >= values.size()) { + ET_LOG(Error, + "%.*s::execute: value_id=%u out of range " + "(values.size()=%zu) — router/remap bug", + static_cast(backend_name.size()), backend_name.data(), + vid, values.size()); + return fail(Err::InvalidProgram); + } + const auto& ev = values[vid]; + // Non-tensor args legitimately have no Buffer binding. 
+ if (!ev.isTensor()) return Err::Ok; + Buffer* buf = bindings.get(vid); + if (!buf) { + ET_LOG(Error, + "%.*s::execute: tensor value_id=%u has no Buffer binding " + "— init bug", + static_cast(backend_name.size()), backend_name.data(), + vid); + return fail(Err::InvalidState); + } + void* hp = buf->host_ptr(); + if (!hp) { + ET_LOG(Error, + "%.*s::execute: Buffer for value_id=%u has null host_ptr " + "— Buffer construction bug", + static_cast(backend_name.size()), backend_name.data(), + vid); + return fail(Err::Internal); + } + void* current = ev.toTensor().mutable_data_ptr(); + if (current != hp) { + ET_LOG(Error, + "%.*s::execute: tensor value_id=%u data_ptr=%p but bound " + "Buffer host_ptr=%p — sync invariant violated (prebind / " + "bind_inputs / upload_from_host should have kept these in " + "sync)", + static_cast(backend_name.size()), backend_name.data(), + vid, current, hp); + return fail(Err::Internal); + } + return Err::Ok; + }; + + for (size_t i = 0; i < op.num_inputs(); ++i) { + if (auto e = check_arg(op.input(i)); e != Err::Ok) return e; + } + for (size_t i = 0; i < op.num_outputs(); ++i) { + if (auto e = check_arg(op.output(i)); e != Err::Ok) return e; + } + } + return Err::Ok; +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Location.h b/backends/portable/runtime_v2/api/Location.h new file mode 100644 index 00000000000..e64be523222 --- /dev/null +++ b/backends/portable/runtime_v2/api/Location.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +// Process-wide opaque identity, assigned by ProviderRegistry on registration. 
+// Used for Location, capability queries, registry lookups, debug. +// Never to be hardcoded; never used to index hot-path arrays. +using RuntimeId = uint16_t; +inline constexpr RuntimeId kHost = 0xFFFF; + +// Plan-local dense index into Plan::providers / Plan::instances. +// Used only inside Plan/Step/executor. Translated from RuntimeId once at +// route() time. By convention slot 0 is the host (CPU) Provider; CPU is +// required and always present, so kHostIdx == 0 — no sentinel dance. +using RuntimeIndex = uint8_t; +inline constexpr RuntimeIndex kHostIdx = 0; + +/** + * Pure tag describing where a value lives. ~2 bytes; no pointers, no + * ownership. + */ +class Location { + public: + // Default-constructs to host. Convenient for aggregate-initialized + // structs (InputBinding, OutputBinding) that hold a Location. + constexpr Location() : id_(kHost) {} + + static Location host() { return Location{kHost}; } + static Location on(RuntimeId id) { return Location{id}; } + + RuntimeId runtime_id() const { return id_; } + bool is_host() const { return id_ == kHost; } + + bool operator==(Location other) const { return id_ == other.id_; } + bool operator!=(Location other) const { return !(*this == other); } + + private: + explicit constexpr Location(RuntimeId id) : id_(id) {} + RuntimeId id_; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/OpDescriptor.h b/backends/portable/runtime_v2/api/OpDescriptor.h new file mode 100644 index 00000000000..5029de43708 --- /dev/null +++ b/backends/portable/runtime_v2/api/OpDescriptor.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * What the Provider sees for a single op when asked can_run(). + * + * Currently carries only the op name. Per-value descriptors (dtype, + * shape, dynamism, etc.) and capability/cost return types are NOT yet + * here — they're additive when multi-provider routing actually needs + * them, and adding fields to this struct is non-breaking. + */ +struct OpDescriptor { + std::string_view name; // e.g. "aten.add.Tensor" +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Plan.h b/backends/portable/runtime_v2/api/Plan.h new file mode 100644 index 00000000000..5d4764e0776 --- /dev/null +++ b/backends/portable/runtime_v2/api/Plan.h @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Per-Instance Buffer ownership record. Released at ~Plan via + * owner->release_buffer(buf). The value_id is the (one) graph value this + * Buffer was allocated for; for shared/aliased intermediates, multiple + * `OwnedBuffer` records may exist with different value_ids pointing to + * the same `Buffer*` (memory-plan aliasing). + */ +struct OwnedBuffer { + Buffer* buf; + Instance* owner; + uint32_t value_id; // bind this id -> buf at LoadedDelegate construction +}; + +/** + * Each event slot remembers which Instance created it. Required so the + * executor can call the right Instance::wait at the host boundary. + */ +struct EventSlot { + std::unique_ptr event; + Instance* owner; +}; + +/** + * Per-input/output binding (router output). 
+ */ +struct InputBinding { + Location loc; + uint32_t value_id; +}; + +struct OutputBinding { + Location loc; + uint32_t value_id; +}; + +/** + * The router synthesizes new value_ids for cross-runtime transfer + * destinations. Each synthetic id has a "source" graph value_id whose + * dtype/shape it inherits. The LoadedDelegate constructor walks this + * list to extend its EValue array and create matching TensorImpls. + * + * Per-execute, a TransferStep moves bytes from source_id to new_id. + * On host-addressable destination runtimes (CPU, Apple Silicon Metal), + * the destination Buffer for new_id is allocated by the backend's + * allocate_all using the AllocRequest's host_alias hint — meaning the + * destination Buffer aliases the source's host_ptr at init. The + * per-execute upload_from_host then sees host_ptr == host_ptr and + * skip-if-same returns immediately (no work). + * + * On Vulkan/discrete GPUs, the destination Buffer is real VRAM; the + * per-execute upload_from_host actually copies bytes via vkCmdCopyBuffer. + */ +struct SyntheticValueDesc { + uint32_t new_id; + uint32_t source_id; +}; + +/** + * Frozen output of Router::route. Holds: + * - Provider/Instance arrays indexed by RuntimeIndex, + * - the issue-ordered Step list, + * - per-input/output bindings, + * - pre-allocated event slots, + * - owned-buffer ledger, + * - per-provider allocation request lists. + * + * See §4.9 of the design doc. + */ +struct Plan { + // Parallel arrays indexed by RuntimeIndex. By convention, index 0 is + // the host (CPU) Provider. + std::vector providers; + std::vector instances; // non-owning; lifetime is LoadedDelegate + + std::vector steps; // ordered by *issue* (not completion) + + std::vector inputs; + std::vector outputs; + + // Pre-allocated event slots. Each event is reset lazily by its + // producing Instance immediately before signaling. + std::vector events; + + // Released at Plan destruction. 
Filled by the post-route allocation + // step (PortableBackend_v2::allocate_buffers), NOT by the router. + std::vector owned_buffers; + + // Synthesized value_ids for cross-runtime transfer destinations. + // Resolved into TensorImpls by the LoadedDelegate constructor after + // route() returns. Each synthetic value_id always has a TransferStep + // emitted (intra-segment intermediates do not appear here). + std::vector synthetic_values; + + // Per-provider list of allocation requests emitted by the router. + // Allocation itself is performed by a post-route step (host-first + // single-pass), which lets the host allocate first so the device's + // requests can carry host_alias hints pointing at already-allocated + // host buffers (zero-copy alias for cross-runtime synthetics). + // + // alloc_plans[runtime_idx] is the request list for that provider. + std::vector> alloc_plans; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Provider.h b/backends/portable/runtime_v2/api/Provider.h new file mode 100644 index 00000000000..ba31dcbfc18 --- /dev/null +++ b/backends/portable/runtime_v2/api/Provider.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +class ProviderRegistry; + +/** + * Process-wide singleton per runtime (CPU, Metal, Vulkan). + * Queryable for capabilities; factory for Instances and Buffers. + * + * See §4.5 of the design doc. 
+ */ +class Provider { + public: + virtual ~Provider() = default; + virtual std::string_view name() const = 0; + + // Assigned by the ProviderRegistry on registration; unique per process. + // Never to be hardcoded. This is the opaque identity (used in + // Location); the dense per-Plan RuntimeIndex is assigned at route() + // time. + RuntimeId id() const { return id_; } + + virtual bool is_available_on_device() const = 0; + + // Capability query. Cheap; called O(num_ops × num_providers) at routing. + // The Provider sees the OpDescriptor and decides accept/reject. + // Returns true iff this Provider can execute the op. Cost ranking can + // be added later by replacing bool with optional; both + // OpDescriptor and the return type are designed to grow non-breakingly. + virtual bool can_run(const OpDescriptor& op) const = 0; + + // Process-wide state (pools, kernel caches, command stream). + // Lazy-initialized on first instantiate(); lives until process exit. + virtual RuntimeContext& context() = 0; + + // Per-program factory. Holds non-owning RuntimeContext& from this + // Provider. + virtual std::unique_ptr instantiate() = 0; + + private: + friend class ProviderRegistry; + RuntimeId id_ = 0; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/ProviderRegistry.cpp b/backends/portable/runtime_v2/api/ProviderRegistry.cpp new file mode 100644 index 00000000000..88a61ae3085 --- /dev/null +++ b/backends/portable/runtime_v2/api/ProviderRegistry.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +ProviderRegistry::ProviderRegistry( + std::vector> providers) + : owned_(std::move(providers)) { + all_.reserve(owned_.size()); + available_.reserve(owned_.size()); + + for (RuntimeId id = 0; id < owned_.size(); ++id) { + Provider* p = owned_[id].get(); + // Stamp the RuntimeId on the Provider via friend access. + p->id_ = id; + all_.push_back(p); + if (p->is_available_on_device()) { + available_.push_back(p); + } + } +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/ProviderRegistry.h b/backends/portable/runtime_v2/api/ProviderRegistry.h new file mode 100644 index 00000000000..a62afde8b1e --- /dev/null +++ b/backends/portable/runtime_v2/api/ProviderRegistry.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Per-PortableBackend Provider registry. + * + * v1 commits to explicit factory (no static init): the + * PortableBackend constructor takes a ProviderSet (a list of + * unique_ptr); registration order determines RuntimeId. + * Convention: index 0 is the CPU Provider (host-slot invariant, §4.1). + * + * See §4.5.1 of the design doc. + */ +class ProviderRegistry { + public: + // Construct with the full Provider set up front. Each Provider is + // assigned a fresh RuntimeId equal to its index. Caches the + // is_available_on_device() result for every Provider. + explicit ProviderRegistry( + std::vector> providers); + + // Non-copyable, movable. 
+ ProviderRegistry(const ProviderRegistry&) = delete; + ProviderRegistry& operator=(const ProviderRegistry&) = delete; + ProviderRegistry(ProviderRegistry&&) = default; + ProviderRegistry& operator=(ProviderRegistry&&) = default; + + // The set of providers whose is_available_on_device() returned true. + // Cached for the life of the registry. v1 does NOT support hot-plug. + ::executorch::runtime::Span available() const { + return ::executorch::runtime::Span( + available_.data(), available_.size()); + } + + // All providers regardless of availability (for diagnostics). + ::executorch::runtime::Span all() const { + return ::executorch::runtime::Span( + all_.data(), all_.size()); + } + + // Lookup by id. + Provider* lookup(RuntimeId id) const { + if (id == kHost) return nullptr; + if (id >= owned_.size()) return nullptr; + return owned_[id].get(); + } + + private: + std::vector> owned_; // index = RuntimeId + std::vector all_; // raw pointers parallel to owned_ + std::vector available_; // subset where is_available_on_device() +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Router.h b/backends/portable/runtime_v2/api/Router.h new file mode 100644 index 00000000000..0bbdd3f92e9 --- /dev/null +++ b/backends/portable/runtime_v2/api/Router.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +struct RouterOptions { + bool dump_trace = false; +}; + +/** + * Router::route maps (Graph, Providers, Instances) -> Plan. + * + * Default implementation in routers/GreedyRouter.h. See §4.10 of the + * design doc. 
+ */ +class Router { + public: + virtual ~Router() = default; + + virtual ::executorch::runtime::Result route( + const ::executorch::backends::portable::Graph& graph, + ::executorch::runtime::Span providers, + ::executorch::runtime::Span instances, + const ::executorch::runtime::NamedDataMap* ndm, // for upload_constant + const RouterOptions& options) = 0; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/RuntimeContext.h b/backends/portable/runtime_v2/api/RuntimeContext.h new file mode 100644 index 00000000000..032527cc3a7 --- /dev/null +++ b/backends/portable/runtime_v2/api/RuntimeContext.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Process-wide state owned by a Provider (pools, kernel caches, GPU + * queues, command streams, per-Instance submission tracker). + * + * Tag base class; each concrete runtime subclasses with its own state. + * Survives across multiple loaded programs. + * + * See §4.8 of the design doc. + */ +class RuntimeContext { + public: + virtual ~RuntimeContext() = default; +}; + +// Unique within a single RuntimeContext (NOT process-wide). Used by +// SubmissionTracker to scope drain() to a single Instance's submissions. +using InstanceId = uint32_t; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/api/Step.h b/backends/portable/runtime_v2/api/Step.h new file mode 100644 index 00000000000..2cab64ba6da --- /dev/null +++ b/backends/portable/runtime_v2/api/Step.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * One unit of issued work. Carries dense RuntimeIndex (not opaque + * RuntimeId). See §4.9 of the design doc. + */ +struct ComputeStep { + RuntimeIndex runtime_idx; // dense index into Plan::instances + CompiledSegment* segment; + QueueKind queue; // Compute in v1 + ::executorch::runtime::Span wait_for; + EventId signal; // kNoEvent = none +}; + +struct TransferStep { + // Both ends are looked up via bindings at execute time. + uint32_t src_value_id; + uint32_t dst_value_id; + Location src; // identity-tagged, for diagnostics & trace + Location dst; + RuntimeIndex src_idx; // hot-path; kHostIdx if host + RuntimeIndex dst_idx; + QueueKind queue; // typically Transfer + ::executorch::runtime::Span wait_for; + EventId signal; +}; + +using Step = std::variant; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuEvent.h b/backends/portable/runtime_v2/cpu/CpuEvent.h new file mode 100644 index 00000000000..65ddd7bf8bc --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuEvent.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Trivially-completed Event for CPU. CPU operations run synchronously, + * so the producing call sets status to Complete before returning. 
+ * + * Memory ordering: status_ is atomic with acquire/release semantics + * (acquire on read, release on signal) per §4.6. + */ +class CpuEvent final : public Event { + public: + CpuEvent() = default; + + void prepare_signal() override { + // Allowed only when not Failed/Poisoned (executor enforces). + status_.store(EventStatus::Pending, std::memory_order_release); + error_ = ::executorch::runtime::Error::Ok; + } + + EventStatus status() const override { + return status_.load(std::memory_order_acquire); + } + + ::executorch::runtime::Error error() const override { return error_; } + + // CPU-specific helpers used by CpuInstance to drive transitions: + + void signal_complete() override { + status_.store(EventStatus::Complete, std::memory_order_release); + } + + void signal_failed(::executorch::runtime::Error e) override { + error_ = e; + status_.store(EventStatus::Failed, std::memory_order_release); + } + + void signal_poisoned(::executorch::runtime::Error upstream_error) override { + error_ = upstream_error; + status_.store(EventStatus::Poisoned, std::memory_order_release); + } + + private: + std::atomic status_{EventStatus::Pending}; + ::executorch::runtime::Error error_ = ::executorch::runtime::Error::Ok; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuInstance.cpp b/backends/portable/runtime_v2/cpu/CpuInstance.cpp new file mode 100644 index 00000000000..3df44da7389 --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuInstance.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +namespace { + +// Casts the EventStatus path uniformly. Returns Ok if event is null or +// already Complete; otherwise returns the carried error. +::executorch::runtime::Error wait_status(EventStatus s, Event* e) { + if (s == EventStatus::Complete) return ::executorch::runtime::Error::Ok; + if (e == nullptr) return ::executorch::runtime::Error::Ok; + return e->error(); +} + +} // namespace + +CpuInstance::~CpuInstance() = default; + +::executorch::runtime::Result CpuInstance::compile_segment( + const ::executorch::backends::portable::Graph& graph, + ::executorch::runtime::Span instruction_indices, + ::executorch::runtime::Span /*input_value_ids*/, + ::executorch::runtime::Span /*output_value_ids*/, + ::executorch::runtime::Span> + value_remap) { + std::vector idxs( + instruction_indices.begin(), instruction_indices.end()); + std::unordered_map remap; + remap.reserve(value_remap.size()); + for (const auto& kv : value_remap) { + remap.emplace(kv.first, kv.second); + } + auto seg = std::make_unique(&graph, std::move(idxs), + std::move(remap)); + CompiledSegment* raw = seg.get(); + compiled_segments_.push_back(std::move(seg)); + return raw; +} + +::executorch::runtime::Error CpuInstance::allocate_all( + ::executorch::runtime::Span requests, + ::executorch::runtime::Span values, + ::executorch::runtime::Span out_buffers) { + if (requests.size() != out_buffers.size()) { + return ::executorch::runtime::Error::InvalidArgument; + } + // CPU has unified memory: simple per-request allocation. + // Default 16-byte alignment for SIMD. + // host_alias would only be set for cross-runtime synthetic values whose + // CONSUMER is on CPU; in v1 (host=CPU), all synthetic mirrors live on + // device, not on host, so host_alias is always null here. 
We still + // honor it defensively so future "device produces, host consumes" + // cases can re-use the same plumbing. + // Dedup by mem_obj_id (set by the AOT memory planner): multiple + // AllocRequests sharing the same mem_obj_id (>=0) refer to the SAME + // physical slot — e.g. a buffer placeholder and its mutation result + // (in-place buffer mutation), or two values with non-overlapping + // lifetimes that the planner reused. Allocate one Buffer per group; + // every sharer points at the same data_ptr. + std::unordered_map mem_obj_buffers; + for (size_t i = 0; i < requests.size(); ++i) { + const auto& req = requests[i]; + if (req.value_id >= values.size() || + !values[req.value_id].isTensor()) { + ET_LOG(Error, + "CpuInstance::allocate_all: value_id=%u missing or not a tensor", + req.value_id); + return ::executorch::runtime::Error::InvalidArgument; + } + + if (req.mem_obj_id >= 0) { + auto it = mem_obj_buffers.find(req.mem_obj_id); + if (it != mem_obj_buffers.end()) { + out_buffers[i] = it->second; + ET_LOG(Debug, + "[mem] cpu: allocate_all value_id=%u shares mem_obj_id=%d -> Buffer host_ptr=%p", + req.value_id, req.mem_obj_id, it->second->host_ptr()); + continue; + } + } + + const auto& tensor = values[req.value_id].toTensor(); + size_t nbytes = tensor.nbytes(); + Buffer* buf = nullptr; + + // Zero-copy alias path: source already host-addressable, just wrap. + if (req.host_alias && req.host_alias->host_ptr()) { + void* p = req.host_alias->host_ptr(); + std::unique_ptr hb(HostBuffer::alias(p, nbytes)); + buf = hb.get(); + ET_LOG(Debug, + "[mem] cpu: allocate_all value_id=%u bytes=%zu host_alias=%p (zero-copy alias)", + req.value_id, nbytes, p); + owned_buffers_.push_back(std::move(hb)); + } else { + // Fresh allocation. 
+ std::unique_ptr hb(HostBuffer::allocate(nbytes, /*alignment=*/16)); + if (!hb) return ::executorch::runtime::Error::MemoryAllocationFailed; + buf = hb.get(); + ET_LOG(Debug, + "[mem] cpu: allocate_all value_id=%u bytes=%zu host_ptr=%p", + req.value_id, nbytes, buf->host_ptr()); + owned_buffers_.push_back(std::move(hb)); + } + out_buffers[i] = buf; + if (req.mem_obj_id >= 0) { + mem_obj_buffers[req.mem_obj_id] = buf; + } + } + return ::executorch::runtime::Error::Ok; +} + +::executorch::runtime::Result CpuInstance::upload_constant( + const ::executorch::runtime::NamedDataMap& ndm, + std::string_view key) { + // CPU: zero-copy alias the FreeableBuffer's data. The HostBuffer + // takes ownership of the FreeableBuffer (frees it in destructor). + auto data_result = ndm.get_data(key); + if (!data_result.ok()) { + ET_LOG(Error, + "CpuInstance: NDM key not found for constant upload"); + return data_result.error(); + } + size_t bytes = data_result.get().size(); + void* ptr = const_cast(data_result.get().data()); + std::unique_ptr hb( + HostBuffer::alias_ndm(std::move(data_result.get()))); + Buffer* raw = hb.get(); + ET_LOG(Debug, + "[mem] cpu: upload_constant key='%.*s' bytes=%zu host_ptr=%p (zero-copy NDM alias)", + static_cast(key.size()), key.data(), bytes, ptr); + owned_buffers_.push_back(std::move(hb)); + return raw; +} + +std::unique_ptr CpuInstance::make_event() { + return std::make_unique(); +} + +::executorch::runtime::Error CpuInstance::upload_from_host( + ::executorch::runtime::EValue& host_src_ev, + void* host_src_ptr, + ::executorch::runtime::EValue& dev_dst_ev, + Buffer* dev_dst_buf, + QueueKind /*queue*/, + ::executorch::runtime::Span wait_for, + Event* signal) { + if (auto e = check_dependencies_(wait_for, signal); + e != ::executorch::runtime::Error::Ok) { + return e; + } + if (!host_src_ptr || !dev_dst_buf || !host_src_ev.isTensor() || + !dev_dst_ev.isTensor()) { + if (signal) signal->signal_failed(::executorch::runtime::Error::InvalidArgument); + return 
::executorch::runtime::Error::InvalidArgument; + } + + auto& src_t = host_src_ev.toTensor(); + auto& dst_t = dev_dst_ev.toTensor(); + size_t nbytes = src_t.nbytes(); + + // Propagate shape (per the shape-on-event contract). + if (auto e = ::executorch::runtime::resize_tensor(dst_t, src_t.sizes()); + e != ::executorch::runtime::Error::Ok) { + if (signal) signal->signal_failed(e); + return e; + } + + // Re-alias the destination HostBuffer to point at host_src_ptr. + // Idempotent if already pointing there. Frees any prior Owned storage. + auto* hb = static_cast(dev_dst_buf); + bool was_already = (hb->host_ptr() == host_src_ptr); + hb->re_alias(host_src_ptr, nbytes); + dst_t.unsafeGetTensorImpl()->set_data(host_src_ptr); + + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] cpu: upload_from_host caller_ptr=%p bytes=%zu (%s)", + host_src_ptr, nbytes, + was_already ? "alias unchanged" : "re-aliased"); + return ::executorch::runtime::Error::Ok; +} + +::executorch::runtime::Error CpuInstance::download_to_host( + ::executorch::runtime::EValue& dev_src_ev, + Buffer* dev_src_buf, + ::executorch::runtime::EValue& host_dst_ev, + void* host_dst_ptr, + QueueKind /*queue*/, + ::executorch::runtime::Span wait_for, + Event* signal) { + if (auto e = check_dependencies_(wait_for, signal); + e != ::executorch::runtime::Error::Ok) { + return e; + } + if (!host_dst_ptr || !dev_src_buf || !dev_src_ev.isTensor() || + !host_dst_ev.isTensor()) { + if (signal) signal->signal_failed(::executorch::runtime::Error::InvalidArgument); + return ::executorch::runtime::Error::InvalidArgument; + } + + auto& src_t = dev_src_ev.toTensor(); + auto& dst_t = host_dst_ev.toTensor(); + size_t nbytes = src_t.nbytes(); + + if (auto e = ::executorch::runtime::resize_tensor(dst_t, src_t.sizes()); + e != ::executorch::runtime::Error::Ok) { + if (signal) signal->signal_failed(e); + return e; + } + + // For CPU, "download to host" is the symmetric: re-alias the 
source + // Buffer's pointer to the destination host pointer (caller's storage). + // The kernel that produced dev_src_buf wrote directly into host_dst_ptr + // already if alias was set up at bind_outputs time, in which case this + // is idempotent. + auto* hb = static_cast(dev_src_buf); + bool was_already = (hb->host_ptr() == host_dst_ptr); + hb->re_alias(host_dst_ptr, nbytes); + dst_t.unsafeGetTensorImpl()->set_data(host_dst_ptr); + + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] cpu: download_to_host caller_ptr=%p bytes=%zu (%s)", + host_dst_ptr, nbytes, + was_already ? "alias unchanged" : "re-aliased"); + return ::executorch::runtime::Error::Ok; +} + +::executorch::runtime::Error CpuInstance::check_dependencies_( + ::executorch::runtime::Span wait_for, Event* signal) { + return check_async_dependencies(wait_for, signal); +} + +::executorch::runtime::Error CpuInstance::execute( + CompiledSegment* segment, + ::executorch::runtime::Span<::executorch::runtime::EValue> values, + BindingView bindings, + ::executorch::runtime::Span wait_for, + Event* signal) { + if (auto e = check_dependencies_(wait_for, signal); + e != ::executorch::runtime::Error::Ok) { + return e; + } + + auto* seg = static_cast(segment); + if (!seg) { + return ::executorch::runtime::Error::InvalidArgument; + } + const auto* graph = seg->graph(); + if (!graph) return ::executorch::runtime::Error::InvalidState; + + // Verify each tensor EValue's TensorImpl::data_ptr matches the + // currently-bound HostBuffer's host_ptr (after remap). The data_ptr + // is established at prebind / bind_inputs / upload_from_host time; + // by the time we get here it MUST match the bound Buffer. + if (auto e = verify_segment_bindings(seg, graph, values, bindings, + signal, "CpuInstance"); + e != ::executorch::runtime::Error::Ok) { + return e; + } + + // Drive ops via the existing portable kernel registry. 
+ ::executorch::runtime::KernelRuntimeContext kctx{}; + ::executorch::backends::portable::CpuGraph cpu_graph(kctx, values); + + for (uint32_t instr_idx : seg->instruction_indices()) { + auto op = graph->get_instruction(instr_idx); + const char* op_name = op.name(); + if (!op_name) { + if (signal) signal->signal_failed(::executorch::runtime::Error::InvalidProgram); + return ::executorch::runtime::Error::InvalidProgram; + } + + ET_LOG(Debug, "CpuInstance: instr %u op='%s' (in=%zu, out=%zu)", instr_idx, + op_name, op.num_inputs(), op.num_outputs()); + + auto* handler = + ::executorch::backends::portable::cpu_op_registry().try_get_op_fn( + op_name); + if (!handler) { + ET_LOG(Error, "CpuInstance: no handler for %s", op_name); + if (signal) signal->signal_failed(::executorch::runtime::Error::NotSupported); + return ::executorch::runtime::Error::NotSupported; + } + + std::vector<::executorch::backends::portable::ValueRef> args; + args.reserve(op.num_inputs() + op.num_outputs()); + for (size_t i = 0; i < op.num_inputs(); ++i) { + args.push_back(static_cast<::executorch::backends::portable::ValueRef>( + seg->remap(op.input(i)))); + } + for (size_t i = 0; i < op.num_outputs(); ++i) { + args.push_back(static_cast<::executorch::backends::portable::ValueRef>( + seg->remap(op.output(i)))); + } + (*handler)(cpu_graph, args); + + if (kctx.failure_state() != ::executorch::runtime::Error::Ok) { + auto err = kctx.failure_state(); + if (signal) signal->signal_failed(err); + return err; + } + } + + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + return ::executorch::runtime::Error::Ok; +} + +::executorch::runtime::Error CpuInstance::wait(Event* event) { + if (!event) return ::executorch::runtime::Error::Ok; + // CPU events are settled by the producing call. Just translate status. 
+ return wait_status(event->status(), event); +} + +void CpuInstance::release_buffer(Buffer* buf) { + // For Owned/NDM buffers held in owned_buffers_, the destructor of the + // unique_ptr handles freeing on ~CpuInstance. release_buffer is called + // for every owned buffer at ~Plan, which we treat as a no-op here + // (the actual free happens at our destruction). + // + // For Aliasing buffers from the HostImportArena, this is called from + // the executor at the top of the next execute() — also a no-op since + // the arena reset() is what reclaims the slot. + (void)buf; +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuInstance.h b/backends/portable/runtime_v2/cpu/CpuInstance.h new file mode 100644 index 00000000000..630946af546 --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuInstance.h @@ -0,0 +1,152 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * CompiledSegment for CPU: just remembers which instructions to run. + * No actual compilation — CPU dispatches via the portable kernel + * registry per-instruction at execute time. + * + * value_remap rewrites graph value_ids to the value_ids the segment + * should consult in the BindingTable. The router uses this when it + * synthesizes new value_ids for cross-runtime transfer destinations. 
+ */ +class CpuCompiledSegment final : public CompiledSegment { + public: + CpuCompiledSegment( + const ::executorch::backends::portable::Graph* graph, + std::vector instruction_indices, + std::unordered_map value_remap) + : graph_(graph), + instruction_indices_(std::move(instruction_indices)), + value_remap_(std::move(value_remap)) {} + + const ::executorch::backends::portable::Graph* graph() const { + return graph_; + } + const std::vector& instruction_indices() const { + return instruction_indices_; + } + uint32_t remap(uint32_t v) const { + auto it = value_remap_.find(v); + return it != value_remap_.end() ? it->second : v; + } + bool has_remap() const { return !value_remap_.empty(); } + + private: + const ::executorch::backends::portable::Graph* graph_; + std::vector instruction_indices_; + std::unordered_map value_remap_; +}; + +/** + * CPU Instance — synchronous execution via the existing portable kernel + * registry. All work-issuing methods complete before returning; events + * transition to Complete or Failed in-place. 
+ */ +class CpuInstance final : public Instance { + public: + explicit CpuInstance(CpuRuntimeContext& ctx, InstanceId id) + : ctx_(ctx), id_(id) {} + + ~CpuInstance() override; + + ::executorch::runtime::Result compile_segment( + const ::executorch::backends::portable::Graph& graph, + ::executorch::runtime::Span instruction_indices, + ::executorch::runtime::Span input_value_ids, + ::executorch::runtime::Span output_value_ids, + ::executorch::runtime::Span> + value_remap) override; + + ::executorch::runtime::Error allocate_all( + ::executorch::runtime::Span requests, + ::executorch::runtime::Span values, + ::executorch::runtime::Span out_buffers) override; + + ::executorch::runtime::Result upload_constant( + const ::executorch::runtime::NamedDataMap& ndm, + std::string_view key) override; + + std::unique_ptr make_event() override; + + // CpuInstance overrides these to re-alias the destination HostBuffer + // to point at the caller's host pointer (zero-copy). Used for graph + // I/O bindings on the CPU side and for cross-runtime intermediates + // where CPU is the consumer. 
+ ::executorch::runtime::Error upload_from_host( + ::executorch::runtime::EValue& host_src_ev, + void* host_src_ptr, + ::executorch::runtime::EValue& dev_dst_ev, + Buffer* dev_dst_buf, + QueueKind queue, + ::executorch::runtime::Span wait_for, + Event* signal) override; + + ::executorch::runtime::Error download_to_host( + ::executorch::runtime::EValue& dev_src_ev, + Buffer* dev_src_buf, + ::executorch::runtime::EValue& host_dst_ev, + void* host_dst_ptr, + QueueKind queue, + ::executorch::runtime::Span wait_for, + Event* signal) override; + + ::executorch::runtime::Error execute( + CompiledSegment* segment, + ::executorch::runtime::Span<::executorch::runtime::EValue> values, + BindingView bindings, + ::executorch::runtime::Span wait_for, + Event* signal) override; + + ::executorch::runtime::Error wait(Event* event) override; + + InstanceId id() const override { return id_; } + + void drain() override {} // CPU is synchronous; no in-flight work. + + void release_buffer(Buffer* buf) override; + + private: + CpuRuntimeContext& ctx_; + InstanceId id_; + + // Owns all Buffers allocated via allocate() / upload_constant(). + // I/O destination Buffers may transition from Owned → Aliasing + // in-place via HostBuffer::re_alias when upload_from_host re-points + // them at caller storage; the destructor still works correctly via + // mode_ tracking. + std::vector> owned_buffers_; + + // Owns CompiledSegments returned from compile_segment. + std::vector> compiled_segments_; + + // Helper: check wait_for for poison; if any, signal poisons signal and + // returns AsyncDependencyFailed. 
+ ::executorch::runtime::Error check_dependencies_( + ::executorch::runtime::Span wait_for, Event* signal); +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuOpRegistry.h b/backends/portable/runtime_v2/cpu/CpuOpRegistry.h new file mode 100644 index 00000000000..8c2017e997e --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuOpRegistry.h @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +/** + * CpuOpRegistry.h - CPU backend op registration macros + * + * Provides CPU-specific registration macros built on the generic OpRegistry. + * CPU ops dispatch to ExecuTorch's existing portable kernel library. + */ + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable { + +// Forward declaration +class CpuRuntime; + +/** + * CpuGraph - Wrapper providing the Graph interface for CPU execution. + * + * Routes tensor access: boundary values use shared values_ array, + * CPU-internal intermediates use shadow EVaules from CpuRuntime. 
+ */ +class CpuGraph { + public: + CpuGraph( + runtime::KernelRuntimeContext& ctx, + runtime::Span values, + const std::unordered_set* intermediate_indices = nullptr, + std::unordered_map* cpu_shadow = nullptr) + : ctx_(ctx), values_(values), + intermediate_indices_(intermediate_indices), cpu_shadow_(cpu_shadow) {} + + //===--------------------------------------------------------------------===// + // Graph Interface (required by OpView) + //===--------------------------------------------------------------------===// + + bool val_is_tensor(ValueRef idx) const { + return idx < values_.size() && get_value(idx).isTensor(); + } + + bool val_is_none(ValueRef idx) const { + return idx >= values_.size() || get_value(idx).isNone(); + } + + template + T extract_scalar(ValueRef idx) const { + const auto& ev = get_value(idx); + if (ev.isInt()) { + return static_cast(ev.toInt()); + } else if (ev.isDouble()) { + return static_cast(ev.toDouble()); + } else if (ev.isBool()) { + return static_cast(ev.toBool()); + } + return T{}; + } + + executorch::aten::ScalarType dtype_of(ValueRef idx) const { + if (idx < values_.size() && get_value(idx).isTensor()) { + return get_value(idx).toTensor().scalar_type(); + } + return executorch::aten::ScalarType::Float; + } + + std::vector sizes_of(ValueRef idx) const { + if (idx < values_.size() && get_value(idx).isTensor()) { + auto sizes = get_value(idx).toTensor().sizes(); + return std::vector(sizes.begin(), sizes.end()); + } + return {}; + } + + //===--------------------------------------------------------------------===// + // CPU-Specific Accessors + //===--------------------------------------------------------------------===// + + runtime::KernelRuntimeContext& context() { + return ctx_; + } + + /// Get EValue - routes to shadow for intermediates, values_ for boundaries + runtime::EValue& value(ValueRef idx) { + return get_value_mut(idx); + } + + runtime::EValue* value_ptr(ValueRef idx) { + return &get_value_mut(idx); + } + + size_t 
num_values() const { + return values_.size(); + } + + private: + runtime::KernelRuntimeContext& ctx_; + runtime::Span values_; + const std::unordered_set* intermediate_indices_; + std::unordered_map* cpu_shadow_; + + /// Route access: intermediate → shadow, boundary → values_ + const runtime::EValue& get_value(ValueRef idx) const { + if (intermediate_indices_ && cpu_shadow_ && + intermediate_indices_->count(idx) > 0) { + auto it = cpu_shadow_->find(idx); + if (it != cpu_shadow_->end()) { + return it->second; + } + } + return values_[idx]; + } + + runtime::EValue& get_value_mut(ValueRef idx) { + if (intermediate_indices_ && cpu_shadow_ && + intermediate_indices_->count(idx) > 0) { + auto it = cpu_shadow_->find(idx); + if (it != cpu_shadow_->end()) { + return it->second; + } + } + return values_[idx]; + } +}; + +/// Global CPU op registry accessor. +OperatorRegistry& cpu_op_registry(); + +//===----------------------------------------------------------------------===// +// CPU Registration Macros +//===----------------------------------------------------------------------===// + +/// Check if op is registered. +#define CPU_HAS_OP(name) \ + ::executorch::backends::portable::cpu_op_registry().has_op(name) + +/// Check if op supports dtype. +#define CPU_HAS_OP_DTYPE(name, dtype) \ + ::executorch::backends::portable::cpu_op_registry().has_op(name, dtype) + +/// Get op function. +#define CPU_GET_OP_FN(name) \ + ::executorch::backends::portable::cpu_op_registry().get_op_fn(name) + +/// Register an op (all dtypes). +#define CPU_REGISTER_OP(name, fn) \ + ::executorch::backends::portable::cpu_op_registry().register_op(#name, fn) + +/// Register an op with specific dtypes. +#define CPU_REGISTER_OP_DTYPES(name, fn, ...) \ + ::executorch::backends::portable::cpu_op_registry().register_op( \ + #name, fn, {__VA_ARGS__}) + +/// Static registration block. 
+#define REGISTER_CPU_OPERATORS \ + REGISTER_OPERATORS(cpu) + +} // namespace portable +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuOps.cpp b/backends/portable/runtime_v2/cpu/CpuOps.cpp new file mode 100644 index 00000000000..9f55a072406 --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuOps.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * CpuOps.cpp - CPU op implementations for the portable backend + * + * Uses generic dispatch_kernel() to call ExecuTorch portable kernels. + * Adding a new op is a single line: CPU_DISPATCH_OP(op_name, "kernel_name"); + */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable { + +// Global registry instance +OperatorRegistry& cpu_op_registry() { + static OperatorRegistry registry; + return registry; +} + +namespace { + +/// Generic kernel dispatch - passes args + dummy return slot to kernel +void dispatch_kernel( + CpuGraph& graph, + const std::vector& args, + const char* kernel_name) { + + auto& ctx = graph.context(); + + auto kernel = torch::executor::getOpsFn( + kernel_name, runtime::ArrayRef()); + + if (!kernel) { + ET_LOG(Error, "CPU: kernel %s not found", kernel_name); + ctx.fail(runtime::Error::NotSupported); + return; + } + + // Build stack from args + dummy return slot (codegen wrapper expects it) + runtime::EValue dummy; + std::vector stack; + stack.reserve(args.size() + 1); + + for (size_t i = 0; i < args.size(); i++) { + stack.push_back(graph.value_ptr(args[i])); + } + stack.push_back(&dummy); + + kernel(ctx, runtime::Span(stack.data(), stack.size())); +} + +/// Dispatch for aten::copy_ (in-place buffer writeback). 
The IR emits 4 +/// args [self, src, non_blocking, out_formal] — the in-place semantic is +/// "self gets src's bytes." We dispatch to the 3-arg in-place +/// "aten::copy_" kernel directly, ignoring the formal out arg. +/// +/// Short-circuit for self-aliased writebacks: when the AOT memory planner +/// has aliased self and src to the same slot (true for buffer mutations +/// after our `alias_buffer_mutations_post_planning` pass), self.data_ptr +/// == src.data_ptr — the writeback is trivially satisfied by the kernel +/// that already wrote to the shared slot. Skip the (potentially expensive, +/// large for KV-cache) memcpy in that case. +void dispatch_copy_inplace( + CpuGraph& graph, + const std::vector& args) { + auto& ctx = graph.context(); + if (args.size() < 3) { + ET_LOG(Error, "CPU: aten::copy_ expects at least 3 args, got %zu", args.size()); + ctx.fail(runtime::Error::InvalidArgument); + return; + } + auto* self_ev = graph.value_ptr(args[0]); + auto* src_ev = graph.value_ptr(args[1]); + if (self_ev->isTensor() && src_ev->isTensor()) { + if (self_ev->toTensor().const_data_ptr() == + src_ev->toTensor().const_data_ptr()) { + // Aliased: the bytes already match (kernel that produced src wrote + // through to self's slot). Writeback is a no-op. + ET_LOG(Debug, + "CPU: aten::copy_ short-circuit (self == src @ %p) — no-op", + self_ev->toTensor().const_data_ptr()); + return; + } + } + auto kernel = torch::executor::getOpsFn( + "aten::copy_", runtime::ArrayRef()); + if (!kernel) { + ET_LOG(Error, "CPU: kernel aten::copy_ not found"); + ctx.fail(runtime::Error::NotSupported); + return; + } + // 3-arg in-place schema: copy_(self, src, non_blocking) -> self. + // Stack: [self, src, non_blocking, return_slot]. 
+ runtime::EValue dummy; + std::vector stack = { + self_ev, // self + src_ev, // src + graph.value_ptr(args[2]), // non_blocking + &dummy, // return slot + }; + kernel(ctx, runtime::Span(stack.data(), stack.size())); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Op Registration - maps op names to kernel names +//===----------------------------------------------------------------------===// + +#define CPU_DISPATCH_OP(op_name, kernel_name) \ + CPU_REGISTER_OP(op_name, [](CpuGraph& graph, const std::vector& args) { \ + dispatch_kernel(graph, args, kernel_name); \ + }) + +REGISTER_CPU_OPERATORS { + CPU_DISPATCH_OP(aten::add, "aten::add.out"); + CPU_DISPATCH_OP(aten::sub, "aten::sub.out"); + CPU_DISPATCH_OP(aten::mul, "aten::mul.out"); + CPU_DISPATCH_OP(aten::div, "aten::div.out"); + CPU_DISPATCH_OP(aten::permute_copy, "aten::permute_copy.out"); + CPU_DISPATCH_OP(aten::mm, "aten::mm.out"); + CPU_DISPATCH_OP(aten::clone, "aten::clone.out"); + // copy_ uses a custom dispatcher (drops the formal out arg, calls the + // 3-arg in-place kernel). The IR emits 4 args but the in-place kernel + // only takes 3; the formal out is just a memory-plan placeholder. + CPU_REGISTER_OP(aten::copy_, dispatch_copy_inplace); +} + +#undef CPU_DISPATCH_OP + +} // namespace portable +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuProvider.cpp b/backends/portable/runtime_v2/cpu/CpuProvider.cpp new file mode 100644 index 00000000000..4fab5eeb020 --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuProvider.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
 */

#include

#include
#include

#include

namespace executorch {
namespace backends {
namespace portable_v2 {

/// Accept an op iff it passes the optional allowlist AND the portable
/// kernel registry knows it.
bool CpuProvider::can_run(const OpDescriptor& op) const {
  // NOTE(review): this copies op.name into a std::string on every query;
  // fine for routing-time checks, revisit if this lands on a hot path.
  std::string name(op.name);
  // Allowlist filter (test/dev mode).
  if (!supported_ops_.empty() && supported_ops_.count(name) == 0) {
    return false;
  }
  // Default: accept any op the portable kernel registry knows about.
  return ::executorch::backends::portable::cpu_op_registry().has_op(name);
}

/// Create a fresh Instance with a process-unique id. relaxed order is
/// sufficient: the counter only needs atomicity, not ordering with any
/// other memory operation.
std::unique_ptr CpuProvider::instantiate() {
  InstanceId id = next_instance_id_.fetch_add(1, std::memory_order_relaxed);
  return std::make_unique(ctx_, id);
}

} // namespace portable_v2
} // namespace backends
} // namespace executorch
diff --git a/backends/portable/runtime_v2/cpu/CpuProvider.h b/backends/portable/runtime_v2/cpu/CpuProvider.h
new file mode 100644
index 00000000000..9edcaa48b7c
--- /dev/null
+++ b/backends/portable/runtime_v2/cpu/CpuProvider.h
@@ -0,0 +1,68 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include
#include

#include
#include
#include
#include
#include

namespace executorch {
namespace backends {
namespace portable_v2 {

/**
 * CPU Provider — fallback runtime that supports every op in the existing
 * portable kernel registry. Always available; expected at index 0 (host
 * slot invariant).
 *
 * Configurable for testing: pass a name override and/or an op allowlist
 * to spin up a SECOND CpuProvider that pretends to be a different
 * runtime for multi-provider routing tests. (e.g. "fake_accel" with
 * supported_ops = {"aten::add", "aten::mul"}.)
 */
class CpuProvider final : public Provider {
 public:
  CpuProvider() = default;

  // Test/dev constructor: override the registered name and restrict ops.
+ // If supported_ops is empty, behaves like the default ("accept all + // registered ops"). If non-empty, only ops in the set are accepted. + CpuProvider(std::string_view name, + std::unordered_set supported_ops) + : name_(name), supported_ops_(std::move(supported_ops)) {} + + ~CpuProvider() override = default; + + std::string_view name() const override { return name_; } + + bool is_available_on_device() const override { return true; } + + bool can_run(const OpDescriptor& op) const override; + + RuntimeContext& context() override { return ctx_; } + + std::unique_ptr instantiate() override; + + private: + CpuRuntimeContext ctx_; + std::atomic next_instance_id_{0}; + std::string_view name_ = "cpu"; + // Empty = accept any op the portable kernel registry can dispatch. + // Non-empty = only the listed names. + std::unordered_set supported_ops_; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/CpuRuntimeContext.h b/backends/portable/runtime_v2/cpu/CpuRuntimeContext.h new file mode 100644 index 00000000000..3c61b46f185 --- /dev/null +++ b/backends/portable/runtime_v2/cpu/CpuRuntimeContext.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * RuntimeContext for the CPU Provider. Currently empty — CPU runs + * synchronously through the existing portable-kernel registry; no GPU + * queues, no kernel cache, no per-execute arena. Reserved for future + * cross-execute caches (e.g., cached layout-conversion buffers). + * + * Process-global per-Provider; survives across LoadedDelegate lifetimes. 
+ */ +class CpuRuntimeContext final : public RuntimeContext {}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/cpu/HostBuffer.h b/backends/portable/runtime_v2/cpu/HostBuffer.h new file mode 100644 index 00000000000..e1a06642a07 --- /dev/null +++ b/backends/portable/runtime_v2/cpu/HostBuffer.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Concrete Buffer subclass for host (CPU) memory. + * + * Three lifetime modes (selected at construction): + * 1. OWNED — wraps memory we malloc'd; freed in destructor. + * 2. ALIASING — wraps a host pointer the caller keeps alive (used by + * HostImportArena and bind_inputs/bind_outputs); destructor is a no-op. + * 3. NDM_ALIAS — wraps a NamedDataMap FreeableBuffer (held internally, + * released in destructor). + * + * The arena/recycler can re-target an ALIASING HostBuffer's pointer/size + * via re_alias() to recycle slots. + */ +class HostBuffer : public Buffer { + public: + enum class Mode : uint8_t { Owned, Aliasing, NdmAlias }; + + // Mode::Owned — allocate `bytes` via std::aligned_alloc. + static HostBuffer* allocate(size_t bytes, size_t alignment) { + void* mem = std::aligned_alloc( + alignment, ((bytes + alignment - 1) / alignment) * alignment); + return new HostBuffer(Mode::Owned, mem, bytes); + } + + // Mode::Aliasing — wrap a caller-owned pointer. + static HostBuffer* alias(void* ptr, size_t bytes) { + return new HostBuffer(Mode::Aliasing, ptr, bytes); + } + + // Mode::NdmAlias — wrap a FreeableBuffer (move-in; held until ~HostBuffer). 
+ static HostBuffer* alias_ndm(::executorch::runtime::FreeableBuffer&& fb) { + void* ptr = const_cast(fb.data()); + size_t bytes = fb.size(); + auto* hb = new HostBuffer(Mode::NdmAlias, ptr, bytes); + hb->ndm_buffer_.~FreeableBuffer(); + new (&hb->ndm_buffer_) ::executorch::runtime::FreeableBuffer(std::move(fb)); + return hb; + } + + ~HostBuffer() override { + if (mode_ == Mode::Owned && ptr_) { + std::free(ptr_); + } else if (mode_ == Mode::NdmAlias) { + ndm_buffer_.Free(); + } + // Aliasing: nothing to do. + } + + void* host_ptr() override { return ptr_; } + + // Re-target this buffer's pointer/size in place. + // + // - Idempotent: if `ptr == ptr_` and `bytes == size_bytes()`, no-op. + // - Owned → Aliasing transition: if currently Owned and ptr_ != nullptr, + // frees the malloc'd storage (so the destructor doesn't double-free) + // and switches mode to Aliasing. + // - Aliasing-mode rebind: just updates ptr/size. + // - NdmAlias is invalid here (NDM buffers shouldn't be re-aliased); + // asserts in debug. + // + // Used by HostImportArena AND by upload_from_host's rebind path so + // upload-time aliasing replaces a stale Owned allocation in place. 
+ void re_alias(void* ptr, size_t bytes) { + if (ptr_ == ptr && size_bytes() == bytes) { + return; // already aliased here; cheap idempotent path + } + if (mode_ == Mode::Owned && ptr_) { + std::free(ptr_); + } + mode_ = Mode::Aliasing; + ptr_ = ptr; + set_size_bytes(bytes); + } + + Mode mode() const { return mode_; } + + private: + HostBuffer(Mode m, void* ptr, size_t bytes) + : Buffer(Location::host(), bytes), + mode_(m), + ptr_(ptr), + ndm_buffer_(nullptr, 0, nullptr, nullptr) {} + + Mode mode_; + void* ptr_; + ::executorch::runtime::FreeableBuffer ndm_buffer_; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/metal/MetalBuffer.h b/backends/portable/runtime_v2/metal/MetalBuffer.h new file mode 100644 index 00000000000..0c54265530b --- /dev/null +++ b/backends/portable/runtime_v2/metal/MetalBuffer.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +#include +#include + +// Forward declaration so MetalBuffer.h stays pure C++ (no Metal/ObjC +// imports). The actual id is owned by MetalStream's internal +// pools (ptrToBuffer_), so MetalBuffer just holds the host_ptr from +// stream->alloc() / registerExternalBuffer() and the byte count. +namespace executorch { +namespace backends { +namespace metal_v2 { +class MetalStream; +} // namespace metal_v2 +} // namespace backends +} // namespace executorch + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Concrete Buffer subclass for Apple Silicon Metal storage. + * + * On Apple Silicon, MTLBuffer.contents is a host-addressable pointer + * (unified memory). 
We therefore store the host pointer here directly;
 * the underlying id lives in MetalStream's internal
 * ptrToBuffer_ map, looked up via stream->bufferForPtr() when ops need
 * to encode dispatches against it. This keeps MetalBuffer.h pure-C++
 * and avoids dragging Metal/ObjC headers into runtime_v2 callers.
 *
 * Three lifetime modes (selected at construction):
 *   1. OWNED — wraps memory we got from stream->alloc(); destructor
 *      calls stream->free(ptr_) to return it to the pool.
 *   2. ALIASING — wraps a host pointer the caller keeps alive (router's
 *      cross-runtime alias optimization, or upload_from_host re-aliasing
 *      an Owned-mode Buffer to point at caller storage). Destructor is
 *      a no-op; pool memory was already returned during the Owned →
 *      Aliasing transition by re_alias.
 *   3. NDM_ALIAS — wraps a FreeableBuffer from the NamedDataMap (held
 *      internally; released in destructor). The data was registered with
 *      MetalStream via registerExternalBuffer() at upload time.
 *
 * The "location" is set per Provider (router stamps it during routing).
 */
class MetalBuffer : public Buffer {
 public:
  enum class Mode : uint8_t { Owned, Aliasing, NdmAlias };

  // Mode::Owned — `ptr` was returned by stream->alloc(bytes). Destructor
  // calls stream->free(ptr).
  static MetalBuffer* allocate(
      ::executorch::backends::metal_v2::MetalStream* stream,
      Location loc,
      void* ptr,
      size_t bytes) {
    return new MetalBuffer(Mode::Owned, loc, stream, ptr, bytes);
  }

  // Mode::Aliasing — `ptr` came from caller storage; was registered with
  // the stream by upload_from_host (or set up directly by the caller).
  // Destructor is a no-op (caller keeps host memory alive).
  static MetalBuffer* alias(
      ::executorch::backends::metal_v2::MetalStream* stream,
      Location loc,
      void* ptr,
      size_t bytes) {
    return new MetalBuffer(Mode::Aliasing, loc, stream, ptr, bytes);
  }

  // Mode::NdmAlias — `fb` is the NDM FreeableBuffer; `ptr` = fb.data().
  // Was registered with the stream via registerExternalBuffer.
  // Destructor releases the FreeableBuffer (which keeps the mmap'd
  // region alive).
  static MetalBuffer* alias_ndm(
      ::executorch::backends::metal_v2::MetalStream* stream,
      Location loc,
      ::executorch::runtime::FreeableBuffer&& fb) {
    void* ptr = const_cast(fb.data());
    size_t bytes = fb.size();
    auto* mb = new MetalBuffer(Mode::NdmAlias, loc, stream, ptr, bytes);
    // Destroy-then-placement-new swap of the placeholder FreeableBuffer;
    // presumably FreeableBuffer lacks move-assignment — TODO confirm.
    mb->ndm_buffer_.~FreeableBuffer();
    new (&mb->ndm_buffer_)
        ::executorch::runtime::FreeableBuffer(std::move(fb));
    return mb;
  }

  ~MetalBuffer() override; // defined in MetalBuffer.mm to call stream_->free()

  void* host_ptr() override { return ptr_; }

  // Re-target this buffer's pointer/size in place.
  //
  // - Idempotent: if `ptr == ptr_` and `bytes == size_bytes()`, no-op.
  // - Owned → Aliasing transition: if currently Owned, returns the
  //   pool-allocated ptr_ to the stream's pool and switches mode to
  //   Aliasing. The new caller-supplied ptr is then stored. The caller
  //   is responsible for stream->registerExternalBuffer(new_ptr) so
  //   dispatches resolve.
  // - Aliasing-mode rebind: just updates ptr/size.
  //
  // Defined in MetalBuffer.mm to call stream_->free.
+ void re_alias(void* ptr, size_t bytes); + + Mode mode() const { return mode_; } + + ::executorch::backends::metal_v2::MetalStream* stream() const { + return stream_; + } + + private: + MetalBuffer( + Mode m, + Location loc, + ::executorch::backends::metal_v2::MetalStream* stream, + void* ptr, + size_t bytes) + : Buffer(loc, bytes), + mode_(m), + stream_(stream), + ptr_(ptr), + ndm_buffer_(nullptr, 0, nullptr, nullptr) {} + + Mode mode_; + ::executorch::backends::metal_v2::MetalStream* stream_; + void* ptr_; + ::executorch::runtime::FreeableBuffer ndm_buffer_; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/metal/MetalBuffer.mm b/backends/portable/runtime_v2/metal/MetalBuffer.mm new file mode 100644 index 00000000000..1b60f61d8b8 --- /dev/null +++ b/backends/portable/runtime_v2/metal/MetalBuffer.mm @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +MetalBuffer::~MetalBuffer() { + if (mode_ == Mode::Owned && ptr_ && stream_) { + stream_->free(ptr_); + } else if (mode_ == Mode::NdmAlias) { + ndm_buffer_.Free(); + // Note: we don't unregister from MetalStream's ptrToBuffer_ here. + // For Owned mode, stream_->free() handles pool return AND ptrToBuffer_ + // bookkeeping; for NdmAlias, the registration persists for the + // delegate's lifetime, which is correct since constants are + // permanent. Net leak per delegate teardown is bounded by the + // number of constants, and the next delegate gets a fresh stream. + } + // Aliasing: nothing to free. +} + +void MetalBuffer::re_alias(void* ptr, size_t bytes) { + // Idempotent: already aliased here. 
+ if (ptr_ == ptr && size_bytes() == bytes) { + return; + } + // Owned → Aliasing transition: return pool memory to the stream and + // switch modes so the destructor doesn't double-free. + if (mode_ == Mode::Owned && ptr_ && stream_) { + stream_->free(ptr_); + } + mode_ = Mode::Aliasing; + ptr_ = ptr; + set_size_bytes(bytes); + // Note: caller is responsible for stream_->registerExternalBuffer(ptr) + // so dispatches resolve via bufferForPtr(). +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/metal/MetalEvent.h b/backends/portable/runtime_v2/metal/MetalEvent.h new file mode 100644 index 00000000000..a050cf5f670 --- /dev/null +++ b/backends/portable/runtime_v2/metal/MetalEvent.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * MetalEvent — async event signaled by a Metal command buffer's + * completion handler. + * + * Same atomic-flag pattern as CpuEvent. The only difference is who + * calls signal_complete() / signal_failed(): + * - CpuEvent: synchronous code right after the work returns. + * - MetalEvent: a [commandBuffer addCompletedHandler:] block that + * fires on a private dispatch queue after the GPU finishes. + * + * The contract is the same: by the time status() observes Complete, + * both the bytes in any Buffer the producer wrote AND the shape on any + * TensorImpl the producer's outputs reference are valid for host reads. + * + * For MetalEvent specifically, "valid bytes" needs no extra work on + * Apple Silicon (unified memory: GPU writes are visible to host once + * the command buffer completes). 
Shape is updated inside the producer + * Instance's execute() before signaling, per the standard contract. + */ +class MetalEvent final : public Event { + public: + MetalEvent() : status_(EventStatus::Complete), error_(::executorch::runtime::Error::Ok) {} + ~MetalEvent() override = default; + + void prepare_signal() override { + error_ = ::executorch::runtime::Error::Ok; + status_.store(EventStatus::Pending, std::memory_order_release); + } + + EventStatus status() const override { + return status_.load(std::memory_order_acquire); + } + + ::executorch::runtime::Error error() const override { return error_; } + + // Called by MetalInstance code (typically inside a command buffer + // completion handler). + void signal_complete() override { + status_.store(EventStatus::Complete, std::memory_order_release); + } + + void signal_failed(::executorch::runtime::Error err) override { + error_ = err; + status_.store(EventStatus::Failed, std::memory_order_release); + } + + void signal_poisoned(::executorch::runtime::Error err) override { + error_ = err; + status_.store(EventStatus::Poisoned, std::memory_order_release); + } + + private: + std::atomic status_; + ::executorch::runtime::Error error_; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/metal/MetalInstance.h b/backends/portable/runtime_v2/metal/MetalInstance.h new file mode 100644 index 00000000000..2ae2b3a8853 --- /dev/null +++ b/backends/portable/runtime_v2/metal/MetalInstance.h @@ -0,0 +1,169 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +// Forward decl: keep this header pure-C++; .mm includes MetalStream.h. 
+namespace executorch { +namespace backends { +namespace metal_v2 { +class MetalStream; +} // namespace metal_v2 +} // namespace backends +} // namespace executorch + +namespace executorch { +namespace backends { +namespace portable_v2 { + +class MetalProvider; +class MetalBuffer; + +/** + * CompiledSegment for Metal: stores instruction indices and the + * value_remap (analogous to CpuCompiledSegment). Kernel compilation + * happens lazily inside MetalOp::getKernel during dispatch. + */ +class MetalCompiledSegment final : public CompiledSegment { + public: + MetalCompiledSegment( + const ::executorch::backends::portable::Graph* graph, + std::vector instruction_indices, + std::unordered_map value_remap) + : graph_(graph), + instruction_indices_(std::move(instruction_indices)), + value_remap_(std::move(value_remap)) {} + + const ::executorch::backends::portable::Graph* graph() const { + return graph_; + } + const std::vector& instruction_indices() const { + return instruction_indices_; + } + uint32_t remap(uint32_t v) const { + auto it = value_remap_.find(v); + return it != value_remap_.end() ? it->second : v; + } + bool has_remap() const { return !value_remap_.empty(); } + + private: + const ::executorch::backends::portable::Graph* graph_; + std::vector instruction_indices_; + std::unordered_map value_remap_; +}; + +/** + * Metal Instance — dispatches ops via the existing metal_v2 + * MetalOpRegistry. Per-execute model: + * 1. For each op in the segment: look up MetalOp; build EValue* + * vectors for inputs/outputs; sync TensorImpl::data_ptr to the + * bound MetalBuffer's host_ptr (= [mtlBuffer contents]); call + * op->dispatch(stream_, ins, outs). + * 2. After all dispatches, stream_->flush() commits the live command + * buffer; install an addCompletedHandler that signals our event + * when the GPU finishes. + * 3. On Apple Silicon (unified memory), bytes in the buffer are + * visible to host once the event signals. 
 *
 * transfer_tensor handles cross-runtime hand-off via host pointers
 * (source's host_ptr → dst's host_ptr memcpy; works for CPU↔Metal
 * because both are HostBuffer-compatible on Apple Silicon).
 */
class MetalInstance final : public Instance {
 public:
  MetalInstance(MetalProvider* provider, InstanceId id);
  ~MetalInstance() override;

  // Build a MetalCompiledSegment from the given instruction span and
  // value remap; returned pointer is owned by this instance.
  ::executorch::runtime::Result compile_segment(
      const ::executorch::backends::portable::Graph& graph,
      ::executorch::runtime::Span instruction_indices,
      ::executorch::runtime::Span input_value_ids,
      ::executorch::runtime::Span output_value_ids,
      ::executorch::runtime::Span>
          value_remap) override;

  // Allocate (or zero-copy alias) a MetalBuffer for each request; buffers
  // are owned by this instance for its lifetime.
  ::executorch::runtime::Error allocate_all(
      ::executorch::runtime::Span requests,
      ::executorch::runtime::Span values,
      ::executorch::runtime::Span out_buffers) override;

  // Zero-copy alias an NDM constant into a MetalBuffer registered with
  // the stream.
  ::executorch::runtime::Result upload_constant(
      const ::executorch::runtime::NamedDataMap& ndm,
      std::string_view key) override;

  std::unique_ptr make_event() override;

  // Cross-runtime moves: only the device side overrides these.
  // Apple Silicon Metal: re-aliases the destination MetalBuffer to point
  // at the caller's host pointer (zero-copy via newBufferWithBytesNoCopy).
  // Falls back to memcpy into the original pool-allocated buffer if Metal
  // refuses zero-copy on the pointer.
  ::executorch::runtime::Error upload_from_host(
      ::executorch::runtime::EValue& host_src_ev,
      void* host_src_ptr,
      ::executorch::runtime::EValue& dev_dst_ev,
      Buffer* dev_dst_buf,
      QueueKind queue,
      ::executorch::runtime::Span wait_for,
      Event* signal) override;

  ::executorch::runtime::Error download_to_host(
      ::executorch::runtime::EValue& dev_src_ev,
      Buffer* dev_src_buf,
      ::executorch::runtime::EValue& host_dst_ev,
      void* host_dst_ptr,
      QueueKind queue,
      ::executorch::runtime::Span wait_for,
      Event* signal) override;

  ::executorch::runtime::Error execute(
      CompiledSegment* segment,
      ::executorch::runtime::Span<::executorch::runtime::EValue> values,
      BindingView bindings,
      ::executorch::runtime::Span wait_for,
      Event* signal) override;

  ::executorch::runtime::Error wait(Event* event) override;

  InstanceId id() const override { return id_; }

  void drain() override;

  void release_buffer(Buffer* buf) override;

 private:
  // Helper: drain wait_for events; if any is Failed/Poisoned, poison
  // signal and return AsyncDependencyFailed.
  ::executorch::runtime::Error check_dependencies_(
      ::executorch::runtime::Span wait_for, Event* signal);

  MetalProvider* provider_; // not owned; provider outlives instance
  ::executorch::backends::metal_v2::MetalStream* stream_; // borrowed from provider
  InstanceId id_;

  // Owns CompiledSegments returned from compile_segment.
  std::vector> compiled_segments_;

  // Owns all MetalBuffers from allocate / upload_constant. I/O destination
  // Buffers may transition Owned → Aliasing in-place when upload_from_host
  // re-points them at caller storage; the destructor still works correctly
  // via mode_ tracking.
  owned_buffers_.clear();
}

runtime::Error MetalInstance::check_dependencies_(
    runtime::Span wait_for, Event* signal) {
  for (Event* e : wait_for) {
    if (!e) continue;
    if (auto err = wait(e); err != runtime::Error::Ok) {
      if (signal) signal->signal_poisoned(err);
      // NOTE(review): Internal is a stand-in for a dedicated
      // AsyncDependencyFailed code — confirm the intended error taxonomy.
      return runtime::Error::Internal; // AsyncDependencyFailed
    }
  }
  return runtime::Error::Ok;
}

runtime::Result MetalInstance::compile_segment(
    const ::executorch::backends::portable::Graph& graph,
    runtime::Span instruction_indices,
    runtime::Span /*input_value_ids*/,
    runtime::Span /*output_value_ids*/,
    runtime::Span> value_remap) {
  // Copy the spans into owned containers; the segment outlives the call.
  std::vector idxs(
      instruction_indices.begin(), instruction_indices.end());
  std::unordered_map remap;
  remap.reserve(value_remap.size());
  for (const auto& kv : value_remap) {
    remap.emplace(kv.first, kv.second);
  }
  auto seg = std::make_unique(
      &graph, std::move(idxs), std::move(remap));
  CompiledSegment* raw = seg.get();
  compiled_segments_.push_back(std::move(seg));
  return raw;
}

runtime::Error MetalInstance::allocate_all(
    runtime::Span requests,
    runtime::Span values,
    runtime::Span out_buffers) {
  if (!stream_) return runtime::Error::InvalidState;
  if (requests.size() != out_buffers.size()) {
    return runtime::Error::InvalidArgument;
  }
  Location loc = Location::on(provider_->id());
  for (size_t i = 0; i < requests.size(); ++i) {
    const auto& req = requests[i];
    if (req.value_id >= values.size() ||
        !values[req.value_id].isTensor()) {
      ET_LOG(Error,
          "MetalInstance::allocate_all: value_id=%u missing or not a tensor",
          req.value_id);
      return runtime::Error::InvalidArgument;
    }
    const auto& tensor = values[req.value_id].toTensor();
    size_t nbytes = tensor.nbytes();

    // Cross-runtime synthetic with a host-addressable source: try
    // zero-copy alias. Register the source's host pointer with the
    // stream as a standalone MTLBuffer (newBufferWithBytesNoCopy). If
    // Metal accepts, return a Mode::Aliasing MetalBuffer wrapping the
    // host pointer — NO pool allocation.
    if (req.host_alias && req.host_alias->host_ptr()) {
      void* p = req.host_alias->host_ptr();
      if (stream_->registerExternalBuffer(p, nbytes, /*strict_zero_copy=*/true)) {
        auto* buf = MetalBuffer::alias(stream_, loc, p, nbytes);
        out_buffers[i] = buf;
        ET_LOG(Debug,
            "[mem] metal: allocate_all value_id=%u bytes=%zu host_alias=%p (zero-copy alias)",
            req.value_id, nbytes, p);
        owned_buffers_.emplace_back(buf);
        continue;
      }
      // Metal refused zero-copy on this pointer (alignment etc.). Fall
      // through to a fresh pool allocation; per-execute upload_from_host
      // will memcpy bytes in.
    }

    // Fresh pool allocation: intermediates, IO destinations, alias-refused
    // synthetics, and (in the future) alias hints whose source has no
    // host_ptr (Vulkan-produced values).
    // NOTE(review): on a mid-loop allocation failure, buffers created so
    // far stay in owned_buffers_ and are reclaimed at instance teardown.
    void* ptr = stream_->alloc(nbytes);
    if (!ptr) return runtime::Error::MemoryAllocationFailed;
    (void)stream_->bufferForPtr(ptr, nbytes);
    auto* buf = MetalBuffer::allocate(stream_, loc, ptr, nbytes);
    out_buffers[i] = buf;
    ET_LOG(Debug,
        "[mem] metal: allocate_all value_id=%u bytes=%zu host_ptr=%p",
        req.value_id, nbytes, ptr);
    owned_buffers_.emplace_back(buf);
  }
  return runtime::Error::Ok;
}

runtime::Result MetalInstance::upload_constant(
    const runtime::NamedDataMap& ndm, std::string_view key) {
  if (!stream_) return runtime::Error::InvalidState;
  auto fb_result = ndm.get_data(key);
  if (!fb_result.ok()) {
    ET_LOG(Error, "MetalInstance: upload_constant: NDM key '%.*s' not found",
        static_cast(key.size()), key.data());
    return fb_result.error();
  }
  runtime::FreeableBuffer fb = std::move(fb_result.get());
  void* ptr = const_cast(fb.data());
  size_t bytes = fb.size();
  // Zero-copy alias: register the FreeableBuffer's mmap'd region with
  // MetalStream so kernels' bufferForPtr() resolves to a wrapping
  // MTLBuffer (newBufferWithBytesNoCopy under
the hood). + stream_->registerExternalBuffer(ptr, bytes); + (void)stream_->bufferForPtr(ptr, bytes); + Location loc = Location::on(provider_->id()); + auto* buf = MetalBuffer::alias_ndm(stream_, loc, std::move(fb)); + owned_buffers_.emplace_back(buf); + ET_LOG(Debug, + "[mem] metal: upload_constant key='%.*s' bytes=%zu host_ptr=%p (zero-copy NDM alias via registerExternalBuffer)", + static_cast(key.size()), key.data(), bytes, ptr); + return buf; +} + +std::unique_ptr MetalInstance::make_event() { + return std::make_unique(); +} + +runtime::Error MetalInstance::upload_from_host( + runtime::EValue& host_src_ev, + void* host_src_ptr, + runtime::EValue& dev_dst_ev, + Buffer* dev_dst_buf, + QueueKind /*queue*/, + runtime::Span wait_for, + Event* signal) { + // 1. Wait on producer event(s). + if (auto e = check_dependencies_(wait_for, signal); + e != runtime::Error::Ok) { + return e; + } + + // 2. Validate. + if (!host_src_ptr || !dev_dst_buf || !host_src_ev.isTensor() || + !dev_dst_ev.isTensor()) { + if (signal) signal->signal_failed(runtime::Error::InvalidArgument); + return runtime::Error::InvalidArgument; + } + + auto& src_t = host_src_ev.toTensor(); + auto& dst_t = dev_dst_ev.toTensor(); + size_t nbytes = src_t.nbytes(); + + // 3. Propagate shape. + if (auto e = runtime::resize_tensor(dst_t, src_t.sizes()); + e != runtime::Error::Ok) { + if (signal) signal->signal_failed(e); + return e; + } + + auto* mb = static_cast(dev_dst_buf); + + // 4. Skip-if-same: if dest is already aliased here, no work. + if (mb->host_ptr() == host_src_ptr) { + dst_t.unsafeGetTensorImpl()->set_data(host_src_ptr); + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] metal: upload_from_host caller_ptr=%p bytes=%zu (alias unchanged)", + host_src_ptr, nbytes); + return runtime::Error::Ok; + } + + // 5. Try zero-copy: register caller pointer with the stream as a + // standalone MTLBuffer (newBufferWithBytesNoCopy). 
If Metal accepts, + // re-alias the destination MetalBuffer to point at host_src_ptr (the + // pool-allocated storage is returned to the pool). + if (stream_->registerExternalBuffer( + host_src_ptr, nbytes, /*strict_zero_copy=*/true)) { + mb->re_alias(host_src_ptr, nbytes); + dst_t.unsafeGetTensorImpl()->set_data(host_src_ptr); + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] metal: upload_from_host caller_ptr=%p bytes=%zu (re-aliased zero-copy)", + host_src_ptr, nbytes); + return runtime::Error::Ok; + } + + // 6. Fallback: Metal refused zero-copy on this pointer. memcpy into + // the existing MetalBuffer's pool-allocated storage. + void* dst_ptr = mb->host_ptr(); + if (!dst_ptr) { + if (signal) signal->signal_failed(runtime::Error::InvalidArgument); + return runtime::Error::InvalidArgument; + } + std::memcpy(dst_ptr, host_src_ptr, nbytes); + dst_t.unsafeGetTensorImpl()->set_data(dst_ptr); + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] metal: upload_from_host caller_ptr=%p bytes=%zu (memcpy fallback; zero-copy refused)", + host_src_ptr, nbytes); + return runtime::Error::Ok; +} + +runtime::Error MetalInstance::download_to_host( + runtime::EValue& dev_src_ev, + Buffer* dev_src_buf, + runtime::EValue& host_dst_ev, + void* host_dst_ptr, + QueueKind /*queue*/, + runtime::Span wait_for, + Event* signal) { + // 1. Wait on producer event(s). + if (auto e = check_dependencies_(wait_for, signal); + e != runtime::Error::Ok) { + return e; + } + + // 2. Validate. + if (!host_dst_ptr || !dev_src_buf || !dev_src_ev.isTensor() || + !host_dst_ev.isTensor()) { + if (signal) signal->signal_failed(runtime::Error::InvalidArgument); + return runtime::Error::InvalidArgument; + } + + auto& src_t = dev_src_ev.toTensor(); + auto& dst_t = host_dst_ev.toTensor(); + size_t nbytes = src_t.nbytes(); + + // 3. Propagate shape. 
// NOTE(review): mangled patch fragment — newlines collapsed, template arguments
// stripped (e.g. `static_cast(dev_src_buf)`, `std::vector ins;`). Preserved
// verbatim. Contains: download_to_host continuation (stream_->sync() before any
// host read of device-written bytes, then skip-if-same / zero-copy re-alias /
// memcpy fallback — note step 6 copies the OLD bytes before re_alias so data is
// not lost when storage is rebound) and execute() (validate segment, verify
// bindings, dispatch each instruction through MetalOpRegistry, then a
// synchronous stream_->sync() before signaling so shape AND bytes are host-visible).
// NOTE(review): in execute(), a remapped value id >= values.size() is silently
// skipped when building ins/outs rather than reported — presumably the router
// guarantees in-range ids; TODO confirm this cannot mask a malformed plan.
+ if (auto e = runtime::resize_tensor(dst_t, src_t.sizes()); + e != runtime::Error::Ok) { + if (signal) signal->signal_failed(e); + return e; + } + + auto* mb = static_cast(dev_src_buf); + + // 4. Sync GPU work so any pending writes to dev_src_buf are visible. + stream_->sync(); + + // 5. Skip-if-same: dev_src_buf already aliases host_dst_ptr (kernel + // wrote directly into caller storage). + if (mb->host_ptr() == host_dst_ptr) { + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] metal: download_to_host caller_ptr=%p bytes=%zu (alias unchanged)", + host_dst_ptr, nbytes); + return runtime::Error::Ok; + } + + // 6. Try zero-copy alias (uncommon for outputs but symmetric). + if (stream_->registerExternalBuffer( + host_dst_ptr, nbytes, /*strict_zero_copy=*/true)) { + // Need to memcpy current bytes since the buffer just got rebound. + void* old_ptr = mb->host_ptr(); + std::memcpy(host_dst_ptr, old_ptr, nbytes); + mb->re_alias(host_dst_ptr, nbytes); + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] metal: download_to_host caller_ptr=%p bytes=%zu (re-aliased after copy)", + host_dst_ptr, nbytes); + return runtime::Error::Ok; + } + + // 7. Fallback: memcpy from device buffer to caller storage.
+ void* src_ptr = mb->host_ptr(); + if (!src_ptr) { + if (signal) signal->signal_failed(runtime::Error::InvalidArgument); + return runtime::Error::InvalidArgument; + } + std::memcpy(host_dst_ptr, src_ptr, nbytes); + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + ET_LOG(Debug, + "[mem] metal: download_to_host caller_ptr=%p bytes=%zu (memcpy fallback)", + host_dst_ptr, nbytes); + return runtime::Error::Ok; +} + +runtime::Error MetalInstance::execute( + CompiledSegment* segment, + runtime::Span values, + BindingView bindings, + runtime::Span wait_for, + Event* signal) { + if (auto e = check_dependencies_(wait_for, signal); + e != runtime::Error::Ok) { + return e; + } + + auto* seg = static_cast(segment); + if (!seg) return runtime::Error::InvalidArgument; + const auto* graph = seg->graph(); + if (!graph) return runtime::Error::InvalidState; + + // Verify each tensor EValue's TensorImpl::data_ptr matches the + // currently-bound MetalBuffer's host_ptr (= [mtlBuffer contents]) + // after remap. The data_ptr is established at prebind / bind_inputs / + // upload_from_host time; by the time we get here it MUST match the + // bound Buffer. + if (auto e = verify_segment_bindings(seg, graph, values, bindings, + signal, "MetalInstance"); + e != runtime::Error::Ok) { + return e; + } + + // Dispatch each instruction via the metal_v2 registry.
+ auto& registry = metal_v2_ns::MetalOpRegistry::shared(); + for (uint32_t instr_idx : seg->instruction_indices()) { + auto op = graph->get_instruction(instr_idx); + const char* op_name = op.name(); + if (!op_name) { + if (signal) signal->signal_failed(runtime::Error::InvalidProgram); + return runtime::Error::InvalidProgram; + } + metal_v2_ns::MetalOp* metal_op = registry.get(op_name); + if (!metal_op) { + ET_LOG(Error, "MetalInstance: op '%s' not in MetalOpRegistry", op_name); + if (signal) signal->signal_failed(runtime::Error::NotSupported); + return runtime::Error::NotSupported; + } + + ET_LOG(Debug, "MetalInstance: instr %u op='%s' (in=%zu, out=%zu)", + instr_idx, op_name, op.num_inputs(), op.num_outputs()); + + // Build EValue* vectors for the op (with value_remap applied). + std::vector ins; + std::vector outs; + ins.reserve(op.num_inputs()); + outs.reserve(op.num_outputs()); + for (size_t i = 0; i < op.num_inputs(); ++i) { + uint32_t vid = seg->remap(op.input(i)); + if (vid < values.size()) ins.push_back(&values[vid]); + } + for (size_t i = 0; i < op.num_outputs(); ++i) { + uint32_t vid = seg->remap(op.output(i)); + if (vid < values.size()) outs.push_back(&values[vid]); + } + + metal_op->dispatch( + stream_, + runtime::Span(ins.data(), ins.size()), + runtime::Span(outs.data(), outs.size())); + } + + // Flush + wait so the shape-on-event contract holds: by the time + // signal goes Complete, host can read both shape (already on + // TensorImpls — set host-side by each op's resizeOutput before + // dispatch) AND bytes (ensured by stream->wait()). + stream_->sync(); + + if (signal) { + signal->prepare_signal(); + signal->signal_complete(); + } + return runtime::Error::Ok; +} + +runtime::Error MetalInstance::wait(Event* event) { + if (!event) return runtime::Error::Ok; + // Spin until the event's atomic status reaches a terminal state.
// NOTE(review): mangled patch fragment (collapsed newlines; #include targets and
// most template arguments stripped by tooling). Preserved verbatim. Contains:
// wait() (busy-spin on event status — the comment justifies it because events
// are signaled synchronously before wait() is reached; it also references
// transfer_tensor(), a name that does not appear in this file — presumably it
// means upload_from_host/download_to_host, TODO confirm and fix the name),
// drain(), release_buffer() (deliberate no-op; buffer lifetime is tied to the
// Instance), and the start of MetalProvider.h (full header visible through the
// next chunk).
+ // For execute() and transfer_tensor(), we synchronously sync the + // stream before signaling, so by the time the executor reaches a + // wait() the event is already Complete and this is a few-load + // hot path. + while (true) { + auto s = event->status(); + if (s == EventStatus::Complete) return runtime::Error::Ok; + if (s == EventStatus::Failed || s == EventStatus::Poisoned) { + return event->error(); + } + // Pending: yield. (Acceptable for v1; a condvar is the future + // optimization if we move signaling into completion handlers.) + } +} + +void MetalInstance::drain() { + if (stream_) stream_->sync(); +} + +void MetalInstance::release_buffer(Buffer* buf) { + // All Buffers live in owned_buffers_; their unique_ptrs run their + // destructors at ~MetalInstance, which return Owned-mode pool memory + // and release NdmAlias FreeableBuffers. Aliasing-mode Buffers (those + // re-aliased via upload_from_host) are no-op destructors. + // Per-Plan release_buffer is a no-op; lifetime is tied to the Instance. + (void)buf; +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/metal/MetalProvider.h b/backends/portable/runtime_v2/metal/MetalProvider.h new file mode 100644 index 00000000000..8b08108b5c1 --- /dev/null +++ b/backends/portable/runtime_v2/metal/MetalProvider.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +// Forward declaration: keep MetalProvider.h pure-C++; the .mm +// implementation includes MetalStream.h.
+namespace executorch { +namespace backends { +namespace metal_v2 { +class MetalStream; +} // namespace metal_v2 +} // namespace backends +} // namespace executorch + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Metal Provider — Apple-silicon accelerator using the existing + * `metal_v2` infrastructure (MetalStream, MetalOpRegistry, MetalOp). + * + * Owns a private MetalStream (via MetalStream::create()). Constructable + * iff MTLCreateSystemDefaultDevice() returns non-nil. + * + * Op support is delegated to MetalOpRegistry: can_run() returns Some + * iff the op's name is registered there. Initial set: aten::add / + * aten::mul / aten::sub / aten::relu / aten::mm / aten::bmm. + * + * Buffers are MetalBuffers wrapping host pointers from + * MetalStream::alloc() (Apple Silicon unified memory means + * MTLBuffer.contents() is a directly-addressable host pointer). + * + * v1 constraint: at most one non-CPU provider per process. MetalProvider + * is mutually exclusive with FakeAccelProvider in + * make_default_providers(). + */ +class MetalProvider final : public Provider { + public: + MetalProvider(); + ~MetalProvider() override; + + // Returns true iff stream construction succeeded (i.e., MTLDevice + // present). Use after construction to decide whether to register. + bool stream_ready() const; + + std::string_view name() const override { return "metal"; } + bool is_available_on_device() const override { return stream_ready(); } + + bool can_run(const OpDescriptor& op) const override; + + RuntimeContext& context() override { return ctx_; } + + std::unique_ptr instantiate() override; + + // Used by MetalInstance to obtain the stream we own. + ::executorch::backends::metal_v2::MetalStream* stream() { return stream_.get(); } + + private: + // Tag-only RuntimeContext (we don't carry any per-Provider state in + // it; the MetalStream IS the per-Provider state and is held directly + // on this class).
+ struct MetalRuntimeContext : public RuntimeContext {}; + + std::unique_ptr<::executorch::backends::metal_v2::MetalStream> stream_; + MetalRuntimeContext ctx_; + std::atomic next_instance_id_{0}; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/metal/MetalProvider.mm b/backends/portable/runtime_v2/metal/MetalProvider.mm new file mode 100644 index 00000000000..a7a75d6dded --- /dev/null +++ b/backends/portable/runtime_v2/metal/MetalProvider.mm @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +#include + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +MetalProvider::MetalProvider() { + // MetalStream::create() returns nullptr internally if MTLDevice + // construction fails (no Metal-capable GPU). The wrapping unique_ptr + // would still be valid even if the underlying state isn't fully + // constructed; check stream_ready() before using. + stream_ = ::executorch::backends::metal_v2::MetalStream::create(); + if (!stream_) { + ET_LOG(Info, "MetalProvider: MetalStream::create() returned nullptr; " + "no Metal-capable device available"); + } +} + +MetalProvider::~MetalProvider() = default; + +bool MetalProvider::stream_ready() const { + return stream_ != nullptr && stream_->device() != nil; +} + +bool MetalProvider::can_run(const OpDescriptor& op) const { + if (!stream_ready()) return false; + std::string n(op.name); + return ::executorch::backends::metal_v2::MetalOpRegistry::shared().hasOp(n); + // dtype-level filtering would happen here once OpDescriptor carries + // dtype info; for v1 we accept unconditionally (assume Float, which is + // what the metal_v2 ops default to).
// NOTE(review): mangled patch fragment (collapsed newlines; #include targets and
// template arguments stripped — e.g. `std::vector assignments;` was presumably
// `std::vector<int>`, `std::unordered_map>` lost its key/value arguments).
// Preserved verbatim. Contains: MetalProvider::instantiate() (monotonic
// InstanceId from an atomic counter), GreedyRouter.cpp file head, pick_provider
// (non-host providers at index >= 1 win in registration order; host slot 0 is
// the fallback; -1 means no provider can run the op), and route() steps 1-4:
// per-instruction assignment, segmentation of consecutive same-provider
// instructions (the first segment is seeded before the loop, so the i == 0
// iteration only records the op's inputs/outputs), producer map, IO set, and
// consumer-provider map.
+} + +std::unique_ptr MetalProvider::instantiate() { + InstanceId id = next_instance_id_.fetch_add(1, std::memory_order_relaxed); + return std::make_unique(this, id); +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/routers/GreedyRouter.cpp b/backends/portable/runtime_v2/routers/GreedyRouter.cpp new file mode 100644 index 00000000000..e20db6bc5b1 --- /dev/null +++ b/backends/portable/runtime_v2/routers/GreedyRouter.cpp @@ -0,0 +1,598 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +namespace { + +using ::executorch::backends::portable::Graph; +using ::executorch::backends::portable::OperatorCall; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; +using ::executorch::runtime::Span; +using ValueType = ::executorch::backends::portable::ValueType; + +// Pick a provider for an op. Priority: any provider at index >= 1 that +// returns true from can_run wins (host-slot CPU is the fallback). +int pick_provider(const OperatorCall& op, Span providers) { + OpDescriptor desc; + desc.name = op.name() ? op.name() : ""; + + // Try non-host providers first (index 1..N). + for (size_t i = 1; i < providers.size(); ++i) { + if (providers[i]->can_run(desc)) { + return static_cast(i); + } + } + // Fallback: host (index 0) if available.
+ if (!providers.empty() && providers[0]->can_run(desc)) { + return 0; + } + return -1; +} + +} // namespace + +Result GreedyRouter::route( + const Graph& graph, + Span providers, + Span instances, + const ::executorch::runtime::NamedDataMap* ndm, + const RouterOptions& options) { + if (providers.size() != instances.size()) return Error::InvalidArgument; + if (providers.empty()) return Error::InvalidArgument; + + Plan plan; + plan.providers.assign(providers.begin(), providers.end()); + plan.instances.assign(instances.begin(), instances.end()); + + // ---- 1. Per-instruction provider assignment --------------------------- + std::vector assignments; + assignments.reserve(graph.num_instructions()); + for (uint32_t i = 0; i < graph.num_instructions(); ++i) { + OperatorCall op = graph.get_instruction(i); + int p = pick_provider(op, providers); + if (p < 0) { + ET_LOG(Error, "GreedyRouter: no provider for op '%s'", + op.name() ? op.name() : "?"); + return Error::NotSupported; + } + ET_LOG(Debug, "GreedyRouter: instr %u op='%s' -> provider %d (%s)", i, + op.name() ? op.name() : "?", p, + std::string(providers[p]->name()).c_str()); + assignments.push_back(p); + } + + // ---- 2.
Group consecutive same-runtime instructions into segments ----- + struct PendingSegment { + int provider_idx; + std::vector instruction_indices; + std::set input_value_ids; // consumed-but-not-produced + std::set output_value_ids; // produced + }; + std::vector segments; + + if (!assignments.empty()) { + segments.push_back({assignments[0], {0}, {}, {}}); + } + for (uint32_t i = 0; i < assignments.size(); ++i) { + if (i > 0 && assignments[i] != segments.back().provider_idx) { + segments.push_back({assignments[i], {i}, {}, {}}); + } else if (i > 0) { + segments.back().instruction_indices.push_back(i); + } + OperatorCall op = graph.get_instruction(i); + auto& seg = segments.back(); + for (size_t j = 0; j < op.num_inputs(); ++j) { + uint32_t v = op.input(j); + if (seg.output_value_ids.count(v) == 0) seg.input_value_ids.insert(v); + } + for (size_t j = 0; j < op.num_outputs(); ++j) { + seg.output_value_ids.insert(op.output(j)); + } + } + ET_LOG(Debug, "GreedyRouter: built %zu segments", segments.size()); + + // ---- 3. Build value_id -> producing-segment-index map ---------------- + std::unordered_map value_producer_seg; + for (size_t s = 0; s < segments.size(); ++s) { + for (uint32_t v : segments[s].output_value_ids) { + value_producer_seg[v] = s; + } + } + + // ---- 4. Collect graph IO and value-meta ------------------------------- + std::set io_ids; + for (size_t i = 0; i < graph.num_input_ids(); ++i) + io_ids.insert(graph.input_id(i)); + for (size_t i = 0; i < graph.num_output_ids(); ++i) + io_ids.insert(graph.output_id(i)); + + // Provider that consumes each value (could be multiple). + std::unordered_map> value_consumer_providers; + for (auto& seg : segments) { + for (uint32_t v : seg.input_value_ids) { + value_consumer_providers[v].insert(seg.provider_idx); + } + } + + // ---- 5. Decide where each value's "home" Buffer lives ---------------- + // Unified rule: + // - Graph IO → home = host (provider 0).
// NOTE(review): mangled patch fragment (collapsed newlines; template arguments
// stripped). Preserved verbatim. Contains route() step 5 (per-value "home"
// provider decision: graph IO → host; semantic mem_obj_id alias groups homed
// together; cross-runtime intermediates → host; single-runtime → producer) and
// step 6 (constant upload + AllocRequest emission). Two review observations:
// (a) `mem_id_emitted` is declared but never referenced in the visible code —
// presumably leftover from the dedupe design described in the comment above it;
// TODO confirm and remove. (b) For a constant with several consumers, only the
// FIRST consumer's Buffer is recorded in plan.owned_buffers (`p ==
// *consumers.begin()`); presumably the other instances track their own copies —
// TODO confirm those buffers are released elsewhere. The cross-segment section
// header that followed a second "---- 6." has been renumbered "6a" to keep the
// step numbering unique (6 → 6a → 6b → 7).
+ // - Aliased group (same mem_obj_id, set by AOT memory planner): if + // ALL ops touching ANY value in the group are on a single non-host + // runtime, home = that runtime (e.g. KV cache stays on Vulkan when + // all ops including the writeback are on Vulkan). Otherwise home = + // host (correctness via cross-runtime mirrors; perf hit acceptable). + // - Cross-runtime intermediate → home = host. + // - Otherwise (single-runtime) → home = producer's runtime. + // + // Allocating cross-runtime values on host means every synthetic mirror + // sources from host, so allocate_buffers' host-first iteration order + // naturally satisfies dependencies — no two-pass needed. + + // Build mem_obj_id groups (only for values with mem_obj_id >= 0; the + // sort-and-index Graph::mem_obj_id assigns -1 for non-tensor / no-alloc + // values, and a unique non-negative id per (pool, offset) slot). + std::unordered_map> mem_obj_groups; + for (uint32_t i = 0; i < graph.num_values(); ++i) { + if (graph.value_type(i) != ValueType::Tensor) continue; + if (graph.tensor_constant_data_key(i) != nullptr) continue; + int32_t mid = graph.mem_obj_id(i); + if (mid >= 0) mem_obj_groups[mid].push_back(i); + } + + // Identify SEMANTIC alias mem_obj_ids: groups that include a mutable + // buffer placeholder (value with allocation info, not graph IO, not + // constant, not produced by any op). These come from tag_mutated_buffer + // + AOT spec-sharing: the placeholder and the mutation source share a + // mem_obj_id by design, so writing to the source IS writing to the + // buffer — they MUST share storage at runtime. + // + // Other mem_obj_id groups (e.g. lifetime reuse where the planner + // packed two values with non-overlapping lifetimes into one slot) + // don't have this constraint — they're aliased in storage but not + // semantically the same data; per-value home rules suffice.
+ std::set semantic_alias_mids; + for (size_t i = 0; i < graph.num_mutable_buffer_ids(); ++i) { + uint32_t buf_vid = graph.mutable_buffer_id(i); + int32_t mid = graph.mem_obj_id(buf_vid); + if (mid >= 0) semantic_alias_mids.insert(mid); + } + + // Helper: union of all touching runtimes for one value_id. + auto touching_runtimes = [&](uint32_t v) -> std::set { + std::set rts; + auto pit = value_producer_seg.find(v); + if (pit != value_producer_seg.end()) { + rts.insert(segments[pit->second].provider_idx); + } + auto cit = value_consumer_providers.find(v); + if (cit != value_consumer_providers.end()) { + for (int p : cit->second) rts.insert(p); + } + return rts; + }; + + std::unordered_map value_home_provider; + for (uint32_t i = 0; i < graph.num_values(); ++i) { + if (graph.value_type(i) != ValueType::Tensor) continue; + if (graph.tensor_constant_data_key(i) != nullptr) continue; // constant + if (io_ids.count(i) > 0) { + value_home_provider[i] = 0; // graph IO → host + continue; + } + + // Aliased-group rule: applies ONLY to semantic alias groups (where + // the planner deliberately shares storage between a buffer placeholder + // and its mutation source). Lifetime-reuse aliasing falls through to + // per-value rules. + int32_t mid = graph.mem_obj_id(i); + if (mid >= 0 && semantic_alias_mids.count(mid)) { + auto git = mem_obj_groups.find(mid); + if (git != mem_obj_groups.end() && git->second.size() > 1) { + std::set all_rts; + for (uint32_t v : git->second) { + for (int p : touching_runtimes(v)) all_rts.insert(p); + } + // Single non-host runtime touches all members → home there + // (fast path: zero cross-runtime traffic). + if (all_rts.size() == 1 && *all_rts.begin() != 0) { + value_home_provider[i] = *all_rts.begin(); + } else { + // Mixed runtimes (or all on host) → home = host. Correctness + // via cross-runtime mirrors; non-unified backends (Vulkan) + // pay upload+download per execute.
+ value_home_provider[i] = 0; + } + continue; + } + } + + // Non-aliased: per-value rule. + auto pit = value_producer_seg.find(i); + auto cit = value_consumer_providers.find(i); + + if (pit != value_producer_seg.end()) { + // Has a producer in the delegate (intermediate). Cross-runtime if + // any cross-segment consumer is on a different runtime than the + // producer. + int producer_p = segments[pit->second].provider_idx; + bool cross_runtime = false; + if (cit != value_consumer_providers.end()) { + for (int c : cit->second) { + if (c != producer_p) { cross_runtime = true; break; } + } + } + value_home_provider[i] = cross_runtime ? 0 : producer_p; + } else { + // No producer = placeholder (mutable buffer pulled into delegate + // by tag_mutated_buffer; not a graph input). Home = consumer's + // runtime if single-runtime consumer set, else host. + if (cit == value_consumer_providers.end()) continue; // truly unused + const std::set& consumers = cit->second; + if (consumers.size() == 1) { + value_home_provider[i] = *consumers.begin(); + } else { + value_home_provider[i] = 0; + } + } + } + + // ---- 6. Upload constants and emit alloc-plans for intermediates ----- + // Constants are uploaded to ALL providers that consume them (synchronous + // upload_constant call; not subject to allocate_all). For each + // (provider, value_id, mem_obj_id) tuple we dedupe via a map. + // Intermediates are NOT allocated here — the router emits an + // AllocRequest entry into plan.alloc_plans[home_provider]; the + // executor's allocate_buffers step (called after route + materialize) + // performs the actual allocation via allocate_all. + + // Initialize alloc_plans (one entry per provider). + plan.alloc_plans.assign(providers.size(), {}); + + // (provider_idx, mem_obj_id) -> sentinel to track if mem_obj_id already + // emitted for this provider (so we emit at most one AllocRequest per + // (provider, mem_obj_id) group).
+ std::set> mem_id_emitted; + + for (uint32_t i = 0; i < graph.num_values(); ++i) { + if (graph.value_type(i) != ValueType::Tensor) continue; + + if (const char* key = graph.tensor_constant_data_key(i); key != nullptr) { + // Constant. Upload to each consuming provider via upload_constant + // (separate path from allocate_all; persistent zero-copy alias). + if (!ndm) { + ET_LOG(Error, "GreedyRouter: constant '%s' needs NDM", key); + return Error::InvalidArgument; + } + auto it = value_consumer_providers.find(i); + std::set consumers = + it != value_consumer_providers.end() ? it->second : std::set{0}; + for (int p : consumers) { + auto buf_result = instances[p]->upload_constant(*ndm, key); + if (!buf_result.ok()) { + ET_LOG(Error, "GreedyRouter: upload_constant('%s') on provider %d failed", + key, p); + return buf_result.error(); + } + ET_LOG(Debug, + "[mem] router: upload_constant value_id=%u key='%s' provider=%d (%s) bytes=%zu", + i, key, p, + std::string(providers[p]->name()).c_str(), + buf_result.get()->size_bytes()); + if (p == *consumers.begin()) { + plan.owned_buffers.push_back({buf_result.get(), instances[p], i}); + } + } + continue; + } + + if (io_ids.count(i) > 0) continue; // emitted in graph-IO loop below + + // Intermediate: emit AllocRequest on the value's home provider. + auto hit = value_home_provider.find(i); + if (hit == value_home_provider.end()) continue; // unused + int home_p = hit->second; + + int32_t mem_id = graph.mem_obj_id(i); + size_t nbytes = graph.tensor_nbytes_max(i); + if (nbytes == 0) continue; + + // Always emit an AllocRequest per value_id. Even if mem_id is + // shared with another value (e.g., AOT memory planner aliased them + // because their lifetimes don't overlap, or because of an op like + // aten::copy_'s formal `out` sharing data with `src`), each value_id + // gets its own binding entry.
Backends that want to honor mem_id + // sharing as actual storage aliasing (Vulkan SharedObject style) + // can do so internally based on req.mem_obj_id; for our current + // backends each value_id gets its own Buffer. + Instance::AllocRequest req; + req.value_id = i; + req.mem_obj_id = mem_id; + req.host_alias = nullptr; + plan.alloc_plans[home_p].push_back(req); + ET_LOG(Debug, + "[mem] router: alloc-request intermediate value_id=%u home_provider=%d (%s) mem_id=%d nbytes=%zu", + i, home_p, + std::string(providers[home_p]->name()).c_str(), + mem_id, nbytes); + } + + // Add IO destination requests on the host provider (slot 0). bind_inputs + // / bind_outputs will re-alias these to caller storage per execute. + for (size_t i = 0; i < graph.num_input_ids(); ++i) { + Instance::AllocRequest req; + req.value_id = graph.input_id(i); + req.mem_obj_id = -1; + req.host_alias = nullptr; + plan.alloc_plans[0].push_back(req); + ET_LOG(Debug, + "[mem] router: alloc-request graph input value_id=%u provider=0 (host)", + req.value_id); + } + for (size_t i = 0; i < graph.num_output_ids(); ++i) { + Instance::AllocRequest req; + req.value_id = graph.output_id(i); + req.mem_obj_id = -1; + req.host_alias = nullptr; + plan.alloc_plans[0].push_back(req); + ET_LOG(Debug, + "[mem] router: alloc-request graph output value_id=%u provider=0 (host)", + req.value_id); + } + + // ---- 6a. For each cross-segment value, synthesize destination ---------- + // For each segment's input value V where the producing segment's provider + // differs from this segment's provider: + // - Synthesize a destination value_id V_synth on dst_p. + // - Emit an AllocRequest for V_synth into plan.alloc_plans[dst_p]. + // The executor's allocate_buffers step will patch host_alias to + // point at V's Buffer (after V is allocated on src_p) so the + // destination backend can choose to zero-copy alias the host + // pointer (CPU/Apple-Silicon Metal) or allocate fresh + copy + // per-execute (Vulkan).
// NOTE(review): mangled patch fragment (collapsed newlines; template arguments
// stripped — e.g. `std::vector>> seg_remaps` lost its nested pair/vector
// arguments). Preserved verbatim. Contains route() consumer-side mirrors
// (synthetic value ids start at graph.num_values(), so they never collide with
// real value ids; the dedupe scan over plan.alloc_plans[dst_p] is documented as
// required for correctness, not just an optimization), the symmetric
// producer-side mirror loop (6b) — note the host-homed values are collected and
// std::sort'ed before iteration so synthetic id assignment is deterministic
// across runs despite the unordered_map — and step 7, per-segment
// compile_segment calls with the accumulated remaps.
+ // - Emit a TransferStep V → V_synth (always — per-execute work is + // a no-op skip-if-same on host-addressable runtimes). + // - Record V → V_synth remap for the destination CompiledSegment. + + uint32_t next_synth_id = static_cast(graph.num_values()); + // segment idx -> (V_orig, V_synth) pairs to pass as remap. + std::vector>> seg_remaps( + segments.size()); + // For each segment, the list of (src_value_id, dst_value_id) transfers. + // Inserted before the segment's ComputeStep. + struct PendingTransfer { + uint32_t src_value_id; // V (producer's binding) + uint32_t dst_value_id; // V' (this segment's view) + int src_provider_idx; + int dst_provider_idx; + }; + std::vector> seg_transfers(segments.size()); + + for (size_t s = 0; s < segments.size(); ++s) { + auto& seg = segments[s]; + int dst_p = seg.provider_idx; + for (uint32_t v : seg.input_value_ids) { + auto pit = value_producer_seg.find(v); + if (pit == value_producer_seg.end()) { + // Graph input or constant — handled by bind_inputs / step 5. + continue; + } + int src_p = segments[pit->second].provider_idx; + if (src_p == dst_p) continue; // same provider, no transfer + + // Dedup: if v already has an allocation on dst_p (e.g., it's a + // graph output whose host alloc was emitted in step 5, and dst_p + // is host), the consumer can read v directly from its existing + // Buffer. Skipping the redundant mirror is also REQUIRED for + // correctness here: if v is allocated in dst_p's plan AND we + // emit a synthetic mirror also in dst_p's plan sourcing from v, + // allocate_buffers' host_alias patching can't resolve the source + // (the source is in the same provider's plan and isn't in the + // value_to_buf ledger until that plan's allocate_all returns).
+ bool v_already_on_dst = false; + for (const auto& req : plan.alloc_plans[dst_p]) { + if (req.value_id == v) { v_already_on_dst = true; break; } + } + if (v_already_on_dst) { + ET_LOG(Debug, + "[mem] router: skip mirror for value_id=%u in seg=%zu " + "(already allocated on provider=%d)", + v, s, dst_p); + continue; + } + + // Only tensors get transferred / aliased. + if (graph.value_type(v) != ValueType::Tensor) continue; + size_t nbytes = graph.tensor_nbytes_max(v); + if (nbytes == 0) continue; + + uint32_t v_synth = next_synth_id++; + + // Emit AllocRequest for the synthetic value on dst_p. + // host_alias is left null here; the executor patches it to point + // at v's Buffer (looked up from already-allocated src_p) before + // calling allocate_all. mem_obj_id = -1 (synthetic dedicated). + Instance::AllocRequest req; + req.value_id = v_synth; + req.mem_obj_id = -1; + req.host_alias = nullptr; // patched by allocate_buffers + plan.alloc_plans[dst_p].push_back(req); + + plan.synthetic_values.push_back({v_synth, v}); + seg_remaps[s].push_back({v, v_synth}); + seg_transfers[s].push_back({v, v_synth, src_p, dst_p}); + + ET_LOG(Debug, + "[mem] router: cross-runtime mirror seg=%zu value_id=%u (%s) -> synth_id=%u (%s) bytes=%zu", + s, v, std::string(providers[src_p]->name()).c_str(), + v_synth, std::string(providers[dst_p]->name()).c_str(), + nbytes); + } + } + + // ---- 6b. Producer-side mirror for any value whose home is host but + // whose producing segment is on a non-host provider. ------- + // Includes: + // - Graph outputs produced by a non-host segment. + // - Cross-runtime intermediates (homed on host by step 5). + // For each such value V: + // - Synthesize V_synth on producer's runtime, source = V's host alloc. + // - Add seg_remap so producer kernel writes to V_synth (which + // host_aliases V's host buffer; bytes land in host memory directly + // on Apple Silicon Metal — zero copy).
+ // - Emit TransferStep host -> producer BEFORE producer's ComputeStep + // so V_synth's Metal Buffer re-aliases V's current host_ptr per + // execute (matches bind_outputs / bind_inputs rebinding). + // + // This is the symmetric counterpart of the consumer-side mirror loop + // above. Together they handle all cross-runtime data flow with the + // same machinery; the value's "home" decides which side gets the mirror. + for (uint32_t v : [&]() { + std::vector vs; + for (const auto& kv : value_home_provider) { + if (kv.second == 0) vs.push_back(kv.first); + } + std::sort(vs.begin(), vs.end()); + return vs; + }()) { + auto pit = value_producer_seg.find(v); + if (pit == value_producer_seg.end()) continue; // graph input (no producer in delegate) + int producer_p = segments[pit->second].provider_idx; + if (producer_p == 0) continue; // producer IS host; writes directly + if (graph.value_type(v) != ValueType::Tensor) continue; + + uint32_t v_synth = next_synth_id++; + Instance::AllocRequest req; + req.value_id = v_synth; + req.mem_obj_id = -1; + req.host_alias = nullptr; // patched by allocate_buffers from v's host Buffer + plan.alloc_plans[producer_p].push_back(req); + + plan.synthetic_values.push_back({v_synth, v}); + seg_remaps[pit->second].push_back({v, v_synth}); + // Per-execute re-alias: host -> producer mirror, runs BEFORE the + // producing segment's ComputeStep. + seg_transfers[pit->second].push_back( + {v, v_synth, /*src_p=*/0, /*dst_p=*/producer_p}); + + ET_LOG(Debug, + "[mem] router: producer-side mirror value_id=%u (host) -> " + "synth_id=%u (%s) for producing seg=%zu", + v, v_synth, + std::string(providers[producer_p]->name()).c_str(), + pit->second); + } + + // ---- 7.
Compile each segment with its value remap -------------------- + std::vector compiled_segments; + for (size_t s = 0; s < segments.size(); ++s) { + auto& seg = segments[s]; + Instance* inst = instances[seg.provider_idx]; + std::vector ins(seg.input_value_ids.begin(), + seg.input_value_ids.end()); + std::vector outs(seg.output_value_ids.begin(), + seg.output_value_ids.end()); + ET_LOG(Debug, + "[mem] router: compile_segment %zu provider=%d (%s) instructions=%zu " + "inputs=%zu outputs=%zu remaps=%zu", + s, seg.provider_idx, + std::string(providers[seg.provider_idx]->name()).c_str(), + seg.instruction_indices.size(), ins.size(), + outs.size(), seg_remaps[s].size()); + auto r = inst->compile_segment( + graph, + Span(seg.instruction_indices.data(), + seg.instruction_indices.size()), + Span(ins.data(), ins.size()), + Span(outs.data(), outs.size()), + Span>(seg_remaps[s].data(), + seg_remaps[s].size())); + if (!r.ok()) return r.error(); + compiled_segments.push_back(r.get()); + } + + // ---- 8. Reserve graph input/output bindings -------------------------- + // Input/output destination Buffers are pre-allocated by the executor's + // allocate_buffers() call (post-route) and bound persistently by + // prebind_owned_buffers. bind_inputs/bind_outputs re-alias the + // existing Buffer in place each execute via upload_from_host. + for (size_t i = 0; i < graph.num_input_ids(); ++i) { + InputBinding ib; + ib.loc = Location::host(); + ib.value_id = graph.input_id(i); + plan.inputs.push_back(ib); + } + for (size_t i = 0; i < graph.num_output_ids(); ++i) { + OutputBinding ob; + ob.loc = Location::host(); + ob.value_id = graph.output_id(i); + plan.outputs.push_back(ob); + } + + // ---- 9. Emit Steps in order: per-segment transfers, then ComputeStep -- + for (size_t s = 0; s < segments.size(); ++s) { + for (const auto& xfer : seg_transfers[s]) { + TransferStep ts; + ts.src_value_id = xfer.src_value_id; + ts.dst_value_id = xfer.dst_value_id; + ts.src = (xfer.src_provider_idx == 0) + ?
Location::host() + : Location::on(providers[xfer.src_provider_idx]->id()); + ts.dst = (xfer.dst_provider_idx == 0) + ? Location::host() + : Location::on(providers[xfer.dst_provider_idx]->id()); + ts.src_idx = static_cast(xfer.src_provider_idx); + ts.dst_idx = static_cast(xfer.dst_provider_idx); + ts.queue = QueueKind::Transfer; + ts.signal = kNoEvent; + plan.steps.emplace_back(ts); + } + + ComputeStep cs; + cs.runtime_idx = static_cast(segments[s].provider_idx); + cs.segment = compiled_segments[s]; + cs.queue = QueueKind::Compute; + cs.signal = kNoEvent; + plan.steps.emplace_back(cs); + } + + if (options.dump_trace) { + ET_LOG(Info, + "GreedyRouter: %zu segments / %zu steps / %zu owned buffers / " + "%zu synthetic values", + segments.size(), plan.steps.size(), plan.owned_buffers.size(), + plan.synthetic_values.size()); + } + + return plan; +} + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/routers/GreedyRouter.h b/backends/portable/runtime_v2/routers/GreedyRouter.h new file mode 100644 index 00000000000..1e89c12cac3 --- /dev/null +++ b/backends/portable/runtime_v2/routers/GreedyRouter.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace portable_v2 { + +/** + * Default greedy router. See §4.10 of PORTABLE_BACKEND_API_PROPOSAL.md. + * + * Algorithm: + * 1. Build dense index space (CPU at index 0; others by registration order). + * 2. For each instruction, ask each Provider in priority order; pick the + * first whose can_run() returns true. + * 3. Group consecutive same-runtime instructions into segments; call + * Instance::compile_segment for each. + * 4. 
For each value crossing a segment boundary: synthesize a destination + * value_id on the consumer side, emit a TransferStep, and emit an + * AllocRequest carrying the source's host buffer as a host_alias hint + * (the executor patches the actual Buffer* before allocate_all). + * 5. Upload constants via Instance::upload_constant. + * 6. Emit per-provider AllocRequest lists into Plan::alloc_plans (the + * executor's allocate_buffers step performs the actual allocation + * host-first so device requests can resolve their host_alias hints). + * 7. Build Plan::inputs / outputs. + * + * v1 scope: + * - CPU-only happy path: one segment, no TransferSteps, host-aliased + * inputs/outputs (host-addressable runtimes alias caller storage zero-copy), constants + * uploaded via NDM (zero-copy alias). + * - Multi-provider routing: skeleton in place but the cross-runtime + * transfer machinery is left as TODO until the first non-CPU + * Provider lands. + */ +class GreedyRouter final : public Router { + public: + ::executorch::runtime::Result route( + const ::executorch::backends::portable::Graph& graph, + ::executorch::runtime::Span providers, + ::executorch::runtime::Span instances, + const ::executorch::runtime::NamedDataMap* ndm, + const RouterOptions& options) override; +}; + +} // namespace portable_v2 +} // namespace backends +} // namespace executorch diff --git a/backends/portable/runtime_v2/test_dyn_shapes.cpp b/backends/portable/runtime_v2/test_dyn_shapes.cpp new file mode 100644 index 00000000000..fdd90f87b1a --- /dev/null +++ b/backends/portable/runtime_v2/test_dyn_shapes.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Integration test: batch-varying execute() calls. 
+ * + * Loads /tmp/dyn_linear_v2.pte (a Linear model with dynamic batch dim + * in [1, 8], traced at batch=3) and runs forward() with batch sizes + * {1, 3, 5, 8} in succession against the SAME loaded delegate, to + * verify true runtime-varying dynamic-shape behavior across our v2 + * runtime — including: + * + * - Buffer max-shape allocation (sized once at init for batch=8). + * - bind_inputs accepting a smaller-than-max input shape. + * - Per-execute reset of transient bindings. + * - Cross-runtime TransferStep (cpu → metal for permute_copy + * output) propagating the actual batch each call. + * - MetalInstance kernels reading the actual shape via + * TensorImpl.sizes() and producing correctly-shaped output. + * + * Expected output for each batch B: a [B, 5] tensor of zeros, since + * x = ones(B,4); weight=full(5,4,0.25); bias=full(5,-1). + * x @ weight.T + bias = 4*1*0.25 + (-1) = 0. + */ + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +namespace { + +constexpr const char* kDefaultModelPath = "/tmp/dyn_linear_v2.pte"; + +bool run_one(Module& mod, int32_t batch) { + // Build a [batch, 4] float input filled with 1.0. The data buffer + // must outlive the execute() call. 
+ std::vector data(static_cast(batch) * 4, 1.0f); + std::vector<::executorch::aten::SizesType> sizes = {batch, 4}; + + auto input_tensor = from_blob( + data.data(), sizes, ::executorch::aten::ScalarType::Float); + + auto outputs_result = mod.execute("forward", {EValue(input_tensor)}); + if (!outputs_result.ok()) { + std::fprintf(stderr, " batch=%d: execute() failed: 0x%x\n", + batch, static_cast(outputs_result.error())); + return false; + } + + const auto& outputs = outputs_result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + std::fprintf(stderr, " batch=%d: no tensor output\n", batch); + return false; + } + const auto& out_t = outputs[0].toTensor(); + + auto out_sizes = out_t.sizes(); + std::printf(" batch=%d: output sizes=[", batch); + for (size_t i = 0; i < out_sizes.size(); ++i) { + std::printf("%s%d", i ? ", " : "", + static_cast(out_sizes[i])); + } + std::printf("]"); + + // Verify: shape should be [batch, 5]. + bool shape_ok = (out_sizes.size() == 2 && out_sizes[0] == batch && + out_sizes[1] == 5); + + // Verify: every element should be 0.0. + const float* odata = out_t.const_data_ptr(); + size_t numel = static_cast(batch) * 5; + bool values_ok = true; + float max_abs = 0.0f; + for (size_t i = 0; i < numel; ++i) { + float a = std::fabs(odata[i]); + if (a > max_abs) max_abs = a; + if (a > 1e-5f) values_ok = false; + } + + std::printf(", max_abs=%g %s%s\n", max_abs, + shape_ok ? "shape ✓" : "shape ✗", + values_ok ? " values ✓" : " values ✗"); + return shape_ok && values_ok; +} + +} // namespace + +int main(int argc, char** argv) { + ::executorch::runtime::runtime_init(); + + const char* model_path = + (argc > 1) ? 
argv[1] : kDefaultModelPath; + std::printf("Loading model: %s\n", model_path); + + Module mod(model_path); + auto load_err = mod.load_forward(); + if (load_err != Error::Ok) { + std::fprintf(stderr, "load_forward failed: 0x%x\n", + static_cast(load_err)); + return 1; + } + + std::printf("\nBatch-varying execute() across one loaded delegate:\n"); + bool all_ok = true; + for (int32_t b : {1, 3, 5, 8}) { + if (!run_one(mod, b)) all_ok = false; + } + + std::printf("\n%s\n", all_ok ? "PASS \u2014 all batch sizes correct" + : "FAIL \u2014 at least one batch produced wrong output"); + return all_ok ? 0 : 1; +} diff --git a/backends/portable/runtime_v2/test_stateful.cpp b/backends/portable/runtime_v2/test_stateful.cpp new file mode 100644 index 00000000000..707f09517d0 --- /dev/null +++ b/backends/portable/runtime_v2/test_stateful.cpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Integration test: stateful model with a mutable buffer. + * + * Loads /tmp/stateful_add_v2.pte (a model with `register_buffer("acc", + * zeros(2,3))` that does `self.acc.add_(x); return self.acc.clone()`) + * and runs forward(ones(2,3)) three times against the SAME loaded + * delegate, expecting the accumulator to grow: + * + * call 1 → output = ones(2,3) + * call 2 → output = twos(2,3) + * call 3 → output = threes(2,3) + * + * This exercises the IR's mutable-buffer concept. The current Graph + * adapter has `num_mutable_buffer_ids() = 0` (stub), so this test + * surfaces what happens end-to-end when the executor doesn't know to + * preserve the mutable buffer's state across calls. 
+ */ + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +namespace { + +constexpr const char* kDefaultModelPath = "/tmp/stateful_add_v2.pte"; + +bool run_one(Module& mod, int call_idx, float expected) { + std::vector data(6, 1.0f); + std::vector<::executorch::aten::SizesType> sizes = {2, 3}; + auto input_tensor = from_blob( + data.data(), sizes, ::executorch::aten::ScalarType::Float); + + auto outputs_result = mod.execute("forward", {EValue(input_tensor)}); + if (!outputs_result.ok()) { + std::fprintf(stderr, " call %d: execute() failed: 0x%x\n", + call_idx, static_cast(outputs_result.error())); + return false; + } + const auto& outputs = outputs_result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + std::fprintf(stderr, " call %d: no tensor output\n", call_idx); + return false; + } + const auto& out_t = outputs[0].toTensor(); + auto out_sizes = out_t.sizes(); + bool shape_ok = + (out_sizes.size() == 2 && out_sizes[0] == 2 && out_sizes[1] == 3); + + const float* odata = out_t.const_data_ptr(); + float max_abs_err = 0.0f; + for (size_t i = 0; i < 6; ++i) { + float err = std::fabs(odata[i] - expected); + if (err > max_abs_err) max_abs_err = err; + } + bool values_ok = (max_abs_err < 1e-5f); + + std::printf(" call %d: out=[%.1f %.1f %.1f; %.1f %.1f %.1f] " + "expected=%.1f %s%s\n", + call_idx, + odata[0], odata[1], odata[2], + odata[3], odata[4], odata[5], + expected, + shape_ok ? "shape ✓" : "shape ✗", + values_ok ? " values ✓" : " values ✗"); + return shape_ok && values_ok; +} + +} // namespace + +int main(int argc, char** argv) { + ::executorch::runtime::runtime_init(); + + const char* model_path = (argc > 1) ? 
argv[1] : kDefaultModelPath; + std::printf("Loading model: %s\n", model_path); + + Module mod(model_path); + auto load_err = mod.load_forward(); + if (load_err != Error::Ok) { + std::fprintf(stderr, "load_forward failed: 0x%x\n", + static_cast(load_err)); + return 1; + } + + std::printf("\nStateful execute() — accumulator should grow:\n"); + bool all_ok = true; + for (int call = 1; call <= 3; ++call) { + float expected = static_cast(call); + if (!run_one(mod, call, expected)) all_ok = false; + } + + std::printf("\n%s\n", + all_ok ? "PASS \u2014 mutable buffer state preserved" + : "FAIL \u2014 mutable buffer state NOT preserved correctly"); + return all_ok ? 0 : 1; +} diff --git a/backends/portable/serialization/__init__.py b/backends/portable/serialization/__init__.py new file mode 100644 index 00000000000..f9db75ed410 --- /dev/null +++ b/backends/portable/serialization/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Serialization utilities for Portable Backend. + +The Portable Backend reuses the standard ExecuTorch serialization format +(ExecutionPlan in program.fbs). This module re-exports the relevant schema +classes for convenience. + +For memory aliasing, tensors with the same AllocationDetails.memory_id +share the same storage pool slot. This is computed by the memory planning +pass and preserved in the serialized program. 
+""" + +# Re-export standard ExecuTorch schema classes +from executorch.exir.schema import ( + AllocationDetails, + EValue, + ExecutionPlan, + Instruction, + KernelCall, + Operator, + Program, + Tensor, +) + +__all__ = [ + "AllocationDetails", + "EValue", + "ExecutionPlan", + "Instruction", + "KernelCall", + "Operator", + "Program", + "Tensor", +] diff --git a/backends/portable/test_export_v2.py b/backends/portable/test_export_v2.py new file mode 100644 index 00000000000..e33dda7339f --- /dev/null +++ b/backends/portable/test_export_v2.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Smoke-test export for PortableBackend_v2. + +Exports a tiny PTE that targets the new "PortableBackend_v2" backend +(registered by libportable_backend_v2.a). Reuses the existing +PortablePartitioner machinery; just overrides the delegation spec's +backend_id to point at the v2 backend. +""" + +import os +import sys + +# Make the in-repo executorch package importable by adding the parent of +# the executorch/ root to sys.path. +_EXECUTORCH_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +_PARENT = os.path.dirname(_EXECUTORCH_ROOT) +for _p in (_PARENT, _EXECUTORCH_ROOT): + if _p not in sys.path: + sys.path.insert(0, _p) + +import torch +from torch.export import export + +from executorch.exir import to_edge, EdgeCompileConfig +from executorch.exir.backend.partitioner import DelegationSpec +from executorch.backends.portable.partitioner import PortablePartitioner +# Importing this triggers the Python-side BackendDetails registration +# under the name "PortableBackend_v2" (matched at runtime by the C++ +# register_backend call in runtime_v2/PortableBackend_v2.cpp). 
+from executorch.backends.portable.preprocess_v2 import PortableBackend_v2 # noqa: F401 + + +def _patch_partitioner_to_v2(p: PortablePartitioner) -> PortablePartitioner: + """Mutate p's delegation_spec so the backend_id points at v2.""" + old = p.delegation_spec + p.delegation_spec = DelegationSpec("PortableBackend_v2", old.compile_specs) + return p + + +def export_v2(model, example_inputs, name: str, dynamic_shapes=None) -> str: + print(f"\nExporting {name} (delegated to PortableBackend_v2)...") + if dynamic_shapes is not None: + exported = export(model, example_inputs, dynamic_shapes=dynamic_shapes) + else: + exported = export(model, example_inputs) + # Skip dim_order so .clone() exports as aten::clone (which we + # eventually want in our op registry) rather than + # dim_order_ops::_clone_dim_order (a dim-order-aware variant our + # v2 op registry doesn't dispatch). + edge = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=True)) + + partitioner = _patch_partitioner_to_v2(PortablePartitioner()) + delegated = edge.to_backend(partitioner) + et_program = delegated.to_executorch() + + path = f"/tmp/{name}_v2.pte" + with open(path, "wb") as f: + f.write(et_program.buffer) + print(f" Saved to {path} ({len(et_program.buffer)} bytes)") + return path + + +def main() -> None: + print("=" * 60) + print("PortableBackend_v2 — smoke-test export") + print("=" * 60) + + # Tiny add: y = x + x + class TinyAdd(torch.nn.Module): + def forward(self, x): + return x + x + + export_v2(TinyAdd().eval(), (torch.ones(2, 3),), "tiny_add") + + # Add with constant: y = x + w (exercises NDM upload_constant path) + class AddConst(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(torch.full((2, 3), 0.5)) + + def forward(self, x): + return x + self.w + + export_v2(AddConst().eval(), (torch.ones(2, 3),), "add_const") + + # Two inputs, one output: + # t1 = x + y intermediate (mem_obj_id) + # t2 = t1 + x intermediate (may share with t1) + # 
out = t2 - y + # x = ones(2,3), y = ones(2,3)*2 → out = (1+2)+1 - 2 = 2 + class TwoInputsChain(torch.nn.Module): + def forward(self, x, y): + t1 = x + y + t2 = t1 + x + out = t2 - y + return out + + export_v2( + TwoInputsChain().eval(), + (torch.ones(2, 3), torch.full((2, 3), 2.0)), + "two_inputs_chain", + ) + + # Many sequential adds — exercises mem_obj_id sharing aggressively. + # Use tensor+tensor so the loop doesn't fold to a single op. + # Result: x + 10*y = 0 + 10*1 = 10 everywhere. + class ManyAdds(torch.nn.Module): + def forward(self, x, y): + for _ in range(10): + x = x + y + return x + + export_v2( + ManyAdds().eval(), + (torch.zeros(4, 4), torch.ones(4, 4)), + "many_adds", + ) + + # Linear: matmul + bias. + # x: [3,4]; W: [5,4]; bias: [5]. out = x @ W^T + bias → [3,5] + class Linear(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.full((5, 4), 0.25)) + self.bias = torch.nn.Parameter(torch.full((5,), -1.0)) + + def forward(self, x): + return torch.mm(x, self.weight.t()) + self.bias + + export_v2(Linear().eval(), (torch.ones(3, 4),), "linear") + + # Dynamic shape: x has a dynamic batch dim with bound min=1, max=8. + # At export time we trace with batch=3, but the model should run with + # any batch in [1, 8]. AOT memory planning sizes intermediates to + # max_shape (batch=8); per-execute, the actual batch is propagated + # via TensorImpl.sizes(). 
+ class DynBatch(torch.nn.Module): + def forward(self, x, y): + t1 = x + y + t2 = t1 * y + return t2 - x + + batch = torch.export.Dim("batch", min=1, max=8) + export_v2( + DynBatch().eval(), + (torch.ones(3, 4), torch.full((3, 4), 2.0)), + "dyn_batch", + dynamic_shapes={"x": {0: batch}, "y": {0: batch}}, + ) + + # Dynamic shape across a cross-runtime boundary: + # permute_copy(weight) → CPU (not in MetalOpRegistry) + # mm(x, perm_weight) → Metal (TransferStep cpu→metal for perm_weight) + # add(mm_out, bias) → Metal + # The mm output's shape depends on x's dynamic batch dim; + # the add output's bias broadcast shape varies too. This exercises + # transfer_tensor's resize_tensor shape propagation under dynamism. + class DynLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.full((5, 4), 0.25)) + self.bias = torch.nn.Parameter(torch.full((5,), -1.0)) + + def forward(self, x): + return torch.mm(x, self.weight.t()) + self.bias + + dyn_batch_lin = torch.export.Dim("dyn_batch_lin", min=1, max=8) + export_v2( + DynLinear().eval(), + (torch.ones(3, 4),), + "dyn_linear", + dynamic_shapes={"x": {0: dyn_batch_lin}}, + ) + + # Stateful model with a mutable buffer (KV-cache pattern). + # Each forward call adds x to the running accumulator (acc += x) + # and returns the new accumulator value. State must persist across + # execute() calls. + # call 1 with ones(2,3): output = [1,1,1; 1,1,1] + # call 2 with ones(2,3): output = [2,2,2; 2,2,2] + # call 3 with ones(2,3): output = [3,3,3; 3,3,3] + class StatefulAdd(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("acc", torch.zeros(2, 3)) + + def forward(self, x): + self.acc.add_(x) + return self.acc.clone() + + export_v2(StatefulAdd().eval(), (torch.ones(2, 3),), "stateful_add") + + print("\nDone. 
Run with:") + for name in [ + "tiny_add", + "add_const", + "two_inputs_chain", + "many_adds", + "linear", + "dyn_batch", + "dyn_linear", + "stateful_add", + ]: + print(f" ./cmake-out/executor_runner --model_path /tmp/{name}_v2.pte") + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index e2d256ea396..940887e3f1f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -794,10 +794,6 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: ) ) - # Now cast to the dtype override after quantization, so non-quantized - # components use the desired computation dtype. - edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype()) - return edge_manager @@ -1857,6 +1853,12 @@ def _get_source_transforms( # noqa ) ) + # Cast to dtype_override after quantization transforms, so non-quantized + # components use the desired computation dtype. This must happen before + # _convert_model_for_aarch64 which converts IntxUnpackedToInt8Tensor to + # IntxOpaqueTensor (which doesn't support .to()). + transforms.append(lambda m: m.to(dtype=dtype_override.to_torch_dtype())) + if any([use_torchao_kernels_linear, use_torchao_kernels_tied_embedding]): from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64