11#include < cub/cub.cuh>
22
33#include " infini_train/include/common/cuda/common_cuda.h"
4+ #include " infini_train/include/common/cuda/cub_compat.cuh"
45#include " infini_train/include/common/cuda/kernel_helper.cuh"
56#include " infini_train/include/dispatcher.h"
67#include " infini_train/include/tensor.h"
@@ -14,22 +15,22 @@ namespace {
// Reduction operators.
// Each specialization supplies three pieces used by the reduction kernels:
//   Init()   — the identity element for the reduction,
//   Reduce() — the per-element combine step (via common::cuda helpers so it
//              works for half/bfloat16 as well as float types),
//   Op()     — the functor instance handed to CUB's block/device reduce.
// CubSumOp/CubMaxOp/CubMinOp come from cub_compat.cuh, which papers over the
// CUB/libcu++ functor renames across CUDA toolkit versions.
template <typename T, typename ReduceFunc> struct CubOp;

// Sum reduction: identity is 0, elements combine with Add.
template <typename T> struct CubOp<T, CubSumOp> {
    __device__ static T Init() { return common::cuda::Cast<T>(0); }
    __device__ static T Reduce(T a, T b) { return common::cuda::Add<T>(a, b); }
    __device__ static CubSumOp Op() { return CubSumOp(); }
};
2223
// Max reduction: identity is -inf so any real element wins the first compare;
// elements combine with Max.
template <typename T> struct CubOp<T, CubMaxOp> {
    __device__ static T Init() { return common::cuda::Cast<T>(-kInfinity); }
    __device__ static T Reduce(T a, T b) { return common::cuda::Max<T>(a, b); }
    __device__ static CubMaxOp Op() { return CubMaxOp(); }
};
2829
// Min reduction: identity is +inf so any real element wins the first compare;
// elements combine with Min.
template <typename T> struct CubOp<T, CubMinOp> {
    __device__ static T Init() { return common::cuda::Cast<T>(kInfinity); }
    __device__ static T Reduce(T a, T b) { return common::cuda::Min<T>(a, b); }
    __device__ static CubMinOp Op() { return CubMinOp(); }
};
3435
3536// Finalization strategies
@@ -179,19 +180,19 @@ std::shared_ptr<Tensor> ReduceOpBackward(const std::shared_ptr<Tensor> &grad_out
179180}
180181
// Mean along `dim`: a sum reduction whose result is post-processed by
// MeanFinalize (presumably dividing by the reduced extent — contract lives in
// ReduceOpForward/MeanFinalize above). `keep_dim` preserves the reduced axis
// as size 1 in the output shape.
std::shared_ptr<Tensor> MeanForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
    return ReduceOpForward<CubSumOp, MeanFinalize>(input, dim, keep_dim);
}
184185
// Sum along `dim`: sum reduction with no finalization step (IdentityFinalize).
std::shared_ptr<Tensor> SumForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
    return ReduceOpForward<CubSumOp, IdentityFinalize>(input, dim, keep_dim);
}
188189
// Max along `dim`: max reduction with no finalization step (IdentityFinalize).
std::shared_ptr<Tensor> MaxForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
    return ReduceOpForward<CubMaxOp, IdentityFinalize>(input, dim, keep_dim);
}
192193
// Min along `dim`: min reduction with no finalization step (IdentityFinalize).
std::shared_ptr<Tensor> MinForward(const std::shared_ptr<Tensor> &input, const int64_t dim, const bool keep_dim) {
    return ReduceOpForward<CubMinOp, IdentityFinalize>(input, dim, keep_dim);
}
196197
197198std::shared_ptr<Tensor> MeanBackward (const std::shared_ptr<Tensor> &grad_output, const std::vector<int64_t > &input_dims,
0 commit comments