Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/metax_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/mp_allreduce_sum_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_k_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_k_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_k_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/where_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/where_kernel.cu
Expand Down
1 change: 1 addition & 0 deletions backends/metax_gpu/compile.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export CUCC_PATH=${MACA_PATH}/tools/cu-bridge
export PATH=${PATH}:${CUCC_PATH}/tools:${CUCC_PATH}/bin
export PATH=${MACA_PATH}/bin:${PATH}
export LD_LIBRARY_PATH=${MACA_PATH}/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}
# export MXCC_OVERRIDE_OPTIONS="+-mllvm +-metaxgpu-inline-branch-fold-bias=10000"
export PADDLE_VERSION="3.3.0.dev$(date +%Y%m%d)"
export MACA_AI_VERSION=$(cat /opt/maca/Version.txt | cut -d':' -f2)
if [ ! -d build ]; then
Expand Down
2 changes: 1 addition & 1 deletion backends/metax_gpu/env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export MXCC_OVERRIDE_OPTIONS="+-mllvm +-metaxgpu-inline-branch-fold-bias=10000"
DEFAULT_DIR="/opt/maca"
export MACA_PATH=${1:-$DEFAULT_DIR}
export CUDA_PATH=/usr/local/cuda
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ void FlashAttnGradKernel(const Context& ctx,
// // printf("params.dq dims[2]:%d, params.dk dims[2]:%d, params.dv
// dims[2]:%d\n", params.dq->head_num, params.dk->head_num,
// params.dv->head_num);
print_tensor_info(params.dq);
print_tensor_info(params.dk);
// print_tensor_info(params.dq);
// print_tensor_info(params.dk);
// print_tensor_info(params.dv);
mcflashattnStatus_t succ = phi::dynload::mha_bwd(params.batch_size,
params.seqlen_q,
Expand Down
125 changes: 70 additions & 55 deletions backends/metax_gpu/patch/paddle.patch
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,19 @@ index d970878dc2..fe0382ccad 100644
x = *reinterpret_cast<uint16_t*>(&tmp);

diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h
index 024a7de73e..66b373d698 100644
index d07575028c..ec262da03a 100644
--- a/paddle/phi/core/enforce.h
+++ b/paddle/phi/core/enforce.h
@@ -97,7 +97,7 @@ inline bool is_error(bool stat) { return !stat; }

void ThrowWarnInternal(const std::string& message);
PADDLE_API void ThrowWarnInternal(const std::string& message);

-#if defined(__CUDA_ARCH__)
+#if defined(__CUDACC__)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
@@ -109,7 +109,7 @@ void ThrowWarnInternal(const std::string& message);
@@ -109,7 +109,7 @@ PADDLE_API void ThrowWarnInternal(const std::string& message);
__LINE__, \
#_IS_NOT_ERROR, \
##__VA_ARGS__); \
Expand Down Expand Up @@ -916,45 +916,6 @@ index 75a8f71d8c..cb21e9e301 100644
#include "paddle/phi/kernels/impl/qr_kernel_impl.h"
#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
#include "paddle/phi/kernels/lstsq_kernel.h"
diff --git a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h
index 4a28600c38..d96495b7aa 100644
--- a/paddle/phi/kernels/impl/gammaincc_kernel_impl.h
+++ b/paddle/phi/kernels/impl/gammaincc_kernel_impl.h
@@ -56,8 +56,8 @@ HOSTDEVICE T igam(const T a, const T x) {

template <typename T>
HOSTDEVICE T igamc(const T a, const T x) {
- static T big = 4.503599627370496e15;
- static T biginv = 2.22044604925031308085e-16;
+ const static T big = 4.503599627370496e15;
+ const static T biginv = 2.22044604925031308085e-16;

if ((x <= T{0}) || (a <= T{0})) return (T{1.0});

diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
index c627cc1264..b3941570ee 100644
--- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
@@ -20,8 +20,8 @@
namespace phi {
template <typename T>
HOSTDEVICE T digamma_positive_domain(T x) {
- static T c = T{8.5};
- static T euler_mascheroni = T{0.57721566490153286060};
+ const static T c = T{8.5};
+ const static T euler_mascheroni = T{0.57721566490153286060};
T r;
T value;
T x2;
@@ -54,7 +54,7 @@ HOSTDEVICE T digamma_positive_domain(T x) {

template <typename T>
HOSTDEVICE T digamma(T x) {
- static T pi = T{3.14159265358979323846};
+ const static T pi = T{3.14159265358979323846};

if (x == T{0.0}) {
T inf = std::numeric_limits<T>::infinity();

diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index be6ee4f854..1f507c99f4 100644
Expand Down Expand Up @@ -1056,19 +1017,6 @@ index be6ee4f854..1f507c99f4 100644
} else {
LaunchNormalSoftmaxForward<T, IndexType, LogMode>(

diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
index 0a415200df..b0732e28f3 100644
--- a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
@@ -147,7 +147,7 @@ void CrossEntropyWithSoftmaxGradGPUKernel(const GPUContext& dev_ctx,
DenseTensor* logits_grad) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
- AllocationType::GPU,
+ AllocationType::CUSTOM,
common::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
const T* loss_grad_data = loss_grad.data<T>();
diff --git a/paddle/phi/kernels/funcs/cublaslt.h b/paddle/phi/kernels/funcs/cublaslt.h
index d8bc15926b..6071baf340 100644
--- a/paddle/phi/kernels/funcs/cublaslt.h
Expand Down Expand Up @@ -1102,3 +1050,70 @@ index d8bc15926b..6071baf340 100644
PADDLE_ENFORCE_EQ(
status,

diff --git a/paddle/phi/kernels/funcs/top_k_cuda_kernel.h b/paddle/phi/kernels/funcs/top_k_cuda_kernel.h
index 368cb21c21..f0f99fbd2f 100644
--- a/paddle/phi/kernels/funcs/top_k_cuda_kernel.h
+++ b/paddle/phi/kernels/funcs/top_k_cuda_kernel.h
@@ -167,7 +167,7 @@ struct Bitfield<unsigned int> {
int pos,
int len) {
unsigned int ret;
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
ret = (val >> pos) & ((1u << len) - 1u);
#else
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
@@ -178,7 +178,7 @@ struct Bitfield<unsigned int> {
static __device__ __forceinline__ unsigned int setBitfield(
unsigned int val, unsigned int to_insert, int pos, int len) {
unsigned int ret;
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
unsigned int mask = ((1u << len) - 1u) << pos;
ret = (val & ~mask) | ((to_insert << pos) & mask);
#else
@@ -196,7 +196,7 @@ struct Bitfield<uint64_t> {
int pos,
int len) {
uint64_t ret;
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
ret = (val >> pos) & ((1ULL << len) - 1ULL);
#else
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
@@ -209,7 +209,7 @@ struct Bitfield<uint64_t> {
int pos,
int len) {
uint64_t ret;
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
uint64_t mask = ((1ULL << len) - 1ULL) << pos;
ret = (val & ~mask) | ((to_insert << pos) & mask);
#else
@@ -223,7 +223,7 @@ struct Bitfield<uint64_t> {

// --- getLaneId / getLaneMaskLe ---
__device__ __forceinline__ int getLaneId() {
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
return __lane_id();
#else
int laneId;
@@ -233,7 +233,7 @@ __device__ __forceinline__ int getLaneId() {
}

__device__ __forceinline__ unsigned getLaneMaskLe() {
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
// HIP warp size is 64, construct mask for lanes <= current lane
return (getLaneId() == 63) ? 0xFFFFFFFFFFFFFFFFULL
: (1ULL << (getLaneId() + 1)) - 1ULL;
@@ -245,7 +245,7 @@ __device__ __forceinline__ unsigned getLaneMaskLe() {
}

__device__ __forceinline__ unsigned getLaneMaskLt() {
-#if defined(__HIPCC__)
+#if defined(PADDLE_WITH_CUDA)
return (getLaneId() == 0) ? 0ULL : (1ULL << getLaneId()) - 1ULL;
#else
unsigned mask;
136 changes: 136 additions & 0 deletions backends/metax_gpu/patch/top_p_sampling.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
diff --git a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
index 73da8a62b4..b7ec9080b1 100644
--- a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
@@ -257,64 +257,40 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
Pair<T> topk[],
Pair<T> beam_max[],
int* beam,
- int* k,
int* count,
const int tid,
const int wid,
const int lane) {
- while (true) {
- __syncthreads();
- Pair<T> input_now = topk[0];
- input_now = WarpReduce(input_now);
+ __syncthreads();
+ Pair<T> input_now = topk[0];
+ input_now = WarpReduce(input_now);

- if (lane == 0) {
- shared_max[wid] = input_now;
- }
- __syncthreads();
- input_now = (tid < BlockSize / WARP_SIZE)
- ? shared_max[lane]
- : Pair<T>(std::numeric_limits<T>::min(), -1);
- if (wid == 0) {
- input_now = WarpReduce(input_now);
- if (lane == 0) shared_max[0] = input_now;
- }
- __syncthreads();
- if (tid == 0) {
- beam_max[*count] = shared_max[0];
- (*count)++;
- }
- int tid_max = shared_max[0].id % BlockSize;
- if (tid == tid_max) {
- (*beam)++;
- }
- if (--(*k) == 0) break;
- __syncthreads();
+ if (lane == 0) {
+ shared_max[wid] = input_now;
+ }
+ __syncthreads();
+ input_now = (tid < BlockSize / WARP_SIZE)
+ ? shared_max[lane]
+ : Pair<T>(std::numeric_limits<T>::min(), -1);
+ if (wid == 0) {
+ input_now = WarpReduce(input_now);
+ if (lane == 0) shared_max[0] = input_now;
+ }
+ __syncthreads();
+ if (tid == 0) {
+ beam_max[*count] = shared_max[0];
+ (*count)++;
+ }
+ int tid_max = shared_max[0].id % BlockSize;
+ if (tid == tid_max) {
+ (*beam)++;
+ }

- if (tid == tid_max) {
- if (*beam < MaxLength) {
- topk[0] = topk[*beam];
- }
- }
+ __syncthreads();

- if (MaxLength < 5) {
- if (*beam >= MaxLength) break;
- } else {
-#ifdef PADDLE_WITH_HIP
- uint64_t mask = 0u;
- mask = __ballot(true);
- if (tid_max / WARP_SIZE == wid) {
- if (__shfl_down(*beam, tid_max % WARP_SIZE, WARP_SIZE) == MaxLength)
- break;
- }
-#else
- unsigned mask = 0u;
- mask = __ballot_sync(FINAL_MASK, true);
- if (tid_max / WARP_SIZE == wid) {
- if (__shfl_down_sync(
- FINAL_MASK, *beam, tid_max % WARP_SIZE, WARP_SIZE) == MaxLength)
- break;
- }
-#endif
+ if (tid == tid_max) {
+ if (*beam < MaxLength) {
+ topk[0] = topk[*beam];
}
}
}
@@ -385,7 +361,7 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
topk[j].set(std::numeric_limits<T>::min(), -1);
}

- while (top_num) {
+ for (int iter = 0; iter < TopPBeamTopK; ++iter) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
&beam,
TopPBeamTopK,
@@ -396,7 +372,7 @@ __global__ void KeMatrixTopPBeamTopK(const T* src,
vocab_size,
tid);
BlockReduce<T, MaxLength, BlockSize>(
- shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
+ shared_max, topk, beam_max, &beam, &count, tid, wid, lane);
}
if (tid == 0) {
count_iter_begin[bid] = count_iter[bid];
@@ -488,18 +464,18 @@ __global__ void KeMatrixTopPBeamTopKFt(const T* src,
topk[j].set(std::numeric_limits<T>::min(), -1);
}

- while (top_num) {
+ for (int iter = 0; iter < TopPBeamTopK; ++iter) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
&beam,
TopPBeamTopK,
- src + bid * vocab_size,
+ src + offset,
&firststep,
&is_empty,
&max,
vocab_size,
tid);
BlockReduce<T, MaxLength, BlockSize>(
- shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
+ shared_max, topk, beam_max, &beam, &count, tid, wid, lane);
}
if (tid == 0) {
count_iter_begin[bid] = count_iter[bid];
Loading