11#include "../../../devices/bang/common_bang.h"
22#include "asum_bang.h"
33
4+ #include <type_traits>
5+
46__nram__ char nram_buffer[NRAM_MAX_SIZE];
57
8+ template <typename Tdata>
9+ __mlu_device__ void asumToCompute(float *dst, const Tdata *src, int size) {
10+ if constexpr (std::is_same_v<Tdata, half>) {
11+ __bang_half2float(dst, src, size);
12+ } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
13+ __bang_bfloat162float(dst, src, size);
14+ } else {
15+ __memcpy(dst, src, size * sizeof(float), NRAM2NRAM);
16+ }
17+ }
18+
19+ template <typename Tdata>
20+ __mlu_device__ float asumToCompute(Tdata value) {
21+ if constexpr (std::is_same_v<Tdata, half>) {
22+ return __half2float(value);
23+ } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
24+ return __bfloat162float(value);
25+ } else {
26+ return static_cast<float>(value);
27+ }
28+ }
29+
30+ template <typename Tdata>
31+ __mlu_device__ void asumStoreResult(Tdata *result, Tdata *nram_result, float *nram_compute, float value) {
32+ nram_compute[0] = value;
33+ if constexpr (std::is_same_v<Tdata, half>) {
34+ __bang_float2half(nram_result, nram_compute, 1);
35+ result[0] = nram_result[0];
36+ } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
37+ __bang_float2bfloat16(nram_result, nram_compute, 1);
38+ result[0] = nram_result[0];
39+ } else {
40+ result[0] = nram_compute[0];
41+ }
42+ }
43+
644template <typename Tdata>
745__mlu_global__ void asumKernelContiguous(
8- size_t n,
46+ int n,
947 const Tdata *x,
1048 Tdata *result) {
1149
12- __mlu_shared__ Tdata shared_partial_sum[4];
50+ __mlu_shared__ float shared_partial_sum[4];
1351
14- Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
52+ char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
1553
16- size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer);
17- size_t max_chunk_elements = nram_usable / sizeof(Tdata);
54+ size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer);
55+ size_t max_chunk_elements = nram_usable / ( sizeof(Tdata) + sizeof(float) );
1856
19- int align_elements = ALIGN_SIZE / sizeof(Tdata);
57+ size_t align_elements = ALIGN_SIZE / sizeof(Tdata);
2058 if (align_elements == 0) {
2159 align_elements = 1;
2260 }
23- max_chunk_elements = (max_chunk_elements / align_elements) * align_elements;
61+ int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements);
62+
63+ Tdata *nram_x = (Tdata *)nram_aligned;
64+ float *nram_compute = (float *)(nram_x + chunk_size);
2465
2566 int elements_per_core = n / taskDim;
2667 int remain = n % taskDim;
2768 int core_elements = elements_per_core + (taskId < remain ? 1 : 0);
2869 int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain;
2970
30- int chunks = core_elements / max_chunk_elements ;
31- int chunk_rem = core_elements % max_chunk_elements ;
71+ int chunks = core_elements / chunk_size ;
72+ int chunk_rem = core_elements % chunk_size ;
3273
33- Tdata partial_sum = static_cast<Tdata>(0) ;
74+ float partial_sum = 0.0f ;
3475
3576 for (int c = 0; c < chunks; c++) {
36- size_t current_offset = core_offset + c * max_chunk_elements;
37- __memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM);
77+ int current_offset = core_offset + c * chunk_size;
3878
39- __bang_abs (nram_x, nram_x, max_chunk_elements );
79+ __memcpy (nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM );
4080
41- partial_sum += __bang_sum(nram_x, max_chunk_elements);
81+ asumToCompute(nram_compute, nram_x, chunk_size);
82+ __bang_abs(nram_compute, nram_compute, chunk_size);
83+
84+ partial_sum += __bang_sum(nram_compute, chunk_size);
4285 }
4386
4487 if (chunk_rem > 0) {
45- size_t current_offset = core_offset + chunks * max_chunk_elements ;
88+ int current_offset = core_offset + chunks * chunk_size ;
4689
4790 __memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM);
4891
49- __bang_abs(nram_x, nram_x, chunk_rem);
92+ asumToCompute(nram_compute, nram_x, chunk_rem);
93+ __bang_abs(nram_compute, nram_compute, chunk_rem);
5094
51- partial_sum += __bang_sum(nram_x , chunk_rem);
95+ partial_sum += __bang_sum(nram_compute , chunk_rem);
5296 }
5397
5498 shared_partial_sum[coreId] = partial_sum;
5599
56100 __sync_cluster();
57101
58102 if (coreId == 0) {
59- Tdata cluster_sum = static_cast<Tdata>(0) ;
103+ float cluster_sum = 0.0f ;
60104
61105 for (int i = 0; i < coreDim; i++) {
62106 cluster_sum += shared_partial_sum[i];
63107 }
64108
65- result[0] = cluster_sum;
109+ asumStoreResult( result, nram_x, nram_compute, cluster_sum) ;
66110 }
67111}
68112
69113template <typename Tdata>
70114__mlu_global__ void asumKernelStrided(
71- size_t n,
115+ int n,
72116 const Tdata *x,
73- size_t incx,
117+ int incx,
74118 Tdata *result) {
75119
76- __mlu_shared__ Tdata shared_partial_sum[4];
120+ __mlu_shared__ float shared_partial_sum[4];
121+
122+ char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
123+
124+ float *nram_compute = (float *)nram_aligned;
125+ Tdata *nram_result = (Tdata *)(nram_compute + 1);
77126
78127 int elements_per_core = n / taskDim;
79128 int remain = n % taskDim;
80129 int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0);
81130 int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain;
82131
83- Tdata partial_sum = static_cast<Tdata>(0) ;
132+ float partial_sum = 0.0f ;
84133
85134 for (int i = start_idx; i < start_idx + actual_tasks; ++i) {
86- size_t offset = i * incx;
87- Tdata abs_val = x[offset] > static_cast<Tdata>(0) ? x[offset] : -x[offset];
135+ int offset = i * incx;
136+ float x_val = asumToCompute(x[offset]);
137+ float abs_val = x_val > 0.0f ? x_val : -x_val;
88138
89139 partial_sum += abs_val;
90140 }
@@ -94,12 +144,12 @@ __mlu_global__ void asumKernelStrided(
94144 __sync_cluster();
95145
96146 if (coreId == 0) {
97- Tdata cluster_sum = static_cast<Tdata>(0) ;
147+ float cluster_sum = 0.0f ;
98148
99149 for (int i = 0; i < coreDim; i++) {
100150 cluster_sum += shared_partial_sum[i];
101151 }
102152
103- result[0] = cluster_sum;
153+ asumStoreResult( result, nram_result, nram_compute, cluster_sum) ;
104154 }
105- }
155+ }
0 commit comments