Skip to content

Commit 2fd4df3

Browse files
authored
AggregateFn naming (#6828)
Some naming fixes for the aggregate functions API --------- Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent a8b340a commit 2fd4df3

6 files changed

Lines changed: 118 additions & 117 deletions

File tree

vortex-array/public-api.lock

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -54,35 +54,33 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::fmt(&self, f: &mut core::fmt::
5454

5555
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
5656

57-
pub type vortex_array::aggregate_fn::fns::sum::Sum::GroupState = vortex_array::aggregate_fn::fns::sum::SumGroupState
58-
5957
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions
6058

61-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
59+
pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial
6260

63-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
61+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
6462

65-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, state: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
63+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
6664

67-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
65+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
6866

69-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
67+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
7068

71-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
69+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
7270

73-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_accumulate(&self, state: &mut Self::GroupState, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
71+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
7472

75-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
73+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
7674

77-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_flush(&self, state: &mut Self::GroupState) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
75+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
7876

79-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_is_saturated(&self, state: &Self::GroupState) -> bool
77+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool
8078

81-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_merge(&self, state: &mut Self::GroupState, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
79+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
8280

83-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_new(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::GroupState>
81+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
8482

85-
pub struct vortex_array::aggregate_fn::fns::sum::SumGroupState
83+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
8684

8785
pub struct vortex_array::aggregate_fn::fns::sum::SumOptions
8886

@@ -110,6 +108,8 @@ pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::hash<__H: core::hash::H
110108

111109
impl core::marker::StructuralPartialEq for vortex_array::aggregate_fn::fns::sum::SumOptions
112110

111+
pub struct vortex_array::aggregate_fn::fns::sum::SumPartial
112+
113113
pub mod vortex_array::aggregate_fn::session
114114

115115
pub struct vortex_array::aggregate_fn::session::AggregateFnSession
@@ -294,64 +294,64 @@ pub fn V::id(&self) -> arcref::ArcRef<str>
294294

295295
pub trait vortex_array::aggregate_fn::AggregateFnVTable: 'static + core::marker::Sized + core::clone::Clone + core::marker::Send + core::marker::Sync
296296

297-
pub type vortex_array::aggregate_fn::AggregateFnVTable::GroupState: 'static + core::marker::Send
298-
299297
pub type vortex_array::aggregate_fn::AggregateFnVTable::Options: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + core::fmt::Debug + core::fmt::Display + core::cmp::PartialEq + core::cmp::Eq + core::hash::Hash
300298

299+
pub type vortex_array::aggregate_fn::AggregateFnVTable::Partial: 'static + core::marker::Send
300+
301+
pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Canonical, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
302+
303+
pub fn vortex_array::aggregate_fn::AggregateFnVTable::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
304+
301305
pub fn vortex_array::aggregate_fn::AggregateFnVTable::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
302306

307+
pub fn vortex_array::aggregate_fn::AggregateFnVTable::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
308+
303309
pub fn vortex_array::aggregate_fn::AggregateFnVTable::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
304310

305311
pub fn vortex_array::aggregate_fn::AggregateFnVTable::finalize_scalar(&self, state: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
306312

313+
pub fn vortex_array::aggregate_fn::AggregateFnVTable::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
314+
307315
pub fn vortex_array::aggregate_fn::AggregateFnVTable::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
308316

317+
pub fn vortex_array::aggregate_fn::AggregateFnVTable::is_saturated(&self, state: &Self::Partial) -> bool
318+
319+
pub fn vortex_array::aggregate_fn::AggregateFnVTable::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
320+
309321
pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
310322

311323
pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
312324

313-
pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_accumulate(&self, state: &mut Self::GroupState, batch: &vortex_array::Canonical, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
325+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
314326

315-
pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
327+
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions
316328

317-
pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_flush(&self, state: &mut Self::GroupState) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
329+
pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial
318330

319-
pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_is_saturated(&self, state: &Self::GroupState) -> bool
331+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
320332

321-
pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_merge(&self, state: &mut Self::GroupState, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
333+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
322334

323-
pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_new(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::GroupState>
335+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
324336

325-
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
337+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
326338

327-
pub type vortex_array::aggregate_fn::fns::sum::Sum::GroupState = vortex_array::aggregate_fn::fns::sum::SumGroupState
339+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
328340

329-
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions
341+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
330342

331-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
343+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
332344

333-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
345+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
334346

335-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize_scalar(&self, state: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
347+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::is_saturated(&self, partial: &Self::Partial) -> bool
336348

337-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
349+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
338350

339351
pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
340352

341353
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
342354

343-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_accumulate(&self, state: &mut Self::GroupState, batch: &vortex_array::Canonical, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
344-
345-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
346-
347-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_flush(&self, state: &mut Self::GroupState) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
348-
349-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_is_saturated(&self, state: &Self::GroupState) -> bool
350-
351-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_merge(&self, state: &mut Self::GroupState, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
352-
353-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::state_new(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::GroupState>
354-
355355
pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable
356356

357357
pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ pub struct Accumulator<V: AggregateFnVTable> {
3232
/// The DType of the aggregate.
3333
return_dtype: DType,
3434
/// The DType of the accumulator state.
35-
state_dtype: DType,
36-
/// The current state of the accumulator, updated after each accumulate/merge call.
37-
current_state: V::GroupState,
35+
partial_dtype: DType,
36+
/// The partial state of the accumulator, updated after each accumulate/merge call.
37+
partial: V::Partial,
3838
/// A session used to lookup custom aggregate kernels.
3939
session: VortexSession,
4040
}
@@ -47,17 +47,17 @@ impl<V: AggregateFnVTable> Accumulator<V> {
4747
session: VortexSession,
4848
) -> VortexResult<Self> {
4949
let return_dtype = vtable.return_dtype(&options, &dtype)?;
50-
let state_dtype = vtable.state_dtype(&options, &dtype)?;
51-
let current_state = vtable.state_new(&options, &dtype)?;
50+
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
51+
let partial = vtable.empty_partial(&options, &dtype)?;
5252
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
5353

5454
Ok(Self {
5555
vtable,
5656
aggregate_fn,
5757
dtype,
5858
return_dtype,
59-
state_dtype,
60-
current_state,
59+
partial_dtype,
60+
partial,
6161
session,
6262
})
6363
}
@@ -110,12 +110,12 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110110
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch)?
111111
{
112112
vortex_ensure!(
113-
result.dtype() == &self.state_dtype,
113+
result.dtype() == &self.partial_dtype,
114114
"Aggregate kernel returned {}, expected {}",
115115
result.dtype(),
116-
self.state_dtype,
116+
self.partial_dtype,
117117
);
118-
self.vtable.state_merge(&mut self.current_state, result)?;
118+
self.vtable.combine_partials(&mut self.partial, result)?;
119119
return Ok(());
120120
}
121121

@@ -127,22 +127,22 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
127127
let canonical = batch.execute::<Canonical>(&mut ctx)?;
128128

129129
self.vtable
130-
.state_accumulate(&mut self.current_state, &canonical, &mut ctx)
130+
.accumulate(&mut self.partial, &canonical, &mut ctx)
131131
}
132132

133133
fn is_saturated(&self) -> bool {
134-
self.vtable.state_is_saturated(&self.current_state)
134+
self.vtable.is_saturated(&self.partial)
135135
}
136136

137137
fn flush(&mut self) -> VortexResult<Scalar> {
138-
let partial = self.vtable.state_flush(&mut self.current_state)?;
138+
let partial = self.vtable.flush(&mut self.partial)?;
139139

140140
#[cfg(debug_assertions)]
141141
{
142142
vortex_ensure!(
143-
partial.dtype() == &self.state_dtype,
143+
partial.dtype() == &self.partial_dtype,
144144
"Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
145-
self.state_dtype,
145+
self.partial_dtype,
146146
partial.dtype(),
147147
);
148148
}

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ pub struct GroupedAccumulator<V: AggregateFnVTable> {
5353
dtype: DType,
5454
/// The DType of the aggregate.
5555
return_dtype: DType,
56-
/// The DType of the accumulator state.
57-
state_dtype: DType,
56+
/// The DType of the partial accumulator state.
57+
partial_dtype: DType,
5858
/// The accumulated state for prior batches of groups.
59-
states: Vec<ArrayRef>,
59+
partials: Vec<ArrayRef>,
6060
/// A session used to lookup custom aggregate kernels.
6161
session: VortexSession,
6262
}
@@ -70,16 +70,16 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
7070
) -> VortexResult<Self> {
7171
let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
7272
let return_dtype = vtable.return_dtype(&options, &dtype)?;
73-
let state_dtype = vtable.state_dtype(&options, &dtype)?;
73+
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
7474

7575
Ok(Self {
7676
vtable,
7777
options,
7878
aggregate_fn,
7979
dtype,
8080
return_dtype,
81-
state_dtype,
82-
states: vec![],
81+
partial_dtype,
82+
partials: vec![],
8383
session,
8484
})
8585
}
@@ -129,8 +129,8 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
129129
}
130130

131131
fn flush(&mut self) -> VortexResult<ArrayRef> {
132-
let states = std::mem::take(&mut self.states);
133-
Ok(ChunkedArray::try_new(states, self.state_dtype.clone())?.into_array())
132+
let states = std::mem::take(&mut self.partials);
133+
Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
134134
}
135135

136136
fn finish(&mut self) -> VortexResult<ArrayRef> {
@@ -211,7 +211,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
211211
self.dtype.clone(),
212212
self.session.clone(),
213213
)?;
214-
let mut states = builder_with_capacity(&self.state_dtype, offsets.len());
214+
let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
215215

216216
for (offset, size) in offsets.iter().zip(sizes.iter()) {
217217
let offset = offset.to_usize().vortex_expect("Offset value is not usize");
@@ -277,7 +277,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
277277
self.dtype.clone(),
278278
self.session.clone(),
279279
)?;
280-
let mut states = builder_with_capacity(&self.state_dtype, groups.len());
280+
let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
281281

282282
let mut offset = 0;
283283
let size = groups
@@ -301,12 +301,12 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
301301

302302
fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
303303
vortex_ensure!(
304-
state.dtype() == &self.state_dtype,
304+
state.dtype() == &self.partial_dtype,
305305
"State DType mismatch: expected {}, got {}",
306-
self.state_dtype,
306+
self.partial_dtype,
307307
state.dtype()
308308
);
309-
self.states.push(state);
309+
self.partials.push(state);
310310
Ok(())
311311
}
312312
}

0 commit comments

Comments
 (0)