Skip to content

Commit 35bfb5f

Browse files
committed
use child for values instead of buffer
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 6c1d7aa commit 35bfb5f

File tree

7 files changed

+99
-59
lines changed

7 files changed

+99
-59
lines changed

vortex-array/src/arrays/patched/array.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,19 @@ use crate::ArrayRef;
1212
use crate::Canonical;
1313
use crate::DynArray;
1414
use crate::ExecutionCtx;
15+
use crate::IntoArray;
16+
use crate::arrays::PrimitiveArray;
1517
use crate::arrays::patched::PatchAccessor;
1618
use crate::arrays::patched::TransposedPatches;
1719
use crate::arrays::patched::patch_lanes;
1820
use crate::buffer::BufferHandle;
1921
use crate::dtype::IntegerPType;
2022
use crate::dtype::NativePType;
21-
use crate::dtype::PType;
2223
use crate::match_each_native_ptype;
2324
use crate::match_each_unsigned_integer_ptype;
2425
use crate::patches::Patches;
2526
use crate::stats::ArrayStats;
27+
use crate::validity::Validity;
2628

2729
/// An array that partially "patches" another array with new values.
2830
///
@@ -50,14 +52,17 @@ pub struct PatchedArray {
5052
/// indices within a 1024-element chunk. The PType of these MUST be u16
5153
pub(super) indices: BufferHandle,
5254
/// patch values corresponding to the indices, stored as a child array.
53-
pub(super) values: BufferHandle,
54-
/// PType of the scalars in `values`. Can be any native type.
55-
pub(super) values_ptype: PType,
55+
pub(super) values: ArrayRef,
5656

5757
pub(super) stats_set: ArrayStats,
5858
}
5959

6060
impl PatchedArray {
61+
/// Create a new `PatchedArray` from a child array and a set of [`Patches`].
62+
///
63+
/// # Errors
64+
///
65+
/// Returns an error if `inner` is not of a primitive type, or if its DType differs from the DType of the patches.
6166
pub fn from_array_and_patches(
6267
inner: ArrayRef,
6368
patches: &Patches,
@@ -68,6 +73,11 @@ impl PatchedArray {
6873
"array DType must match patches DType"
6974
);
7075

76+
vortex_ensure!(
77+
inner.dtype().is_primitive(),
78+
"Creating PatchedArray from Patches only supported for primitive arrays"
79+
);
80+
7181
let values_ptype = patches.dtype().as_ptype();
7282

7383
let TransposedPatches {
@@ -80,27 +90,32 @@ impl PatchedArray {
8090

8191
let len = inner.len();
8292

93+
let values = PrimitiveArray::from_buffer_handle(
94+
BufferHandle::new_host(values),
95+
values_ptype,
96+
Validity::NonNullable,
97+
)
98+
.into_array();
99+
83100
Ok(Self {
84101
inner,
85102
n_chunks,
86103
n_lanes,
87-
values_ptype,
88104
offset: 0,
89105
len,
90106
lane_offsets: BufferHandle::new_host(lane_offsets),
91107
indices: BufferHandle::new_host(indices),
92-
values: BufferHandle::new_host(values),
108+
values,
93109
stats_set: ArrayStats::default(),
94110
})
95111
}
96112

97113
/// Get an accessor, which allows ranged access to patches by chunk/lane.
98-
pub fn accessor<V: NativePType>(&self) -> PatchAccessor<'_, V> {
114+
pub fn accessor(&self) -> PatchAccessor<'_> {
99115
PatchAccessor {
100116
n_lanes: self.n_lanes,
101117
lane_offsets: self.lane_offsets.as_host().reinterpret::<u32>(),
102118
indices: self.indices.as_host().reinterpret::<u16>(),
103-
values: self.values.as_host().reinterpret::<V>(),
104119
}
105120
}
106121

@@ -133,7 +148,6 @@ impl PatchedArray {
133148
len,
134149
indices,
135150
values,
136-
values_ptype: self.values_ptype,
137151
lane_offsets: sliced_lane_offsets,
138152
stats_set: ArrayStats::default(),
139153
})

vortex-array/src/arrays/patched/compute/compare.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::IntoArray;
1212
use crate::arrays::BoolArray;
1313
use crate::arrays::ConstantArray;
1414
use crate::arrays::Patched;
15+
use crate::arrays::PrimitiveArray;
1516
use crate::arrays::bool::BoolArrayParts;
1617
use crate::arrays::patched::patch_lanes;
1718
use crate::arrays::primitive::NativeValue;
@@ -28,6 +29,12 @@ impl CompareKernel for Patched {
2829
operator: CompareOperator,
2930
ctx: &mut ExecutionCtx,
3031
) -> VortexResult<Option<ArrayRef>> {
32+
// We only accelerate comparisons for primitives
33+
if !lhs.dtype().is_primitive() {
34+
return Ok(None);
35+
}
36+
37+
// We only accelerate comparisons against constants
3138
let Some(constant) = rhs.as_constant() else {
3239
return Ok(None);
3340
};
@@ -87,9 +94,10 @@ impl CompareKernel for Patched {
8794

8895
let lane_offsets = lhs.lane_offsets.as_host().reinterpret::<u32>();
8996
let indices = lhs.indices.as_host().reinterpret::<u16>();
97+
let values = lhs.values.clone().execute::<PrimitiveArray>(ctx)?;
9098

91-
match_each_native_ptype!(lhs.values_ptype, |V| {
92-
let values = lhs.values.as_host().reinterpret::<V>();
99+
match_each_native_ptype!(values.ptype(), |V| {
100+
let values = values.as_slice::<V>();
93101
let constant = constant
94102
.as_primitive()
95103
.as_::<V>()

vortex-array/src/arrays/patched/compute/take.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ impl TakeExecute for Patched {
2424
indices: &ArrayRef,
2525
ctx: &mut ExecutionCtx,
2626
) -> VortexResult<Option<ArrayRef>> {
27+
// Only pushdown take when we have primitive types.
28+
if !array.dtype().is_primitive() {
29+
return Ok(None);
30+
}
31+
2732
// Perform take on the inner array, including the placeholders.
2833
let inner = array
2934
.inner
@@ -41,6 +46,7 @@ impl TakeExecute for Patched {
4146
match_each_unsigned_integer_ptype!(indices_ptype, |I| {
4247
match_each_native_ptype!(ptype, |V| {
4348
let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
49+
let values = array.values.clone().execute::<PrimitiveArray>(ctx)?;
4450
let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
4551
take_map(
4652
output.as_mut(),
@@ -51,7 +57,7 @@ impl TakeExecute for Patched {
5157
array.n_lanes,
5258
array.lane_offsets.as_host().reinterpret::<u32>(),
5359
array.indices.as_host().reinterpret::<u16>(),
54-
array.values.as_host().reinterpret::<V>(),
60+
values.as_slice::<V>(),
5561
);
5662

5763
// SAFETY: output and validity still have same length after take_map returns.

vortex-array/src/arrays/patched/mod.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,34 @@ const fn patch_lanes<V: Sized>() -> usize {
3232
if size_of::<V>() < 8 { 32 } else { 16 }
3333
}
3434

35-
pub struct PatchAccessor<'a, V> {
35+
pub struct PatchAccessor<'a> {
3636
n_lanes: usize,
3737
lane_offsets: &'a [u32],
3838
indices: &'a [u16],
39-
values: &'a [V],
4039
}
4140

42-
impl<'a, V: Sized> PatchAccessor<'a, V> {
43-
/// Access the patches for a particular lane
44-
pub fn access(&'a self, chunk: usize, lane: usize) -> LanePatches<'a, V> {
41+
pub struct PatchOffset {
42+
/// Global offset into the list of patches, i.e. the position of this patch in `indices` and `values`.
43+
pub index: usize,
44+
/// This is the value stored in the `indices` buffer, which encodes the offset of the `index`-th
45+
/// patch within its 1024-element chunk.
46+
pub chunk_offset: u16,
47+
}
48+
49+
impl<'a> PatchAccessor<'a> {
50+
/// Get an iterator over indices and values offsets.
51+
///
52+
/// The first component is the index into the `indices` and `values`, and the second component
53+
/// is the `u16` entry of `indices` at that position, i.e. the patch's offset within its chunk.
54+
pub fn offsets_iter(
55+
&self,
56+
chunk: usize,
57+
lane: usize,
58+
) -> impl Iterator<Item = (usize, u16)> + '_ {
4559
let start = self.lane_offsets[chunk * self.n_lanes + lane] as usize;
4660
let stop = self.lane_offsets[chunk * self.n_lanes + lane + 1] as usize;
4761

48-
LanePatches {
49-
indices: &self.indices[start..stop],
50-
values: &self.values[start..stop],
51-
}
62+
std::iter::zip(start..stop, self.indices[start..stop].iter().copied())
5263
}
5364
}
5465

vortex-array/src/arrays/patched/vtable/mod.rs

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ impl VTable for Patched {
8686

8787
fn array_hash<H: Hasher>(array: &Self::Array, state: &mut H, precision: Precision) {
8888
array.inner.array_hash(state, precision);
89-
array.values_ptype.hash(state);
9089
array.n_chunks.hash(state);
9190
array.n_lanes.hash(state);
9291
array.lane_offsets.array_hash(state, precision);
@@ -97,7 +96,6 @@ impl VTable for Patched {
9796
fn array_eq(array: &Self::Array, other: &Self::Array, precision: Precision) -> bool {
9897
array.n_chunks == other.n_chunks
9998
&& array.n_lanes == other.n_lanes
100-
&& array.values_ptype == other.values_ptype
10199
&& array.inner.array_eq(&other.inner, precision)
102100
&& array.lane_offsets.array_eq(&other.lane_offsets, precision)
103101
&& array.indices.array_eq(&other.indices, precision)
@@ -112,7 +110,6 @@ impl VTable for Patched {
112110
match idx {
113111
0 => array.lane_offsets.clone(),
114112
1 => array.indices.clone(),
115-
2 => array.values.clone(),
116113
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
117114
}
118115
}
@@ -121,28 +118,27 @@ impl VTable for Patched {
121118
match idx {
122119
0 => Some("lane_offsets".to_string()),
123120
1 => Some("patch_indices".to_string()),
124-
2 => Some("patch_values".to_string()),
125121
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
126122
}
127123
}
128124

129125
fn nchildren(_array: &Self::Array) -> usize {
130-
1
126+
2
131127
}
132128

133129
fn child(array: &Self::Array, idx: usize) -> ArrayRef {
134-
if idx == 0 {
135-
array.inner.clone()
136-
} else {
137-
vortex_panic!("invalid child index for PatchedArray: {idx}");
130+
match idx {
131+
0 => array.inner.clone(),
132+
1 => array.values.clone(),
133+
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
138134
}
139135
}
140136

141137
fn child_name(_array: &Self::Array, idx: usize) -> String {
142-
if idx == 0 {
143-
"inner".to_string()
144-
} else {
145-
vortex_panic!("invalid child index for PatchedArray: {idx}");
138+
match idx {
139+
0 => "inner".to_string(),
140+
1 => "patch_values".to_string(),
141+
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
146142
}
147143
}
148144

@@ -181,10 +177,14 @@ impl VTable for Patched {
181177

182178
let n_lanes = match_each_native_ptype!(dtype.as_ptype(), |P| { patch_lanes::<P>() });
183179

184-
let &[lane_offsets, indices, values] = &buffers else {
180+
let &[lane_offsets, indices] = &buffers else {
185181
vortex_bail!("invalid buffer count for PatchedArray");
186182
};
187183

184+
// values and indices should have same len.
185+
let expected_len = indices.as_host().reinterpret::<u16>().len();
186+
let values = children.get(1, dtype, expected_len)?;
187+
188188
Ok(PatchedArray {
189189
inner,
190190
n_chunks,
@@ -193,19 +193,19 @@ impl VTable for Patched {
193193
len,
194194
lane_offsets: lane_offsets.clone(),
195195
indices: indices.clone(),
196-
values: values.clone(),
197-
values_ptype: dtype.as_ptype(),
196+
values,
198197
stats_set: ArrayStats::default(),
199198
})
200199
}
201200

202201
fn with_children(array: &mut Self::Array, mut children: Vec<ArrayRef>) -> VortexResult<()> {
203202
vortex_ensure!(
204-
children.len() == 1,
205-
"PatchedArray must have exactly 1 child"
203+
children.len() == 2,
204+
"PatchedArray must have exactly 2 children"
206205
);
207206

208207
array.inner = children.remove(0);
208+
array.values = children.remove(0);
209209

210210
Ok(())
211211
}
@@ -226,23 +226,25 @@ impl VTable for Patched {
226226
let lane_offsets: Buffer<u32> =
227227
Buffer::from_byte_buffer(array.lane_offsets.clone().unwrap_host());
228228
let indices: Buffer<u16> = Buffer::from_byte_buffer(array.indices.clone().unwrap_host());
229+
let values = array.values.clone().execute::<PrimitiveArray>(ctx)?;
230+
231+
// TODO(aduffy): add support for non-primitive PatchedArray patches application.
229232

230-
let patched_values = match_each_native_ptype!(array.values_ptype, |V| {
233+
let patched_values = match_each_native_ptype!(values.ptype(), |V| {
231234
let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
232-
let values: Buffer<V> = Buffer::from_byte_buffer(array.values.clone().unwrap_host());
233235

234236
let offset = array.offset;
235237
let len = array.len;
236238

237-
apply::<V>(
239+
apply_patches_primitive::<V>(
238240
&mut output,
239241
offset,
240242
len,
241243
array.n_chunks,
242244
array.n_lanes,
243245
&lane_offsets,
244246
&indices,
245-
&values,
247+
values.as_slice::<V>(),
246248
);
247249

248250
// The output will always be aligned to a chunk boundary, we apply the offset/len
@@ -276,7 +278,7 @@ impl VTable for Patched {
276278

277279
/// Apply patches on top of the existing value types.
278280
#[allow(clippy::too_many_arguments)]
279-
fn apply<V: NativePType>(
281+
fn apply_patches_primitive<V: NativePType>(
280282
output: &mut [V],
281283
offset: usize,
282284
len: usize,

vortex-array/src/arrays/patched/vtable/operations.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::DynArray;
77
use crate::arrays::patched::Patched;
88
use crate::arrays::patched::PatchedArray;
99
use crate::arrays::patched::patch_lanes;
10+
use crate::dtype::PType;
1011
use crate::match_each_native_ptype;
1112
use crate::scalar::Scalar;
1213
use crate::vtable::OperationsVTable;
@@ -17,21 +18,19 @@ impl OperationsVTable<Patched> for Patched {
1718
let chunk = index / 1024;
1819
#[allow(clippy::cast_possible_truncation)]
1920
let chunk_index = (index % 1024) as u16;
20-
match_each_native_ptype!(array.values_ptype, |V| {
21-
let lane = index % patch_lanes::<V>();
22-
let accessor = array.accessor::<V>();
23-
let patches = accessor.access(chunk, lane);
24-
// NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
25-
// be slower.
26-
for (patch_index, patch_value) in patches.iter() {
27-
if patch_index == chunk_index {
28-
return Ok(Scalar::primitive(
29-
patch_value,
30-
array.inner.dtype().nullability(),
31-
));
32-
}
21+
22+
let values_ptype = PType::try_from(array.dtype())?;
23+
24+
let lane = match_each_native_ptype!(values_ptype, |V| { index % patch_lanes::<V>() });
25+
let accessor = array.accessor();
26+
27+
// NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
28+
// be slower.
29+
for (index, patch_index) in accessor.offsets_iter(chunk, lane) {
30+
if patch_index == chunk_index {
31+
return array.values.scalar_at(index);
3332
}
34-
});
33+
}
3534

3635
// Otherwise, access the underlying value.
3736
array.inner.scalar_at(index)

0 commit comments

Comments
 (0)