Skip to content

Commit 84e4dc0

Browse files
authored
AVX2 take handles indices that are equal to the index type max value (#7359)
The in the avx2 take code correctly states that we need to fitler out values greater than but then it also filters out values that are equal, this fixes the edge case for values that are equal to the max index value --------- Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 1d8f307 commit 84e4dc0

1 file changed

Lines changed: 24 additions & 2 deletions

File tree

  • vortex-array/src/arrays/primitive/compute/take

vortex-array/src/arrays/primitive/compute/take/avx2.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::arch::x86_64::_mm_setzero_si128;
1414
use std::arch::x86_64::_mm_shuffle_epi32;
1515
use std::arch::x86_64::_mm_storeu_si128;
1616
use std::arch::x86_64::_mm_unpacklo_epi64;
17+
use std::arch::x86_64::_mm256_andnot_si256;
1718
use std::arch::x86_64::_mm256_cmpgt_epi32;
1819
use std::arch::x86_64::_mm256_cmpgt_epi64;
1920
use std::arch::x86_64::_mm256_cvtepu8_epi32;
@@ -118,7 +119,6 @@ where
118119
/// # Safety
119120
///
120121
/// The caller must ensure the `avx2` feature is enabled.
121-
#[allow(dead_code, unused_variables, reason = "TODO(connor): Implement this")]
122122
#[target_feature(enable = "avx2")]
123123
#[doc(hidden)]
124124
unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V> {
@@ -135,6 +135,10 @@ unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(buffer: &[V], indices: &[I
135135
}};
136136
}
137137

138+
if buffer.is_empty() {
139+
return Buffer::zeroed(indices.len());
140+
}
141+
138142
match (I::PTYPE, V::PTYPE) {
139143
// Int value types. Only 32 and 64 bit types are supported.
140144
(PType::U8, PType::I32) => dispatch_avx2!(u8, i32),
@@ -223,7 +227,7 @@ macro_rules! impl_gather {
223227
// Create a vec of the max idx.
224228
let max_idx_vec = unsafe { $splat(max_idx as _) };
225229
// Create a mask for valid indices (where the max_idx > provided index).
226-
let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
230+
let invalid_mask = unsafe { _mm256_andnot_si256($mask_indices(indices_vec, max_idx_vec), $splat(-1)) };
227231
let invalid_mask = {
228232
let $mask_var = invalid_mask;
229233
$mask_cvt
@@ -579,4 +583,22 @@ mod avx2_tests {
579583
index_type => u64,
580584
value_types => u32, i32, u64, i64, f32, f64
581585
);
586+
587+
#[test]
588+
fn test_avx2_take_last_valid_index_u8() {
589+
let values: Vec<i64> = (0..(255 + 1)).collect();
590+
let indices: Vec<u8> = vec![255; 20];
591+
592+
let result = unsafe { take_avx2(&values, &indices) };
593+
assert_eq!(&vec![255; indices.len()], result.as_slice());
594+
}
595+
596+
#[test]
597+
fn test_avx2_take_last_valid_index_u16() {
598+
let values: Vec<i64> = (0..(65535 + 1)).collect();
599+
let indices: Vec<u16> = vec![65535; 20];
600+
601+
let result = unsafe { take_avx2(&values, &indices) };
602+
assert_eq!(&vec![65535; indices.len()], result.as_slice());
603+
}
582604
}

0 commit comments

Comments
 (0)