-
Notifications
You must be signed in to change notification settings - Fork 161
feat: Mean aggregate
#7201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Mean aggregate
#7201
Changes from 10 commits
6bafabe
ee67d80
2d38b62
4e16945
575672a
3c8b12b
46db9fa
53806a4
0515378
ce2e52c
0e56b5c
1e0d743
f847221
be4fa68
e5a99ea
7f179cc
738e302
3c426f0
a9be15a
44f4b65
348f88c
75e8c5c
55c0d04
db815bf
4401d73
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,254 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| //! Generic adapter for aggregates whose result is computed from two child | ||
| //! aggregate functions, e.g. `Mean = Sum / Count`. | ||
|
|
||
| use std::fmt::{self, Debug, Display, Formatter}; | ||
| use std::hash::Hash; | ||
|
|
||
| use vortex_error::{VortexResult, vortex_bail, vortex_err}; | ||
| use vortex_session::VortexSession; | ||
|
|
||
| use crate::aggregate_fn::{AggregateFnId, AggregateFnVTable}; | ||
| use crate::builtins::ArrayBuiltins; | ||
| use crate::dtype::{DType, FieldName, FieldNames, Nullability, StructFields}; | ||
| use crate::scalar::Scalar; | ||
| use crate::{ArrayRef, Columnar, ExecutionCtx}; | ||
|
|
||
| /// Pair of options for the two children of a [`BinaryCombined`] aggregate. | ||
| /// | ||
| /// Wrapper around `(L, R)` because the [`AggregateFnVTable::Options`] bound | ||
| /// requires `Display`, which tuples don't implement. | ||
| #[derive(Clone, Debug, PartialEq, Eq, Hash)] | ||
| pub struct PairOptions<L, R>(pub L, pub R); | ||
|
|
||
| impl<L: Display, R: Display> Display for PairOptions<L, R> { | ||
| fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { | ||
| write!(f, "({}, {})", self.0, self.1) | ||
| } | ||
| } | ||
|
|
||
| // Convenience aliases so signatures stay readable. | ||
| type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options; | ||
| type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options; | ||
| type LeftPartial<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Partial; | ||
| type RightPartial<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Partial; | ||
| /// Combined options for a [`BinaryCombined`] aggregate. | ||
| pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>; | ||
|
|
||
| /// Declare an aggregate function in terms of two child aggregates. | ||
| pub trait BinaryCombined: 'static + Send + Sync + Clone { | ||
| /// The left child aggregate vtable. | ||
| type Left: AggregateFnVTable; | ||
| /// The right child aggregate vtable. | ||
| type Right: AggregateFnVTable; | ||
|
|
||
| /// Stable identifier for the combined aggregate. | ||
| fn id(&self) -> AggregateFnId; | ||
|
|
||
| /// Construct the left child vtable. | ||
| fn left(&self) -> Self::Left; | ||
|
|
||
| /// Construct the right child vtable. | ||
| fn right(&self) -> Self::Right; | ||
|
|
||
| /// Field name for the left child in the partial struct dtype. | ||
| fn left_name(&self) -> &'static str { | ||
| "left" | ||
| } | ||
|
|
||
| /// Field name for the right child in the partial struct dtype. | ||
| fn right_name(&self) -> &'static str { | ||
| "right" | ||
| } | ||
|
Comment on lines
+68
to
+76
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why default these?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. those don't matter, so may as well use left/right |
||
|
|
||
| /// Return type of the combined aggregate. | ||
| fn return_dtype(&self, input_dtype: &DType) -> Option<DType>; | ||
|
|
||
| /// Combine the finalized left and right results into the final aggregate. | ||
| fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>; | ||
|
|
||
| /// Serialize the options for this combined aggregate. Default: not serializable. | ||
| fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> { | ||
| let _ = options; | ||
| Ok(None) | ||
| } | ||
|
|
||
| /// Deserialize the options for this combined aggregate. Default: bails. | ||
| fn deserialize( | ||
| &self, | ||
| metadata: &[u8], | ||
| session: &VortexSession, | ||
| ) -> VortexResult<CombinedOptions<Self>> { | ||
| let _ = (metadata, session); | ||
| vortex_bail!( | ||
| "Combined aggregate function {} is not deserializable", | ||
| BinaryCombined::id(self) | ||
| ); | ||
| } | ||
|
|
||
| /// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`. | ||
| fn coerce_args( | ||
| &self, | ||
| options: &CombinedOptions<Self>, | ||
| input_dtype: &DType, | ||
| ) -> VortexResult<DType> { | ||
| let left_coerced = self.left().coerce_args(&options.0, input_dtype)?; | ||
| self.right().coerce_args(&options.1, &left_coerced) | ||
| } | ||
| } | ||
|
|
||
| /// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`]. | ||
| #[derive(Clone, Debug)] | ||
| pub struct Combined<T: BinaryCombined>(pub T); | ||
|
|
||
| impl<T: BinaryCombined> Combined<T> { | ||
| /// Construct a new combined aggregate vtable. | ||
| pub fn new(inner: T) -> Self { | ||
| Self(inner) | ||
| } | ||
| } | ||
|
|
||
| impl<T: BinaryCombined> AggregateFnVTable for Combined<T> { | ||
| type Options = CombinedOptions<T>; | ||
| type Partial = (LeftPartial<T>, RightPartial<T>); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How will this partial be seralized?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as a struct scalar with left/right fields |
||
|
|
||
| fn id(&self) -> AggregateFnId { | ||
| self.0.id() | ||
| } | ||
|
|
||
| fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> { | ||
| BinaryCombined::serialize(&self.0, options) | ||
| } | ||
|
|
||
| fn deserialize( | ||
| &self, | ||
| metadata: &[u8], | ||
| session: &VortexSession, | ||
| ) -> VortexResult<Self::Options> { | ||
| BinaryCombined::deserialize(&self.0, metadata, session) | ||
| } | ||
|
|
||
| fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> { | ||
| BinaryCombined::coerce_args(&self.0, options, input_dtype) | ||
| } | ||
|
|
||
| fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> { | ||
| BinaryCombined::return_dtype(&self.0, input_dtype) | ||
| } | ||
|
|
||
| fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> { | ||
| let l = self.0.left().partial_dtype(&options.0, input_dtype)?; | ||
| let r = self.0.right().partial_dtype(&options.1, input_dtype)?; | ||
| Some(struct_dtype(self.0.left_name(), self.0.right_name(), l, r)) | ||
| } | ||
|
|
||
| fn empty_partial( | ||
| &self, | ||
| options: &Self::Options, | ||
| input_dtype: &DType, | ||
| ) -> VortexResult<Self::Partial> { | ||
| Ok(( | ||
| self.0.left().empty_partial(&options.0, input_dtype)?, | ||
| self.0.right().empty_partial(&options.1, input_dtype)?, | ||
| )) | ||
| } | ||
|
|
||
| fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { | ||
| if other.is_null() { | ||
| return Ok(()); | ||
| } | ||
| let s = other.as_struct(); | ||
| let lname = self.0.left_name(); | ||
| let rname = self.0.right_name(); | ||
| let l_field = s | ||
| .field(lname) | ||
| .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?; | ||
| let r_field = s | ||
| .field(rname) | ||
| .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?; | ||
| self.0.left().combine_partials(&mut partial.0, l_field)?; | ||
| self.0.right().combine_partials(&mut partial.1, r_field)?; | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> { | ||
| let l_scalar = self.0.left().to_scalar(&partial.0)?; | ||
| let r_scalar = self.0.right().to_scalar(&partial.1)?; | ||
| let dtype = struct_dtype( | ||
| self.0.left_name(), | ||
| self.0.right_name(), | ||
| l_scalar.dtype().clone(), | ||
| r_scalar.dtype().clone(), | ||
| ); | ||
| Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar])) | ||
| } | ||
|
|
||
| fn reset(&self, partial: &mut Self::Partial) { | ||
| self.0.left().reset(&mut partial.0); | ||
| self.0.right().reset(&mut partial.1); | ||
| } | ||
|
|
||
| fn is_saturated(&self, partial: &Self::Partial) -> bool { | ||
| self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1) | ||
| } | ||
|
|
||
| /// Fans out to each child's `try_accumulate`, falling back to `accumulate` | ||
| /// against a lazily-canonicalized batch. We always claim to handle the | ||
| /// batch ourselves so [`Self::accumulate`] is unreachable — this is the | ||
| /// same trick `Count` uses to opt out of the canonicalization path. | ||
| fn try_accumulate( | ||
| &self, | ||
| state: &mut Self::Partial, | ||
| batch: &ArrayRef, | ||
| ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<bool> { | ||
| let mut canonical: Option<Columnar> = None; | ||
| if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? { | ||
| canonical = Some(batch.clone().execute::<Columnar>(ctx)?); | ||
| self.0 | ||
| .left() | ||
| .accumulate(&mut state.0, canonical.as_ref().expect("just set"), ctx)?; | ||
| } | ||
| if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? { | ||
| if canonical.is_none() { | ||
| canonical = Some(batch.clone().execute::<Columnar>(ctx)?); | ||
| } | ||
| self.0 | ||
| .right() | ||
| .accumulate(&mut state.1, canonical.as_ref().expect("just set"), ctx)?; | ||
| } | ||
| Ok(true) | ||
| } | ||
|
|
||
| fn accumulate( | ||
| &self, | ||
| _state: &mut Self::Partial, | ||
| _batch: &Columnar, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<()> { | ||
| unreachable!("Combined::try_accumulate handles all batches") | ||
| } | ||
|
|
||
| fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> { | ||
| let l_field = states.get_item(FieldName::from(self.0.left_name()))?; | ||
| let r_field = states.get_item(FieldName::from(self.0.right_name()))?; | ||
| let l_finalized = self.0.left().finalize(l_field)?; | ||
| let r_finalized = self.0.right().finalize(r_field)?; | ||
| BinaryCombined::finalize(&self.0, l_finalized, r_finalized) | ||
| } | ||
| } | ||
|
|
||
| fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType { | ||
| DType::Struct( | ||
| StructFields::new( | ||
| FieldNames::from_iter([ | ||
| FieldName::from(left_name), | ||
| FieldName::from(right_name), | ||
| ]), | ||
| vec![left, right], | ||
| ), | ||
| Nullability::NonNullable, | ||
| ) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can factor this out of this PR, we could at a later date add the back to on disk version will be the same?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i blocked serialization for now?
why?