Skip to content

Commit 33cdc46

Browse files
committed
take
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 047ede0 commit 33cdc46

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

vortex-array/src/arrays/patched/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
mod compare;
55
mod filter;
66
pub(crate) mod rules;
7+
mod take;
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use rustc_hash::FxHashMap;
5+
use vortex_buffer::Buffer;
6+
use vortex_error::VortexResult;
7+
8+
use crate::arrays::dict::TakeExecute;
9+
use crate::arrays::primitive::PrimitiveArrayParts;
10+
use crate::arrays::{Patched, PrimitiveArray};
11+
use crate::dtype::{IntegerPType, NativePType};
12+
use crate::{ArrayRef, DynArray, IntoArray, match_each_native_ptype};
13+
use crate::{ExecutionCtx, match_each_unsigned_integer_ptype};
14+
15+
impl TakeExecute for Patched {
16+
fn take(
17+
array: &Self::Array,
18+
indices: &ArrayRef,
19+
ctx: &mut ExecutionCtx,
20+
) -> VortexResult<Option<ArrayRef>> {
21+
// Perform take on the inner array, including the placeholders.
22+
let inner = array
23+
.inner
24+
.take(indices.clone())?
25+
.execute::<PrimitiveArray>(ctx)?;
26+
27+
let PrimitiveArrayParts {
28+
buffer,
29+
validity,
30+
ptype,
31+
} = inner.into_parts();
32+
33+
let indices_ptype = indices.dtype().as_ptype();
34+
35+
match_each_unsigned_integer_ptype!(indices_ptype, |I| {
36+
match_each_native_ptype!(ptype, |V| {
37+
let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
38+
let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
39+
take_map(
40+
output.as_mut(),
41+
indices.as_slice::<I>(),
42+
array.offset,
43+
array.len,
44+
array.n_chunks,
45+
array.n_lanes,
46+
array.lane_offsets.as_host().reinterpret::<u32>(),
47+
array.indices.as_host().reinterpret::<u16>(),
48+
array.values.as_host().reinterpret::<V>(),
49+
);
50+
51+
// SAFETY: output and validity still have same length after take_map returns.
52+
unsafe {
53+
return Ok(Some(
54+
PrimitiveArray::new_unchecked(output.freeze(), validity).into_array(),
55+
));
56+
}
57+
})
58+
});
59+
}
60+
}
61+
62+
/// Take patches for the given `indices` and apply them onto an `output` using a hash map.
63+
///
64+
/// First, builds a hashmap from index to patch value, then uses the hashmap in a loop to collect
65+
/// the values.
66+
fn take_map<I: IntegerPType, V: NativePType>(
67+
output: &mut [V],
68+
indices: &[I],
69+
offset: usize,
70+
len: usize,
71+
n_chunks: usize,
72+
n_lanes: usize,
73+
lane_offsets: &[u32],
74+
patch_index: &[u16],
75+
patch_value: &[V],
76+
) {
77+
// Build a hashmap of patch_index -> values.
78+
let mut index_map = FxHashMap::with_capacity(indices.len());
79+
for chunk in 0..n_chunks {
80+
for lane in 0..n_lanes {
81+
let [lane_start, lane_end] = lane_offsets[chunk * n_lanes + lane..][..2];
82+
for i in lane_start..lane_end {
83+
let patch_idx = patch_index[i as usize];
84+
let patch_value = patch_value[i as usize];
85+
86+
let index = chunk * 1024 + patch_idx as usize;
87+
if index >= offset && index < offset + len {
88+
index_map.insert(index, patch_value);
89+
}
90+
}
91+
}
92+
}
93+
94+
// Now, iterate the take indices using the prebuilt hashmap.
95+
// Undefined/null indices will miss the hash map, which we can ignore.
96+
for index in indices {
97+
let index = index.as_();
98+
if let Some(&patch_value) = index_map.get(&index) {
99+
output[index] = patch_value;
100+
}
101+
}
102+
}
103+
104+
#[cfg(test)]
mod tests {
    /// Placeholder for the patched-array take test; it currently asserts nothing.
    #[test]
    fn test_take() {
        // TODO: patch some values here, take from the patched array, and check that the
        // patched values show up in the result.
    }
}

0 commit comments

Comments
 (0)