Skip to content

Commit 0a20758

Browse files
committed
up
1 parent caa85a8 commit 0a20758

80 files changed

Lines changed: 15920 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// AOTI kernel-dispatch C ABI for the v2 Metal backend.
10+
//
11+
// Symbols here are called by the AOTI .so when it has a hand-written
12+
// shader to dispatch (shader library + kernel function + arg encoding
13+
// + dispatch). Buffer mgmt and tensor lifecycle live in aoti_tensor.h;
14+
// op-registry fallbacks (mm/bmm) live in aoti_ops.h.
15+
16+
#pragma once
17+
18+
#include <executorch/backends/apple/metal/runtime/shims/v2/aoti_types.h>
19+
20+
namespace executorch {
21+
namespace backends {
22+
namespace metal {
23+
24+
struct AOTIMetalKernelFunctionOpaque;
25+
using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*;
26+
27+
struct AOTIMetalShaderLibraryOpaque;
28+
using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*;
29+
30+
#ifdef __cplusplus
31+
extern "C" {
32+
#endif
33+
34+
// Shader library
35+
AOTITorchError aoti_torch_mps_create_shader_library(
36+
const char* metal_shader_source,
37+
AOTIMetalShaderLibraryHandle* library_handle);
38+
39+
AOTITorchError aoti_torch_mps_delete_shader_library(
40+
AOTIMetalShaderLibraryHandle library_handle);
41+
42+
AOTITorchError aoti_torch_mps_get_kernel_function(
43+
AOTIMetalShaderLibraryHandle library_handle,
44+
const char* kernel_name,
45+
AOTIMetalKernelFunctionHandle* function_handle);
46+
47+
// Kernel arg / dispatch
48+
AOTITorchError aoti_torch_mps_start_encoding(
49+
AOTIMetalKernelFunctionHandle func);
50+
51+
AOTITorchError aoti_torch_mps_set_arg_tensor(
52+
AOTIMetalKernelFunctionHandle func,
53+
unsigned idx,
54+
AOTITensorHandle tensor);
55+
56+
AOTITorchError aoti_torch_mps_set_arg_int(
57+
AOTIMetalKernelFunctionHandle func,
58+
unsigned idx,
59+
int64_t val);
60+
61+
AOTITorchError aoti_torch_mps_dispatch_single(
62+
AOTIMetalKernelFunctionHandle func,
63+
uint64_t length);
64+
65+
AOTITorchError aoti_torch_mps_dispatch_single_with_group_size(
66+
AOTIMetalKernelFunctionHandle func,
67+
uint64_t length,
68+
uint64_t group_size);
69+
70+
AOTITorchError aoti_torch_mps_dispatch_array(
71+
AOTIMetalKernelFunctionHandle func,
72+
const uint64_t* length,
73+
size_t length_size);
74+
75+
AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
76+
AOTIMetalKernelFunctionHandle func,
77+
const uint64_t* length,
78+
size_t length_size,
79+
const uint64_t* group_size,
80+
size_t group_size_size);
81+
82+
// Command block
83+
typedef void (*aoti_torch_mps_command_block_callback_t)(
84+
AOTIMetalKernelFunctionHandle func,
85+
void* user_data);
86+
87+
void aoti_torch_mps_shared_callback(
88+
AOTIMetalKernelFunctionHandle func,
89+
void* user_data);
90+
91+
AOTITorchError aoti_torch_mps_run_command_block(
92+
AOTIMetalKernelFunctionHandle func,
93+
aoti_torch_mps_command_block_callback_t callback,
94+
void* user_data);
95+
96+
#ifdef __cplusplus
97+
} // extern "C"
98+
#endif
99+
100+
} // namespace metal
101+
} // namespace backends
102+
} // namespace executorch

0 commit comments

Comments
 (0)