Skip to content

Commit f91e4d3

Browse files
committed
Fix format and type conversion errors
1 parent 74564bd commit f91e4d3

45 files changed

Lines changed: 579 additions & 429 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

python/infinicore/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@
6464
from infinicore.ops.axpy import axpy
6565
from infinicore.ops.baddbmm import baddbmm
6666
from infinicore.ops.bilinear import bilinear
67-
from infinicore.ops.blas_amax import blas_amax
68-
from infinicore.ops.blas_amin import blas_amin
69-
from infinicore.ops.blas_copy import blas_copy
70-
from infinicore.ops.blas_dot import blas_dot
7167
from infinicore.ops.binary_cross_entropy_with_logits import (
7268
binary_cross_entropy_with_logits,
7369
)
7470
from infinicore.ops.bitwise_right_shift import bitwise_right_shift
71+
from infinicore.ops.blas_amax import blas_amax
72+
from infinicore.ops.blas_amin import blas_amin
73+
from infinicore.ops.blas_copy import blas_copy
74+
from infinicore.ops.blas_dot import blas_dot
7575
from infinicore.ops.block_diag import block_diag
7676
from infinicore.ops.broadcast_to import broadcast_to
7777
from infinicore.ops.cat import cat
@@ -121,9 +121,9 @@
121121
from infinicore.ops.scal import scal
122122
from infinicore.ops.scatter import scatter
123123
from infinicore.ops.sinh import sinh
124-
from infinicore.ops.swap import swap
125124
from infinicore.ops.squeeze import squeeze
126125
from infinicore.ops.sum import sum
126+
from infinicore.ops.swap import swap
127127
from infinicore.ops.take import take
128128
from infinicore.ops.tan import tan
129129
from infinicore.ops.topk import topk

src/infiniop/ops/asum/bang/asum_bang.mlu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ infiniStatus_t calculateAsum(
3333
Tdata *result,
3434
cnrtQueue_t queue) {
3535

36-
const size_t n = info.n;
37-
const ptrdiff_t incx = info.incx;
36+
const int n = utils::cast<int>(info.n);
37+
const int incx = utils::cast<int>(info.incx);
3838

3939
cnrtDim3_t k_dim;
4040
cnrtFunctionType_t k_type;
Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,140 @@
11
#include "../../../devices/bang/common_bang.h"
22
#include "asum_bang.h"
33

4+
#include <type_traits>
5+
46
__nram__ char nram_buffer[NRAM_MAX_SIZE];
57

8+
template <typename Tdata>
9+
__mlu_device__ void asumToCompute(float *dst, const Tdata *src, int size) {
10+
if constexpr (std::is_same_v<Tdata, half>) {
11+
__bang_half2float(dst, src, size);
12+
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
13+
__bang_bfloat162float(dst, src, size);
14+
} else {
15+
__memcpy(dst, src, size * sizeof(float), NRAM2NRAM);
16+
}
17+
}
18+
19+
template <typename Tdata>
20+
__mlu_device__ float asumToCompute(Tdata value) {
21+
if constexpr (std::is_same_v<Tdata, half>) {
22+
return __half2float(value);
23+
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
24+
return __bfloat162float(value);
25+
} else {
26+
return static_cast<float>(value);
27+
}
28+
}
29+
30+
template <typename Tdata>
31+
__mlu_device__ void asumStoreResult(Tdata *result, Tdata *nram_result, float *nram_compute, float value) {
32+
nram_compute[0] = value;
33+
if constexpr (std::is_same_v<Tdata, half>) {
34+
__bang_float2half(nram_result, nram_compute, 1);
35+
result[0] = nram_result[0];
36+
} else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
37+
__bang_float2bfloat16(nram_result, nram_compute, 1);
38+
result[0] = nram_result[0];
39+
} else {
40+
result[0] = nram_compute[0];
41+
}
42+
}
43+
644
template <typename Tdata>
745
__mlu_global__ void asumKernelContiguous(
8-
size_t n,
46+
int n,
947
const Tdata *x,
1048
Tdata *result) {
1149

12-
__mlu_shared__ Tdata shared_partial_sum[4];
50+
__mlu_shared__ float shared_partial_sum[4];
1351

14-
Tdata *nram_x = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
52+
char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
1553

16-
size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_x - nram_buffer);
17-
size_t max_chunk_elements = nram_usable / sizeof(Tdata);
54+
size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer);
55+
size_t max_chunk_elements = nram_usable / (sizeof(Tdata) + sizeof(float));
1856

19-
int align_elements = ALIGN_SIZE / sizeof(Tdata);
57+
size_t align_elements = ALIGN_SIZE / sizeof(Tdata);
2058
if (align_elements == 0) {
2159
align_elements = 1;
2260
}
23-
max_chunk_elements = (max_chunk_elements / align_elements) * align_elements;
61+
int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements);
62+
63+
Tdata *nram_x = (Tdata *)nram_aligned;
64+
float *nram_compute = (float *)(nram_x + chunk_size);
2465

2566
int elements_per_core = n / taskDim;
2667
int remain = n % taskDim;
2768
int core_elements = elements_per_core + (taskId < remain ? 1 : 0);
2869
int core_offset = taskId < remain ? taskId * core_elements : taskId * elements_per_core + remain;
2970

30-
int chunks = core_elements / max_chunk_elements;
31-
int chunk_rem = core_elements % max_chunk_elements;
71+
int chunks = core_elements / chunk_size;
72+
int chunk_rem = core_elements % chunk_size;
3273

33-
Tdata partial_sum = static_cast<Tdata>(0);
74+
float partial_sum = 0.0f;
3475

3576
for (int c = 0; c < chunks; c++) {
36-
size_t current_offset = core_offset + c * max_chunk_elements;
37-
__memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM);
77+
int current_offset = core_offset + c * chunk_size;
3878

39-
__bang_abs(nram_x, nram_x, max_chunk_elements);
79+
__memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM);
4080

41-
partial_sum += __bang_sum(nram_x, max_chunk_elements);
81+
asumToCompute(nram_compute, nram_x, chunk_size);
82+
__bang_abs(nram_compute, nram_compute, chunk_size);
83+
84+
partial_sum += __bang_sum(nram_compute, chunk_size);
4285
}
4386

4487
if (chunk_rem > 0) {
45-
size_t current_offset = core_offset + chunks * max_chunk_elements;
88+
int current_offset = core_offset + chunks * chunk_size;
4689

4790
__memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM);
4891

49-
__bang_abs(nram_x, nram_x, chunk_rem);
92+
asumToCompute(nram_compute, nram_x, chunk_rem);
93+
__bang_abs(nram_compute, nram_compute, chunk_rem);
5094

51-
partial_sum += __bang_sum(nram_x, chunk_rem);
95+
partial_sum += __bang_sum(nram_compute, chunk_rem);
5296
}
5397

5498
shared_partial_sum[coreId] = partial_sum;
5599

56100
__sync_cluster();
57101

58102
if (coreId == 0) {
59-
Tdata cluster_sum = static_cast<Tdata>(0);
103+
float cluster_sum = 0.0f;
60104

61105
for (int i = 0; i < coreDim; i++) {
62106
cluster_sum += shared_partial_sum[i];
63107
}
64108

65-
result[0] = cluster_sum;
109+
asumStoreResult(result, nram_x, nram_compute, cluster_sum);
66110
}
67111
}
68112

69113
template <typename Tdata>
70114
__mlu_global__ void asumKernelStrided(
71-
size_t n,
115+
int n,
72116
const Tdata *x,
73-
size_t incx,
117+
int incx,
74118
Tdata *result) {
75119

76-
__mlu_shared__ Tdata shared_partial_sum[4];
120+
__mlu_shared__ float shared_partial_sum[4];
121+
122+
char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
123+
124+
float *nram_compute = (float *)nram_aligned;
125+
Tdata *nram_result = (Tdata *)(nram_compute + 1);
77126

78127
int elements_per_core = n / taskDim;
79128
int remain = n % taskDim;
80129
int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0);
81130
int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain;
82131

83-
Tdata partial_sum = static_cast<Tdata>(0);
132+
float partial_sum = 0.0f;
84133

85134
for (int i = start_idx; i < start_idx + actual_tasks; ++i) {
86-
size_t offset = i * incx;
87-
Tdata abs_val = x[offset] > static_cast<Tdata>(0) ? x[offset] : -x[offset];
135+
int offset = i * incx;
136+
float x_val = asumToCompute(x[offset]);
137+
float abs_val = x_val > 0.0f ? x_val : -x_val;
88138

89139
partial_sum += abs_val;
90140
}
@@ -94,12 +144,12 @@ __mlu_global__ void asumKernelStrided(
94144
__sync_cluster();
95145

96146
if (coreId == 0) {
97-
Tdata cluster_sum = static_cast<Tdata>(0);
147+
float cluster_sum = 0.0f;
98148

99149
for (int i = 0; i < coreDim; i++) {
100150
cluster_sum += shared_partial_sum[i];
101151
}
102152

103-
result[0] = cluster_sum;
153+
asumStoreResult(result, nram_result, nram_compute, cluster_sum);
104154
}
105-
}
155+
}

src/infiniop/ops/asum/cpu/asum_cpu.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,24 @@ infiniStatus_t calculateAsum(
3131
const Tdata *x,
3232
Tdata *result) {
3333

34-
const ptrdiff_t n = info.n;
34+
const size_t n = info.n;
3535
const ptrdiff_t incx = info.incx;
3636

3737
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
3838
float total_sum = 0.0;
3939

40-
for (ptrdiff_t i = 0; i < n; ++i) {
41-
total_sum += std::abs(utils::cast<float>(x[i * incx]));
40+
for (size_t i = 0; i < n; ++i) {
41+
const ptrdiff_t idx = utils::cast<ptrdiff_t>(i) * incx;
42+
total_sum += std::abs(utils::cast<float>(x[idx]));
4243
}
4344

4445
result[0] = utils::cast<Tdata>(total_sum);
4546
} else {
4647
Tdata total_sum = 0.0;
4748

48-
for (ptrdiff_t i = 0; i < n; ++i) {
49-
total_sum += std::abs(x[i * incx]);
49+
for (size_t i = 0; i < n; ++i) {
50+
const ptrdiff_t idx = utils::cast<ptrdiff_t>(i) * incx;
51+
total_sum += std::abs(x[idx]);
5052
}
5153

5254
result[0] = total_sum;

src/infiniop/ops/asum/metax/asum_metax.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ infiniStatus_t Descriptor::calculate(
4242
(void)workspace;
4343
(void)workspace_size;
4444

45-
const size_t n = _info.n;
46-
const ptrdiff_t incx = _info.incx;
45+
const int n = utils::cast<int>(_info.n);
46+
const int incx = utils::cast<int>(_info.incx);
4747
const infiniDtype_t data_type = _info.data_type;
4848

4949
CHECK_STATUS(_opaque->internal->useMcblas(

src/infiniop/ops/axpy/bang/axpy_bang.mlu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ infiniStatus_t calculateAxpy(
3636
Tdata *y,
3737
cnrtQueue_t queue) {
3838

39-
const size_t size = info.n;
40-
const ptrdiff_t incx = info.incx;
41-
const ptrdiff_t incy = info.incy;
39+
const int n = utils::cast<int>(info.n);
40+
const int incx = utils::cast<int>(info.incx);
41+
const int incy = utils::cast<int>(info.incy);
4242

4343
cnrtDim3_t k_dim;
4444
cnrtFunctionType_t k_type;
@@ -50,13 +50,13 @@ infiniStatus_t calculateAxpy(
5050

5151
if (incx == 1 && incy == 1) {
5252
axpyKernelContiguous<Tdata><<<k_dim, k_type, queue>>>(
53-
size,
53+
n,
5454
alpha,
5555
x,
5656
y);
5757
} else {
5858
axpyKernelStrided<Tdata><<<k_dim, k_type, queue>>>(
59-
size,
59+
n,
6060
alpha,
6161
x,
6262
incx,

src/infiniop/ops/axpy/bang/axpy_bang_kernel.mlu

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,24 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE];
55

66
template <typename Tdata>
77
__mlu_global__ void axpyKernelContiguous(
8-
size_t n,
8+
int n,
99
const Tdata *alpha,
1010
const Tdata *x,
1111
Tdata *y) {
1212

13-
Tdata *nram_align = (Tdata *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
13+
char *nram_aligned = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
1414

15-
size_t nram_usable = NRAM_MAX_SIZE - ((char *)nram_align - nram_buffer);
15+
size_t nram_usable = NRAM_MAX_SIZE - (nram_aligned - nram_buffer);
1616
size_t max_chunk_elements = nram_usable / (2 * sizeof(Tdata));
1717

18-
int align_elements = ALIGN_SIZE / sizeof(Tdata);
18+
size_t align_elements = ALIGN_SIZE / sizeof(Tdata);
1919
if (align_elements == 0) {
2020
align_elements = 1;
2121
}
22-
max_chunk_elements = (max_chunk_elements / align_elements) * align_elements;
22+
int chunk_size = (int)((max_chunk_elements / align_elements) * align_elements);
2323

24-
Tdata *nram_x = nram_align;
25-
Tdata *nram_y = nram_align + max_chunk_elements;
24+
Tdata *nram_x = (Tdata *)nram_aligned;
25+
Tdata *nram_y = nram_x + chunk_size;
2626

2727
int elements_per_core = n / taskDim;
2828
int remain = n % taskDim;
@@ -33,22 +33,23 @@ __mlu_global__ void axpyKernelContiguous(
3333
return;
3434
}
3535

36-
int chunks = core_elements / max_chunk_elements;
37-
int chunk_rem = core_elements % max_chunk_elements;
36+
int chunks = core_elements / chunk_size;
37+
int chunk_rem = core_elements % chunk_size;
3838

3939
for (int c = 0; c < chunks; c++) {
40-
size_t current_offset = core_offset + c * max_chunk_elements;
41-
__memcpy(nram_x, x + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM);
42-
__memcpy(nram_y, y + current_offset, max_chunk_elements * sizeof(Tdata), GDRAM2NRAM);
40+
int current_offset = core_offset + c * chunk_size;
4341

44-
__bang_mul_scalar(nram_x, nram_x, alpha[0], max_chunk_elements);
45-
__bang_add(nram_y, nram_y, nram_x, max_chunk_elements);
42+
__memcpy(nram_x, x + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM);
43+
__memcpy(nram_y, y + current_offset, chunk_size * sizeof(Tdata), GDRAM2NRAM);
4644

47-
__memcpy(y + current_offset, nram_y, max_chunk_elements * sizeof(Tdata), NRAM2GDRAM);
45+
__bang_mul_scalar(nram_x, nram_x, alpha[0], chunk_size);
46+
__bang_add(nram_y, nram_y, nram_x, chunk_size);
47+
48+
__memcpy(y + current_offset, nram_y, chunk_size * sizeof(Tdata), NRAM2GDRAM);
4849
}
4950

5051
if (chunk_rem > 0) {
51-
size_t current_offset = core_offset + chunks * max_chunk_elements;
52+
int current_offset = core_offset + chunks * chunk_size;
5253
int align_rem = ((chunk_rem + align_elements - 1) / align_elements) * align_elements;
5354

5455
__memcpy(nram_x, x + current_offset, chunk_rem * sizeof(Tdata), GDRAM2NRAM);
@@ -63,22 +64,22 @@ __mlu_global__ void axpyKernelContiguous(
6364

6465
template <typename Tdata>
6566
__mlu_global__ void axpyKernelStrided(
66-
size_t n,
67+
int n,
6768
const Tdata *alpha,
6869
const Tdata *x,
69-
size_t incx,
70+
int incx,
7071
Tdata *y,
71-
size_t incy) {
72+
int incy) {
7273

7374
int elements_per_core = n / taskDim;
7475
int remain = n % taskDim;
7576
int actual_tasks = elements_per_core + (taskId < remain ? 1 : 0);
7677
int start_idx = taskId < remain ? taskId * actual_tasks : taskId * elements_per_core + remain;
7778

7879
for (int i = start_idx; i < start_idx + actual_tasks; ++i) {
79-
size_t idx_x = i * incx;
80-
size_t idx_y = i * incy;
80+
int idx_x = i * incx;
81+
int idx_y = i * incy;
8182

8283
y[idx_y] += alpha[0] * x[idx_x];
8384
}
84-
}
85+
}

0 commit comments

Comments
 (0)