
Commit 45c11cb

JYMiracle305 authored and kilinchange committed

feat: compatible with Cub versions

1 parent c4703e1 · commit 45c11cb

4 files changed: 33 additions & 13 deletions

infini_train/include/common/cuda/cub_compat.cuh

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <cub/version.cuh>
+
+namespace infini_train::kernels::cuda {
+
+#if defined(CUB_VERSION) && CUB_VERSION >= 200800
+using CubSumOp = ::cuda::std::plus<>;
+using CubMaxOp = ::cuda::maximum<>;
+using CubMinOp = ::cuda::minimum<>;
+#else
+using CubSumOp = cub::Sum;
+using CubMaxOp = cub::Max;
+using CubMinOp = cub::Min;
+#endif
+
+} // namespace infini_train::kernels::cuda
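The `200800` threshold follows CUB's version encoding: `cub/version.cuh` defines `CUB_VERSION` as `CUB_MAJOR_VERSION * 100000 + CUB_MINOR_VERSION * 100 + CUB_SUBMINOR_VERSION`, so the gate picks the libcu++ functors (`cuda::std::plus<>`, `cuda::maximum<>`, `cuda::minimum<>`) from CUB 2.8.0 onward, where the classic `cub::Sum`/`cub::Max`/`cub::Min` functors are deprecated, and falls back to the classic functors on older toolkits. A minimal sketch of a compile-time check documenting that encoding assumption (illustrative only, not part of this commit):

    #include <cub/version.cuh>

    // cub/version.cuh encodes major.minor.subminor as major*100000 + minor*100 + subminor,
    // so CUB 2.8.0 yields 200800 -- the threshold used in the #if above.
    static_assert(CUB_VERSION == CUB_MAJOR_VERSION * 100000 + CUB_MINOR_VERSION * 100 + CUB_SUBMINOR_VERSION,
                  "unexpected CUB version encoding");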

infini_train/src/kernels/cuda/cross_entropy.cu

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@
 #include <cuda_runtime.h>
 
 #include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/common/cuda/cub_compat.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
@@ -44,7 +45,7 @@ __global__ void CrossEntropyForwardKernel(const InputType *__restrict__ input_pt
     for (int i = tid; i < num_classes; i += BLOCK_SIZE) {
         thread_max = fmaxf(thread_max, common::cuda::Cast<float>(input_ptr[base + i]));
     }
-    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, ::cuda::maximum<>());
+    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, CubMaxOp());
     if (tid == 0) {
         shared.max_logit = block_max;
     }
@@ -139,7 +140,7 @@ __global__ void CrossEntropyBackwardKernel(const InputType *__restrict__ input_p
     for (int i = tid; i < num_classes; i += BLOCK_SIZE) {
         thread_max = fmaxf(thread_max, common::cuda::Cast<float>(input_ptr[idx_base + i]));
     }
-    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, ::cuda::maximum<>());
+    const float block_max = cub::BlockReduce<float, BLOCK_SIZE>(shared.reduce).Reduce(thread_max, CubMaxOp());
     if (tid == 0) {
         shared.max_logit = block_max;
     }
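Both alternatives behind `CubMaxOp` are stateless, default-constructible binary functors, which is why the call sites above can construct one inline as `CubMaxOp()` under either CUB version. A small illustrative check (hypothetical snippet, not in the commit; assumes the compat header is on the include path and the file is compiled with nvcc):

    #include <cstdio>
    #include <cub/cub.cuh>

    #include "infini_train/include/common/cuda/cub_compat.cuh"

    int main() {
        // cub::Max and ::cuda::maximum<> are both host/device-callable,
        // so the alias behaves identically from host code as well.
        using infini_train::kernels::cuda::CubMaxOp;
        std::printf("%g\n", CubMaxOp()(1.0f, 2.0f)); // prints 2
        return 0;
    }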

infini_train/src/kernels/cuda/reduction.cu

Lines changed: 11 additions & 10 deletions
@@ -1,6 +1,7 @@
 #include <cub/cub.cuh>
 
 #include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/common/cuda/cub_compat.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
@@ -14,22 +15,22 @@ namespace {
 // Reduction operators
 template <typename T, typename ReduceFunc> struct CubOp;
 
-template <typename T> struct CubOp<T, ::cuda::std::plus<>> {
+template <typename T> struct CubOp<T, CubSumOp> {
     __device__ static T Init() { return common::cuda::Cast<T>(0); }
     __device__ static T Reduce(T a, T b) { return common::cuda::Add<T>(a, b); }
-    __device__ static ::cuda::std::plus<> Op() { return ::cuda::std::plus<>(); }
+    __device__ static CubSumOp Op() { return CubSumOp(); }
 };
 
-template <typename T> struct CubOp<T, ::cuda::maximum<>> {
+template <typename T> struct CubOp<T, CubMaxOp> {
     __device__ static T Init() { return common::cuda::Cast<T>(-kInfinity); }
     __device__ static T Reduce(T a, T b) { return common::cuda::Max<T>(a, b); }
-    __device__ static ::cuda::maximum<> Op() { return ::cuda::maximum<>(); }
+    __device__ static CubMaxOp Op() { return CubMaxOp(); }
 };
 
-template <typename T> struct CubOp<T, ::cuda::minimum<>> {
+template <typename T> struct CubOp<T, CubMinOp> {
     __device__ static T Init() { return common::cuda::Cast<T>(kInfinity); }
     __device__ static T Reduce(T a, T b) { return common::cuda::Min<T>(a, b); }
-    __device__ static ::cuda::minimum<> Op() { return ::cuda::minimum<>(); }
+    __device__ static CubMinOp Op() { return CubMinOp(); }
 };
 
 // Finalization strategies
@@ -179,19 +180,19 @@ std::shared_ptr<Tensor> ReduceOpBackward(const std::shared_ptr<Tensor> &grad_out
 }
 
 std::shared_ptr<Tensor> MeanForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::std::plus<>, MeanFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubSumOp, MeanFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> SumForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::std::plus<>, IdentityFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubSumOp, IdentityFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> MaxForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::maximum<>, IdentityFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubMaxOp, IdentityFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> MinForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
-    return ReduceOpForward<::cuda::minimum<>, IdentityFinalize>(input, dim, keep_dim);
+    return ReduceOpForward<CubMinOp, IdentityFinalize>(input, dim, keep_dim);
 }
 
 std::shared_ptr<Tensor> MeanBackward(const std::shared_ptr<Tensor> &grad_output, const std::vector<int64_t> &input_dims,
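The `CubOp` trait keeps each reduction generic over its operator: `Init()` supplies the identity element, `Reduce()` the per-thread combine, and `Op()` the functor handed to CUB. A sketch of how a kernel can dispatch through it (illustrative only; it assumes the `CubOp` specializations above are in scope, and `GenericReduceKernel` is a made-up name, since the body of the commit's actual `ReduceOpForward` is not shown in this diff):

    #include <cub/cub.cuh>

    // Illustrative block-wide reduction driven entirely by the trait.
    template <typename T, typename ReduceFunc, int BLOCK_SIZE>
    __global__ void GenericReduceKernel(const T *in, T *out, int n) {
        using BlockReduce = cub::BlockReduce<T, BLOCK_SIZE>;
        __shared__ typename BlockReduce::TempStorage temp;

        // Start from the operator's identity (0, -inf, or +inf) and fold a strided slice.
        T acc = CubOp<T, ReduceFunc>::Init();
        for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) {
            acc = CubOp<T, ReduceFunc>::Reduce(acc, in[i]);
        }
        // Hand the version-portable functor to CUB; thread 0 holds the block result.
        T result = BlockReduce(temp).Reduce(acc, CubOp<T, ReduceFunc>::Op());
        if (threadIdx.x == 0) {
            *out = result;
        }
    }

    // e.g. GenericReduceKernel<float, CubSumOp, 256><<<1, 256>>>(d_in, d_out, n);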

infini_train/src/kernels/cuda/softmax.cu

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/common/cuda/common_cuda.h"
+#include "infini_train/include/common/cuda/cub_compat.cuh"
 #include "infini_train/include/common/cuda/kernel_helper.cuh"
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
@@ -31,7 +32,7 @@ __global__ void SoftmaxForwardKernel(T *output, const T *input, int64_t outer_si
         int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
         thread_max = max(thread_max, common::cuda::Cast<float>(input[idx]));
     }
-    float block_max = BlockReduce(temp_storage_max).Reduce(thread_max, ::cuda::maximum<>());
+    float block_max = BlockReduce(temp_storage_max).Reduce(thread_max, CubMaxOp());
 
     if (tid == 0) {
         row_max = block_max;
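The changed line is the max-reduction pass of the numerically stable softmax: the row maximum is subtracted before exponentiating so `expf` cannot overflow. A condensed single-row sketch of the full pattern (illustrative; the kernel name and structure are assumptions, not the commit's actual `SoftmaxForwardKernel`):

    #include <cub/cub.cuh>

    #include "infini_train/include/common/cuda/cub_compat.cuh"

    template <int BLOCK_SIZE>
    __global__ void StableSoftmaxRow(const float *in, float *out, int n) {
        using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
        __shared__ typename BlockReduce::TempStorage temp;
        __shared__ float row_max, row_sum;

        // Pass 1: block-wide max via the version-portable functor alias.
        float tmax = -INFINITY;
        for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) { tmax = fmaxf(tmax, in[i]); }
        float bmax = BlockReduce(temp).Reduce(tmax, infini_train::kernels::cuda::CubMaxOp());
        if (threadIdx.x == 0) { row_max = bmax; }
        __syncthreads(); // also required before reusing `temp` below

        // Pass 2: sum of max-shifted exponentials (cannot overflow).
        float tsum = 0.0f;
        for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) { tsum += expf(in[i] - row_max); }
        float bsum = BlockReduce(temp).Sum(tsum);
        if (threadIdx.x == 0) { row_sum = bsum; }
        __syncthreads();

        // Pass 3: normalize.
        for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) { out[i] = expf(in[i] - row_max) / row_sum; }
    }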
