Skip to content

Commit 7d9d5f9

Browse files
committed
Realize cinn_interface for metax_gpu
1 parent 2505ea9 commit 7d9d5f9

12 files changed

Lines changed: 877 additions & 9 deletions

File tree

backends/metax_gpu/CMakeLists.txt

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
2727
set(WITH_MKLML ON)
2828

2929
include(paddle)
30+
31+
# 【修改点 1】: 添加 CINN 子目录编译
32+
if(WITH_CINN)
33+
message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn")
34+
add_subdirectory(cinn)
35+
endif()
36+
3037
set(THIRD_PARTY_PATH
3138
"${PADDLE_SOURCE_DIR}/build/third_party"
3239
CACHE PATH "Third party libraries directory.")
@@ -761,6 +768,14 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/")
761768

762769
add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS})
763770

771+
# 【修改点 2】: 添加 CINN 接口的头文件搜索路径
772+
# 这样 runtime/runtime.cc 里的 #include "../cinn/cinn_interface.h" 才能生效
773+
if(WITH_CINN)
774+
target_include_directories(${TARGET_NAME} PRIVATE
775+
"${CMAKE_CURRENT_SOURCE_DIR}/cinn"
776+
)
777+
endif()
778+
764779
target_include_directories(
765780
${TARGET_NAME}
766781
PRIVATE ${PADDLE_SOURCE_DIR}
@@ -790,15 +805,27 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so)
790805
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so)
791806
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so)
792807

808+
# 【修改点 3】: 将 CINN 编译出的对象文件链接进最终的 .so
809+
# 只有这样,Plugin 加载时才能找到 InitCinnInterface 等符号
810+
if(WITH_CINN)
811+
message(STATUS "[MetaX] Linking CINN object library")
812+
target_link_libraries(${TARGET_NAME} $<TARGET_OBJECTS:metax_cinn_obj>)
813+
endif()
814+
793815
include_directories(BEFORE ${PADDLE_SOURCE_DIR})
794816
include_directories(BEFORE ${CMAKE_SOURCE_DIR}/headers)
795817

796818
target_compile_definitions(
797819
${TARGET_NAME}
798820
PUBLIC PADDLE_WITH_CUDA=1
799821
PADDLE_WITH_CUSTOM_DEVICE=1
800-
mcblasContext=cublasContext
822+
cublasContext=mcblasContext
801823
cublasLtContext=mcblasLtContext
824+
cublasLtMatmulDescOpaque_t=mcblasLtMatmulDescOpaque_t
825+
cublasLtMatrixLayoutOpaque_t=mcblasLtMatrixLayoutOpaque_t
826+
cublasLtMatmulPreferenceOpaque_t=mcblasLtMatmulPreferenceOpaque_t
827+
cublasLtMatmulAlgoOpaque_t=mcblasLtMatmulAlgoOpaque_t
828+
cublasStatus_t=mcblasStatus_t
802829
GPUContext=CustomContext
803830
KPSContext=CustomContext
804831
STREAM_TYPE=cudaStream_t

backends/metax_gpu/change_patch.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ cp -r patch/eigen3/ ../../Paddle/third_party/eigen3
2424
rm -r patch/eigen3
2525
# cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core
2626
cd ../../Paddle/
27-
git apply --verbose ../backends/metax_gpu/patch/paddle.patch
27+
git apply --verbose /home/sw/Baidu-xuyuhan/PaddleCustomDevice/backends/metax_gpu/patch/paddle.patch
2828
cd -
2929
# cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# =============================================================================
2+
# CINN Plugin for MetaX (MACA) Backend
3+
# =============================================================================
4+
5+
# 1. 查找 MACA 路径
6+
# 为了在 runtime/cinn_runtime.cc 或 compiler.cc 中能 #include <maca_runtime.h>
7+
# 我们需要把沐曦 SDK 的头文件路径加进来
8+
set(MACA_PATH $ENV{MACA_PATH})
9+
if(NOT MACA_PATH)
10+
set(MACA_PATH "/opt/maca") # 默认回退路径
11+
message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}")
12+
else()
13+
message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}")
14+
endif()
15+
16+
# 2. 定义源文件列表
17+
# 这里必须包含所有涉及到 CINN 实现的 .cc 文件
18+
set(CINN_SRCS
19+
cinn_interface.cc # 总入口,负责 InitCinnInterface
20+
compiler/compiler.cc # 【关键】负责 MetaxCompile 和 MetaxGetRuntimeSource
21+
runtime/cinn_runtime.cc # 负责 MetaxModuleLoad, MetaxLaunchKernel
22+
passes/pass_manager.cc # 负责 MetaxApplyCustomPass
23+
)
24+
25+
# 3. 创建 OBJECT 库
26+
# 使用 OBJECT 模式,只编译出 .o 文件,不生成 .a 或 .so
27+
# 这样上一级的 CMake 可以直接抓取这些 .o 文件链接进最终的 plugin.so
28+
add_library(metax_cinn_obj OBJECT ${CINN_SRCS})
29+
30+
# 4. 配置头文件搜索路径
31+
target_include_directories(metax_cinn_obj PRIVATE
32+
${CMAKE_CURRENT_SOURCE_DIR} # 允许引用当前目录头文件 (cinn_interface.h)
33+
${CMAKE_CURRENT_SOURCE_DIR}/../ # 允许引用上层头文件 (如 common/)
34+
${MACA_PATH}/include # 【关键】允许引用 <maca_runtime.h>
35+
${PADDLE_SOURCE_DIR} # 【新增】必须加这个!否则找不到 paddle/phi/...
36+
# Paddle 的头文件路径通常由外部环境 (Paddle_DIR) 自动包含
37+
)
38+
39+
# 5. 编译选项设置
40+
# CINN 组件通常依赖 C++17 标准
41+
set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17)
42+
43+
# 开启 PIC (Position Independent Code)
44+
# 因为这些 .o 文件最终要被链接进动态库,必须开启此选项
45+
set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON)
46+
47+
# 如果 compiler.cc 需要使用 filesystem 等库,可能需要链接 stdc++fs (视 GCC 版本而定)
48+
# 但因为是 OBJECT 库,链接操作推迟到父级进行
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "cinn_interface.h"
16+
#include <cstring> // For memset
17+
#include <iostream>
18+
19+
namespace paddle {
20+
namespace custom_device {
21+
namespace metax {
22+
23+
// ============================================================
24+
// 外部函数声明 (External Function Declarations)
25+
// 这些函数需要在对应的子目录文件中实现 (.cc)
26+
// ============================================================
27+
28+
// --- 来自 compiler/compiler.cc ---
29+
// 负责调用 mxcc 将 CINN 生成的源代码编译为二进制
30+
extern C_Status MetaxCompile(void* dev_ptr,
31+
const char* code,
32+
char* out_path,
33+
size_t len);
34+
35+
// 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu)
36+
extern const char* MetaxGetRuntimeSource(void* dev_ptr);
37+
38+
39+
// --- 来自 runtime/cinn_runtime.cc ---
40+
// 负责加载编译好的二进制模块 (.mx / .so)
41+
extern C_Status MetaxModuleLoad(void* dev_ptr,
42+
const char* path,
43+
void** mod_out);
44+
45+
// 负责卸载模块
46+
extern C_Status MetaxModuleUnload(void* dev_ptr,
47+
void* module_handle);
48+
49+
// 负责从模块中查找核函数地址
50+
extern C_Status MetaxGetKernelAddress(void* dev_ptr,
51+
void* module_handle,
52+
const char* func_name,
53+
void** func_out);
54+
55+
// 负责启动核函数 (Launch Kernel)
56+
extern C_Status MetaxLaunchKernel(void* dev_ptr,
57+
void* func_ptr,
58+
void** args,
59+
int num_args,
60+
int gx, int gy, int gz,
61+
int bx, int by, int bz,
62+
int shm,
63+
void* stream);
64+
65+
66+
// --- 来自 passes/pass_manager.cc ---
67+
// 负责应用自定义的图优化 Pass
68+
extern C_Status MetaxApplyCustomPass(void* dev_ptr,
69+
void* ir_module);
70+
71+
72+
// ============================================================
73+
// 接口初始化实现 (Interface Initialization)
74+
// ============================================================
75+
76+
// 静态实例,确保在插件生命周期内有效
77+
static C_CinnInterface metax_cinn_impl;
78+
79+
void InitCinnInterface(C_DeviceInterface* device_interface) {
80+
// 1. 安全起见,先清零
81+
std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface));
82+
83+
// 2. 设置结构体大小 (用于版本校验)
84+
metax_cinn_impl.size = sizeof(C_CinnInterface);
85+
86+
// 3. 设置上下文指针 (可选)
87+
// 如果你的实现需要全局状态,可以指向一个结构体;否则设为 nullptr
88+
metax_cinn_impl.dev_ptr = nullptr;
89+
90+
// 4. 挂载 Compiler Toolchain 接口
91+
metax_cinn_impl.compile = MetaxCompile;
92+
metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource;
93+
94+
// 5. 挂载 Runtime Strategy 接口
95+
metax_cinn_impl.module_load = MetaxModuleLoad;
96+
metax_cinn_impl.module_unload = MetaxModuleUnload;
97+
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
98+
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;
99+
100+
// 6. 挂载 Compile Strategy 接口
101+
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;
102+
103+
// 7. 【关键】将填好的表挂载到 Paddle 主设备接口上
104+
if (device_interface) {
105+
device_interface->cinn_interface = &metax_cinn_impl;
106+
// VLOG(3) << "[MetaX] CINN Interface initialized successfully.";
107+
} else {
108+
std::cerr << "[MetaX] Error: device_interface is null during CINN init." << std::endl;
109+
}
110+
}
111+
112+
} // namespace metax
113+
} // namespace custom_device
114+
} // namespace paddle
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
// 引入 Paddle 定义的 C 接口结构体
18+
#include "paddle/phi/backends/device_ext.h"
19+
20+
namespace paddle {
21+
namespace custom_device {
22+
namespace metax {
23+
24+
/**
25+
* @brief 初始化 CINN 接口
26+
* * 这个函数由 runtime.cc 中的 InitPlugin 调用。
27+
* 它负责将 metax_gpu/cinn 下实现的编译器和运行时函数指针,
28+
* 填充到 device_interface->cinn_interface 中。
29+
* * @param device_interface Paddle Host 侧传入的设备接口指针
30+
*/
31+
void InitCinnInterface(C_DeviceInterface* device_interface);
32+
33+
} // namespace metax
34+
} // namespace custom_device
35+
} // namespace paddle

0 commit comments

Comments
 (0)