@@ -700,6 +700,65 @@ EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO)
// Pass CINN_DISCRETE_REDUCE_MACRO to the BOOL and FP16 expansion macros.
// NOTE(review): the EXPAND_REDUCE_* macros and CINN_DISCRETE_REDUCE_MACRO are
// defined outside this excerpt; presumably each expansion stamps out the
// discrete-reduce functions for that element type -- confirm at the definition.
700700EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO)
701701EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO)
702702
703+ // ===============================================================
704+ // ArgMin/ArgMax Support (ArgIdx Structures & Combine Functions)
705+ // Must be defined before discrete/block/grid reduce functions that use them
706+ // ===============================================================
707+
// Value/index pair used by the arg-reduce (ArgMin/ArgMax) helpers.
//
// ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) defines a struct TYPENAME
// with a `value` member of type DTYPE and an `index` member of type ITYPE:
//   TYPENAME()               - leaves both members uninitialized.
//   TYPENAME(value)          - explicit; pairs the value with index IINIT.
//   TYPENAME(value, index)   - sets both members.
//   explicit operator ITYPE  - yields `index`; const-qualified so it can also
//                              be applied to const objects and references.
//   operator=                - member-wise copy, plus an overload for
//                              volatile-qualified objects.
// Do not define operator<; force dispatch through std::max overloads.
#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT)                    \
  struct TYPENAME {                                                           \
    DTYPE value;                                                              \
    ITYPE index;                                                              \
    /* Leaves both members uninitialized. */                                  \
    __device__ TYPENAME() {}                                                  \
    /* Value-only construction pairs the value with index IINIT. */           \
    __device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \
    __device__ TYPENAME(DTYPE value, ITYPE index)                             \
        : value(value), index(index) {}                                       \
    /* const-qualified: reading the index never modifies the pair. */         \
    __device__ explicit operator ITYPE() const { return index; }              \
    /* Assignment operator support */                                         \
    __device__ inline TYPENAME& operator=(const TYPENAME& other) {            \
      value = other.value;                                                    \
      index = other.index;                                                    \
      return *this;                                                           \
    }                                                                         \
    /* Same member-wise copy for volatile-qualified objects. */               \
    __device__ inline volatile TYPENAME& operator=(                           \
        const volatile TYPENAME& other) volatile {                            \
      value = other.value;                                                    \
      index = other.index;                                                    \
      return *this;                                                           \
    }                                                                         \
  };
731+
// Concrete value/index pair types stamped out from ARGIDX_STRUCT_MACRO.
// Naming scheme: argidx_<value type>_<index type>.

// Pairs with an `int` index, starting at 0.
ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0)
ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0)

// Pairs with an `int64_t` index, starting at 0LL: integer value types.
ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL)

// Pairs with an `int64_t` index, starting at 0LL: floating-point value types.
ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL)
// The float16 pair only exists when CINN_CUDA_FP16 is defined.
#if defined(CINN_CUDA_FP16)
ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL)
#endif  // CINN_CUDA_FP16
745+
// cinn_min_<TYPENAME> / cinn_max_<TYPENAME> combine functions for arg-reduce
// pairs. These are called by CINN_DISCRETE_REDUCE_IMPL via cinn_##REDUCE_TYPE
// token pasting.
//   cinn_min_<TYPENAME>(a, b): returns the operand with the smaller `value`.
//   cinn_max_<TYPENAME>(a, b): returns the operand with the larger `value`.
// When both values compare equal, the operand with the smaller `index` is
// returned (b if the indices are equal as well).
// The functions are declared `inline`, like the neighbouring
// cinn_discrete_reduce_* function, so their definitions cannot clash at link
// time if this source ends up in more than one translation unit.
// NOTE(review): for floating-point values, if either `value` is NaN both
// comparisons are false and b is returned, so the result depends on operand
// order -- confirm that this is acceptable for the callers.
#define ARGIDX_COMBINE_MACRO(TYPENAME)                                     \
  __device__ inline TYPENAME cinn_min_##TYPENAME(TYPENAME a, TYPENAME b) { \
    return a.value == b.value ? (a.index < b.index ? a : b)                \
                              : (a.value < b.value ? a : b);               \
  }                                                                        \
  __device__ inline TYPENAME cinn_max_##TYPENAME(TYPENAME a, TYPENAME b) { \
    return a.value == b.value ? (a.index < b.index ? a : b)                \
                              : (a.value > b.value ? a : b);               \
  }
757+
// Instantiate the cinn_min_* / cinn_max_* combine functions for three of the
// pair types: int/int, float/int and float/int64_t.
ARGIDX_COMBINE_MACRO(argidx_i32_i32)
ARGIDX_COMBINE_MACRO(argidx_fp32_i32)
ARGIDX_COMBINE_MACRO(argidx_fp32_i64)
761+
703762// Discrete reduce for argidx types
704763__device__ inline argidx_fp32_i32 cinn_discrete_reduce_max_argidx_fp32_i32(
705764 const argidx_fp32_i32 value, argidx_fp32_i32 *shm) {
@@ -983,47 +1042,8 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf,
9831042} // extern "C"
9841043
9851044// ===============================================================
986- // 8. ArgMin/ArgMax Support (ArgIdx Structures & Shuffles)
1045+ // 8. ArgMin/ArgMax std::max/min Overloads & Block Reduce
9871046// ===============================================================
988- // --- C++ Scope Start ---
989-
// Value/index pair used by the arg-reduce (ArgMin/ArgMax) helpers.
//
// ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) defines a struct TYPENAME
// with a `value` member of type DTYPE and an `index` member of type ITYPE:
//   TYPENAME()               - leaves both members uninitialized.
//   TYPENAME(value)          - explicit; pairs the value with index IINIT.
//   TYPENAME(value, index)   - sets both members.
//   explicit operator ITYPE  - yields `index`; const-qualified so it can also
//                              be applied to const objects and references.
//   operator=                - member-wise copy, plus an overload for
//                              volatile-qualified objects.
// Do not define operator<; force dispatch through std::max overloads.
#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT)                    \
  struct TYPENAME {                                                           \
    DTYPE value;                                                              \
    ITYPE index;                                                              \
    /* Leaves both members uninitialized. */                                  \
    __device__ TYPENAME() {}                                                  \
    /* Value-only construction pairs the value with index IINIT. */           \
    __device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \
    __device__ TYPENAME(DTYPE value, ITYPE index)                             \
        : value(value), index(index) {}                                       \
    /* const-qualified: reading the index never modifies the pair. */         \
    __device__ explicit operator ITYPE() const { return index; }              \
    /* Assignment operator support */                                         \
    __device__ inline TYPENAME& operator=(const TYPENAME& other) {            \
      value = other.value;                                                    \
      index = other.index;                                                    \
      return *this;                                                           \
    }                                                                         \
    /* Same member-wise copy for volatile-qualified objects. */               \
    __device__ inline volatile TYPENAME& operator=(                           \
        const volatile TYPENAME& other) volatile {                            \
      value = other.value;                                                    \
      index = other.index;                                                    \
      return *this;                                                           \
    }                                                                         \
  };
1013-
// Concrete value/index pair types stamped out from ARGIDX_STRUCT_MACRO.
// Naming scheme: argidx_<value type>_<index type>.

// Pairs with an `int` index, starting at 0.
ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0)
ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0)

// Pairs with an `int64_t` index, starting at 0LL: integer value types.
ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL)

// Pairs with an `int64_t` index, starting at 0LL: floating-point value types.
ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL)
ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL)
// The float16 pair only exists when CINN_CUDA_FP16 is defined.
#if defined(CINN_CUDA_FP16)
ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL)
#endif  // CINN_CUDA_FP16
10271047
10281048// std::max overloads
10291049namespace std {
0 commit comments