Skip to content

Commit 51885c9

Browse files
committed
Replace cub::Max() with cuda::maximum<> in kernel reductions
1 parent c6f7d27 commit 51885c9

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

csrc/kernels.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <cub/block/block_reduce.cuh>
1212
#include <cub/block/block_store.cuh>
1313
#include <cub/cub.cuh>
14+
#include <cuda/std/functional>
1415
#include <cub/warp/warp_reduce.cuh>
1516
#include <cuda_fp16.h>
1617
#include <math_constants.h>
@@ -416,7 +417,7 @@ __global__ void kQuantizeBlockwise(
416417
for (int j = 0; j < NUM_PER_TH; j++)
417418
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
418419

419-
local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
420+
local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cuda::maximum<>{}, valid_items);
420421

421422
if (threadIdx.x == 0) {
422423
smem_absmax_value[0] = 1.0f / local_abs_max;
@@ -1002,9 +1003,9 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
10021003
}
10031004

10041005
__syncthreads();
1005-
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
1006+
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cuda::maximum<>{}, valid_items);
10061007
__syncthreads();
1007-
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
1008+
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cuda::maximum<>{}, valid_items);
10081009
if (unorm != NULL) {
10091010
__syncthreads();
10101011
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
@@ -1213,7 +1214,7 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
12131214
}
12141215

12151216
__syncthreads();
1216-
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
1217+
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cuda::maximum<>{}, valid_items);
12171218
if (threadIdx.x == 0) {
12181219
atomicMax(&new_max1[0], local_max_s1);
12191220
}
@@ -1843,7 +1844,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
18431844
}
18441845

18451846
// Reduce thread-local absmax across the block.
1846-
const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
1847+
const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cuda::maximum<>{}, cols);
18471848
if (threadIdx.x == 0) {
18481849
// Save our block's absmax to shared memory for the quantization step.
18491850
rowStats[row_id] = smem_row_absmax = row_absmax;
@@ -1898,7 +1899,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
18981899

18991900
// Reduce thread-local absmax across the block.
19001901
// TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
1901-
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
1902+
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cuda::maximum<>{}, cols);
19021903
if (threadIdx.x == 0) {
19031904
// Save our block's absmax to shared memory for the quantization step.
19041905
rowStats[row_id] = row_absmax;

0 commit comments

Comments
 (0)