1+ // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+ #include " cinn_interface.h"
16+ #include < cstring> // For memset
17+ #include < iostream>
18+
19+ namespace paddle {
20+ namespace custom_device {
21+ namespace metax {
22+
23+ // ============================================================
24+ // 外部函数声明 (External Function Declarations)
25+ // 这些函数需要在对应的子目录文件中实现 (.cc)
26+ // ============================================================
27+
28+ // --- 来自 compiler/compiler.cc ---
29+ // 负责调用 mxcc 将 CINN 生成的源代码编译为二进制
30+ extern C_Status MetaxCompile (void * dev_ptr,
31+ const char * code,
32+ char * out_path,
33+ size_t len);
34+
35+ // 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu)
36+ extern const char * MetaxGetRuntimeSource (void * dev_ptr);
37+
38+
39+ // --- 来自 runtime/cinn_runtime.cc ---
40+ // 负责加载编译好的二进制模块 (.mx / .so)
41+ extern C_Status MetaxModuleLoad (void * dev_ptr,
42+ const char * path,
43+ void ** mod_out);
44+
45+ // 负责卸载模块
46+ extern C_Status MetaxModuleUnload (void * dev_ptr,
47+ void * module_handle);
48+
49+ // 负责从模块中查找核函数地址
50+ extern C_Status MetaxGetKernelAddress (void * dev_ptr,
51+ void * module_handle,
52+ const char * func_name,
53+ void ** func_out);
54+
55+ // 负责启动核函数 (Launch Kernel)
56+ extern C_Status MetaxLaunchKernel (void * dev_ptr,
57+ void * func_ptr,
58+ void ** args,
59+ int num_args,
60+ int gx, int gy, int gz,
61+ int bx, int by, int bz,
62+ int shm,
63+ void * stream);
64+
65+
66+ // --- 来自 passes/pass_manager.cc ---
67+ // 负责应用自定义的图优化 Pass
68+ extern C_Status MetaxApplyCustomPass (void * dev_ptr,
69+ void * ir_module);
70+
71+
72+ // ============================================================
73+ // 接口初始化实现 (Interface Initialization)
74+ // ============================================================
75+
76+ // 静态实例,确保在插件生命周期内有效
77+ static C_CinnInterface metax_cinn_impl;
78+
79+ void InitCinnInterface (C_DeviceInterface* device_interface) {
80+ // 1. 安全起见,先清零
81+ std::memset (&metax_cinn_impl, 0 , sizeof (C_CinnInterface));
82+
83+ // 2. 设置结构体大小 (用于版本校验)
84+ metax_cinn_impl.size = sizeof (C_CinnInterface);
85+
86+ // 3. 设置上下文指针 (可选)
87+ // 如果你的实现需要全局状态,可以指向一个结构体;否则设为 nullptr
88+ metax_cinn_impl.dev_ptr = nullptr ;
89+
90+ // 4. 挂载 Compiler Toolchain 接口
91+ metax_cinn_impl.compile = MetaxCompile;
92+ metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource;
93+
94+ // 5. 挂载 Runtime Strategy 接口
95+ metax_cinn_impl.module_load = MetaxModuleLoad;
96+ metax_cinn_impl.module_unload = MetaxModuleUnload;
97+ metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
98+ metax_cinn_impl.launch_kernel = MetaxLaunchKernel;
99+
100+ // 6. 挂载 Compile Strategy 接口
101+ metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;
102+
103+ // 7. 【关键】将填好的表挂载到 Paddle 主设备接口上
104+ if (device_interface) {
105+ device_interface->cinn_interface = &metax_cinn_impl;
106+ // VLOG(3) << "[MetaX] CINN Interface initialized successfully.";
107+ } else {
108+ std::cerr << " [MetaX] Error: device_interface is null during CINN init." << std::endl;
109+ }
110+ }
111+
112+ } // namespace metax
113+ } // namespace custom_device
114+ } // namespace paddle
0 commit comments