Skip to content

Commit 0515378

Browse files
committed
Combined
Signed-off-by: blaginin <github@blaginin.me>
1 parent 53806a4 commit 0515378

3 files changed

Lines changed: 257 additions & 0 deletions

File tree

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Generic adapter for aggregates whose result is computed from two child
5+
//! aggregate functions, e.g. `Mean = Sum / Count`.
6+
7+
use std::fmt::{self, Debug, Display, Formatter};
8+
use std::hash::Hash;
9+
10+
use vortex_error::{VortexResult, vortex_bail, vortex_err};
11+
use vortex_session::VortexSession;
12+
13+
use crate::aggregate_fn::{AggregateFnId, AggregateFnVTable};
14+
use crate::builtins::ArrayBuiltins;
15+
use crate::dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
16+
use crate::scalar::Scalar;
17+
use crate::{ArrayRef, Columnar, ExecutionCtx};
18+
19+
/// Pair of options for the two children of a [`BinaryCombined`] aggregate.
///
/// Wrapper around `(L, R)` because the [`AggregateFnVTable::Options`] bound
/// requires `Display`, which tuples don't implement.
///
/// Field `0` carries the left child's options and field `1` the right child's.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PairOptions<L, R>(pub L, pub R);

/// Renders the pair as `(<left>, <right>)`, delegating to each side's own
/// `Display` implementation.
impl<L: Display, R: Display> Display for PairOptions<L, R> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "({}, {})", self.0, self.1)
    }
}
31+
32+
// Convenience aliases so signatures stay readable.
33+
type LeftOptions<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Options;
34+
type RightOptions<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Options;
35+
type LeftPartial<T> = <<T as BinaryCombined>::Left as AggregateFnVTable>::Partial;
36+
type RightPartial<T> = <<T as BinaryCombined>::Right as AggregateFnVTable>::Partial;
37+
/// Combined options for a [`BinaryCombined`] aggregate.
38+
pub type CombinedOptions<T> = PairOptions<LeftOptions<T>, RightOptions<T>>;
39+
40+
/// Declare an aggregate function in terms of two child aggregates.
41+
pub trait BinaryCombined: 'static + Send + Sync + Clone {
42+
/// The left child aggregate vtable.
43+
type Left: AggregateFnVTable;
44+
/// The right child aggregate vtable.
45+
type Right: AggregateFnVTable;
46+
47+
/// Stable identifier for the combined aggregate.
48+
fn id(&self) -> AggregateFnId;
49+
50+
/// Construct the left child vtable.
51+
fn left(&self) -> Self::Left;
52+
53+
/// Construct the right child vtable.
54+
fn right(&self) -> Self::Right;
55+
56+
/// Field name for the left child in the partial struct dtype.
57+
fn left_name(&self) -> &'static str {
58+
"left"
59+
}
60+
61+
/// Field name for the right child in the partial struct dtype.
62+
fn right_name(&self) -> &'static str {
63+
"right"
64+
}
65+
66+
/// Return type of the combined aggregate.
67+
fn return_dtype(&self, input_dtype: &DType) -> Option<DType>;
68+
69+
/// Combine the finalized left and right results into the final aggregate.
70+
fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult<ArrayRef>;
71+
72+
/// Serialize the options for this combined aggregate. Default: not serializable.
73+
fn serialize(&self, options: &CombinedOptions<Self>) -> VortexResult<Option<Vec<u8>>> {
74+
let _ = options;
75+
Ok(None)
76+
}
77+
78+
/// Deserialize the options for this combined aggregate. Default: bails.
79+
fn deserialize(
80+
&self,
81+
metadata: &[u8],
82+
session: &VortexSession,
83+
) -> VortexResult<CombinedOptions<Self>> {
84+
let _ = (metadata, session);
85+
vortex_bail!(
86+
"Combined aggregate function {} is not deserializable",
87+
BinaryCombined::id(self)
88+
);
89+
}
90+
91+
/// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`.
92+
fn coerce_args(
93+
&self,
94+
options: &CombinedOptions<Self>,
95+
input_dtype: &DType,
96+
) -> VortexResult<DType> {
97+
let left_coerced = self.left().coerce_args(&options.0, input_dtype)?;
98+
self.right().coerce_args(&options.1, &left_coerced)
99+
}
100+
}
101+
102+
/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`].
103+
#[derive(Clone, Debug)]
104+
pub struct Combined<T: BinaryCombined>(pub T);
105+
106+
impl<T: BinaryCombined> Combined<T> {
107+
/// Construct a new combined aggregate vtable.
108+
pub fn new(inner: T) -> Self {
109+
Self(inner)
110+
}
111+
}
112+
113+
impl<T: BinaryCombined> AggregateFnVTable for Combined<T> {
114+
type Options = CombinedOptions<T>;
115+
type Partial = (LeftPartial<T>, RightPartial<T>);
116+
117+
fn id(&self) -> AggregateFnId {
118+
self.0.id()
119+
}
120+
121+
fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
122+
BinaryCombined::serialize(&self.0, options)
123+
}
124+
125+
fn deserialize(
126+
&self,
127+
metadata: &[u8],
128+
session: &VortexSession,
129+
) -> VortexResult<Self::Options> {
130+
BinaryCombined::deserialize(&self.0, metadata, session)
131+
}
132+
133+
fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
134+
BinaryCombined::coerce_args(&self.0, options, input_dtype)
135+
}
136+
137+
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
138+
BinaryCombined::return_dtype(&self.0, input_dtype)
139+
}
140+
141+
fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
142+
let l = self.0.left().partial_dtype(&options.0, input_dtype)?;
143+
let r = self.0.right().partial_dtype(&options.1, input_dtype)?;
144+
Some(struct_dtype(self.0.left_name(), self.0.right_name(), l, r))
145+
}
146+
147+
fn empty_partial(
148+
&self,
149+
options: &Self::Options,
150+
input_dtype: &DType,
151+
) -> VortexResult<Self::Partial> {
152+
Ok((
153+
self.0.left().empty_partial(&options.0, input_dtype)?,
154+
self.0.right().empty_partial(&options.1, input_dtype)?,
155+
))
156+
}
157+
158+
fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
159+
if other.is_null() {
160+
return Ok(());
161+
}
162+
let s = other.as_struct();
163+
let lname = self.0.left_name();
164+
let rname = self.0.right_name();
165+
let l_field = s
166+
.field(lname)
167+
.ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?;
168+
let r_field = s
169+
.field(rname)
170+
.ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?;
171+
self.0.left().combine_partials(&mut partial.0, l_field)?;
172+
self.0.right().combine_partials(&mut partial.1, r_field)?;
173+
Ok(())
174+
}
175+
176+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
177+
let l_scalar = self.0.left().to_scalar(&partial.0)?;
178+
let r_scalar = self.0.right().to_scalar(&partial.1)?;
179+
let dtype = struct_dtype(
180+
self.0.left_name(),
181+
self.0.right_name(),
182+
l_scalar.dtype().clone(),
183+
r_scalar.dtype().clone(),
184+
);
185+
Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar]))
186+
}
187+
188+
fn reset(&self, partial: &mut Self::Partial) {
189+
self.0.left().reset(&mut partial.0);
190+
self.0.right().reset(&mut partial.1);
191+
}
192+
193+
fn is_saturated(&self, partial: &Self::Partial) -> bool {
194+
self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1)
195+
}
196+
197+
/// Fans out to each child's `try_accumulate`, falling back to `accumulate`
198+
/// against a lazily-canonicalized batch. We always claim to handle the
199+
/// batch ourselves so [`Self::accumulate`] is unreachable — this is the
200+
/// same trick `Count` uses to opt out of the canonicalization path.
201+
fn try_accumulate(
202+
&self,
203+
state: &mut Self::Partial,
204+
batch: &ArrayRef,
205+
ctx: &mut ExecutionCtx,
206+
) -> VortexResult<bool> {
207+
let mut canonical: Option<Columnar> = None;
208+
if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? {
209+
canonical = Some(batch.clone().execute::<Columnar>(ctx)?);
210+
self.0
211+
.left()
212+
.accumulate(&mut state.0, canonical.as_ref().expect("just set"), ctx)?;
213+
}
214+
if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? {
215+
if canonical.is_none() {
216+
canonical = Some(batch.clone().execute::<Columnar>(ctx)?);
217+
}
218+
self.0
219+
.right()
220+
.accumulate(&mut state.1, canonical.as_ref().expect("just set"), ctx)?;
221+
}
222+
Ok(true)
223+
}
224+
225+
fn accumulate(
226+
&self,
227+
_state: &mut Self::Partial,
228+
_batch: &Columnar,
229+
_ctx: &mut ExecutionCtx,
230+
) -> VortexResult<()> {
231+
unreachable!("Combined::try_accumulate handles all batches")
232+
}
233+
234+
fn finalize(&self, states: ArrayRef) -> VortexResult<ArrayRef> {
235+
let l_field = states.get_item(FieldName::from(self.0.left_name()))?;
236+
let r_field = states.get_item(FieldName::from(self.0.right_name()))?;
237+
let l_finalized = self.0.left().finalize(l_field)?;
238+
let r_finalized = self.0.right().finalize(r_field)?;
239+
BinaryCombined::finalize(&self.0, l_finalized, r_finalized)
240+
}
241+
}
242+
243+
fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType {
244+
DType::Struct(
245+
StructFields::new(
246+
FieldNames::from_iter([
247+
FieldName::from(left_name),
248+
FieldName::from(right_name),
249+
]),
250+
vec![left, right],
251+
),
252+
Nullability::NonNullable,
253+
)
254+
}

vortex-array/src/aggregate_fn/fns/mean/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ use crate::aggregate_fn::AggregateFnId;
1818
use crate::aggregate_fn::AggregateFnVTable;
1919
use crate::aggregate_fn::DynAccumulator;
2020
use crate::aggregate_fn::EmptyOptions;
21+
use crate::arrays::bool::BoolArrayExt;
2122
use crate::arrays::PrimitiveArray;
23+
use crate::arrays::struct_::StructArrayExt;
2224
use crate::canonical::ToCanonical;
2325
use crate::dtype::DType;
2426
use crate::dtype::FieldName;

vortex-array/src/aggregate_fn/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub mod fns;
3333
pub mod kernels;
3434
pub mod proto;
3535
pub mod session;
36+
pub mod combined;
3637

3738
/// A unique identifier for an aggregate function.
3839
pub type AggregateFnId = ArcRef<str>;

0 commit comments

Comments
 (0)