Skip to content

Commit 0e50f08

Browse files
authored
Aggregates: ChunkedArray kernel (#6837)
First implementation of an aggregate push-down kernel. Requires making a few changes to the API. It also removes checked flag from the sum aggregate function since this is not supported in the old world so no point trying to re-create it yet.
1 parent 5ec8d2f commit 0e50f08

10 files changed

Lines changed: 411 additions & 197 deletions

File tree

vortex-array/public-api.lock

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::fmt(&self, f: &mut core::fmt::
5454

5555
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
5656

57-
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions
57+
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::EmptyOptions
5858

5959
pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial
6060

@@ -64,7 +64,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia
6464

6565
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
6666

67-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
67+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
6868

6969
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
7070

@@ -82,33 +82,19 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options:
8282

8383
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
8484

85-
pub struct vortex_array::aggregate_fn::fns::sum::SumOptions
86-
87-
impl core::clone::Clone for vortex_array::aggregate_fn::fns::sum::SumOptions
88-
89-
pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::clone(&self) -> vortex_array::aggregate_fn::fns::sum::SumOptions
90-
91-
impl core::cmp::Eq for vortex_array::aggregate_fn::fns::sum::SumOptions
92-
93-
impl core::cmp::PartialEq for vortex_array::aggregate_fn::fns::sum::SumOptions
94-
95-
pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::eq(&self, other: &vortex_array::aggregate_fn::fns::sum::SumOptions) -> bool
96-
97-
impl core::fmt::Debug for vortex_array::aggregate_fn::fns::sum::SumOptions
85+
pub struct vortex_array::aggregate_fn::fns::sum::SumPartial
9886

99-
pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
87+
pub mod vortex_array::aggregate_fn::kernels
10088

101-
impl core::fmt::Display for vortex_array::aggregate_fn::fns::sum::SumOptions
89+
pub trait vortex_array::aggregate_fn::kernels::DynAggregateKernel: 'static + core::marker::Send + core::marker::Sync + core::fmt::Debug
10290

103-
pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
91+
pub fn vortex_array::aggregate_fn::kernels::DynAggregateKernel::aggregate(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::scalar::Scalar>>
10492

105-
impl core::hash::Hash for vortex_array::aggregate_fn::fns::sum::SumOptions
93+
pub trait vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel: 'static + core::marker::Send + core::marker::Sync + core::fmt::Debug
10694

107-
pub fn vortex_array::aggregate_fn::fns::sum::SumOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
95+
pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_aggregate(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, groups: &vortex_array::arrays::ListViewArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
10896

109-
impl core::marker::StructuralPartialEq for vortex_array::aggregate_fn::fns::sum::SumOptions
110-
111-
pub struct vortex_array::aggregate_fn::fns::sum::SumPartial
97+
pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_aggregate_fixed_size(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, groups: &vortex_array::arrays::FixedSizeListArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
11298

11399
pub mod vortex_array::aggregate_fn::session
114100

@@ -118,11 +104,13 @@ impl vortex_array::aggregate_fn::session::AggregateFnSession
118104

119105
pub fn vortex_array::aggregate_fn::session::AggregateFnSession::register<V: vortex_array::aggregate_fn::AggregateFnVTable>(&self, vtable: V)
120106

107+
pub fn vortex_array::aggregate_fn::session::AggregateFnSession::register_aggregate_kernel(&self, array_id: vortex_array::vtable::ArrayId, agg_fn_id: core::option::Option<vortex_array::aggregate_fn::AggregateFnId>, kernel: &'static dyn vortex_array::aggregate_fn::kernels::DynAggregateKernel)
108+
121109
pub fn vortex_array::aggregate_fn::session::AggregateFnSession::registry(&self) -> &vortex_array::aggregate_fn::session::AggregateFnRegistry
122110

123111
impl core::default::Default for vortex_array::aggregate_fn::session::AggregateFnSession
124112

125-
pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> vortex_array::aggregate_fn::session::AggregateFnSession
113+
pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> Self
126114

127115
impl core::fmt::Debug for vortex_array::aggregate_fn::session::AggregateFnSession
128116

@@ -324,7 +312,7 @@ pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options:
324312

325313
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
326314

327-
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::fns::sum::SumOptions
315+
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::EmptyOptions
328316

329317
pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggregate_fn::fns::sum::SumPartial
330318

@@ -334,7 +322,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partia
334322

335323
pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
336324

337-
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
325+
pub fn vortex_array::aggregate_fn::fns::sum::Sum::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
338326

339327
pub fn vortex_array::aggregate_fn::fns::sum::Sum::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
340328

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,17 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
105105
break;
106106
}
107107

108-
let kernel_key = (self.vtable.id(), batch.encoding_id());
109-
if let Some(kernel) = kernels.read().get(&kernel_key)
110-
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch)?
108+
let kernels_r = kernels.read();
109+
let batch_id = batch.encoding_id();
110+
if let Some(result) = kernels_r
111+
.get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
112+
.or_else(|| kernels_r.get(&(batch_id, None)))
113+
.and_then(|kernel| {
114+
kernel
115+
.aggregate(&self.aggregate_fn, &batch, &mut ctx)
116+
.transpose()
117+
})
118+
.transpose()?
111119
{
112120
vortex_ensure!(
113121
result.dtype() == &self.partial_dtype,

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -164,21 +164,27 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
164164
break;
165165
}
166166

167-
let kernel_key = (self.vtable.id(), elements.encoding_id());
168-
if let Some(kernel) = kernels.read().get(&kernel_key) {
169-
// SAFETY: we assume that elements execution is safe
170-
let groups = unsafe {
171-
ListViewArray::new_unchecked(
172-
elements.clone(),
173-
groups.offsets().clone(),
174-
groups.sizes().clone(),
175-
groups.validity().clone(),
176-
)
177-
};
178-
179-
if let Some(result) = kernel.grouped_aggregate(&self.aggregate_fn, &groups)? {
180-
return self.push_result(result);
181-
}
167+
let kernels_r = kernels.read();
168+
if let Some(result) = kernels_r
169+
.get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
170+
.or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
171+
.and_then(|kernel| {
172+
// SAFETY: we assume that elements execution is safe
173+
let groups = unsafe {
174+
ListViewArray::new_unchecked(
175+
elements.clone(),
176+
groups.offsets().clone(),
177+
groups.sizes().clone(),
178+
groups.validity().clone(),
179+
)
180+
};
181+
kernel
182+
.grouped_aggregate(&self.aggregate_fn, &groups)
183+
.transpose()
184+
})
185+
.transpose()?
186+
{
187+
return self.push_result(result);
182188
}
183189

184190
// Execute one step and try again
@@ -244,23 +250,28 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
244250
break;
245251
}
246252

247-
let kernel_key = (self.vtable.id(), elements.encoding_id());
248-
if let Some(kernel) = kernels.read().get(&kernel_key) {
249-
// SAFETY: we assume that elements execution is safe
250-
let groups = unsafe {
251-
FixedSizeListArray::new_unchecked(
252-
elements.clone(),
253-
groups.list_size(),
254-
groups.validity().clone(),
255-
groups.len(),
256-
)
257-
};
258-
259-
if let Some(result) =
260-
kernel.grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)?
261-
{
262-
return self.push_result(result);
263-
}
253+
let kernels_r = kernels.read();
254+
if let Some(result) = kernels_r
255+
.get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
256+
.or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
257+
.and_then(|kernel| {
258+
// SAFETY: we assume that elements execution is safe
259+
let groups = unsafe {
260+
FixedSizeListArray::new_unchecked(
261+
elements.clone(),
262+
groups.list_size(),
263+
groups.validity().clone(),
264+
groups.len(),
265+
)
266+
};
267+
268+
kernel
269+
.grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
270+
.transpose()
271+
})
272+
.transpose()?
273+
{
274+
return self.push_result(result);
264275
}
265276

266277
// Execute one step and try again

0 commit comments

Comments
 (0)