Skip to content

Commit c05e4d0

Browse files
committed
add unit tests
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 33cdc46 commit c05e4d0

File tree

1 file changed

+131
-11
lines changed
  • vortex-array/src/arrays/patched/compute

1 file changed

+131
-11
lines changed

vortex-array/src/arrays/patched/compute/take.rs

Lines changed: 131 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@ use rustc_hash::FxHashMap;
55
use vortex_buffer::Buffer;
66
use vortex_error::VortexResult;
77

8+
use crate::ArrayRef;
9+
use crate::DynArray;
10+
use crate::ExecutionCtx;
11+
use crate::IntoArray;
12+
use crate::arrays::Patched;
13+
use crate::arrays::PrimitiveArray;
814
use crate::arrays::dict::TakeExecute;
915
use crate::arrays::primitive::PrimitiveArrayParts;
10-
use crate::arrays::{Patched, PrimitiveArray};
11-
use crate::dtype::{IntegerPType, NativePType};
12-
use crate::{ArrayRef, DynArray, IntoArray, match_each_native_ptype};
13-
use crate::{ExecutionCtx, match_each_unsigned_integer_ptype};
16+
use crate::dtype::IntegerPType;
17+
use crate::dtype::NativePType;
18+
use crate::match_each_native_ptype;
19+
use crate::match_each_unsigned_integer_ptype;
1420

1521
impl TakeExecute for Patched {
1622
fn take(
@@ -50,19 +56,20 @@ impl TakeExecute for Patched {
5056

5157
// SAFETY: output and validity still have same length after take_map returns.
5258
unsafe {
53-
return Ok(Some(
59+
Ok(Some(
5460
PrimitiveArray::new_unchecked(output.freeze(), validity).into_array(),
55-
));
61+
))
5662
}
5763
})
58-
});
64+
})
5965
}
6066
}
6167

6268
/// Take patches for the given `indices` and apply them onto an `output` using a hash map.
6369
///
6470
/// First, builds a hashmap from index to patch value, then uses the hashmap in a loop to collect
6571
/// the values.
72+
#[allow(clippy::too_many_arguments)]
6673
fn take_map<I: IntegerPType, V: NativePType>(
6774
output: &mut [V],
6875
indices: &[I],
@@ -75,10 +82,11 @@ fn take_map<I: IntegerPType, V: NativePType>(
7582
patch_value: &[V],
7683
) {
7784
// Build a hashmap of patch_index -> values.
78-
let mut index_map = FxHashMap::with_capacity(indices.len());
85+
let mut index_map = FxHashMap::with_capacity_and_hasher(indices.len(), Default::default());
7986
for chunk in 0..n_chunks {
8087
for lane in 0..n_lanes {
81-
let [lane_start, lane_end] = lane_offsets[chunk * n_lanes + lane..][..2];
88+
let lane_start = lane_offsets[chunk * n_lanes + lane];
89+
let lane_end = lane_offsets[chunk * n_lanes + lane + 1];
8290
for i in lane_start..lane_end {
8391
let patch_idx = patch_index[i as usize];
8492
let patch_value = patch_value[i as usize];
@@ -103,8 +111,120 @@ fn take_map<I: IntegerPType, V: NativePType>(
103111

104112
#[cfg(test)]
105113
mod tests {
114+
use vortex_buffer::buffer;
115+
use vortex_error::VortexResult;
116+
use vortex_session::VortexSession;
117+
118+
use crate::DynArray;
119+
use crate::ExecutionCtx;
120+
use crate::IntoArray;
121+
use crate::arrays::PatchedArray;
122+
use crate::arrays::PrimitiveArray;
123+
use crate::assert_arrays_eq;
124+
use crate::patches::Patches;
125+
126+
fn make_patched_array(
127+
base: &[u16],
128+
patch_indices: &[u32],
129+
patch_values: &[u16],
130+
) -> VortexResult<PatchedArray> {
131+
let values = PrimitiveArray::from_iter(base.iter().copied()).into_array();
132+
let patches = Patches::new(
133+
base.len(),
134+
0,
135+
PrimitiveArray::from_iter(patch_indices.iter().copied()).into_array(),
136+
PrimitiveArray::from_iter(patch_values.iter().copied()).into_array(),
137+
None,
138+
)?;
139+
140+
let session = VortexSession::empty();
141+
let mut ctx = ExecutionCtx::new(session);
142+
143+
PatchedArray::from_array_and_patches(values, &patches, &mut ctx)
144+
}
145+
146+
#[test]
147+
fn test_take_basic() -> VortexResult<()> {
148+
// Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
149+
let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30])?.into_array();
150+
151+
// Take indices [0, 1, 2, 3, 4] - should get [0, 10, 0, 30, 0]
152+
let indices = buffer![0u32, 1, 2, 3, 4].into_array();
153+
let result = array.take(indices)?;
154+
155+
let expected = PrimitiveArray::from_iter([0u16, 10, 0, 30, 0]).into_array();
156+
assert_arrays_eq!(expected, result);
157+
158+
Ok(())
159+
}
160+
161+
#[test]
162+
fn test_take_out_of_order() -> VortexResult<()> {
163+
// Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
164+
let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30])?.into_array();
165+
166+
// Take indices in reverse order
167+
let indices = buffer![4u32, 3, 2, 1, 0].into_array();
168+
let result = array.take(indices)?;
169+
170+
let expected = PrimitiveArray::from_iter([0u16, 30, 0, 10, 0]).into_array();
171+
assert_arrays_eq!(expected, result);
172+
173+
Ok(())
174+
}
175+
176+
#[test]
177+
fn test_take_duplicates() -> VortexResult<()> {
178+
// Array with base values [0, 0, 0, 0, 0] patched at index [2] with value [99]
179+
let array = make_patched_array(&[0; 5], &[2], &[99])?.into_array();
180+
181+
// Take the same patched index multiple times
182+
let indices = buffer![2u32, 2, 0, 2].into_array();
183+
let result = array.take(indices)?;
184+
185+
let expected = PrimitiveArray::from_iter([99u16, 99, 0, 99]).into_array();
186+
assert_arrays_eq!(expected, result);
187+
188+
Ok(())
189+
}
190+
106191
#[test]
107-
fn test_take() {
108-
// Patch some values here instead.
192+
fn test_take_with_null_indices() -> VortexResult<()> {
193+
use crate::arrays::BoolArray;
194+
use crate::validity::Validity;
195+
196+
// Array: 10 elements, base value 0, patches at indices 2, 5, 8 with values 20, 50, 80
197+
let array = make_patched_array(&[0; 10], &[2, 5, 8], &[20, 50, 80])?.into_array();
198+
199+
// Take 10 indices, with nulls at positions 1, 4, 7
200+
// Indices: [0, 2, 2, 5, 8, 0, 5, 8, 3, 1]
201+
// Nulls: [ , , N, , , N, , , N, ]
202+
// Position 2 (index=2, patched) is null
203+
// Position 5 (index=0, unpatched) is null
204+
// Position 8 (index=3, unpatched) is null
205+
let indices = PrimitiveArray::new(
206+
buffer![0u32, 2, 2, 5, 8, 0, 5, 8, 3, 1],
207+
Validity::Array(
208+
BoolArray::from_iter([
209+
true, true, false, true, true, false, true, true, false, true,
210+
])
211+
.into_array(),
212+
),
213+
);
214+
let result = array.take(indices.into_array())?;
215+
216+
// Expected: [0, 20, null, 50, 80, null, 50, 80, null, 0]
217+
let expected = PrimitiveArray::new(
218+
buffer![0u16, 20, 0, 50, 80, 0, 50, 80, 0, 0],
219+
Validity::Array(
220+
BoolArray::from_iter([
221+
true, true, false, true, true, false, true, true, false, true,
222+
])
223+
.into_array(),
224+
),
225+
);
226+
assert_arrays_eq!(expected.into_array(), result);
227+
228+
Ok(())
109229
}
110230
}

0 commit comments

Comments
 (0)