Skip to content

Commit b8431cb

Browse files
committed
use child for values instead of buffer
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 5dcae7d commit b8431cb

File tree

7 files changed

+99
-59
lines changed

7 files changed

+99
-59
lines changed

vortex-array/src/arrays/patched/array.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,19 @@ use crate::ArrayRef;
1212
use crate::Canonical;
1313
use crate::DynArray;
1414
use crate::ExecutionCtx;
15+
use crate::IntoArray;
16+
use crate::arrays::PrimitiveArray;
1517
use crate::arrays::patched::PatchAccessor;
1618
use crate::arrays::patched::TransposedPatches;
1719
use crate::arrays::patched::patch_lanes;
1820
use crate::buffer::BufferHandle;
1921
use crate::dtype::IntegerPType;
2022
use crate::dtype::NativePType;
21-
use crate::dtype::PType;
2223
use crate::match_each_native_ptype;
2324
use crate::match_each_unsigned_integer_ptype;
2425
use crate::patches::Patches;
2526
use crate::stats::ArrayStats;
27+
use crate::validity::Validity;
2628

2729
/// An array that partially "patches" another array with new values.
2830
///
@@ -50,14 +52,17 @@ pub struct PatchedArray {
5052
/// indices within a 1024-element chunk. The PType of these MUST be u16
5153
pub(super) indices: BufferHandle,
5254
/// patch values corresponding to the indices. The ptype is specified by `values_ptype`.
53-
pub(super) values: BufferHandle,
54-
/// PType of the scalars in `values`. Can be any native type.
55-
pub(super) values_ptype: PType,
55+
pub(super) values: ArrayRef,
5656

5757
pub(super) stats_set: ArrayStats,
5858
}
5959

6060
impl PatchedArray {
61+
/// Create a new `PatchedArray` from a child array and a set of [`Patches`].
62+
///
63+
/// # Errors
64+
///
65+
/// The `inner` array must be primitive type, and it must have the same DType as the patches.
6166
pub fn from_array_and_patches(
6267
inner: ArrayRef,
6368
patches: &Patches,
@@ -68,6 +73,11 @@ impl PatchedArray {
6873
"array DType must match patches DType"
6974
);
7075

76+
vortex_ensure!(
77+
inner.dtype().is_primitive(),
78+
"Creating PatchedArray from Patches only supported for primitive arrays"
79+
);
80+
7181
let values_ptype = patches.dtype().as_ptype();
7282

7383
let TransposedPatches {
@@ -80,27 +90,32 @@ impl PatchedArray {
8090

8191
let len = inner.len();
8292

93+
let values = PrimitiveArray::from_buffer_handle(
94+
BufferHandle::new_host(values),
95+
values_ptype,
96+
Validity::NonNullable,
97+
)
98+
.into_array();
99+
83100
Ok(Self {
84101
inner,
85102
n_chunks,
86103
n_lanes,
87-
values_ptype,
88104
offset: 0,
89105
len,
90106
lane_offsets: BufferHandle::new_host(lane_offsets),
91107
indices: BufferHandle::new_host(indices),
92-
values: BufferHandle::new_host(values),
108+
values,
93109
stats_set: ArrayStats::default(),
94110
})
95111
}
96112

97113
/// Get an accessor, which allows ranged access to patches by chunk/lane.
98-
pub fn accessor<V: NativePType>(&self) -> PatchAccessor<'_, V> {
114+
pub fn accessor(&self) -> PatchAccessor<'_> {
99115
PatchAccessor {
100116
n_lanes: self.n_lanes,
101117
lane_offsets: self.lane_offsets.as_host().reinterpret::<u32>(),
102118
indices: self.indices.as_host().reinterpret::<u16>(),
103-
values: self.values.as_host().reinterpret::<V>(),
104119
}
105120
}
106121

@@ -133,7 +148,6 @@ impl PatchedArray {
133148
len,
134149
indices,
135150
values,
136-
values_ptype: self.values_ptype,
137151
lane_offsets: sliced_lane_offsets,
138152
stats_set: ArrayStats::default(),
139153
})

vortex-array/src/arrays/patched/compute/compare.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::IntoArray;
1212
use crate::arrays::BoolArray;
1313
use crate::arrays::ConstantArray;
1414
use crate::arrays::Patched;
15+
use crate::arrays::PrimitiveArray;
1516
use crate::arrays::bool::BoolArrayParts;
1617
use crate::arrays::patched::patch_lanes;
1718
use crate::arrays::primitive::NativeValue;
@@ -28,6 +29,12 @@ impl CompareKernel for Patched {
2829
operator: CompareOperator,
2930
ctx: &mut ExecutionCtx,
3031
) -> VortexResult<Option<ArrayRef>> {
32+
// We only accelerate comparisons for primitives
33+
if !lhs.dtype().is_primitive() {
34+
return Ok(None);
35+
}
36+
37+
// We only accelerate comparisons against constants
3138
let Some(constant) = rhs.as_constant() else {
3239
return Ok(None);
3340
};
@@ -87,9 +94,10 @@ impl CompareKernel for Patched {
8794

8895
let lane_offsets = lhs.lane_offsets.as_host().reinterpret::<u32>();
8996
let indices = lhs.indices.as_host().reinterpret::<u16>();
97+
let values = lhs.values.clone().execute::<PrimitiveArray>(ctx)?;
9098

91-
match_each_native_ptype!(lhs.values_ptype, |V| {
92-
let values = lhs.values.as_host().reinterpret::<V>();
99+
match_each_native_ptype!(values.ptype(), |V| {
100+
let values = values.as_slice::<V>();
93101
let constant = constant
94102
.as_primitive()
95103
.as_::<V>()

vortex-array/src/arrays/patched/compute/take.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ impl TakeExecute for Patched {
2424
indices: &ArrayRef,
2525
ctx: &mut ExecutionCtx,
2626
) -> VortexResult<Option<ArrayRef>> {
27+
// Only pushdown take when we have primitive types.
28+
if !array.dtype().is_primitive() {
29+
return Ok(None);
30+
}
31+
2732
// Perform take on the inner array, including the placeholders.
2833
let inner = array
2934
.inner
@@ -41,6 +46,7 @@ impl TakeExecute for Patched {
4146
match_each_unsigned_integer_ptype!(indices_ptype, |I| {
4247
match_each_native_ptype!(ptype, |V| {
4348
let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
49+
let values = array.values.clone().execute::<PrimitiveArray>(ctx)?;
4450
let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
4551
take_map(
4652
output.as_mut(),
@@ -51,7 +57,7 @@ impl TakeExecute for Patched {
5157
array.n_lanes,
5258
array.lane_offsets.as_host().reinterpret::<u32>(),
5359
array.indices.as_host().reinterpret::<u16>(),
54-
array.values.as_host().reinterpret::<V>(),
60+
values.as_slice::<V>(),
5561
);
5662

5763
// SAFETY: output and validity still have same length after take_map returns.

vortex-array/src/arrays/patched/mod.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,34 @@ const fn patch_lanes<V: Sized>() -> usize {
3232
if size_of::<V>() < 8 { 32 } else { 16 }
3333
}
3434

35-
pub struct PatchAccessor<'a, V> {
35+
pub struct PatchAccessor<'a> {
3636
n_lanes: usize,
3737
lane_offsets: &'a [u32],
3838
indices: &'a [u16],
39-
values: &'a [V],
4039
}
4140

42-
impl<'a, V: Sized> PatchAccessor<'a, V> {
43-
/// Access the patches for a particular lane
44-
pub fn access(&'a self, chunk: usize, lane: usize) -> LanePatches<'a, V> {
41+
pub struct PatchOffset {
42+
/// Global offset into the list of patches. These are some of the
43+
pub index: usize,
44+
/// This is the value stored in the `indices` buffer, which encodes the offset of the `index`-th
45+
/// patch
46+
pub chunk_offset: u16,
47+
}
48+
49+
impl<'a> PatchAccessor<'a> {
50+
/// Get an iterator over indices and values offsets.
51+
///
52+
/// The first component is the index into the `indices` and `values`, and the second component
53+
/// is the set of values instead here...I think?
54+
pub fn offsets_iter(
55+
&self,
56+
chunk: usize,
57+
lane: usize,
58+
) -> impl Iterator<Item = (usize, u16)> + '_ {
4559
let start = self.lane_offsets[chunk * self.n_lanes + lane] as usize;
4660
let stop = self.lane_offsets[chunk * self.n_lanes + lane + 1] as usize;
4761

48-
LanePatches {
49-
indices: &self.indices[start..stop],
50-
values: &self.values[start..stop],
51-
}
62+
std::iter::zip(start..stop, self.indices[start..stop].iter().copied())
5263
}
5364
}
5465

vortex-array/src/arrays/patched/vtable/mod.rs

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ impl VTable for Patched {
9191

9292
fn array_hash<H: Hasher>(array: &Self::Array, state: &mut H, precision: Precision) {
9393
array.inner.array_hash(state, precision);
94-
array.values_ptype.hash(state);
9594
array.n_chunks.hash(state);
9695
array.n_lanes.hash(state);
9796
array.lane_offsets.array_hash(state, precision);
@@ -102,7 +101,6 @@ impl VTable for Patched {
102101
fn array_eq(array: &Self::Array, other: &Self::Array, precision: Precision) -> bool {
103102
array.n_chunks == other.n_chunks
104103
&& array.n_lanes == other.n_lanes
105-
&& array.values_ptype == other.values_ptype
106104
&& array.inner.array_eq(&other.inner, precision)
107105
&& array.lane_offsets.array_eq(&other.lane_offsets, precision)
108106
&& array.indices.array_eq(&other.indices, precision)
@@ -117,7 +115,6 @@ impl VTable for Patched {
117115
match idx {
118116
0 => array.lane_offsets.clone(),
119117
1 => array.indices.clone(),
120-
2 => array.values.clone(),
121118
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
122119
}
123120
}
@@ -126,28 +123,27 @@ impl VTable for Patched {
126123
match idx {
127124
0 => Some("lane_offsets".to_string()),
128125
1 => Some("patch_indices".to_string()),
129-
2 => Some("patch_values".to_string()),
130126
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
131127
}
132128
}
133129

134130
fn nchildren(_array: &Self::Array) -> usize {
135-
1
131+
2
136132
}
137133

138134
fn child(array: &Self::Array, idx: usize) -> ArrayRef {
139-
if idx == 0 {
140-
array.inner.clone()
141-
} else {
142-
vortex_panic!("invalid child index for PatchedArray: {idx}");
135+
match idx {
136+
0 => array.inner.clone(),
137+
1 => array.values.clone(),
138+
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
143139
}
144140
}
145141

146142
fn child_name(_array: &Self::Array, idx: usize) -> String {
147-
if idx == 0 {
148-
"inner".to_string()
149-
} else {
150-
vortex_panic!("invalid child index for PatchedArray: {idx}");
143+
match idx {
144+
0 => "inner".to_string(),
145+
1 => "patch_values".to_string(),
146+
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
151147
}
152148
}
153149

@@ -186,10 +182,14 @@ impl VTable for Patched {
186182

187183
let n_lanes = match_each_native_ptype!(dtype.as_ptype(), |P| { patch_lanes::<P>() });
188184

189-
let &[lane_offsets, indices, values] = &buffers else {
185+
let &[lane_offsets, indices] = &buffers else {
190186
vortex_bail!("invalid buffer count for PatchedArray");
191187
};
192188

189+
// values and indices should have same len.
190+
let expected_len = indices.as_host().reinterpret::<u16>().len();
191+
let values = children.get(1, dtype, expected_len)?;
192+
193193
Ok(PatchedArray {
194194
inner,
195195
n_chunks,
@@ -198,19 +198,19 @@ impl VTable for Patched {
198198
len,
199199
lane_offsets: lane_offsets.clone(),
200200
indices: indices.clone(),
201-
values: values.clone(),
202-
values_ptype: dtype.as_ptype(),
201+
values,
203202
stats_set: ArrayStats::default(),
204203
})
205204
}
206205

207206
fn with_children(array: &mut Self::Array, mut children: Vec<ArrayRef>) -> VortexResult<()> {
208207
vortex_ensure!(
209-
children.len() == 1,
210-
"PatchedArray must have exactly 1 child"
208+
children.len() == 2,
209+
"PatchedArray must have exactly 2 children"
211210
);
212211

213212
array.inner = children.remove(0);
213+
array.values = children.remove(0);
214214

215215
Ok(())
216216
}
@@ -231,23 +231,25 @@ impl VTable for Patched {
231231
let lane_offsets: Buffer<u32> =
232232
Buffer::from_byte_buffer(array.lane_offsets.clone().unwrap_host());
233233
let indices: Buffer<u16> = Buffer::from_byte_buffer(array.indices.clone().unwrap_host());
234+
let values = array.values.clone().execute::<PrimitiveArray>(ctx)?;
235+
236+
// TODO(aduffy): add support for non-primitive PatchedArray patches application.
234237

235-
let patched_values = match_each_native_ptype!(array.values_ptype, |V| {
238+
let patched_values = match_each_native_ptype!(values.ptype(), |V| {
236239
let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
237-
let values: Buffer<V> = Buffer::from_byte_buffer(array.values.clone().unwrap_host());
238240

239241
let offset = array.offset;
240242
let len = array.len;
241243

242-
apply::<V>(
244+
apply_patches_primitive::<V>(
243245
&mut output,
244246
offset,
245247
len,
246248
array.n_chunks,
247249
array.n_lanes,
248250
&lane_offsets,
249251
&indices,
250-
&values,
252+
values.as_slice::<V>(),
251253
);
252254

253255
// The output will always be aligned to a chunk boundary, we apply the offset/len
@@ -281,7 +283,7 @@ impl VTable for Patched {
281283

282284
/// Apply patches on top of the existing value types.
283285
#[allow(clippy::too_many_arguments)]
284-
fn apply<V: NativePType>(
286+
fn apply_patches_primitive<V: NativePType>(
285287
output: &mut [V],
286288
offset: usize,
287289
len: usize,

vortex-array/src/arrays/patched/vtable/operations.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::DynArray;
77
use crate::arrays::patched::Patched;
88
use crate::arrays::patched::PatchedArray;
99
use crate::arrays::patched::patch_lanes;
10+
use crate::dtype::PType;
1011
use crate::match_each_native_ptype;
1112
use crate::scalar::Scalar;
1213
use crate::vtable::OperationsVTable;
@@ -17,21 +18,19 @@ impl OperationsVTable<Patched> for Patched {
1718
let chunk = index / 1024;
1819
#[allow(clippy::cast_possible_truncation)]
1920
let chunk_index = (index % 1024) as u16;
20-
match_each_native_ptype!(array.values_ptype, |V| {
21-
let lane = index % patch_lanes::<V>();
22-
let accessor = array.accessor::<V>();
23-
let patches = accessor.access(chunk, lane);
24-
// NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
25-
// be slower.
26-
for (patch_index, patch_value) in patches.iter() {
27-
if patch_index == chunk_index {
28-
return Ok(Scalar::primitive(
29-
patch_value,
30-
array.inner.dtype().nullability(),
31-
));
32-
}
21+
22+
let values_ptype = PType::try_from(array.dtype())?;
23+
24+
let lane = match_each_native_ptype!(values_ptype, |V| { index % patch_lanes::<V>() });
25+
let accessor = array.accessor();
26+
27+
// NOTE: we do linear scan as lane has <= 32 patches, binary search would likely
28+
// be slower.
29+
for (index, patch_index) in accessor.offsets_iter(chunk, lane) {
30+
if patch_index == chunk_index {
31+
return array.values.scalar_at(index);
3332
}
34-
});
33+
}
3534

3635
// Otherwise, access the underlying value.
3736
array.inner.scalar_at(index)

0 commit comments

Comments
 (0)