11#include " common.cuh"
22
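// Row reduction kernel template: each block handles the row given by blockIdx.x and writes
// either the row sum (norm=false) or the row mean, sum / ncols (norm=true), to dst[row].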
4- template <bool norm>
4+ template <bool norm>
55static __global__ void reduce_rows_f32 (const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
66 const int row = blockIdx .x ;
77 const int col = threadIdx .x ;
88
9- float sum = 0 .0f ;
9+ float sum = 0 .0f ;
1010 const int num_unroll = 8 ;
11- float temp[num_unroll];
12- float sum_temp[num_unroll] = {0 .0f };
11+ float temp[num_unroll];
12+ float sum_temp[num_unroll] = { 0 .0f };
1313 for (int i = col; i < ncols;) {
14- for (int j = 0 ; j < num_unroll; ++j){
15- if (i < ncols){
14+ for (int j = 0 ; j < num_unroll; ++j) {
15+ if (i < ncols) {
1616 temp[j] = x[row * ncols + i];
17- }
18- else {
17+ } else {
1918 temp[j] = 0 ;
2019 }
2120 i += blockDim .x ;
2221 }
23- for (int j = 0 ; j < num_unroll; ++j){
22+ for (int j = 0 ; j < num_unroll; ++j) {
2423 sum_temp[j] += temp[j];
2524 }
2625 }
27- for (int j = 0 ; j < num_unroll; ++j){
28- sum += sum_temp[j];
26+ for (int j = 0 ; j < num_unroll; ++j) {
27+ sum += sum_temp[j];
2928 }
3029
3130 // sum up partial sums
3231 sum = warp_reduce_sum (sum);
3332 if (blockDim .x > WARP_SIZE) {
3433 assert ((blockDim .x <= 1024 ) && (blockDim .x % WARP_SIZE) == 0 );
3534 __shared__ float s_sum[32 ];
36- const int warp_id = threadIdx .x / WARP_SIZE;
37- const int lane_id = threadIdx .x % WARP_SIZE;
35+ const int warp_id = threadIdx .x / WARP_SIZE;
36+ const int lane_id = threadIdx .x % WARP_SIZE;
3837 if (lane_id == 0 ) {
3938 s_sum[warp_id] = sum;
4039 }
@@ -51,4 +50,4 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
5150 }
5251
5352 dst[row] = norm ? sum / ncols : sum;
54- }
53+ }
0 commit comments