Skip to content

Commit e6edb07

Browse files
authored
Merge pull request #231 from numpy/copilot/remove-remove-const-t
Propagate `const T*` through argsort/argselect call chain, remove `std::remove_const_t` workarounds
2 parents f4a87c8 + e0f2519 commit e6edb07

6 files changed

Lines changed: 43 additions & 47 deletions

src/avx2-32bit-half.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ struct avx2_half_vector<int32_t> {
8484
return _mm_mask_i32gather_epi32(
8585
src, (const int *)base, index, mask, scale);
8686
}
87-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
87+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
8888
{
8989
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
9090
}
@@ -237,7 +237,7 @@ struct avx2_half_vector<uint32_t> {
237237
return _mm_mask_i32gather_epi32(
238238
src, (const int *)base, index, mask, scale);
239239
}
240-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
240+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
241241
{
242242
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
243243
}
@@ -421,7 +421,7 @@ struct avx2_half_vector<float> {
421421
return _mm_mask_i32gather_ps(
422422
src, (const float *)base, index, _mm_castsi128_ps(mask), scale);
423423
}
424-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
424+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
425425
{
426426
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
427427
}

src/avx2-64bit-qsort.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct avx2_vector<int64_t> {
9999
return _mm256_mask_i32gather_epi64(
100100
src, (const long long int *)base, index, mask, scale);
101101
}
102-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
102+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
103103
{
104104
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
105105
}
@@ -269,7 +269,7 @@ struct avx2_vector<uint64_t> {
269269
return _mm256_mask_i32gather_epi64(
270270
src, (const long long int *)base, index, mask, scale);
271271
}
272-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
272+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
273273
{
274274
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
275275
}
@@ -499,7 +499,7 @@ struct avx2_vector<double> {
499499
scale);
500500
;
501501
}
502-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
502+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
503503
{
504504
return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
505505
}

src/avx512-64bit-common.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct ymm_vector<float> {
9999
{
100100
return _mm256_mmask_i32gather_ps(src, mask, index, base, scale);
101101
}
102-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
102+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
103103
{
104104
return set(arr[ind[7]],
105105
arr[ind[6]],
@@ -293,7 +293,7 @@ struct ymm_vector<uint32_t> {
293293
{
294294
return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale);
295295
}
296-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
296+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
297297
{
298298
return set(arr[ind[7]],
299299
arr[ind[6]],
@@ -481,7 +481,7 @@ struct ymm_vector<int32_t> {
481481
{
482482
return _mm256_mmask_i32gather_epi32(src, mask, index, base, scale);
483483
}
484-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
484+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
485485
{
486486
return set(arr[ind[7]],
487487
arr[ind[6]],
@@ -680,7 +680,7 @@ struct zmm_vector<int64_t> {
680680
{
681681
return _mm512_mask_i32gather_epi64(src, mask, index, base, scale);
682682
}
683-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
683+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
684684
{
685685
return set(arr[ind[7]],
686686
arr[ind[6]],
@@ -843,7 +843,7 @@ struct zmm_vector<uint64_t> {
843843
{
844844
return _mm512_mask_i32gather_epi64(src, mask, index, base, scale);
845845
}
846-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
846+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
847847
{
848848
return set(arr[ind[7]],
849849
arr[ind[6]],
@@ -1062,7 +1062,7 @@ struct zmm_vector<double> {
10621062
{
10631063
return _mm512_mask_i32gather_pd(src, mask, index, base, scale);
10641064
}
1065-
static reg_t i64gather(type_t *arr, arrsize_t *ind)
1065+
static reg_t i64gather(const type_t *arr, arrsize_t *ind)
10661066
{
10671067
return set(arr[ind[7]],
10681068
arr[ind[6]],

src/xss-common-argsort.h

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
#include <numeric>
1212

1313
template <typename T>
14-
X86_SIMD_SORT_INLINE void std_argselect_withnan(
15-
T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right)
14+
X86_SIMD_SORT_INLINE void std_argselect_withnan(const T *arr,
15+
arrsize_t *arg,
16+
arrsize_t k,
17+
arrsize_t left,
18+
arrsize_t right)
1619
{
1720
std::nth_element(arg + left,
1821
arg + k,
@@ -32,8 +35,10 @@ X86_SIMD_SORT_INLINE void std_argselect_withnan(
3235

3336
/* argsort using std::sort */
3437
template <typename T>
35-
X86_SIMD_SORT_INLINE void
36-
std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
38+
X86_SIMD_SORT_INLINE void std_argsort_withnan(const T *arr,
39+
arrsize_t *arg,
40+
arrsize_t left,
41+
arrsize_t right)
3742
{
3843
std::sort(arg + left,
3944
arg + right,
@@ -53,7 +58,7 @@ std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
5358
/* argsort using std::sort */
5459
template <typename T>
5560
X86_SIMD_SORT_INLINE void
56-
std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
61+
std_argsort(const T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
5762
{
5863
std::sort(arg + left,
5964
arg + right,
@@ -172,7 +177,7 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
172177
* last element that is less than equal to the pivot.
173178
*/
174179
template <typename vtype, typename argtype, typename type_t>
175-
X86_SIMD_SORT_INLINE arrsize_t argpartition(type_t *arr,
180+
X86_SIMD_SORT_INLINE arrsize_t argpartition(const type_t *arr,
176181
arrsize_t *arg,
177182
arrsize_t left,
178183
arrsize_t right,
@@ -291,7 +296,7 @@ template <typename vtype,
291296
typename argtype,
292297
int num_unroll,
293298
typename type_t = typename vtype::type_t>
294-
X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr,
299+
X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(const type_t *arr,
295300
arrsize_t *arg,
296301
arrsize_t left,
297302
arrsize_t right,
@@ -422,7 +427,7 @@ X86_SIMD_SORT_INLINE arrsize_t argpartition_unrolled(type_t *arr,
422427
}
423428

424429
template <typename vtype, typename type_t>
425-
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
430+
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(const type_t *arr,
426431
arrsize_t *arg,
427432
const arrsize_t left,
428433
const arrsize_t right)
@@ -468,7 +473,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
468473
}
469474

470475
template <typename vtype, typename argtype, typename type_t>
471-
X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
476+
X86_SIMD_SORT_INLINE void argsort_(const type_t *arr,
472477
arrsize_t *arg,
473478
arrsize_t left,
474479
arrsize_t right,
@@ -549,7 +554,7 @@ X86_SIMD_SORT_INLINE void argsort_(type_t *arr,
549554
}
550555

551556
template <typename vtype, typename argtype, typename type_t>
552-
X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
557+
X86_SIMD_SORT_INLINE void argselect_(const type_t *arr,
553558
arrsize_t *arg,
554559
arrsize_t pos,
555560
arrsize_t left,
@@ -590,7 +595,7 @@ template <typename T,
590595
typename full_vector,
591596
template <typename...>
592597
typename half_vector>
593-
X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
598+
X86_SIMD_SORT_INLINE void xss_argsort(const T *arr,
594599
arrsize_t *arg,
595600
arrsize_t arrsize,
596601
bool hasnan = false,
@@ -669,29 +674,25 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
669674
}
670675

671676
template <typename T>
672-
X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
677+
X86_SIMD_SORT_INLINE void avx512_argsort(const T *arr,
673678
arrsize_t *arg,
674679
arrsize_t arrsize,
675680
bool hasnan = false,
676681
bool descending = false)
677682
{
678-
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
679-
using base_t = std::remove_const_t<T>;
680-
xss_argsort<base_t, zmm_vector, ymm_vector>(
681-
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
683+
xss_argsort<T, zmm_vector, ymm_vector>(
684+
arr, arg, arrsize, hasnan, descending);
682685
}
683686

684687
template <typename T>
685-
X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
688+
X86_SIMD_SORT_INLINE void avx2_argsort(const T *arr,
686689
arrsize_t *arg,
687690
arrsize_t arrsize,
688691
bool hasnan = false,
689692
bool descending = false)
690693
{
691-
// Safe: argsort never mutates arr; const is dropped only for SIMD type instantiation
692-
using base_t = std::remove_const_t<T>;
693-
xss_argsort<base_t, avx2_vector, avx2_half_vector>(
694-
const_cast<base_t *>(arr), arg, arrsize, hasnan, descending);
694+
xss_argsort<T, avx2_vector, avx2_half_vector>(
695+
arr, arg, arrsize, hasnan, descending);
695696
}
696697

697698
/* argselect methods for 32-bit and 64-bit dtypes */
@@ -700,7 +701,7 @@ template <typename T,
700701
typename full_vector,
701702
template <typename...>
702703
typename half_vector>
703-
X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
704+
X86_SIMD_SORT_INLINE void xss_argselect(const T *arr,
704705
arrsize_t *arg,
705706
arrsize_t k,
706707
arrsize_t arrsize,
@@ -735,29 +736,24 @@ X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
735736
}
736737

737738
template <typename T>
738-
X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
739+
X86_SIMD_SORT_INLINE void avx512_argselect(const T *arr,
739740
arrsize_t *arg,
740741
arrsize_t k,
741742
arrsize_t arrsize,
742743
bool hasnan = false)
743744
{
744-
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
745-
using base_t = std::remove_const_t<T>;
746-
xss_argselect<base_t, zmm_vector, ymm_vector>(
747-
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
745+
xss_argselect<T, zmm_vector, ymm_vector>(arr, arg, k, arrsize, hasnan);
748746
}
749747

750748
template <typename T>
751-
X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
749+
X86_SIMD_SORT_INLINE void avx2_argselect(const T *arr,
752750
arrsize_t *arg,
753751
arrsize_t k,
754752
arrsize_t arrsize,
755753
bool hasnan = false)
756754
{
757-
// Safe: argselect never mutates arr; const is dropped only for SIMD type instantiation
758-
using base_t = std::remove_const_t<T>;
759-
xss_argselect<base_t, avx2_vector, avx2_half_vector>(
760-
const_cast<base_t *>(arr), arg, k, arrsize, hasnan);
755+
xss_argselect<T, avx2_vector, avx2_half_vector>(
756+
arr, arg, k, arrsize, hasnan);
761757
}
762758

763759
#endif // XSS_COMMON_ARGSORT

src/xss-common-qsort.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size)
7979
}
8080

8181
template <typename vtype, typename type_t>
82-
X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size)
82+
X86_SIMD_SORT_INLINE bool array_has_nan(const type_t *arr, arrsize_t size)
8383
{
8484
using opmask_t = typename vtype::opmask_t;
8585
using reg_t = typename vtype::reg_t;

src/xss-network-keyvaluesort.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ bitonic_fullmerge_n_vec(typename keyType::reg_t *keys,
208208

209209
template <typename keyType, typename indexType, int numVecs>
210210
X86_SIMD_SORT_INLINE void
211-
argsort_n_vec(typename keyType::type_t *keys, arrsize_t *indices, int N)
211+
argsort_n_vec(const typename keyType::type_t *keys, arrsize_t *indices, int N)
212212
{
213213
using kreg_t = typename keyType::reg_t;
214214
using ireg_t = typename indexType::reg_t;
@@ -354,7 +354,7 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,
354354

355355
template <typename keyType, typename indexType, int maxN>
356356
X86_SIMD_SORT_INLINE void
357-
argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N)
357+
argsort_n(const typename keyType::type_t *keys, arrsize_t *indices, int N)
358358
{
359359
static_assert(keyType::numlanes == indexType::numlanes,
360360
"invalid pairing of value/index types");

0 commit comments

Comments
 (0)