Skip to content

Commit b6534ad

Browse files
authored
Add more hardware attribute to support CINN. (#2316)
* Add more hardware attribute to support CINN. * Realize cinn_interface for metax_gpu * Demo test_elementwise_pow_op_metax.py is pass! * Support custom_device_intrinscs_reduce * Fix CINN compilation errors and incorrect reduction results on MetaX backend. Run test_elementwise_pow_op_metax.py success. * Fix some bug. * Add CINN_ENTAIL_LOOP_CONDITION into /backends/metax_gpu/cinn/compiler/compiler.cc * Support argidx ArgMin/ArgMax Block Reduce for CINN metax_gpu. * CINN metax_gpu compiler.cc support Welford for BatchNorm_11_class.py * Fix warp reduce, block reduce for metax_gpu warp_size=64. * Update GetMaxSharedMemPerBlock. * Add compiler.cc int64 int32 abs. int64 vs double abs. * Translate annotation to English. Fix undeclared identifier 'cinn_discrete_reduce_max_argidx_fp32_i32' * Recover CINN irrelevant code. * Fix Code-style. * Update Paddle. * Fix CMakeLists.txt PADDLE_WARP_SIZE 32->64. Fix argidx_fp32_i32 forward reference error in MetaX runtime.
1 parent c4c1031 commit b6534ad

10 files changed

Lines changed: 1751 additions & 3 deletions

File tree

Paddle

Submodule Paddle updated 1296 files

backends/metax_gpu/CMakeLists.txt

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,29 @@ project(${PROJ_NAME} CXX C CUDA)
1717

1818
set(TARGET_NAME ${PROJ_NAME})
1919

20+
option(WITH_CINN "Compile with CINN support" ON)
21+
2022
find_package(Python3 REQUIRED COMPONENTS Interpreter)
2123
set(PY_VERSION ${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR})
2224
message(STATUS "Python version detected: ${PY_VERSION}")
2325
set(PYTHON_VERSION ${PY_VERSION})
2426

2527
set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
2628
message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
29+
30+
if(NOT DEFINED PADDLE_WARP_SIZE)
31+
set(PADDLE_WARP_SIZE 64)
32+
endif()
33+
math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1")
34+
if(PADDLE_WARP_SIZE EQUAL 64)
35+
set(PADDLE_WARP_SHIFT 6)
36+
else()
37+
set(PADDLE_WARP_SHIFT 5)
38+
endif()
39+
add_definitions(-DPADDLE_WARP_SIZE=${PADDLE_WARP_SIZE})
40+
add_definitions(-DPADDLE_WARP_MASK=${PADDLE_WARP_MASK})
41+
add_definitions(-DPADDLE_WARP_SHIFT=${PADDLE_WARP_SHIFT})
42+
2743
set(WITH_MKLML ON)
2844
if(WITH_ARM)
2945
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
@@ -40,6 +56,13 @@ if(WITH_ARM)
4056
add_definitions(-DPADDLE_WITH_ARM)
4157
endif()
4258
include(paddle)
59+
60+
if(WITH_CINN)
61+
message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn")
62+
add_definitions(-DWITH_CINN)
63+
add_subdirectory(cinn)
64+
endif()
65+
4366
set(THIRD_PARTY_PATH
4467
"${PADDLE_SOURCE_DIR}/build/third_party"
4568
CACHE PATH "Third party libraries directory.")
@@ -792,6 +815,11 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/")
792815

793816
add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS})
794817

818+
if(WITH_CINN)
819+
target_include_directories(${TARGET_NAME}
820+
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cinn")
821+
endif()
822+
795823
target_include_directories(
796824
${TARGET_NAME}
797825
PRIVATE ${PADDLE_SOURCE_DIR}
@@ -821,6 +849,11 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so)
821849
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so)
822850
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so)
823851

852+
if(WITH_CINN)
853+
message(STATUS "[MetaX] Linking CINN object library")
854+
target_link_libraries(${TARGET_NAME} $<TARGET_OBJECTS:metax_cinn_obj>)
855+
endif()
856+
824857
include_directories(BEFORE ${PADDLE_SOURCE_DIR})
825858
include_directories(BEFORE ${CMAKE_SOURCE_DIR}/headers)
826859

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# =============================================================================
2+
# CINN Plugin for MetaX (MACA) Backend
3+
# =============================================================================
4+
5+
# 1. Locate MACA SDK path To allow #include <maca_runtime.h> in
6+
# runtime/cinn_runtime.cc or compiler.cc, we need to add the MetaX SDK header
7+
# search path.
8+
set(MACA_PATH $ENV{MACA_PATH})
9+
if(NOT MACA_PATH)
10+
set(MACA_PATH "/opt/maca") # Default fallback path
11+
message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}")
12+
else()
13+
message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}")
14+
endif()
15+
16+
# 1. Define source file list All .cc files involved in the CINN implementation
17+
# must be included here.
18+
set(CINN_SRCS
19+
cinn_interface.cc # Main entry point, responsible for InitCinnInterface
20+
compiler/compiler.cc # Implements MetaxCompile and MetaxGetRuntimeSource
21+
runtime/cinn_runtime.cc # Implements MetaxModuleLoad, MetaxLaunchKernel
22+
passes/pass_manager.cc # Implements MetaxApplyCustomPass
23+
)
24+
25+
# 1. Create OBJECT library Use OBJECT mode to compile into .o files only, without
26+
# generating .a or .so. This allows the parent CMake to directly collect these
27+
# .o files and link them into the final plugin.so.
28+
add_library(metax_cinn_obj OBJECT ${CINN_SRCS})
29+
30+
# 1. Configure header search paths
31+
target_include_directories(
32+
metax_cinn_obj
33+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} # Allow referencing headers in current
34+
# directory (cinn_interface.h)
35+
${CMAKE_CURRENT_SOURCE_DIR}/../ # Allow referencing parent-level
36+
# headers (e.g., common/)
37+
${MACA_PATH}/include # Allow referencing <maca_runtime.h>
38+
${PADDLE_SOURCE_DIR} # Allow referencing paddle/phi/... headers
39+
# Paddle header paths are typically auto-included via the external
40+
# environment (Paddle_DIR)
41+
)
42+
43+
# 1. Compiler options The CINN component typically requires C++17 standard
44+
set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17)
45+
46+
# Enable PIC (Position Independent Code) Required because these .o files will
47+
# ultimately be linked into a shared library
48+
set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "cinn/cinn_interface.h"
16+
17+
#include <cstring> // For memset
18+
#include <iostream>
19+
20+
namespace paddle {
21+
namespace custom_device {
22+
namespace metax {
23+
24+
// ============================================================
25+
// External Function Declarations
26+
// These functions must be implemented in the corresponding subdirectory files
27+
// (.cc).
28+
// ============================================================
29+
30+
// --- From compiler/compiler.cc ---
31+
// Invokes the mxcc toolchain to compile CINN-generated source code into a
32+
// binary
33+
extern C_Status MetaxCompile(void* dev_ptr,
34+
const char* code,
35+
char* out_path,
36+
size_t len);
37+
38+
// Provides the MetaX GPU device runtime source code
39+
extern const char* MetaxGetRuntimeSource(void* dev_ptr);
40+
41+
// --- From runtime/cinn_runtime.cc ---
42+
// Loads a compiled binary module (.mx / .so)
43+
extern C_Status MetaxModuleLoad(void* dev_ptr,
44+
const char* path,
45+
void** mod_out);
46+
47+
// Unloads a module
48+
extern C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle);
49+
50+
// Retrieves the kernel function address from a loaded module
51+
extern C_Status MetaxGetKernelAddress(void* dev_ptr,
52+
void* module_handle,
53+
const char* func_name,
54+
void** func_out);
55+
56+
// Launches a kernel function
57+
extern C_Status MetaxLaunchKernel(void* dev_ptr,
58+
void* func_ptr,
59+
void** args,
60+
int num_args,
61+
int gx,
62+
int gy,
63+
int gz,
64+
int bx,
65+
int by,
66+
int bz,
67+
int shm,
68+
void* stream);
69+
70+
// --- From passes/pass_manager.cc ---
71+
// Applies custom graph optimization passes
72+
extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module);
73+
74+
// ============================================================
75+
// Interface Initialization
76+
// ============================================================
77+
78+
// Static instance, valid throughout the plugin lifetime
79+
static C_CinnInterface metax_cinn_impl;
80+
81+
void InitCinnInterface(C_DeviceInterface* device_interface) {
82+
// 1. Zero-initialize for safety
83+
std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface));
84+
85+
// 2. Set struct size (used for version validation)
86+
metax_cinn_impl.size = sizeof(C_CinnInterface);
87+
88+
// 3. Set context pointer (optional)
89+
// Point to a global state struct if your implementation needs one; otherwise
90+
// nullptr
91+
metax_cinn_impl.dev_ptr = nullptr;
92+
93+
// 4. Register Compiler Toolchain interface
94+
metax_cinn_impl.compile = MetaxCompile;
95+
metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource;
96+
97+
// 5. Register Runtime Strategy interface
98+
metax_cinn_impl.module_load = MetaxModuleLoad;
99+
metax_cinn_impl.module_unload = MetaxModuleUnload;
100+
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
101+
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;
102+
103+
// 6. Register Compilation Strategy interface
104+
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;
105+
106+
// 7. Attach the populated dispatch table to the Paddle device interface
107+
if (device_interface) {
108+
device_interface->cinn_interface = &metax_cinn_impl;
109+
} else {
110+
std::cerr << "[MetaX] Error: device_interface is null during CINN init."
111+
<< std::endl;
112+
}
113+
}
114+
115+
} // namespace metax
116+
} // namespace custom_device
117+
} // namespace paddle
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
// Include the Paddle-defined C interface structures
18+
#include "paddle/phi/backends/device_ext.h"
19+
20+
namespace paddle {
21+
namespace custom_device {
22+
namespace metax {
23+
24+
/**
25+
* @brief Initialize the CINN interface.
26+
*
27+
* This function is called by InitPlugin in runtime.cc.
28+
* It populates device_interface->cinn_interface with the compiler
29+
* and runtime function pointers implemented under metax_gpu/cinn.
30+
*
31+
* @param device_interface The device interface pointer passed from the Paddle
32+
* host side.
33+
*/
34+
void InitCinnInterface(C_DeviceInterface* device_interface);
35+
36+
} // namespace metax
37+
} // namespace custom_device
38+
} // namespace paddle

0 commit comments

Comments
 (0)