Skip to content

Commit cf7cdce

Browse files
committed
Add compiler.cc int64 int32 abs. int64 vs double abs.
1 parent bd8c5c2 commit cf7cdce

1 file changed

Lines changed: 24 additions & 1 deletion

File tree

backends/metax_gpu/cinn/compiler/compiler.cc

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ __device__ inline int FN_INT32(mod)(int a, int b) {
213213
}
214214
__device__ inline int FN_INT32(max)(int a, int b) { return cinn_max(a, b); }
215215
__device__ inline int FN_INT32(min)(int a, int b) { return cinn_min(a, b); }
216+
__device__ inline int FN_INT32(abs)(int x) { return abs(x); }
216217
__device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; }
217218
__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; }
218219
__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; }
@@ -239,6 +240,7 @@ __device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a
239240
__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; }
240241
__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); }
241242
__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); }
243+
__device__ inline int64_t FN_INT64(abs)(int64_t x) { return llabs(x); }
242244
__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { return ((uint64_t)a >> b); }
243245
__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; }
244246
__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { int64_t res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; }
@@ -997,7 +999,28 @@ ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0)
997999
ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0)
9981000
9991001
// 手写 std::max 重载
1000-
namespace std {
1002+
namespace std {
1003+
// --- 之前加的 long long / int64_t 补丁保持不变 ---
1004+
__device__ __forceinline__ int64_t max(long long a, int64_t b) { return a > b ? a : b; }
1005+
__device__ __forceinline__ int64_t max(int64_t a, long long b) { return a > b ? a : b; }
1006+
__device__ __forceinline__ int64_t min(long long a, int64_t b) { return a < b ? a : b; }
1007+
__device__ __forceinline__ int64_t min(int64_t a, long long b) { return a < b ? a : b; }
1008+
1009+
// ==============================================================
1010+
// 【新增防弹补丁】:解决 CINN 漏打 'f' 后缀导致的 float 和 double 混合报错
1011+
// ==============================================================
1012+
__device__ __forceinline__ double max(float a, double b) { return a > b ? (double)a : b; }
1013+
__device__ __forceinline__ double max(double a, float b) { return a > b ? a : (double)b; }
1014+
__device__ __forceinline__ double min(float a, double b) { return a < b ? (double)a : b; }
1015+
__device__ __forceinline__ double min(double a, float b) { return a < b ? a : (double)b; }
1016+
1017+
// 以防万一,解决 CINN 把 0 打印成 int 与 float 混合的报错 (如 std::max(val, 0))
1018+
__device__ __forceinline__ float max(float a, int b) { return a > b ? a : (float)b; }
1019+
__device__ __forceinline__ float max(int a, float b) { return a > b ? (float)a : b; }
1020+
__device__ __forceinline__ float min(float a, int b) { return a < b ? a : (float)b; }
1021+
__device__ __forceinline__ float min(int a, float b) { return a < b ? (float)a : b; }
1022+
// ==============================================================
1023+
10011024
// ArgMax 实现
10021025
template <typename T>
10031026
__device__ __forceinline__ T max_argidx_impl(const T& a, const T& b) {

0 commit comments

Comments
 (0)