Skip to content

Commit 75c6b8e

Browse files
committed
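Change the default PADDLE_WARP_SIZE in CMakeLists.txt from 32 to 64. Fix the
argidx_fp32_i32 forward-reference error in the MetaX runtime by defining the
ArgIdx structs and combine functions before the reduce functions that use them.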
1 parent 1cf361b commit 75c6b8e

2 files changed

Lines changed: 61 additions & 41 deletions

File tree

backends/metax_gpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
2828
message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
2929

3030
if(NOT DEFINED PADDLE_WARP_SIZE)
31-
set(PADDLE_WARP_SIZE 32)
31+
set(PADDLE_WARP_SIZE 64)
3232
endif()
3333
math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1")
3434
if(PADDLE_WARP_SIZE EQUAL 64)

backends/metax_gpu/cinn/compiler/compiler.cc

Lines changed: 60 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -700,6 +700,65 @@ EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO)
700700
EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO)
701701
EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO)
702702
703+
// ===============================================================
704+
// ArgMin/ArgMax Support (ArgIdx Structures & Combine Functions)
705+
// Must be defined before discrete/block/grid reduce functions that use them
706+
// ===============================================================
707+
708+
// arg reduce arg index struct
709+
// Do not define operator<; force dispatch through std::max overloads
710+
#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \
711+
struct TYPENAME { \
712+
DTYPE value; \
713+
ITYPE index; \
714+
__device__ TYPENAME() {} \
715+
__device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \
716+
__device__ TYPENAME(DTYPE value, ITYPE index) \
717+
: value(value), index(index) {} \
718+
__device__ explicit operator ITYPE() { return index; } \
719+
/* Assignment operator support */ \
720+
__device__ inline TYPENAME& operator=(const TYPENAME& other) { \
721+
value = other.value; \
722+
index = other.index; \
723+
return *this; \
724+
} \
725+
__device__ inline volatile TYPENAME& operator=(const volatile TYPENAME& other) volatile { \
726+
value = other.value; \
727+
index = other.index; \
728+
return *this; \
729+
} \
730+
};
731+
732+
// Instantiate structs
733+
#ifdef CINN_CUDA_FP16
734+
ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL)
735+
#endif
736+
ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL)
737+
ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL)
738+
ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL)
739+
ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL)
740+
ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL)
741+
ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL)
742+
743+
ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0)
744+
ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0)
745+
746+
// cinn_min_<TYPENAME> / cinn_max_<TYPENAME> combine functions for argidx structs; when values are equal, the smaller index wins
747+
// These are called by CINN_DISCRETE_REDUCE_IMPL via cinn_##REDUCE_TYPE token pasting
748+
#define ARGIDX_COMBINE_MACRO(TYPENAME) \
749+
__device__ TYPENAME cinn_min_##TYPENAME(TYPENAME a, TYPENAME b) { \
750+
return a.value == b.value ? (a.index < b.index ? a : b) \
751+
: (a.value < b.value ? a : b); \
752+
} \
753+
__device__ TYPENAME cinn_max_##TYPENAME(TYPENAME a, TYPENAME b) { \
754+
return a.value == b.value ? (a.index < b.index ? a : b) \
755+
: (a.value > b.value ? a : b); \
756+
}
757+
758+
ARGIDX_COMBINE_MACRO(argidx_fp32_i32)
759+
ARGIDX_COMBINE_MACRO(argidx_fp32_i64)
760+
ARGIDX_COMBINE_MACRO(argidx_i32_i32)
761+
703762
// Discrete reduce for argidx types
704763
__device__ inline argidx_fp32_i32 cinn_discrete_reduce_max_argidx_fp32_i32(
705764
const argidx_fp32_i32 value, argidx_fp32_i32 *shm) {
@@ -983,47 +1042,8 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf,
9831042
} // extern "C"
9841043
9851044
// ===============================================================
986-
// 8. ArgMin/ArgMax Support (ArgIdx Structures & Shuffles)
1045+
// 8. ArgMin/ArgMax std::max/min Overloads & Block Reduce
9871046
// ===============================================================
988-
// --- C++ Scope Start ---
989-
990-
// arg reduce arg index struct
991-
// Do not define operator<; force dispatch through std::max overloads
992-
#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \
993-
struct TYPENAME { \
994-
DTYPE value; \
995-
ITYPE index; \
996-
__device__ TYPENAME() {} \
997-
__device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \
998-
__device__ TYPENAME(DTYPE value, ITYPE index) \
999-
: value(value), index(index) {} \
1000-
__device__ explicit operator ITYPE() { return index; } \
1001-
/* Assignment operator support */ \
1002-
__device__ inline TYPENAME& operator=(const TYPENAME& other) { \
1003-
value = other.value; \
1004-
index = other.index; \
1005-
return *this; \
1006-
} \
1007-
__device__ inline volatile TYPENAME& operator=(const volatile TYPENAME& other) volatile { \
1008-
value = other.value; \
1009-
index = other.index; \
1010-
return *this; \
1011-
} \
1012-
};
1013-
1014-
// Instantiate structs
1015-
#ifdef CINN_CUDA_FP16
1016-
ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL)
1017-
#endif
1018-
ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL)
1019-
ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL)
1020-
ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL)
1021-
ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL)
1022-
ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL)
1023-
ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL)
1024-
1025-
ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0)
1026-
ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0)
10271047
10281048
// std::max overloads
10291049
namespace std {

0 commit comments

Comments
 (0)