Skip to content

Commit 2ba5636

Browse files
authored
Cut over sum aggregate function (#6910)
Move sum compute function over to the Sum impl of AggregatFnVTable Note that we remove the ability to sum extension types since we were incorrectly just summing the storage DType. This doesn't make sense, e.g. summing timestamps is invalid, but summing time deltas is allowed. --------- Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 50868db commit 2ba5636

37 files changed

Lines changed: 1856 additions & 2630 deletions

File tree

fuzz/src/array/mod.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ use tracing::debug;
4242
use vortex_array::ArrayRef;
4343
use vortex_array::DynArray;
4444
use vortex_array::IntoArray;
45+
use vortex_array::VortexSessionExecute;
46+
use vortex_array::aggregate_fn::fns::sum::sum;
4547
use vortex_array::arrays::ConstantArray;
4648
use vortex_array::arrays::PrimitiveArray;
4749
use vortex_array::arrays::arbitrary::ArbitraryArray;
4850
use vortex_array::builtins::ArrayBuiltins;
4951
use vortex_array::compute::MinMaxResult;
5052
use vortex_array::compute::min_max;
51-
use vortex_array::compute::sum;
5253
use vortex_array::dtype::DType;
5354
use vortex_array::dtype::Nullability;
5455
use vortex_array::scalar::Scalar;
@@ -68,6 +69,7 @@ use vortex_error::vortex_panic;
6869
use vortex_mask::Mask;
6970
use vortex_utils::aliases::hash_set::HashSet;
7071

72+
use crate::SESSION;
7173
use crate::error::Backtrace;
7274
use crate::error::VortexFuzzError;
7375
use crate::error::VortexFuzzResult;
@@ -173,6 +175,8 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
173175
let array = ArbitraryArray::arbitrary(u)?.0;
174176
let mut current_array = array.to_array();
175177

178+
let mut ctx = SESSION.create_execution_ctx();
179+
176180
let mut valid_actions = actions_for_dtype(current_array.dtype())
177181
.into_iter()
178182
.collect::<Vec<_>>();
@@ -330,6 +334,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
330334
current_array
331335
.to_canonical()
332336
.vortex_expect("to_canonical should succeed in fuzz test"),
337+
&mut ctx,
333338
)
334339
.vortex_expect("sum_canonical_array should succeed in fuzz test");
335340
(Action::Sum, ExpectedValue::Scalar(sum_result))
@@ -566,6 +571,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult<bool> {
566571
let FuzzArrayAction { array, actions } = fuzz_action;
567572
let mut current_array = array.to_array();
568573

574+
let mut ctx = SESSION.create_execution_ctx();
575+
569576
debug!(
570577
"Initial array:\nTree:\n{}Values:\n{:#}",
571578
current_array.display_tree(),
@@ -640,8 +647,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult<bool> {
640647
current_array = cast_result;
641648
}
642649
Action::Sum => {
643-
let sum_result =
644-
sum(&current_array).vortex_expect("sum operation should succeed in fuzz test");
650+
let sum_result = sum(&current_array, &mut ctx)
651+
.vortex_expect("sum operation should succeed in fuzz test");
645652
assert_scalar_eq(&expected.scalar(), &sum_result, i)?;
646653
}
647654
Action::MinMax => {

fuzz/src/array/sum.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_array::Canonical;
5+
use vortex_array::ExecutionCtx;
56
use vortex_array::IntoArray as _;
6-
use vortex_array::compute::sum;
7+
use vortex_array::aggregate_fn::fns::sum::sum;
78
use vortex_array::scalar::Scalar;
89
use vortex_error::VortexResult;
910

1011
/// Compute sum on the canonical form of the array to get a consistent baseline.
11-
pub fn sum_canonical_array(canonical: Canonical) -> VortexResult<Scalar> {
12+
pub fn sum_canonical_array(canonical: Canonical, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
1213
// TODO(joe): replace with baseline not using canonical
13-
sum(&canonical.into_array())
14+
sum(&canonical.into_array(), ctx)
1415
}

vortex-array/public-api.lock

Lines changed: 22 additions & 138 deletions
Large diffs are not rendered by default.

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
use vortex_error::VortexResult;
55
use vortex_error::vortex_ensure;
6-
use vortex_session::VortexSession;
6+
use vortex_error::vortex_err;
77

88
use crate::AnyCanonical;
99
use crate::ArrayRef;
1010
use crate::Columnar;
1111
use crate::DynArray;
12-
use crate::VortexSessionExecute;
12+
use crate::ExecutionCtx;
1313
use crate::aggregate_fn::AggregateFn;
1414
use crate::aggregate_fn::AggregateFnRef;
1515
use crate::aggregate_fn::AggregateFnVTable;
@@ -35,19 +35,24 @@ pub struct Accumulator<V: AggregateFnVTable> {
3535
partial_dtype: DType,
3636
/// The partial state of the accumulator, updated after each accumulate/merge call.
3737
partial: V::Partial,
38-
/// A session used to lookup custom aggregate kernels.
39-
session: VortexSession,
4038
}
4139

4240
impl<V: AggregateFnVTable> Accumulator<V> {
43-
pub fn try_new(
44-
vtable: V,
45-
options: V::Options,
46-
dtype: DType,
47-
session: VortexSession,
48-
) -> VortexResult<Self> {
49-
let return_dtype = vtable.return_dtype(&options, &dtype)?;
50-
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
41+
pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
42+
let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
43+
vortex_err!(
44+
"Aggregate function {} cannot be applied to dtype {}",
45+
vtable.id(),
46+
dtype
47+
)
48+
})?;
49+
let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
50+
vortex_err!(
51+
"Aggregate function {} cannot be applied to dtype {}",
52+
vtable.id(),
53+
dtype
54+
)
55+
})?;
5156
let partial = vtable.empty_partial(&options, &dtype)?;
5257
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
5358

@@ -58,7 +63,6 @@ impl<V: AggregateFnVTable> Accumulator<V> {
5863
return_dtype,
5964
partial_dtype,
6065
partial,
61-
session,
6266
})
6367
}
6468
}
@@ -67,7 +71,7 @@ impl<V: AggregateFnVTable> Accumulator<V> {
6771
/// function is not known at compile time.
6872
pub trait DynAccumulator: 'static + Send {
6973
/// Accumulate a new array into the accumulator's state.
70-
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>;
74+
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
7175

7276
/// Whether the accumulator's result is fully determined.
7377
fn is_saturated(&self) -> bool;
@@ -84,7 +88,7 @@ pub trait DynAccumulator: 'static + Send {
8488
}
8589

8690
impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
87-
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
91+
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
8892
if self.is_saturated() {
8993
return Ok(());
9094
}
@@ -96,9 +100,9 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
96100
batch.dtype()
97101
);
98102

99-
let kernels = &self.session.aggregate_fns().kernels;
103+
let session = ctx.session().clone();
104+
let kernels = &session.aggregate_fns().kernels;
100105

101-
let mut ctx = self.session.create_execution_ctx();
102106
let mut batch = batch.clone();
103107
for _ in 0..*MAX_ITERATIONS {
104108
if batch.is::<AnyCanonical>() {
@@ -112,7 +116,7 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
112116
.or_else(|| kernels_r.get(&(batch_id, None)))
113117
.and_then(|kernel| {
114118
kernel
115-
.aggregate(&self.aggregate_fn, &batch, &mut ctx)
119+
.aggregate(&self.aggregate_fn, &batch, ctx)
116120
.transpose()
117121
})
118122
.transpose()?
@@ -128,14 +132,13 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
128132
}
129133

130134
// Execute one step and try again
131-
batch = batch.execute(&mut ctx)?;
135+
batch = batch.execute(ctx)?;
132136
}
133137

134138
// Otherwise, execute the batch until it is columnar and accumulate it into the state.
135-
let columnar = batch.execute::<Columnar>(&mut ctx)?;
139+
let columnar = batch.execute::<Columnar>(ctx)?;
136140

137-
self.vtable
138-
.accumulate(&mut self.partial, &columnar, &mut ctx)
141+
self.vtable.accumulate(&mut self.partial, &columnar, ctx)
139142
}
140143

141144
fn is_saturated(&self) -> bool {

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use vortex_error::VortexExpect;
77
use vortex_error::VortexResult;
88
use vortex_error::vortex_bail;
99
use vortex_error::vortex_ensure;
10+
use vortex_error::vortex_err;
1011
use vortex_error::vortex_panic;
1112
use vortex_mask::Mask;
12-
use vortex_session::VortexSession;
1313

1414
use crate::AnyCanonical;
1515
use crate::ArrayRef;
@@ -18,7 +18,6 @@ use crate::Columnar;
1818
use crate::DynArray;
1919
use crate::ExecutionCtx;
2020
use crate::IntoArray;
21-
use crate::VortexSessionExecute;
2221
use crate::aggregate_fn::Accumulator;
2322
use crate::aggregate_fn::AggregateFn;
2423
use crate::aggregate_fn::AggregateFnRef;
@@ -58,20 +57,25 @@ pub struct GroupedAccumulator<V: AggregateFnVTable> {
5857
partial_dtype: DType,
5958
/// The accumulated state for prior batches of groups.
6059
partials: Vec<ArrayRef>,
61-
/// A session used to lookup custom aggregate kernels.
62-
session: VortexSession,
6360
}
6461

6562
impl<V: AggregateFnVTable> GroupedAccumulator<V> {
66-
pub fn try_new(
67-
vtable: V,
68-
options: V::Options,
69-
dtype: DType,
70-
session: VortexSession,
71-
) -> VortexResult<Self> {
63+
pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
7264
let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
73-
let return_dtype = vtable.return_dtype(&options, &dtype)?;
74-
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
65+
let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
66+
vortex_err!(
67+
"Aggregate function {} cannot be applied to dtype {}",
68+
vtable.id(),
69+
dtype
70+
)
71+
})?;
72+
let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
73+
vortex_err!(
74+
"Aggregate function {} cannot be applied to dtype {}",
75+
vtable.id(),
76+
dtype
77+
)
78+
})?;
7579

7680
Ok(Self {
7781
vtable,
@@ -81,7 +85,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
8185
return_dtype,
8286
partial_dtype,
8387
partials: vec![],
84-
session,
8588
})
8689
}
8790
}
@@ -90,7 +93,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
9093
/// function is not known at compile time.
9194
pub trait DynGroupedAccumulator: 'static + Send {
9295
/// Accumulate a list of groups into the accumulator.
93-
fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;
96+
fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
9497

9598
/// Finish the accumulation and return the partial aggregate results for all groups.
9699
/// Resets the accumulator state for the next round of accumulation.
@@ -102,7 +105,7 @@ pub trait DynGroupedAccumulator: 'static + Send {
102105
}
103106

104107
impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
105-
fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
108+
fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
106109
let elements_dtype = match groups.dtype() {
107110
DType::List(elem, _) => elem,
108111
DType::FixedSizeList(elem, ..) => elem,
@@ -118,17 +121,15 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
118121
elements_dtype
119122
);
120123

121-
let mut ctx = self.session.create_execution_ctx();
122-
123124
// We first execute the groups until it is a ListView or FixedSizeList, since we only
124125
// dispatch the aggregate kernel over the elements of these arrays.
125-
let canonical = match groups.clone().execute::<Columnar>(&mut ctx)? {
126+
let canonical = match groups.clone().execute::<Columnar>(ctx)? {
126127
Columnar::Canonical(c) => c,
127-
Columnar::Constant(c) => c.into_array().execute::<Canonical>(&mut ctx)?,
128+
Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
128129
};
129130
match canonical {
130-
Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx),
131-
Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx),
131+
Canonical::List(groups) => self.accumulate_list_view(&groups, ctx),
132+
Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx),
132133
_ => vortex_panic!("We checked the DType above, so this should never happen"),
133134
}
134135
}
@@ -160,8 +161,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
160161
ctx: &mut ExecutionCtx,
161162
) -> VortexResult<()> {
162163
let mut elements = groups.elements().clone();
163-
let session = self.session.clone();
164-
164+
let session = ctx.session().clone();
165165
let kernels = &session.aggregate_fns().grouped_kernels;
166166

167167
for _ in 0..*MAX_ITERATIONS {
@@ -205,7 +205,13 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
205205
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
206206
let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
207207
let sizes = sizes.execute::<Buffer<O>>(ctx)?;
208-
self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
208+
self.accumulate_list_view_typed(
209+
&elements,
210+
offsets.as_ref(),
211+
sizes.as_ref(),
212+
&validity,
213+
ctx,
214+
)
209215
})
210216
}
211217

@@ -215,12 +221,12 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
215221
offsets: &[O],
216222
sizes: &[O],
217223
validity: &Mask,
224+
ctx: &mut ExecutionCtx,
218225
) -> VortexResult<()> {
219226
let mut accumulator = Accumulator::try_new(
220227
self.vtable.clone(),
221228
self.options.clone(),
222229
self.dtype.clone(),
223-
self.session.clone(),
224230
)?;
225231
let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
226232

@@ -230,7 +236,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
230236

231237
if validity.value(offset) {
232238
let group = elements.slice(offset..offset + size)?;
233-
accumulator.accumulate(&group)?;
239+
accumulator.accumulate(&group, ctx)?;
234240
states.append_scalar(&accumulator.finish()?)?;
235241
} else {
236242
states.append_null()
@@ -246,8 +252,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
246252
ctx: &mut ExecutionCtx,
247253
) -> VortexResult<()> {
248254
let mut elements = groups.elements().clone();
249-
250-
let session = self.session.clone();
255+
let session = ctx.session().clone();
251256
let kernels = &session.aggregate_fns().grouped_kernels;
252257

253258
for _ in 0..64 {
@@ -291,7 +296,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
291296
self.vtable.clone(),
292297
self.options.clone(),
293298
self.dtype.clone(),
294-
self.session.clone(),
295299
)?;
296300
let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
297301

@@ -304,7 +308,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
304308
for i in 0..groups.len() {
305309
if validity.value(i) {
306310
let group = elements.slice(offset..offset + size)?;
307-
accumulator.accumulate(&group)?;
311+
accumulator.accumulate(&group, ctx)?;
308312
states.append_scalar(&accumulator.finish()?)?;
309313
} else {
310314
states.append_null()

0 commit comments

Comments
 (0)