Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vortex-array/src/aggregate_fn/accumulator_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
if validity.value(i) {
let group = elements.slice(offset..offset + size)?;
accumulator.accumulate(&group, ctx)?;
states.append_scalar(&accumulator.finish()?)?;
states.append_scalar(&accumulator.flush()?)?;
} else {
states.append_null()
}
Expand Down
254 changes: 254 additions & 0 deletions vortex-array/src/aggregate_fn/combined.rs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can factor this out of this PR; we could add it back at a later date. The on-disk version will be the same?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disk version will be the same

i blocked serialization for now?

we can factor this out of this PR

why?

Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Generic adapter for aggregates whose result is computed from two child
//! aggregate functions, e.g. `Mean = Sum / Count`.

use std::fmt::{self, Debug, Display, Formatter};
use std::hash::Hash;

use vortex_error::{VortexResult, vortex_bail, vortex_err};
use vortex_session::VortexSession;

use crate::aggregate_fn::{AggregateFnId, AggregateFnVTable};
use crate::builtins::ArrayBuiltins;
use crate::dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
use crate::scalar::Scalar;
use crate::{ArrayRef, Columnar, ExecutionCtx};

/// Options bundle for the two children of a [`BinaryCombined`] aggregate.
///
/// Field `0` holds the left child's options and field `1` the right child's.
/// A dedicated tuple struct is used instead of a bare `(L, R)` because the
/// [`AggregateFnVTable::Options`] bound requires `Display`, which tuples
/// don't implement.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

/// Renders the pair as `(<left>, <right>)` using each side's own `Display`.
impl<L, R> Display for PairOptions<L, R>
where
    L: Display,
    R: Display,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(left, right) = self;
        write!(f, "({left}, {right})")
    }
}

// Convenience aliases so signatures stay readable. Each one projects an
// associated type out of one of the two child vtables of a `BinaryCombined`.

/// Options type of the left child aggregate of `T`.
type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
/// Options type of the right child aggregate of `T`.
type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
/// Partial (in-progress accumulator state) type of the left child aggregate of `T`.
type LeftPartial<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Partial;
/// Partial (in-progress accumulator state) type of the right child aggregate of `T`.
type RightPartial<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Partial;
/// Combined options for a [`BinaryCombined`] aggregate: the left child's
/// options in field `0` and the right child's in field `1` of a [`PairOptions`].
pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;

/// Declare an aggregate function in terms of two child aggregates.
pub trait BinaryCombined: 'static + Send + Sync + Clone {
/// The left child aggregate vtable.
type Left: AggregateFnVTable;
/// The right child aggregate vtable.
type Right: AggregateFnVTable;

/// Stable identifier for the combined aggregate.
fn id(&self) -> AggregateFnId;

/// Construct the left child vtable.
fn left(&self) -> Self::Left;

/// Construct the right child vtable.
fn right(&self) -> Self::Right;

/// Field name for the left child in the partial struct dtype.
///
/// Defaults to `"left"`. [`Combined`] uses this name both when it builds the
/// partial struct (dtype and scalar) and when it looks the field up again while
/// merging or finalizing partials, so writers and readers must agree on it.
fn left_name(&self) -> &'static str {
"left"
}

/// Field name for the right child in the partial struct dtype.
///
/// Defaults to `"right"`. [`Combined`] uses this name both when it builds the
/// partial struct (dtype and scalar) and when it looks the field up again while
/// merging or finalizing partials, so writers and readers must agree on it.
fn right_name(&self) -> &'static str {
"right"
}
Comment on lines +68 to +76
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why default these?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those don't matter, so may as well use left/right


/// Return type of the combined aggregate.
fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;

/// Combine the finalized left and right results into the final aggregate.
fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;

/// Serialize the options for this combined aggregate. Default: not serializable.
///
/// The default implementation ignores `options` and returns `Ok(None)`.
/// Implementors that want their options persisted should override this
/// together with [`BinaryCombined::deserialize`].
fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
// Explicitly discard the argument so the default body has no unused-variable warning.
let _ = options;
Ok(None)
}

/// Deserialize the options for this combined aggregate. Default: bails.
///
/// # Errors
///
/// The default implementation ignores `metadata` and `session` and always
/// returns an error naming this aggregate's id.
fn deserialize(
&self,
metadata: &[u8],
session: &VortexSession,
) -> VortexResult<CombinedOptions<Self>> {
// Explicitly discard the arguments so the default body has no unused-variable warnings.
let _ = (metadata, session);
vortex_bail!(
"Combined aggregate function {} is not deserializable",
BinaryCombined::id(self)
);
}

/// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`.
fn coerce_args(
&self,
options: &CombinedOptions<Self>,
input_dtype: &DType,
) -> VortexResult<DType> {
let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
self.right().coerce_args(&options.1, &left_coerced)
}
}

/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`].
///
/// This is a newtype around the wrapped `T`; the inner value is public as
/// field `0`.
#[derive(Clone, Debug)]
pub struct Combined<T: BinaryCombined>(pub T);

impl<T: BinaryCombined> Combined<T> {
/// Construct a new combined aggregate vtable.
///
/// Equivalent to writing `Combined(inner)` directly.
pub fn new(inner: T) -> Self {
Self(inner)
}
}

impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
type Options = CombinedOptions<T>;
type Partial = (LeftPartial<T>, RightPartial<T>);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How will this partial be serialized?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a struct scalar with left/right fields


/// Identifier of the combined aggregate, taken from the wrapped [`BinaryCombined`].
fn id(&self) -> AggregateFnId {
self.0.id()
}

/// Serializes `options` by delegating to [`BinaryCombined::serialize`] on the
/// wrapped value (whose default returns `Ok(None)`).
fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
BinaryCombined::serialize(&self.0, options)
}

/// Deserializes options by delegating to [`BinaryCombined::deserialize`] on the
/// wrapped value (whose default always returns an error).
fn deserialize(
&self,
metadata: &[u8],
session: &VortexSession,
) -> VortexResult<Self::Options> {
BinaryCombined::deserialize(&self.0, metadata, session)
}

/// Coerces the input dtype by delegating to [`BinaryCombined::coerce_args`] on
/// the wrapped value.
fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
BinaryCombined::coerce_args(&self.0, options, input_dtype)
}

/// Result dtype of the combined aggregate for `input_dtype`.
///
/// Delegates to [`BinaryCombined::return_dtype`]; the options are not consulted.
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
BinaryCombined::return_dtype(&self.0, input_dtype)
}

fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
Some(struct_dtype(self.0.left_name(), self.0.right_name(), l, r))
}

fn empty_partial(
&self,
options: &Self::Options,
input_dtype: &DType,
) -> VortexResult<Self::Partial> {
Ok((
self.0.left().empty_partial(&options.0, input_dtype)?,
self.0.right().empty_partial(&options.1, input_dtype)?,
))
}

/// Merges an externally produced partial (`other`) into the in-memory `partial`.
///
/// A null `other` is a no-op. Otherwise `other` is read as a struct scalar, and the
/// fields named by `left_name()` / `right_name()` are handed to the matching
/// child's `combine_partials`. Both fields are looked up before either child is
/// touched, so a missing field leaves `partial` unmodified.
///
/// # Errors
///
/// Fails if either named field is absent from `other`, or if a child merge fails.
fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
    if other.is_null() {
        return Ok(());
    }

    let inner = &self.0;
    let fields = other.as_struct();
    let (left_name, right_name) = (inner.left_name(), inner.right_name());

    let left_part = fields
        .field(left_name)
        .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", left_name))?;
    let right_part = fields
        .field(right_name)
        .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", right_name))?;

    inner.left().combine_partials(&mut partial.0, left_part)?;
    inner.right().combine_partials(&mut partial.1, right_part)
}

/// Converts the in-memory partial into a struct scalar with one field per child.
///
/// Each child renders its own partial via `to_scalar`; the struct's dtype is
/// assembled from the dtypes of those two scalars, with fields named by
/// `left_name()` / `right_name()` (left first).
///
/// # Errors
///
/// Propagates the first error from either child's `to_scalar`.
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
    let inner = &self.0;
    let left_value = inner.left().to_scalar(&partial.0)?;
    let right_value = inner.right().to_scalar(&partial.1)?;

    let partial_dtype = struct_dtype(
        inner.left_name(),
        inner.right_name(),
        left_value.dtype().clone(),
        right_value.dtype().clone(),
    );
    let field_values = vec![left_value, right_value];
    Ok(Scalar::struct_(partial_dtype, field_values))
}

/// Resets both halves of the partial by delegating to each child's `reset`
/// (left first, then right).
fn reset(&self, partial: &mut Self::Partial) {
    let (left_state, right_state) = partial;
    self.0.left().reset(left_state);
    self.0.right().reset(right_state);
}

/// Reports whether the combined partial is saturated, which requires *both*
/// children to be saturated. The right child is only asked if the left child
/// already reports saturation.
fn is_saturated(&self, partial: &Self::Partial) -> bool {
    let (left_state, right_state) = partial;
    if !self.0.left().is_saturated(left_state) {
        return false;
    }
    self.0.right().is_saturated(right_state)
}

/// Feeds `batch` to both children, preferring each child's own `try_accumulate`.
///
/// When a child declines the batch (returns `false`), the batch is executed into a
/// [`Columnar`] and passed to that child's `accumulate` instead. The canonical form
/// is produced at most once: if the left child already needed it, the right child
/// reuses that same value.
///
/// Always returns `Ok(true)`, i.e. we claim to have handled the batch ourselves so
/// [`Self::accumulate`] is unreachable — this is the same trick `Count` uses to opt
/// out of the canonicalization path.
///
/// # Errors
///
/// Propagates the first error from either child or from canonicalizing `batch`.
fn try_accumulate(
    &self,
    state: &mut Self::Partial,
    batch: &ArrayRef,
    ctx: &mut ExecutionCtx,
) -> VortexResult<bool> {
    // Left child: keep the canonical batch around if we had to build it.
    let left_canonical = if self.0.left().try_accumulate(&mut state.0, batch, ctx)? {
        None
    } else {
        let columnar = batch.clone().execute::<Columnar>(ctx)?;
        self.0.left().accumulate(&mut state.0, &columnar, ctx)?;
        Some(columnar)
    };

    // Right child: reuse the left child's canonical batch when available.
    if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? {
        let columnar = match left_canonical {
            Some(shared) => shared,
            None => batch.clone().execute::<Columnar>(ctx)?,
        };
        self.0.right().accumulate(&mut state.1, &columnar, ctx)?;
    }

    Ok(true)
}

/// Canonical-batch accumulation entry point; intentionally never implemented.
///
/// `Combined::try_accumulate` always returns `Ok(true)`, so the intent is that
/// this method is never invoked for a `Combined` aggregate.
///
/// # Panics
///
/// Always panics via `unreachable!` if called.
fn accumulate(
&self,
_state: &mut Self::Partial,
_batch: &Columnar,
_ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
unreachable!("Combined::try_accumulate handles all batches")
}

/// Turns an array of combined partial states into the final aggregate values.
///
/// The `left_name()` / `right_name()` fields are extracted from `states`, each is
/// finalized by its own child, and the two finalized arrays are handed to
/// [`BinaryCombined::finalize`] to produce the result.
///
/// # Errors
///
/// Propagates errors from field extraction, either child's `finalize`, or the
/// combining step. Both fields are extracted before either child is finalized.
fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
    let inner = &self.0;

    let left_states = states.get_item(FieldName::from(inner.left_name()))?;
    let right_states = states.get_item(FieldName::from(inner.right_name()))?;

    let left_result = inner.left().finalize(left_states)?;
    let right_result = inner.right().finalize(right_states)?;

    BinaryCombined::finalize(inner, left_result, right_result)
}
}

/// Builds the non-nullable two-field struct dtype used for combined partials.
///
/// The fields appear in `(left, right)` order, named `left_name` and `right_name`
/// and typed `left` and `right` respectively.
fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType {
    let names = FieldNames::from_iter([
        FieldName::from(left_name),
        FieldName::from(right_name),
    ]);
    let fields = StructFields::new(names, vec![left, right]);
    DType::Struct(fields, Nullability::NonNullable)
}
Loading
Loading