Skip to content

Commit 6f54d3d

Browse files
authored
Reorder agg kernel dispatch, and have Combined use inner accumulators (#7889)
Aggregate kernel dispatch previously tried the vtable before the plugin registry, meaning plugins could not intercept the aggregate function as they can with array execution. This change also fixed Combined to use inner accumulators, instead of just inner partials. This fixes the kernel dispatch for the delegated aggregate functions. e.g. a custom Sum<->Dict kernel previously would not have been used for Combiner<Avg><->Dict, even though it delegates to Sum + Count. --------- Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 3e93048 commit 6f54d3d

7 files changed

Lines changed: 405 additions & 106 deletions

File tree

vortex-array/public-api.lock

Lines changed: 72 additions & 2 deletions
Large diffs are not rendered by default.

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 285 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ use vortex_error::VortexResult;
55
use vortex_error::vortex_ensure;
66
use vortex_error::vortex_err;
77

8-
use crate::AnyCanonical;
98
use crate::ArrayRef;
109
use crate::Columnar;
1110
use crate::ExecutionCtx;
1211
use crate::aggregate_fn::AggregateFn;
1312
use crate::aggregate_fn::AggregateFnRef;
1413
use crate::aggregate_fn::AggregateFnVTable;
1514
use crate::aggregate_fn::session::AggregateFnSessionExt;
15+
use crate::columnar::AnyColumnar;
1616
use crate::dtype::DType;
1717
use crate::executor::max_iterations;
1818
use crate::scalar::Scalar;
@@ -72,9 +72,26 @@ pub trait DynAccumulator: 'static + Send {
7272
/// Accumulate a new array into the accumulator's state.
7373
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
7474

75+
/// Fold an external partial-state scalar into this accumulator's state.
76+
///
77+
/// The scalar must have the dtype reported by the vtable's `partial_dtype` for the
78+
/// options and input dtype used to construct this accumulator.
79+
fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
80+
7581
/// Whether the accumulator's result is fully determined.
7682
fn is_saturated(&self) -> bool;
7783

84+
/// Reset the accumulator's state to the empty group.
85+
fn reset(&mut self);
86+
87+
/// Read the current partial state as a scalar without resetting it.
88+
///
89+
/// The returned scalar has the dtype reported by the vtable's `partial_dtype`.
90+
fn partial_scalar(&self) -> VortexResult<Scalar>;
91+
92+
/// Compute the final aggregate result as a scalar without resetting state.
93+
fn final_scalar(&self) -> VortexResult<Scalar>;
94+
7895
/// Flush the accumulation state and return the partial aggregate result as a scalar.
7996
///
8097
/// Resets the accumulator state back to the initial state.
@@ -99,31 +116,75 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
99116
batch.dtype()
100117
);
101118

102-
// Allow the vtable to short-circuit on the raw array before decompression.
103-
if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
119+
// 0. Stats-driven shortcut: if the aggregate can be derived directly from the batch's
120+
// cached statistics, use that and skip both kernel dispatch and decode. This is the
121+
// only layer that consults `batch.statistics()`; encoding kernels must not.
122+
if let Some(result) = self.vtable.try_partial_from_stats(batch)? {
123+
vortex_ensure!(
124+
result.dtype() == &self.partial_dtype,
125+
"Aggregate try_partial_from_stats returned {}, expected {}",
126+
result.dtype(),
127+
self.partial_dtype,
128+
);
129+
self.vtable.combine_partials(&mut self.partial, result)?;
104130
return Ok(());
105131
}
106132

107133
let session = ctx.session().clone();
108134
let kernels = &session.aggregate_fns().kernels;
109135

136+
// 1. Kernel registry first: a registered `(encoding, aggregate_fn)` kernel is strictly
137+
// more specific than the vtable's `try_accumulate` short-circuit. Checking the
138+
// registry first gives kernels for `Combined<V>` aggregates a chance to fire —
139+
// `Combined::try_accumulate` always returns true, so a later kernel check would be
140+
// unreachable.
141+
{
142+
let kernels_r = kernels.read();
143+
let batch_id = batch.encoding_id();
144+
let kernel = kernels_r
145+
.get(&(batch_id, Some(self.aggregate_fn.id())))
146+
.or_else(|| kernels_r.get(&(batch_id, None)))
147+
.copied();
148+
drop(kernels_r);
149+
if let Some(kernel) = kernel
150+
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
151+
{
152+
vortex_ensure!(
153+
result.dtype() == &self.partial_dtype,
154+
"Aggregate kernel returned {}, expected {}",
155+
result.dtype(),
156+
self.partial_dtype,
157+
);
158+
self.vtable.combine_partials(&mut self.partial, result)?;
159+
return Ok(());
160+
}
161+
}
162+
163+
// 2. Allow the vtable to short-circuit on the raw array before decompression.
164+
if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
165+
return Ok(());
166+
}
167+
168+
// 3. Iteratively check the registry against each intermediate encoding, executing one
169+
// step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`.
170+
// Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of
171+
// keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant)
172+
// since the vtable's `accumulate(&Columnar)` handles both cases directly.
110173
let mut batch = batch.clone();
111174
for _ in 0..max_iterations() {
112-
if batch.is::<AnyCanonical>() {
175+
if batch.is::<AnyColumnar>() {
113176
break;
114177
}
115178

116179
let kernels_r = kernels.read();
117180
let batch_id = batch.encoding_id();
118-
if let Some(result) = kernels_r
181+
let kernel = kernels_r
119182
.get(&(batch_id, Some(self.aggregate_fn.id())))
120183
.or_else(|| kernels_r.get(&(batch_id, None)))
121-
.and_then(|kernel| {
122-
kernel
123-
.aggregate(&self.aggregate_fn, &batch, ctx)
124-
.transpose()
125-
})
126-
.transpose()?
184+
.copied();
185+
drop(kernels_r);
186+
if let Some(kernel) = kernel
187+
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
127188
{
128189
vortex_ensure!(
129190
result.dtype() == &self.partial_dtype,
@@ -135,29 +196,35 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
135196
return Ok(());
136197
}
137198

138-
// Execute one step and try again
139199
batch = batch.execute(ctx)?;
140200
}
141201

142-
// Otherwise, execute the batch until it is columnar and accumulate it into the state.
202+
// 4. Otherwise, execute the batch until it is columnar and accumulate it into the state.
143203
let columnar = batch.execute::<Columnar>(ctx)?;
144204

145205
self.vtable.accumulate(&mut self.partial, &columnar, ctx)
146206
}
147207

208+
fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
209+
self.vtable.combine_partials(&mut self.partial, other)
210+
}
211+
148212
fn is_saturated(&self) -> bool {
149213
self.vtable.is_saturated(&self.partial)
150214
}
151215

152-
fn flush(&mut self) -> VortexResult<Scalar> {
153-
let partial = self.vtable.to_scalar(&self.partial)?;
216+
fn reset(&mut self) {
154217
self.vtable.reset(&mut self.partial);
218+
}
219+
220+
fn partial_scalar(&self) -> VortexResult<Scalar> {
221+
let partial = self.vtable.to_scalar(&self.partial)?;
155222

156223
#[cfg(debug_assertions)]
157224
{
158225
vortex_ensure!(
159226
partial.dtype() == &self.partial_dtype,
160-
"Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
227+
"Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
161228
self.partial_dtype,
162229
partial.dtype(),
163230
);
@@ -166,17 +233,216 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
166233
Ok(partial)
167234
}
168235

169-
fn finish(&mut self) -> VortexResult<Scalar> {
236+
fn final_scalar(&self) -> VortexResult<Scalar> {
170237
let result = self.vtable.finalize_scalar(&self.partial)?;
171-
self.vtable.reset(&mut self.partial);
172238

173239
vortex_ensure!(
174240
result.dtype() == &self.return_dtype,
175-
"Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
241+
"Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
176242
self.return_dtype,
177243
result.dtype(),
178244
);
179245

180246
Ok(result)
181247
}
248+
249+
fn flush(&mut self) -> VortexResult<Scalar> {
250+
let partial = self.partial_scalar()?;
251+
self.reset();
252+
Ok(partial)
253+
}
254+
255+
fn finish(&mut self) -> VortexResult<Scalar> {
256+
let result = self.final_scalar()?;
257+
self.reset();
258+
Ok(result)
259+
}
260+
}
261+
262+
#[cfg(test)]
263+
mod tests {
264+
use vortex_buffer::buffer;
265+
use vortex_error::VortexResult;
266+
use vortex_session::SessionExt;
267+
use vortex_session::VortexSession;
268+
269+
use crate::ArrayRef;
270+
use crate::ExecutionCtx;
271+
use crate::IntoArray;
272+
use crate::VortexSessionExecute;
273+
use crate::aggregate_fn::Accumulator;
274+
use crate::aggregate_fn::AggregateFnRef;
275+
use crate::aggregate_fn::AggregateFnVTable;
276+
use crate::aggregate_fn::DynAccumulator;
277+
use crate::aggregate_fn::EmptyOptions;
278+
use crate::aggregate_fn::combined::Combined;
279+
use crate::aggregate_fn::combined::PairOptions;
280+
use crate::aggregate_fn::fns::mean::Mean;
281+
use crate::aggregate_fn::fns::sum::Sum;
282+
use crate::aggregate_fn::kernels::DynAggregateKernel;
283+
use crate::aggregate_fn::session::AggregateFnSession;
284+
use crate::array::VTable;
285+
use crate::arrays::Dict;
286+
use crate::arrays::DictArray;
287+
use crate::dtype::DType;
288+
use crate::dtype::Nullability;
289+
use crate::dtype::PType;
290+
use crate::scalar::Scalar;
291+
use crate::session::ArraySession;
292+
293+
/// Mean partial sentinel `{sum: 42.0, count: 1}` — distinguishable from the
294+
/// natural fan-out result `{sum: 7.0, count: 1}` that `Combined::try_accumulate`
295+
/// would produce for `dict_of_seven()`.
296+
#[derive(Debug)]
297+
struct SentinelMeanPartialKernel;
298+
impl DynAggregateKernel for SentinelMeanPartialKernel {
299+
fn aggregate(
300+
&self,
301+
_aggregate_fn: &AggregateFnRef,
302+
_batch: &ArrayRef,
303+
_ctx: &mut ExecutionCtx,
304+
) -> VortexResult<Option<Scalar>> {
305+
Ok(Some(sentinel_partial()))
306+
}
307+
}
308+
309+
/// Returns `Ok(None)` => kernel declined, dispatch falls through.
310+
#[derive(Debug)]
311+
struct DeclineKernel;
312+
impl DynAggregateKernel for DeclineKernel {
313+
fn aggregate(
314+
&self,
315+
_aggregate_fn: &AggregateFnRef,
316+
_batch: &ArrayRef,
317+
_ctx: &mut ExecutionCtx,
318+
) -> VortexResult<Option<Scalar>> {
319+
Ok(None)
320+
}
321+
}
322+
323+
/// Sum partial sentinel `42.0` — distinguishable from the natural Sum of
324+
/// `dict_of_seven()` which is `7.0`.
325+
#[derive(Debug)]
326+
struct SentinelSumPartialKernel;
327+
impl DynAggregateKernel for SentinelSumPartialKernel {
328+
fn aggregate(
329+
&self,
330+
_aggregate_fn: &AggregateFnRef,
331+
_batch: &ArrayRef,
332+
_ctx: &mut ExecutionCtx,
333+
) -> VortexResult<Option<Scalar>> {
334+
Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
335+
}
336+
}
337+
338+
fn fresh_session() -> VortexSession {
339+
VortexSession::empty().with::<ArraySession>()
340+
}
341+
342+
fn dict_of_seven() -> ArrayRef {
343+
DictArray::try_new(buffer![0u32].into_array(), buffer![7.0f64].into_array())
344+
.expect("valid dictionary")
345+
.into_array()
346+
}
347+
348+
fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
349+
let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
350+
Accumulator::try_new(
351+
Mean::combined(),
352+
PairOptions(EmptyOptions, EmptyOptions),
353+
dtype,
354+
)
355+
}
356+
357+
fn sentinel_partial() -> Scalar {
358+
let acc = mean_f64_accumulator().expect("build accumulator");
359+
let sum = Scalar::primitive(42.0f64, Nullability::Nullable);
360+
let count = Scalar::primitive(1u64, Nullability::NonNullable);
361+
Scalar::struct_(acc.partial_dtype, vec![sum, count])
362+
}
363+
364+
/// Kernel registered for `(Dict, Combined<Mean>)` fires in preference to
365+
/// `Combined::try_accumulate`'s fan-out path — proves the dispatch reorder.
366+
#[test]
367+
fn combined_kernel_fires() -> VortexResult<()> {
368+
static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
369+
let session = fresh_session();
370+
session
371+
.get::<AggregateFnSession>()
372+
.register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
373+
let mut ctx = session.create_execution_ctx();
374+
375+
let mut acc = mean_f64_accumulator()?;
376+
acc.accumulate(&dict_of_seven(), &mut ctx)?;
377+
let partial = acc.flush()?;
378+
379+
let s = partial.as_struct();
380+
assert_eq!(
381+
s.field("sum").unwrap().as_primitive().as_::<f64>(),
382+
Some(42.0)
383+
);
384+
assert_eq!(
385+
s.field("count").unwrap().as_primitive().as_::<u64>(),
386+
Some(1)
387+
);
388+
Ok(())
389+
}
390+
391+
/// Kernel returns `Ok(None)` => dispatch falls through to `Combined::try_accumulate`'s
392+
/// natural fan-out. The natural partial is `{sum: 7.0, count: 1}`.
393+
#[test]
394+
fn fallback_when_kernel_declines() -> VortexResult<()> {
395+
static KERNEL: DeclineKernel = DeclineKernel;
396+
let session = fresh_session();
397+
session
398+
.get::<AggregateFnSession>()
399+
.register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
400+
let mut ctx = session.create_execution_ctx();
401+
402+
let mut acc = mean_f64_accumulator()?;
403+
acc.accumulate(&dict_of_seven(), &mut ctx)?;
404+
let partial = acc.flush()?;
405+
406+
let s = partial.as_struct();
407+
assert_eq!(
408+
s.field("sum").unwrap().as_primitive().as_::<f64>(),
409+
Some(7.0)
410+
);
411+
assert_eq!(
412+
s.field("count").unwrap().as_primitive().as_::<u64>(),
413+
Some(1)
414+
);
415+
Ok(())
416+
}
417+
418+
/// A kernel registered for the inner `(Dict, Sum)` child fires when accumulating a
419+
/// Dict batch through `Combined<Mean>`. This is the reusable-primitive case the
420+
/// refactor enables: no `(Dict, Combined<Mean>)` kernel is needed.
421+
#[test]
422+
fn child_kernel_fires_through_combined() -> VortexResult<()> {
423+
static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
424+
let session = fresh_session();
425+
session
426+
.get::<AggregateFnSession>()
427+
.register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);
428+
let mut ctx = session.create_execution_ctx();
429+
430+
let mut acc = mean_f64_accumulator()?;
431+
acc.accumulate(&dict_of_seven(), &mut ctx)?;
432+
let partial = acc.flush()?;
433+
434+
let s = partial.as_struct();
435+
// `Sum` child returned the sentinel 42.0 — proves the (Dict, Sum) kernel fired
436+
// via `Combined<Mean>`'s fan-out. `Count`'s native `try_accumulate` reads the
437+
// batch's valid_count, so count is the real 1.
438+
assert_eq!(
439+
s.field("sum").unwrap().as_primitive().as_::<f64>(),
440+
Some(42.0)
441+
);
442+
assert_eq!(
443+
s.field("count").unwrap().as_primitive().as_::<u64>(),
444+
Some(1)
445+
);
446+
Ok(())
447+
}
182448
}

0 commit comments

Comments
 (0)