Skip to content

Commit 6dc9b51

Browse files
committed
Replace cub::Max() with cuda::maximum<> in kernel reductions
1 parent 51885c9 commit 6dc9b51

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

csrc/kernels.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
10081008
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cuda::maximum<>{}, valid_items);
10091009
if (unorm != NULL) {
10101010
__syncthreads();
1011-
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
1011+
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cuda::std::plus<>{}, valid_items);
10121012
}
10131013

10141014
if (threadIdx.x == 0) {
@@ -1220,7 +1220,7 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
12201220
}
12211221
if (unorm != NULL) {
12221222
__syncthreads();
1223-
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
1223+
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cuda::std::plus<>{}, valid_items);
12241224
if (threadIdx.x == 0) {
12251225
atomicAdd(&unorm[0], local_unorm);
12261226
}
@@ -1525,11 +1525,11 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise(
15251525
}
15261526

15271527
// reduce: 2.51/1.60 -> 2.67/1.69
1528-
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
1529-
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
1528+
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cuda::maximum<>{});
1529+
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cuda::maximum<>{});
15301530

15311531
if (OPTIMIZER == ADEMAMIX) {
1532-
new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max());
1532+
new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cuda::maximum<>{});
15331533
}
15341534

15351535
if (threadIdx.x == 0) {
@@ -1738,7 +1738,7 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise(
17381738
}
17391739

17401740
// reduce: 2.51/1.60 -> 2.67/1.69
1741-
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
1741+
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cuda::maximum<>{});
17421742

17431743
if (threadIdx.x == 0)
17441744
smem_exchange1[0] = new_local_abs_max1;

0 commit comments

Comments
 (0)