@@ -213,6 +213,7 @@ __device__ inline int FN_INT32(mod)(int a, int b) {
213213}
214214__device__ inline int FN_INT32(max)(int a, int b) { return cinn_max(a, b); }
215215__device__ inline int FN_INT32(min)(int a, int b) { return cinn_min(a, b); }
216+ __device__ inline int FN_INT32(abs)(int x) { return abs(x); }
216217__device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; }
217218__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; }
218219__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; }
@@ -239,6 +240,7 @@ __device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a
239240__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; }
240241__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); }
241242__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); }
243+ __device__ inline int64_t FN_INT64(abs)(int64_t x) { return llabs(x); }
242244__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { return ((uint64_t)a >> b); }
243245__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; }
244246__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { int64_t res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; }
@@ -997,7 +999,28 @@ ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0)
997999ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0)
9981000
9991001// 手写 std::max 重载
1000- namespace std {
1002+ namespace std {
1003+ // --- 之前加的 long long / int64_t 补丁保持不变 ---
1004+ __device__ __forceinline__ int64_t max(long long a, int64_t b) { return a > b ? a : b; }
1005+ __device__ __forceinline__ int64_t max(int64_t a, long long b) { return a > b ? a : b; }
1006+ __device__ __forceinline__ int64_t min(long long a, int64_t b) { return a < b ? a : b; }
1007+ __device__ __forceinline__ int64_t min(int64_t a, long long b) { return a < b ? a : b; }
1008+
1009+ // ==============================================================
1010+ // 【新增防弹补丁】:解决 CINN 漏打 'f' 后缀导致的 float 和 double 混合报错
1011+ // ==============================================================
1012+ __device__ __forceinline__ double max(float a, double b) { return a > b ? (double)a : b; }
1013+ __device__ __forceinline__ double max(double a, float b) { return a > b ? a : (double)b; }
1014+ __device__ __forceinline__ double min(float a, double b) { return a < b ? (double)a : b; }
1015+ __device__ __forceinline__ double min(double a, float b) { return a < b ? a : (double)b; }
1016+
1017+ // 以防万一,解决 CINN 把 0 打印成 int 与 float 混合的报错 (如 std::max(val, 0))
1018+ __device__ __forceinline__ float max(float a, int b) { return a > b ? a : (float)b; }
1019+ __device__ __forceinline__ float max(int a, float b) { return a > b ? (float)a : b; }
1020+ __device__ __forceinline__ float min(float a, int b) { return a < b ? a : (float)b; }
1021+ __device__ __forceinline__ float min(int a, float b) { return a < b ? (float)a : b; }
1022+ // ==============================================================
1023+
10011024 // ArgMax 实现
10021025 template <typename T>
10031026 __device__ __forceinline__ T max_argidx_impl(const T& a, const T& b) {
0 commit comments