Skip to content

Commit cb9f080

Browse files
authored
Merge pull request #229 from AnkitAhlawat7742/fix/argsort-const-correctness
Make argsort const-correct
2 parents 5fcbec1 + 5b86d5e commit cb9f080

10 files changed

Lines changed: 32 additions & 21 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ data types.
5656

5757
## Arg sort routines on arrays
5858
```cpp
59-
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending);
59+
std::vector<size_t> arg = x86simdsort::argsort(const T* arr, size_t size, bool hasnan, bool descending);
6060
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
6161
```
6262
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double,

lib/x86simdsort-avx2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
} \
2424
template <> \
2525
std::vector<size_t> argsort( \
26-
type *arr, size_t arrsize, bool hasnan, bool descending) \
26+
const type *arr, size_t arrsize, bool hasnan, bool descending) \
2727
{ \
2828
return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \
2929
} \

lib/x86simdsort-internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
bool hasnan = false, \
4545
bool descending = false); \
4646
template <typename T> \
47-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, \
47+
XSS_HIDE_SYMBOL std::vector<size_t> argsort(const T *arr, \
4848
size_t arrsize, \
4949
bool hasnan = false, \
5050
bool descending = false); \

lib/x86simdsort-scalar.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ namespace scalar {
7171
}
7272
template <typename T>
7373
std::vector<size_t>
74-
argsort(T *arr, size_t arrsize, bool hasnan, bool reversed)
74+
argsort(const T *arr, size_t arrsize, bool hasnan, bool reversed)
7575
{
7676
UNUSED(hasnan);
7777
std::vector<size_t> arg(arrsize);

lib/x86simdsort-skx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
} \
2424
template <> \
2525
std::vector<size_t> argsort( \
26-
type *arr, size_t arrsize, bool hasnan, bool descending) \
26+
const type *arr, size_t arrsize, bool hasnan, bool descending) \
2727
{ \
2828
return x86simdsortStatic::argsort(arr, arrsize, hasnan, descending); \
2929
} \

lib/x86simdsort.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ namespace x86simdsort {
8888

8989
#define DECLARE_INTERNAL_argsort(TYPE) \
9090
static std::vector<size_t> (*internal_argsort##TYPE)( \
91-
TYPE *, size_t, bool, bool) \
91+
const TYPE *, size_t, bool, bool) \
9292
= NULL; \
9393
template <> \
9494
std::vector<size_t> argsort( \
95-
TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
95+
const TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
9696
{ \
9797
return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \
9898
}

lib/x86simdsort.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr,
3535

3636
// argsort
3737
template <typename T>
38-
XSS_EXPORT_SYMBOL std::vector<size_t>
39-
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
38+
XSS_EXPORT_SYMBOL std::vector<size_t> argsort(const T *arr,
39+
size_t arrsize,
40+
bool hasnan = false,
41+
bool descending = false);
4042

4143
// argselect
4244
template <typename T>

src/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Equivalent to `np.argsort` in
6363
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html).
6464

6565
```cpp
66-
void x86simdsortStatic::argsort<T>(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
66+
void x86simdsortStatic::argsort<T>(const T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
6767
```
6868
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
6969
`double`.

src/x86simdsort-static-incl.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ X86_SIMD_SORT_FINLINE void partial_qsort(T *arr,
2525
bool descending = false);
2626

2727
template <typename T>
28-
X86_SIMD_SORT_FINLINE std::vector<size_t>
29-
argsort(T *arr, size_t size, bool hasnan = false, bool descending = false);
28+
X86_SIMD_SORT_FINLINE std::vector<size_t> argsort(const T *arr,
29+
size_t size,
30+
bool hasnan = false,
31+
bool descending = false);
3032

3133
/* argsort API required by NumPy: */
3234
template <typename T>
33-
X86_SIMD_SORT_FINLINE void argsort(T *arr,
35+
X86_SIMD_SORT_FINLINE void argsort(const T *arr,
3436
size_t *arg,
3537
size_t size,
3638
bool hasnan = false,
@@ -90,14 +92,17 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key,
9092
ISA##_partial_qsort(arr, k, size, hasnan, descending); \
9193
} \
9294
template <typename T> \
93-
X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort( \
94-
T *arr, size_t *arg, size_t size, bool hasnan, bool descending) \
95+
X86_SIMD_SORT_FINLINE void x86simdsortStatic::argsort(const T *arr, \
96+
size_t *arg, \
97+
size_t size, \
98+
bool hasnan, \
99+
bool descending) \
95100
{ \
96101
ISA##_argsort(arr, arg, size, hasnan, descending); \
97102
} \
98103
template <typename T> \
99104
X86_SIMD_SORT_FINLINE std::vector<size_t> x86simdsortStatic::argsort( \
100-
T *arr, size_t size, bool hasnan, bool descending) \
105+
const T *arr, size_t size, bool hasnan, bool descending) \
101106
{ \
102107
std::vector<size_t> indices(size); \
103108
std::iota(indices.begin(), indices.end(), 0); \
@@ -211,4 +216,4 @@ XSS_METHODS(avx2)
211216
#error "x86simdsortStatic methods needs to be compiled with avx512/avx2 specific flags"
212217
#endif // (__AVX512VL__ && __AVX512DQ__) || AVX2
213218

214-
#endif // X86_SIMD_SORT_STATIC_METHODS
219+
#endif // X86_SIMD_SORT_STATIC_METHODS

src/xss-common-argsort.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -675,8 +675,10 @@ X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
675675
bool hasnan = false,
676676
bool descending = false)
677677
{
678-
xss_argsort<T, zmm_vector, ymm_vector>(
679-
arr, arg, arrsize, hasnan, descending);
678+
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
679+
using base_t = std::remove_const_t<T>;
680+
xss_argsort<base_t, zmm_vector, ymm_vector>(
681+
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
680682
}
681683

682684
template <typename T>
@@ -686,8 +688,10 @@ X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
686688
bool hasnan = false,
687689
bool descending = false)
688690
{
689-
xss_argsort<T, avx2_vector, avx2_half_vector>(
690-
arr, arg, arrsize, hasnan, descending);
691+
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
692+
using base_t = std::remove_const_t<T>;
693+
xss_argsort<base_t, avx2_vector, avx2_half_vector>(
694+
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
691695
}
692696

693697
/* argselect methods for 32-bit and 64-bit dtypes */

0 commit comments

Comments
 (0)