Skip to content

Commit 18565c4

Browse files
committed
up
1 parent 7071b04 commit 18565c4

147 files changed

Lines changed: 7879 additions & 4829 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,10 @@ if(EXECUTORCH_BUILD_CUDA)
751751
endif()
752752

753753
if(EXECUTORCH_BUILD_METAL)
  # backends/metal/ provides the metal_v2 runtime library. The AOTI v2
  # shim under backends/apple/metal (and any other consumer) depends on
  # it for the metal_v2 runtime types, so it is added before the shim.
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/metal)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/apple/metal)
  list(APPEND _executorch_backends metal_backend)
endif()

backends/apple/metal/CMakeLists.txt

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,43 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
3535
find_package_torch()

# Sources built regardless of shim generation. utils.cpp has no v2
# replacement yet, so it is compiled in both branches below.
set(_aoti_metal_sources
    runtime/stats.cpp
    runtime/shims/utils.cpp
)

# AOTI shim layer selection. v1 carries the vendored MLX qmv_fast int4
# path plus the legacy stream/encoding/buffer-management shims; v2
# routes through metal_v2::MetalStream and the MetalOpRegistry
# (AffineQuantizedLinearOp, SDPAOp with NAX / qmm_t / qmm_t_splitk
# dispatch, per-command-buffer residency).
#
# The two generations deliberately export the same extern "C" AOTI shim
# symbol names — PTE files lower against a single shim ABI — so their
# sources can never coexist in one build; exactly one branch is taken.
if(EXECUTORCH_USE_METAL_V2)
  list(APPEND _aoti_metal_sources
      runtime/metal_backend_v2.cpp
      runtime/shims/v2/aoti_tensor.cpp
      runtime/shims/v2/aoti_dtype_stubs.cpp
      runtime/shims/v2/runtime.mm
      runtime/shims/v2/aoti_kernel.mm
      runtime/shims/v2/aoti_fallback_op.mm
  )
else()
  list(APPEND _aoti_metal_sources
      runtime/metal_backend.cpp
      runtime/shims/memory.cpp
      runtime/shims/et_metal.mm
      runtime/shims/shim_mps.mm
      runtime/shims/tensor_attribute.cpp
      runtime/ops/common.mm
      runtime/ops/op_bmm.mm
      runtime/ops/op_convolution.mm
      runtime/ops/op_linear_4bit.mm
      runtime/ops/op_mm.mm
      runtime/ops/op_sdpa.mm
  )
endif()

add_library(metal_backend STATIC ${_aoti_metal_sources})
5477
target_include_directories(
@@ -87,6 +110,14 @@ endif()
87110

88111
target_link_options(metal_backend PUBLIC -Wl,-export_dynamic)

# With the v2 shim selected, metal_backend needs the metal_v2 library
# defined in backends/metal/ (MetalStream, MetalOpRegistry,
# AffineQuantizedLinearOp, ...). Linked PRIVATE so the dependency does
# not propagate through install exports to metal_backend's downstream
# consumers.
if(EXECUTORCH_USE_METAL_V2)
  target_link_libraries(metal_backend PRIVATE metal_v2)
endif()

# Find PyTorch's OpenMP library specifically for libtorch-less AOTI
get_torch_base_path(TORCH_BASE_PATH)
92123
find_library(
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Dtype + layout helpers shared across the v2 AOTI shim layer.
10+
11+
#pragma once
12+
13+
#include <executorch/backends/aoti/utils.h>
14+
#include <executorch/backends/apple/metal/runtime/shims/v2/aoti_types.h>
15+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
16+
17+
#include <cstddef>
18+
#include <cstdint>
19+
#include <vector>
20+
21+
namespace executorch {
22+
namespace backends {
23+
namespace metal {
24+
25+
// Both enums use the standard PyTorch dtype encoding; value-cast is safe.
26+
inline executorch::aten::ScalarType to_aten_scalar_type(
27+
executorch::backends::aoti::slim::c10::ScalarType slim_dt) {
28+
return static_cast<executorch::aten::ScalarType>(static_cast<int>(slim_dt));
29+
}
30+
31+
// Byte width of one element of the given PyTorch dtype code.
// Forwards to the shared AOTI helper; exists so v2 shim code has a
// consistently named local entry point.
inline size_t dtype_to_bytes(int32_t dtype) {
  return executorch::backends::aoti::dtype_to_element_size(dtype);
}
34+
35+
// Standard PyTorch-style contiguous strides, in element units.
//
// Matches torch's contiguous-stride convention for degenerate shapes:
// each dimension contributes a factor of max(size, 1), so a 0-sized
// dim never collapses the higher-dim strides to 0. For example,
// sizes (5, 0) yields strides (1, 1), exactly like
// torch.empty(5, 0).contiguous().stride(). (The previous version
// multiplied by the raw size, producing 0 strides above an empty dim,
// which contradicted the torch convention it claimed to follow.)
inline std::vector<int64_t> compute_contiguous_strides(
    const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride = 1;
  // Countdown over size_t avoids POSIX-only ssize_t (not ISO C++).
  for (size_t j = sizes.size(); j > 0; --j) {
    const size_t i = j - 1;
    strides[i] = stride;
    // max(size, 1): 0- and 1-sized dims keep strides torch-compatible.
    stride *= std::max<int64_t>(sizes[i], int64_t{1});
  }
  return strides;
}
49+
50+
// Upper bound on tensor rank accepted by StackTensorView, which backs
// the MetalOpRegistry fallback path (aoti_fallback_op.mm). Direct AOTI
// shader dispatch (aoti_kernel.mm) is rank-agnostic and unaffected by
// this limit.
constexpr size_t kMaxTensorDim = 8;
54+
55+
} // namespace metal
56+
} // namespace backends
57+
} // namespace executorch
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Linker stubs for AOTI dtype trampolines that aoti/common_shims_slim
// does not define. dlopen() of an AOTI-compiled .so must resolve every
// referenced symbol even when the model never touches the dtype, so
// these exist solely to satisfy the dynamic linker.

#include <cstdint>

extern "C" {

// 5 is c10::ScalarType::Half in PyTorch's dtype numbering.
// NOTE(review): per the original author's comment, a model that truly
// uses float16 is expected to fault later inside SlimTensor because
// slim::c10::ScalarType has no Half variant — this stub only resolves
// the symbol; confirm that failure mode is acceptable.
int32_t aoti_torch_dtype_float16() {
  return 5;
}

} // extern "C"

0 commit comments

Comments
 (0)