@@ -1008,7 +1008,7 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
10081008 local_max_s2 = BlockReduce (temp_storage.reduce ).Reduce (local_max_s2, cuda::maximum<>{}, valid_items);
10091009 if (unorm != NULL ) {
10101010 __syncthreads ();
1011- local_unorm = BlockReduce (temp_storage.reduce ).Reduce (local_unorm, cub::Sum () , valid_items);
1011+ local_unorm = BlockReduce (temp_storage.reduce ).Reduce (local_unorm, cuda::std::plus<>{} , valid_items);
10121012 }
10131013
10141014 if (threadIdx .x == 0 ) {
@@ -1220,7 +1220,7 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
12201220 }
12211221 if (unorm != NULL ) {
12221222 __syncthreads ();
1223- local_unorm = BlockReduce (temp_storage.reduce ).Reduce (local_unorm, cub::Sum () , valid_items);
1223+ local_unorm = BlockReduce (temp_storage.reduce ).Reduce (local_unorm, cuda::std::plus<>{} , valid_items);
12241224 if (threadIdx .x == 0 ) {
12251225 atomicAdd (&unorm[0 ], local_unorm);
12261226 }
@@ -1525,11 +1525,11 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise(
15251525 }
15261526
15271527 // reduce: 2.51/1.60 -> 2.67/1.69
1528- new_local_abs_max1 = BlockReduce1 (reduce1).Reduce (new_local_abs_max1, cub::Max () );
1529- new_local_abs_max2 = BlockReduce2 (reduce2).Reduce (new_local_abs_max2, cub::Max () );
1528+ new_local_abs_max1 = BlockReduce1 (reduce1).Reduce (new_local_abs_max1, cuda::maximum<>{} );
1529+ new_local_abs_max2 = BlockReduce2 (reduce2).Reduce (new_local_abs_max2, cuda::maximum<>{} );
15301530
15311531 if (OPTIMIZER == ADEMAMIX) {
1532- new_local_abs_max3 = BlockReduce3 (reduce3).Reduce (new_local_abs_max3, cub::Max () );
1532+ new_local_abs_max3 = BlockReduce3 (reduce3).Reduce (new_local_abs_max3, cuda::maximum<>{} );
15331533 }
15341534
15351535 if (threadIdx .x == 0 ) {
@@ -1738,7 +1738,7 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise(
17381738 }
17391739
17401740 // reduce: 2.51/1.60 -> 2.67/1.69
1741- new_local_abs_max1 = BlockReduce1 (reduce1).Reduce (new_local_abs_max1, cub::Max () );
1741+ new_local_abs_max1 = BlockReduce1 (reduce1).Reduce (new_local_abs_max1, cuda::maximum<>{} );
17421742
17431743 if (threadIdx .x == 0 )
17441744 smem_exchange1[0 ] = new_local_abs_max1;
0 commit comments