Skip to content

Commit 761c404

Browse files
authored
Aggregate Fns (#6721)
First PR implementing the Aggregate Functions proposal in vortex-data/rfcs#21 --------- Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 60476a9 commit 761c404

File tree

18 files changed

+2417
-15
lines changed

18 files changed

+2417
-15
lines changed

vortex-array/public-api.lock

Lines changed: 378 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_error::vortex_ensure;
6+
use vortex_session::VortexSession;
7+
8+
use crate::AnyCanonical;
9+
use crate::ArrayRef;
10+
use crate::Canonical;
11+
use crate::DynArray;
12+
use crate::VortexSessionExecute;
13+
use crate::aggregate_fn::AggregateFn;
14+
use crate::aggregate_fn::AggregateFnRef;
15+
use crate::aggregate_fn::AggregateFnVTable;
16+
use crate::aggregate_fn::session::AggregateFnSessionExt;
17+
use crate::dtype::DType;
18+
use crate::executor::MAX_ITERATIONS;
19+
use crate::scalar::Scalar;
20+
21+
/// Boxed, type-erased accumulator.
///
/// This is a `Box`, so the accumulator is uniquely owned (it is not reference-counted).
22+
pub type AccumulatorRef = Box<dyn DynAccumulator>;
23+
24+
/// An accumulator used for computing aggregates over an entire stream of arrays.
///
/// Generic over the aggregate function's vtable `V`. The type-erased interface is provided by
/// the [`DynAccumulator`] implementation for this type.
25+
pub struct Accumulator<V: AggregateFnVTable> {
26+
/// The vtable of the aggregate function.
27+
vtable: V,
28+
/// Type-erased aggregate function, passed to aggregate kernels during kernel dispatch.
29+
aggregate_fn: AggregateFnRef,
30+
/// The DType of the input. Batches passed to `accumulate` are checked against it.
31+
dtype: DType,
32+
/// The DType of the final aggregate. The result of `finish` is checked against it.
33+
return_dtype: DType,
34+
/// The DType of the accumulator state. Partial results returned by aggregate kernels are
/// checked against it (and, in debug builds, so is the result of `flush`).
35+
state_dtype: DType,
36+
/// The current state of the accumulator. It is mutated by `accumulate` (through the
/// vtable's `state_accumulate` / `state_merge`) and by `flush` (through `state_flush`).
37+
current_state: V::GroupState,
38+
/// A session used to look up custom aggregate kernels and to create execution contexts.
39+
session: VortexSession,
40+
}
41+
42+
impl<V: AggregateFnVTable> Accumulator<V> {
43+
/// Create a new accumulator for the aggregate function described by `vtable` and `options`,
/// operating on inputs of type `dtype`.
///
/// The return dtype, the state dtype and the initial state are all obtained from the vtable.
/// The `session` is stored and later used to look up aggregate kernels and to create
/// execution contexts.
///
/// # Errors
///
/// Propagates any error returned by the vtable's `return_dtype`, `state_dtype` or
/// `state_new` for the given `options` and `dtype`.
pub fn try_new(
44+
vtable: V,
45+
options: V::Options,
46+
dtype: DType,
47+
session: VortexSession,
48+
) -> VortexResult<Self> {
49+
let return_dtype = vtable.return_dtype(&options, &dtype)?;
50+
let state_dtype = vtable.state_dtype(&options, &dtype)?;
51+
let current_state = vtable.state_new(&options, &dtype)?;
52+
// `options` is moved into the type-erased function here, so it must be the last use of
// `options`. The vtable is cloned because it is also stored on `Self` below.
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
53+
54+
Ok(Self {
55+
vtable,
56+
aggregate_fn,
57+
dtype,
58+
return_dtype,
59+
state_dtype,
60+
current_state,
61+
session,
62+
})
63+
}
64+
}
65+
66+
/// A trait for type-erased accumulators, used (as `dyn DynAccumulator`) for dynamic dispatch
67+
/// when the aggregate function is not known at compile time.
68+
pub trait DynAccumulator: 'static + Send {
69+
/// Accumulate a new array into the accumulator's state.
///
/// In the [`Accumulator`] implementation, a batch whose dtype differs from the accumulator's
/// input dtype is rejected with an error, unless the accumulator is already saturated, in
/// which case the batch is ignored.
70+
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>;
71+
72+
/// Whether the accumulator's result is fully determined.
///
/// In the [`Accumulator`] implementation, `accumulate` returns immediately once this is true.
73+
fn is_saturated(&self) -> bool;
74+
75+
/// Flush the accumulation state and return the partial aggregate result as a scalar.
76+
///
77+
/// Resets the accumulator state back to the initial state.
78+
fn flush(&mut self) -> VortexResult<Scalar>;
79+
80+
/// Finish the accumulation and return the final aggregate result as a scalar.
81+
///
82+
/// Resets the accumulator state back to the initial state.
///
/// In the [`Accumulator`] implementation, this flushes the state and then finalizes the
/// flushed partial result, so it differs from `flush` only in the finalization step.
83+
fn finish(&mut self) -> VortexResult<Scalar>;
84+
}
85+
86+
impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
87+
/// Accumulate `batch` into the current state.
///
/// A kernel registered in the session for this aggregate function and the batch's current
/// encoding is tried first; otherwise the batch is executed step by step and, as a last
/// resort, canonicalized and handed to the vtable's `state_accumulate`.
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
88+
// Once saturated the batch is skipped entirely. Note that this happens before the dtype
// check below, so a mismatched batch is not reported in that case.
if self.is_saturated() {
89+
return Ok(());
90+
}
91+
92+
vortex_ensure!(
93+
batch.dtype() == &self.dtype,
94+
"Input DType mismatch: expected {}, got {}",
95+
self.dtype,
96+
batch.dtype()
97+
);
98+
99+
// Aggregate kernels registered in the session, keyed below by
// (aggregate function id, array encoding id).
let kernels = &self.session.aggregate_fns().kernels;
100+
101+
let mut ctx = self.session.create_execution_ctx();
102+
let mut batch = batch.clone();
103+
// Bounded by MAX_ITERATIONS: on each pass, either a kernel handles the batch in its current
// encoding, or the batch is executed one step and the lookup is retried.
for _ in 0..*MAX_ITERATIONS {
104+
// Canonical arrays are handled by the fallback path after the loop.
if batch.is::<AnyCanonical>() {
105+
break;
106+
}
107+
108+
let kernel_key = (self.vtable.id(), batch.encoding_id());
109+
// Falls through to the execution step when no kernel is registered for this key or when
// the kernel returns `None`.
if let Some(kernel) = kernels.read().get(&kernel_key)
110+
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch)?
111+
{
112+
// The kernel result is merged as a partial state, so it must have the state dtype.
vortex_ensure!(
113+
result.dtype() == &self.state_dtype,
114+
"Aggregate kernel returned {}, expected {}",
115+
result.dtype(),
116+
self.state_dtype,
117+
);
118+
self.vtable.state_merge(&mut self.current_state, result)?;
119+
return Ok(());
120+
}
121+
122+
// Execute one step and try the kernel lookup again with the resulting array.
123+
batch = batch.execute(&mut ctx)?;
124+
}
125+
126+
// No kernel handled the batch: execute it until it is canonical and accumulate it into
// the state through the vtable.
127+
let canonical = batch.execute::<Canonical>(&mut ctx)?;
128+
129+
self.vtable
130+
.state_accumulate(&mut self.current_state, &canonical, &mut ctx)
131+
}
132+
133+
/// Delegates to the vtable's `state_is_saturated` on the current state.
fn is_saturated(&self) -> bool {
134+
self.vtable.state_is_saturated(&self.current_state)
135+
}
136+
137+
/// Flush the current state through the vtable's `state_flush` and return the partial result.
fn flush(&mut self) -> VortexResult<Scalar> {
138+
let partial = self.vtable.state_flush(&mut self.current_state)?;
139+
140+
// This dtype check is compiled only into debug builds; release builds return the partial
// result unchecked.
#[cfg(debug_assertions)]
141+
{
142+
vortex_ensure!(
143+
partial.dtype() == &self.state_dtype,
144+
"Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
145+
self.state_dtype,
146+
partial.dtype(),
147+
);
148+
}
149+
150+
Ok(partial)
151+
}
152+
153+
/// Flush the state, then finalize the partial result through the vtable's `finalize_scalar`.
fn finish(&mut self) -> VortexResult<Scalar> {
154+
let partial = self.flush()?;
155+
let result = self.vtable.finalize_scalar(partial)?;
156+
157+
// Unlike the check in `flush`, this check runs in all build profiles.
vortex_ensure!(
158+
result.dtype() == &self.return_dtype,
159+
"Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
160+
self.return_dtype,
161+
result.dtype(),
162+
);
163+
164+
Ok(result)
165+
}
166+
}

0 commit comments

Comments
 (0)