Skip to content

Commit f4a87c8

Browse files
authored
Merge pull request #230 from AnkitAhlawat7742/fix/argselect-const
Fix argselect const-correct
2 parents cb9f080 + d0a20fa commit f4a87c8

10 files changed

Lines changed: 23 additions & 17 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ data types.
5757
## Arg sort routines on arrays
5858
```cpp
5959
std::vector<size_t> arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending);
60-
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
60+
std::vector<size_t> arg = x86simdsort::argselect(const T* arr, size_t k, size_t size, bool hasnan);
6161
```
6262
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double,
6363
uint64_t, int64_t]` Note that argsort and argselect are not accelerated with SIMD when using 16-bit

lib/x86simdsort-avx2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
} \
3030
template <> \
3131
std::vector<size_t> argselect( \
32-
type *arr, size_t k, size_t arrsize, bool hasnan) \
32+
const type *arr, size_t k, size_t arrsize, bool hasnan) \
3333
{ \
3434
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3535
}

lib/x86simdsort-internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
bool descending = false); \
5151
template <typename T> \
5252
XSS_HIDE_SYMBOL std::vector<size_t> \
53-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \
53+
argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false); \
5454
}
5555

5656
namespace xss {

lib/x86simdsort-scalar.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ namespace scalar {
8888
return arg;
8989
}
9090
template <typename T>
91-
std::vector<size_t> argselect(T *arr, size_t k, size_t arrsize, bool hasnan)
91+
std::vector<size_t>
92+
argselect(const T *arr, size_t k, size_t arrsize, bool hasnan)
9293
{
9394
UNUSED(hasnan);
9495
std::vector<size_t> arg(arrsize);

lib/x86simdsort-skx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
} \
3030
template <> \
3131
std::vector<size_t> argselect( \
32-
type *arr, size_t k, size_t arrsize, bool hasnan) \
32+
const type *arr, size_t k, size_t arrsize, bool hasnan) \
3333
{ \
3434
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3535
}

lib/x86simdsort.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ namespace x86simdsort {
9999

100100
#define DECLARE_INTERNAL_argselect(TYPE) \
101101
static std::vector<size_t> (*internal_argselect##TYPE)( \
102-
TYPE *, size_t, size_t, bool) \
102+
const TYPE *, size_t, size_t, bool) \
103103
= NULL; \
104104
template <> \
105105
std::vector<size_t> argselect( \
106-
TYPE *arr, size_t k, size_t arrsize, bool hasnan) \
106+
const TYPE *arr, size_t k, size_t arrsize, bool hasnan) \
107107
{ \
108108
return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \
109109
}

lib/x86simdsort.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ XSS_EXPORT_SYMBOL std::vector<size_t> argsort(const T *arr,
4343
// argselect
4444
template <typename T>
4545
XSS_EXPORT_SYMBOL std::vector<size_t>
46-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
46+
argselect(const T *arr, size_t k, size_t arrsize, bool hasnan = false);
4747

4848
// keyvalue sort
4949
template <typename T1, typename T2>

src/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Equivalent to `np.argselect` in
7575
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html).
7676

7777
```cpp
78-
void x86simdsortStatic::argselect<T>(T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false);
78+
void x86simdsortStatic::argselect<T>(const T* arr, size_t *arg, size_t k, size_t arrsize, bool hasnan = false);
7979
```
8080
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
8181
`double`.

src/x86simdsort-static-incl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ X86_SIMD_SORT_FINLINE void argsort(const T *arr,
4040

4141
template <typename T>
4242
X86_SIMD_SORT_FINLINE std::vector<size_t>
43-
argselect(T *arr, size_t k, size_t size, bool hasnan = false);
43+
argselect(const T *arr, size_t k, size_t size, bool hasnan = false);
4444

4545
/* argselect API required by NumPy: */
4646
template <typename T>
47-
void X86_SIMD_SORT_FINLINE
48-
argselect(T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false);
47+
void X86_SIMD_SORT_FINLINE argselect(
48+
const T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false);
4949

5050
template <typename T1, typename T2>
5151
X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key,
@@ -112,13 +112,13 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key,
112112
} \
113113
template <typename T> \
114114
X86_SIMD_SORT_FINLINE void x86simdsortStatic::argselect( \
115-
T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \
115+
const T *arr, size_t *arg, size_t k, size_t size, bool hasnan) \
116116
{ \
117117
ISA##_argselect(arr, arg, k, size, hasnan); \
118118
} \
119119
template <typename T> \
120120
X86_SIMD_SORT_FINLINE std::vector<size_t> x86simdsortStatic::argselect( \
121-
T *arr, size_t k, size_t size, bool hasnan) \
121+
const T *arr, size_t k, size_t size, bool hasnan) \
122122
{ \
123123
std::vector<size_t> indices(size); \
124124
std::iota(indices.begin(), indices.end(), 0); \

src/xss-common-argsort.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,10 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
741741
arrsize_t arrsize,
742742
bool hasnan = false)
743743
{
744-
xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
744+
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
745+
using base_t = std::remove_const_t<T>;
746+
xss_argselect<base_t, zmm_vector, ymm_vector>(
747+
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
745748
}
746749

747750
template <typename T>
@@ -751,8 +754,10 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
751754
arrsize_t arrsize,
752755
bool hasnan = false)
753756
{
754-
xss_argselect<T, avx2_vector, avx2_half_vector>(
755-
arr, arg, k, arrsize, hasnan);
757+
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
758+
using base_t = std::remove_const_t<T>;
759+
xss_argselect<base_t, avx2_vector, avx2_half_vector>(
760+
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
756761
}
757762

758763
#endif // XSS_COMMON_ARGSORT

0 commit comments

Comments
 (0)