Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/metax_gpu/change_patch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ cd ../../Paddle/
git apply --verbose ../backends/metax_gpu/patch/paddle.patch
cd -
# cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/
cp -r ./patch/warpctc.patch ../../Paddle/third_party/warpctc/
23 changes: 20 additions & 3 deletions backends/metax_gpu/cmake/warpctc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,19 @@ else()
)
endif()

# set(WARPCTC_METAX_PATCH
# "${CUSTOM_DEVICE_SOURCE_DIR}/patches/warpctc-corex.patch") message(STATUS
# "warpctc: METAX patch at PATCH step: ${WARPCTC_METAX_PATCH}")

if(NOT WIN32 AND WITH_GPU)
# set(NVCC_FLAGS_EXTRA "${NVCC_FLAGS_EXTRA} --generate-line-info")
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION}
VERSION_GREATER 12.0)
file(TO_NATIVE_PATH
${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.patch native_src)
set(WARPCTC_PATCH_COMMAND git checkout -- . && git checkout ${WARPCTC_TAG}
&& patch -Nd ${SOURCE_DIR} < ${native_src} &&)
set(WARPCTC_CCBIN_OPTION -DCCBIN_COMPILER=${CCBIN_COMPILER})
set(WARPCTC_CCBIN_OPTIO -DCCBIN_COMPILER=${CCBIN_COMPILER})
endif()
endif()

Expand All @@ -66,6 +71,11 @@ endif()
set(WARPCTC_INCLUDE_DIR
"${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)

set(WARPCTC_METAX_PATCH
"${PADDLE_SOURCE_DIR}/third_party/warpctc/warpctc.patch")
message(STATUS "warpctc: METAX patch at PATCH step: ${WARPCTC_METAX_PATCH}")

# Used in unit test test_WarpCTCLayer
set(WARPCTC_LIB_DIR
"${WARPCTC_INSTALL_DIR}/lib"
Expand Down Expand Up @@ -123,6 +133,11 @@ ExternalProject_Add(
COMMAND ${WARPCTC_PATCH_CUDA_COMMAND}
COMMAND ${COPY_COMMAND}
COMMAND ${WARPCTC_PATHCH_ROCM_COMMAND}
COMMAND ${CMAKE_COMMAND} -E echo
"warpctc: applying metax warpctc-corex.patch..."
COMMAND patch -p1 -i "${WARPCTC_METAX_PATCH}"
COMMAND ${CMAKE_COMMAND} -E echo
"warpctc: metax warpctc-corex.patch applied successfully."
# BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
Expand All @@ -133,9 +148,11 @@ ExternalProject_Add(
-DCMAKE_CXX_FLAGS_RELEASE=${WARPCTC_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${WARPCTC_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
# -DWITH_GPU=${WITH_GPU}
-DWITH_GPU=ON
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
# -DWITH_OMP=${USE_OMP}
-DWITH_OMP=OFF
-DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA}
-DWITH_TORCH=OFF
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
Expand Down
128 changes: 73 additions & 55 deletions backends/metax_gpu/patch/intrinsics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ MGPU_DEVICE uint prmt_ptx(uint a, uint b, uint index) {
// shfl_up

__device__ __forceinline__ float shfl_up(float var,
unsigned int delta, int width = 32) {
unsigned int delta, int width = 64) {

#if __CUDA_ARCH__ >= 300
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
Expand All @@ -122,7 +122,7 @@ __device__ __forceinline__ float shfl_up(float var,
}

__device__ __forceinline__ double shfl_up(double var,
unsigned int delta, int width = 32) {
unsigned int delta, int width = 64) {

#if __CUDA_ARCH__ >= 300
int2 p = mgpu::double_as_int2(var);
Expand Down Expand Up @@ -167,48 +167,66 @@ __device__ __forceinline__ double shfl_up(double var,
// return result;
// }

MGPU_DEVICE int shfl_add(int x, int offset, int width = 32)
{
#if __CUDA_ARCH__ >= 300
unsigned fullMask = 0xffffffffU;
unsigned mask = (width == 32) ? fullMask : ((1U << width) - 1U);
int src = 0;
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
src = __shfl_up_sync(mask, x, offset, width); // CUDA 9+
#else
src = __shfl_up(x, offset, width); // CUDA 8-
#endif
int lane = threadIdx.x & 31;
return (lane >= offset) ? (src + x) : x;
#else
return x;
#endif
}
// MGPU_DEVICE int shfl_add(int x, int offset, int width = 32)
// {
// #if __CUDA_ARCH__ >= 300
// unsigned fullMask = 0xffffffffU;
// unsigned mask = (width == 32) ? fullMask : ((1U << width) - 1U);
// int src = 0;
// #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
// src = __shfl_up_sync(mask, x, offset, width); // CUDA 9+
// #else
// src = __shfl_up(x, offset, width); // CUDA 8-
// #endif
// int lane = threadIdx.x & 31;
// return (lane >= offset) ? (src + x) : x;
// #else
// return x;
// #endif
// }

MGPU_DEVICE int shfl_add(int x, int offset, int width = WARP_SIZE) {
unsigned int lane_id = (unsigned)(threadIdx.x % (width > 0 ? width : 1));
int result = __shfl_up_sync(0xffffffffffffffffUL, x, offset, width);
if (lane_id < (unsigned)width && lane_id >= (unsigned)offset)
result += x;
return result;
}


// MGPU_DEVICE int shfl_max(int x, int offset, int width = WARP_SIZE) {
// int result = 0;
// #if __CUDA_ARCH__ >= 300
// int mask = (WARP_SIZE - width)<< 8;
// #if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
// asm(
// "{.reg .s32 r0;"
// ".reg .pred p;"
// "shfl.up.sync.b32 r0|p, %1, %2, %3, 0xFFFFFFFF;"
// "@p max.s32 r0, r0, %4;"
// "mov.s32 %0, r0; }"
// : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x));
// #else
// asm(
// "{.reg .s32 r0;"
// ".reg .pred p;"
// "shfl.up.b32 r0|p, %1, %2, %3;"
// "@p max.s32 r0, r0, %4;"
// "mov.s32 %0, r0; }"
// : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x));
// #endif
// #endif
// return result;
// }

MGPU_DEVICE int shfl_max(int x, int offset, int width = WARP_SIZE) {
unsigned int lane_id = (unsigned)(threadIdx.x % (width > 0 ? width : 1));
int result = __shfl_up_sync(0xffffffffffffffffUL, x, offset, width);
if (lane_id < (unsigned)width && lane_id >= (unsigned)offset)
result = (result > x) ? result : x;
return result;
}

MGPU_DEVICE int shfl_max(int x, int offset, int width = WARP_SIZE) {
int result = 0;
#if __CUDA_ARCH__ >= 300
int mask = (WARP_SIZE - width)<< 8;
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
asm(
"{.reg .s32 r0;"
".reg .pred p;"
"shfl.up.sync.b32 r0|p, %1, %2, %3, 0xFFFFFFFF;"
"@p max.s32 r0, r0, %4;"
"mov.s32 %0, r0; }"
: "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x));
#else
asm(
"{.reg .s32 r0;"
".reg .pred p;"
"shfl.up.b32 r0|p, %1, %2, %3;"
"@p max.s32 r0, r0, %4;"
"mov.s32 %0, r0; }"
: "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x));
#endif
#endif
return result;
}

////////////////////////////////////////////////////////////////////////////////
// brev, popc, clz, bfe, bfi, prmt
Expand Down Expand Up @@ -260,31 +278,31 @@ MGPU_HOST_DEVICE int ffs(int x) {
}

MGPU_HOST_DEVICE uint bfe(uint x, uint bit, uint numBits) {
#if __CUDA_ARCH__ >= 200
return bfe_ptx(x, bit, numBits);
#else
// #if __CUDA_ARCH__ >= 200
// return bfe_ptx(x, bit, numBits);
// #else
return ((1<< numBits) - 1) & (x>> bit);
#endif
// #endif
}

MGPU_HOST_DEVICE uint bfi(uint x, uint y, uint bit, uint numBits) {
uint result;
#if __CUDA_ARCH__ >= 200
result = bfi_ptx(x, y, bit, numBits);
#else
// #if __CUDA_ARCH__ >= 200
// result = bfi_ptx(x, y, bit, numBits);
// #else
if(bit + numBits > 32) numBits = 32 - bit;
uint mask = ((1<< numBits) - 1)<< bit;
result = y & ~mask;
result |= mask & (x<< bit);
#endif
// #endif
return result;
}

MGPU_HOST_DEVICE uint prmt(uint a, uint b, uint index) {
uint result;
#if __CUDA_ARCH__ >= 200
result = prmt_ptx(a, b, index);
#else
// #if __CUDA_ARCH__ >= 200
// result = prmt_ptx(a, b, index);
// #else
result = 0;
for(int i = 0; i < 4; ++i) {
uint sel = 0xf & (index>> (4 * i));
Expand All @@ -293,7 +311,7 @@ MGPU_HOST_DEVICE uint prmt(uint a, uint b, uint index) {
if(8 & sel) x = (128 & x) ? 0xff : 0;
result |= x<< (8 * i);
}
#endif
// #endif
return result;
}

Expand Down
15 changes: 15 additions & 0 deletions backends/metax_gpu/patch/paddle.patch
Original file line number Diff line number Diff line change
Expand Up @@ -1190,3 +1190,18 @@ index 7ced1fdc17..e49759ebb4 100644

template <typename T>

diff --git a/test/legacy_test/test_imperative_hook_for_layer.py b/test/legacy_test/test_imperative_hook_for_layer.py
index 6655bd61d9..282844f750 100644
--- a/test/legacy_test/test_imperative_hook_for_layer.py
+++ b/test/legacy_test/test_imperative_hook_for_layer.py
@@ -20,6 +20,10 @@ from op_test import get_places
import paddle
from paddle import base

+from paddle.base import core
+
+core.set_cublas_switch(False)
+
call_forward_post_hook = False
call_forward_pre_hook = False

46 changes: 46 additions & 0 deletions backends/metax_gpu/patch/warpctc.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
diff --git a/include/contrib/moderngpu/include/device/devicetypes.cuh b/include/contrib/moderngpu/include/device/devicetypes.cuh
index 282bedd..d8e094a 100644
--- a/include/contrib/moderngpu/include/device/devicetypes.cuh
+++ b/include/contrib/moderngpu/include/device/devicetypes.cuh
@@ -53,8 +53,8 @@ namespace mgpu {
#define MGPU_DEVICE __device__ INLINESYMBOL
#define MGPU_HOST_DEVICE __host__ __device__ INLINESYMBOL

-const int WARP_SIZE = 32;
-const int LOG_WARP_SIZE = 5;
+const int WARP_SIZE = 64;
+const int LOG_WARP_SIZE = 6;

////////////////////////////////////////////////////////////////////////////////
// Device-side comparison operators
diff --git a/include/detail/gpu_ctc.h b/include/detail/gpu_ctc.h
index cafe6ae..0cf5fc2 100644
--- a/include/detail/gpu_ctc.h
+++ b/include/detail/gpu_ctc.h
@@ -355,9 +355,9 @@ GpuCTC<ProbT>::create_metadata_and_choose_config(const int* const flat_labels,
constexpr int num_configs = 12;

int config_NT[num_configs] =
- {32, 64, 128, 64, 128, 32, 64, 128, 64, 128, 128, 128};
+ {64, 64, 128, 64, 128, 64, 64, 128, 64, 128, 128, 128};
int config_VT[num_configs] =
- { 1, 1, 1, 3, 2, 9, 6, 4, 9, 6, 9, 10};
+ { 1, 1, 1, 3, 2, 5, 6, 4, 9, 6, 9, 10};

best_config = 0;

@@ -383,12 +383,12 @@ GpuCTC<ProbT>::launch_gpu_kernels(const ProbT* const probs,
bool l_b) {

switch(config) {
- case 0: {return launch_alpha_beta_kernels<32, 1>(probs, grads, l_a, l_b);}
+ case 0: {return launch_alpha_beta_kernels<64, 1>(probs, grads, l_a, l_b);}
case 1: {return launch_alpha_beta_kernels<64, 1>(probs, grads, l_a, l_b);}
case 2: {return launch_alpha_beta_kernels<128, 1>(probs, grads, l_a, l_b);}
case 3: {return launch_alpha_beta_kernels<64, 3>(probs, grads, l_a, l_b);}
case 4: {return launch_alpha_beta_kernels<128, 2>(probs, grads, l_a, l_b);}
- case 5: {return launch_alpha_beta_kernels<32, 9>(probs, grads, l_a, l_b);}
+ case 5: {return launch_alpha_beta_kernels<64, 5>(probs, grads, l_a, l_b);}
case 6: {return launch_alpha_beta_kernels<64, 6>(probs, grads, l_a, l_b);}
case 7: {return launch_alpha_beta_kernels<128, 4>(probs, grads, l_a, l_b);}
case 8: {return launch_alpha_beta_kernels<64, 9>(probs, grads, l_a, l_b);}
36 changes: 36 additions & 0 deletions backends/metax_gpu/tests/unit_test/test_ctc_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn.functional as F

paddle.set_device("metax_gpu:0") # 或 custom_device

batch = 4
width = 80
classes = 6625

preds = paddle.randn([batch, width, classes], dtype="float32")

preds = preds.transpose([1, 0, 2]) # -> [T,N,C]

labels = paddle.randint(1, classes, shape=[batch, 25], dtype="int32")

pred_len = paddle.full([batch], width, dtype="int64")
# label_len = paddle.full([batch], 15, dtype="int64") # 正常
label_len = paddle.full([batch], 16, dtype="int64") # 报错

loss = F.ctc_loss(preds, labels, pred_len, label_len)

print(loss)
Loading