Skip to content

Commit 8557f46

Browse files
joseph-isaacsclaude
andcommitted
feat: iterative execution for SparseArray
- Add #[array_slots] macro for SparseSlots - Add SparseParts struct + OwnedExt trait for into_parts(self) - require_child! indices => Primitive, values => AnyCanonical - execute_sparse takes SparseParts with resolved patches - Inner functions (varbin, lists, fixed_size_list) take &Patches + &SparseArray for validity Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b13ba9c commit 8557f46

2 files changed

Lines changed: 167 additions & 67 deletions

File tree

encodings/sparse/src/canonical.rs

Lines changed: 85 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -57,66 +57,90 @@ use vortex_error::vortex_panic;
5757
use crate::ConstantArray;
5858
use crate::Sparse;
5959
use crate::SparseArray;
60+
use crate::SparseParts;
61+
62+
/// Build a temporary [`SparseArray`] from resolved patches (for computing validity).
63+
fn sparse_array_for_validity(patches: &Patches, fill_value: &Scalar, len: usize) -> SparseArray {
64+
// Re-wrap resolved patches (offset=0) into a SparseArray so we can call .validity().
65+
Sparse::try_new(
66+
patches.indices().clone(),
67+
patches.values().clone(),
68+
len,
69+
fill_value.clone(),
70+
)
71+
.vortex_expect("rebuilding SparseArray for validity")
72+
}
73+
6074
pub(super) fn execute_sparse(
61-
array: &SparseArray,
75+
parts: SparseParts,
6276
ctx: &mut ExecutionCtx,
6377
) -> VortexResult<ArrayRef> {
64-
if array.patches().num_patches() == 0 {
65-
return Ok(ConstantArray::new(array.fill_scalar().clone(), array.len()).into_array());
78+
let SparseParts { patches, fill_value, dtype, len } = parts;
79+
80+
if patches.num_patches() == 0 {
81+
return Ok(ConstantArray::new(fill_value, len).into_array());
6682
}
6783

68-
Ok(match array.dtype() {
84+
// Patches are already resolved (offset subtracted) by SparseParts::resolve_patches().
85+
Ok(match &dtype {
6986
DType::Null => {
70-
assert!(array.fill_scalar().is_null());
71-
NullArray::new(array.len()).into_array()
87+
assert!(fill_value.is_null());
88+
NullArray::new(len).into_array()
7289
}
7390
DType::Bool(..) => {
74-
let resolved_patches = array.resolved_patches()?;
75-
execute_sparse_bools(&resolved_patches, array.fill_scalar(), ctx)?
91+
execute_sparse_bools(&patches, &fill_value, ctx)?
7692
}
7793
DType::Primitive(ptype, ..) => {
78-
let resolved_patches = array.resolved_patches()?;
7994
match_each_native_ptype!(ptype, |P| {
80-
execute_sparse_primitives::<P>(&resolved_patches, array.fill_scalar(), ctx)?
95+
execute_sparse_primitives::<P>(&patches, &fill_value, ctx)?
8196
})
8297
}
8398
DType::Struct(struct_fields, ..) => execute_sparse_struct(
8499
struct_fields,
85-
array.fill_scalar().as_struct(),
86-
array.dtype(),
87-
array.patches(),
88-
array.len(),
100+
fill_value.as_struct(),
101+
&dtype,
102+
&patches,
103+
len,
89104
ctx,
90105
)?,
91106
DType::Decimal(decimal_dtype, nullability) => {
92107
let canonical_decimal_value_type =
93108
DecimalType::smallest_decimal_value_type(decimal_dtype);
94-
let fill_value = array.fill_scalar().as_decimal();
109+
let fill_decimal = fill_value.as_decimal();
95110
match_each_decimal_value_type!(canonical_decimal_value_type, |D| {
96111
execute_sparse_decimal::<D>(
97112
*decimal_dtype,
98113
*nullability,
99-
fill_value,
100-
array.patches(),
101-
array.len(),
114+
fill_decimal,
115+
&patches,
116+
len,
102117
ctx,
103118
)?
104119
})
105120
}
106121
dtype @ DType::Utf8(..) => {
107-
let fill_value = array.fill_scalar().as_utf8().value().cloned();
108-
let fill_value = fill_value.map(BufferString::into_inner);
109-
execute_varbin(array, dtype.clone(), fill_value, ctx)?
122+
let fill = fill_value.as_utf8().value().cloned();
123+
let fill = fill.map(BufferString::into_inner);
124+
let validity_arr = sparse_array_for_validity(&patches, &fill_value, len);
125+
execute_varbin(&patches, &validity_arr, dtype.clone(), fill, ctx)?
110126
}
111127
dtype @ DType::Binary(..) => {
112-
let fill_value = array.fill_scalar().as_binary().value().cloned();
113-
execute_varbin(array, dtype.clone(), fill_value, ctx)?
128+
let fill = fill_value.as_binary().value().cloned();
129+
let validity_arr = sparse_array_for_validity(&patches, &fill_value, len);
130+
execute_varbin(&patches, &validity_arr, dtype.clone(), fill, ctx)?
114131
}
115132
DType::List(values_dtype, nullability) => {
116-
execute_sparse_lists(array, Arc::clone(values_dtype), *nullability, ctx)?
133+
let validity_arr = sparse_array_for_validity(&patches, &fill_value, len);
134+
execute_sparse_lists(
135+
&patches, &validity_arr, &fill_value,
136+
Arc::clone(values_dtype), *nullability, ctx,
137+
)?
117138
}
118139
DType::FixedSizeList(.., nullability) => {
119-
execute_sparse_fixed_size_list(array, *nullability, ctx)?
140+
let validity_arr = sparse_array_for_validity(&patches, &fill_value, len);
141+
execute_sparse_fixed_size_list(
142+
&patches, &validity_arr, &fill_value, *nullability, ctx,
143+
)?
120144
}
121145
DType::Extension(_ext_dtype) => todo!(),
122146
DType::Variant(_) => vortex_bail!("Sparse canonicalization does not support Variant"),
@@ -128,39 +152,39 @@ pub(super) fn execute_sparse(
128152
reason = "complexity is from nested match_smallest_offset_type macro"
129153
)]
130154
fn execute_sparse_lists(
131-
array: &SparseArray,
155+
resolved: &Patches,
156+
validity_array: &SparseArray,
157+
fill_value: &Scalar,
132158
values_dtype: Arc<DType>,
133159
nullability: Nullability,
134160
ctx: &mut ExecutionCtx,
135161
) -> VortexResult<ArrayRef> {
136-
let resolved_patches = array.resolved_patches()?;
137-
138-
let indices = resolved_patches
162+
let indices = resolved
139163
.indices()
140164
.clone()
141165
.execute::<PrimitiveArray>(ctx)?;
142-
let values = resolved_patches
166+
let values = resolved
143167
.values()
144168
.clone()
145169
.execute::<ListViewArray>(ctx)?;
146-
let fill_value = array.fill_scalar().as_list();
170+
let fill_list = fill_value.as_list();
147171

148-
let n_filled = array.len() - resolved_patches.num_patches();
149-
let total_canonical_values = values.elements().len() + fill_value.len() * n_filled;
172+
let len = validity_array.len();
173+
let n_filled = len - resolved.num_patches();
174+
let total_canonical_values = values.elements().len() + fill_list.len() * n_filled;
150175

151-
let validity = {
152-
let arr = array.as_array();
153-
Validity::from_mask(arr.validity()?.execute_mask(arr.len(), ctx)?, nullability)
154-
};
176+
let arr = validity_array.as_ref();
177+
let validity =
178+
Validity::from_mask(arr.validity()?.execute_mask(arr.len(), ctx)?, nullability);
155179

156180
Ok(match_each_integer_ptype!(indices.ptype(), |I| {
157181
match_smallest_offset_type!(total_canonical_values, |O| {
158182
execute_sparse_lists_inner::<I, O>(
159183
indices.as_slice(),
160184
values,
161-
fill_value,
185+
fill_list,
162186
values_dtype,
163-
array.len(),
187+
len,
164188
total_canonical_values,
165189
validity,
166190
ctx,
@@ -224,32 +248,33 @@ fn execute_sparse_lists_inner<I: IntegerPType, O: IntegerPType>(
224248

225249
/// Canonicalize a sparse [`FixedSizeListArray`] by expanding it into a dense representation.
226250
fn execute_sparse_fixed_size_list(
227-
array: &SparseArray,
251+
resolved: &Patches,
252+
validity_array: &SparseArray,
253+
fill_value: &Scalar,
228254
nullability: Nullability,
229255
ctx: &mut ExecutionCtx,
230256
) -> VortexResult<ArrayRef> {
231-
let resolved_patches = array.resolved_patches()?;
232-
let indices = resolved_patches
257+
let indices = resolved
233258
.indices()
234259
.clone()
235260
.execute::<PrimitiveArray>(ctx)?;
236-
let values = resolved_patches
261+
let values = resolved
237262
.values()
238263
.clone()
239264
.execute::<FixedSizeListArray>(ctx)?;
240-
let fill_value = array.fill_scalar().as_list();
265+
let fill_list = fill_value.as_list();
266+
let len = validity_array.len();
241267

242-
let validity = {
243-
let arr = array.as_array();
244-
Validity::from_mask(arr.validity()?.execute_mask(arr.len(), ctx)?, nullability)
245-
};
268+
let arr = validity_array.as_ref();
269+
let validity =
270+
Validity::from_mask(arr.validity()?.execute_mask(arr.len(), ctx)?, nullability);
246271

247272
Ok(match_each_integer_ptype!(indices.ptype(), |I| {
248273
execute_sparse_fixed_size_list_inner::<I>(
249274
indices.as_slice(),
250275
values,
251-
fill_value,
252-
array.len(),
276+
fill_list,
277+
len,
253278
validity,
254279
ctx,
255280
)
@@ -496,22 +521,20 @@ fn execute_sparse_decimal<D: NativeDecimalType>(
496521
}
497522

498523
fn execute_varbin(
499-
array: &SparseArray,
524+
resolved: &Patches,
525+
validity_array: &SparseArray,
500526
dtype: DType,
501527
fill_value: Option<ByteBuffer>,
502528
ctx: &mut ExecutionCtx,
503529
) -> VortexResult<ArrayRef> {
504-
let patches = array.resolved_patches()?;
505-
let indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
506-
let values = patches.values().clone().execute::<VarBinViewArray>(ctx)?;
507-
let validity = {
508-
let arr = array.as_array();
509-
Validity::from_mask(
510-
arr.validity()?.execute_mask(arr.len(), ctx)?,
511-
dtype.nullability(),
512-
)
513-
};
514-
let len = array.len();
530+
let indices = resolved.indices().clone().execute::<PrimitiveArray>(ctx)?;
531+
let values = resolved.values().clone().execute::<VarBinViewArray>(ctx)?;
532+
let arr = validity_array.as_ref();
533+
let validity = Validity::from_mask(
534+
arr.validity()?.execute_mask(arr.len(), ctx)?,
535+
dtype.nullability(),
536+
);
537+
let len = arr.len();
515538

516539
Ok(match_each_integer_ptype!(indices.ptype(), |I| {
517540
let indices = indices.to_buffer::<I>();

encodings/sparse/src/lib.rs

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ use vortex_array::ArrayEq;
1414
use vortex_array::ArrayHash;
1515
use vortex_array::ArrayId;
1616
use vortex_array::ArrayParts;
17+
use vortex_array::AnyCanonical;
1718
use vortex_array::ArrayRef;
1819
use vortex_array::ArrayView;
1920
use vortex_array::Canonical;
2021
use vortex_array::ExecutionCtx;
2122
use vortex_array::ExecutionResult;
2223
use vortex_array::IntoArray;
2324
use vortex_array::Precision;
25+
use vortex_array::arrays::Primitive;
2426
use vortex_array::arrays::BoolArray;
2527
use vortex_array::arrays::ConstantArray;
2628
use vortex_array::arrays::PrimitiveArray;
@@ -64,6 +66,69 @@ mod slice;
6466
/// A [`Sparse`]-encoded Vortex array.
6567
pub type SparseArray = Array<Sparse>;
6668

69+
#[vortex_array::array_slots(Sparse)]
70+
pub struct SparseSlots {
71+
pub patch_indices: ArrayRef,
72+
pub patch_values: ArrayRef,
73+
pub patch_chunk_offsets: Option<ArrayRef>,
74+
}
75+
76+
/// Concrete parts of a [`SparseArray`] after iterative execution.
77+
pub(crate) struct SparseParts {
78+
pub patches: Patches,
79+
pub fill_value: Scalar,
80+
pub dtype: DType,
81+
pub len: usize,
82+
}
83+
84+
impl SparseParts {
85+
/// Resolve patches by subtracting the offset from indices.
86+
pub fn resolve_patches(mut self) -> VortexResult<Self> {
87+
if self.patches.offset() != 0 {
88+
let offset_scalar =
89+
Scalar::from(self.patches.offset()).cast(self.patches.indices().dtype())?;
90+
let indices = self.patches.indices().binary(
91+
ConstantArray::new(offset_scalar, self.patches.indices().len()).into_array(),
92+
Operator::Sub,
93+
)?;
94+
self.patches = Patches::new(
95+
self.patches.array_len(),
96+
0,
97+
indices,
98+
self.patches.values().clone(),
99+
None,
100+
)?;
101+
}
102+
Ok(self)
103+
}
104+
}
105+
106+
pub(crate) trait SparseOwnedExt {
107+
fn into_parts(self) -> VortexResult<SparseParts>;
108+
}
109+
110+
impl SparseOwnedExt for Array<Sparse> {
111+
fn into_parts(self) -> VortexResult<SparseParts> {
112+
let patches = Patches::new(
113+
self.len(),
114+
self.patches().offset(),
115+
self.as_ref().slots()[SparseSlots::PATCH_INDICES]
116+
.clone()
117+
.vortex_expect("indices"),
118+
self.as_ref().slots()[SparseSlots::PATCH_VALUES]
119+
.clone()
120+
.vortex_expect("values"),
121+
self.as_ref().slots()[SparseSlots::PATCH_CHUNK_OFFSETS].clone(),
122+
)?;
123+
Ok(SparseParts {
124+
patches,
125+
fill_value: self.fill_scalar().clone(),
126+
dtype: self.dtype().clone(),
127+
len: self.len(),
128+
})
129+
}
130+
}
131+
67132
#[derive(Clone, prost::Message)]
68133
#[repr(C)]
69134
pub struct SparseMetadata {
@@ -186,7 +251,7 @@ impl VTable for Sparse {
186251
}
187252

188253
fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
189-
SLOT_NAMES[idx].to_string()
254+
SparseSlots::NAMES[idx].to_string()
190255
}
191256

192257
fn reduce_parent(
@@ -207,13 +272,25 @@ impl VTable for Sparse {
207272
}
208273

209274
fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
210-
execute_sparse(&array, ctx).map(ExecutionResult::done)
275+
// Require children to be executed through the scheduler,
276+
// enabling cross-step optimization via reduce_parent rules.
277+
let array = vortex_array::require_child!(
278+
array, array.patch_indices(), SparseSlots::PATCH_INDICES => Primitive
279+
);
280+
let array = vortex_array::require_child!(
281+
array, array.patch_values(), SparseSlots::PATCH_VALUES => AnyCanonical
282+
);
283+
vortex_array::require_opt_child!(
284+
array,
285+
array.patch_chunk_offsets(),
286+
SparseSlots::PATCH_CHUNK_OFFSETS => Primitive
287+
);
288+
289+
let parts = array.into_parts()?.resolve_patches()?;
290+
execute_sparse(parts, ctx).map(ExecutionResult::done)
211291
}
212292
}
213293

214-
pub(crate) const NUM_SLOTS: usize = 3;
215-
pub(crate) const SLOT_NAMES: [&str; NUM_SLOTS] =
216-
["patch_indices", "patch_values", "patch_chunk_offsets"];
217294

218295
#[derive(Clone, Debug)]
219296
pub struct SparseData {

0 commit comments

Comments
 (0)