Skip to content

Commit 842b03a

Browse files
committed
Reorder kernel dispatch, and have Combined use inner accumulators
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 7a65ca2 commit 842b03a

1 file changed

Lines changed: 55 additions & 28 deletions

File tree

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
151151
return Ok(());
152152
}
153153

154-
// 3. Iteratively execute one step at a time, re-checking the kernel registry against
155-
// each intermediate encoding. (The initial batch's encoding was already checked in
156-
// step 1, so execute first.)
154+
// 3. Iteratively check the registry against each intermediate encoding, executing one
155+
// step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`.
156+
// Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of
157+
// keeping the loop body uniform.
157158
let mut batch = batch.clone();
158159
for _ in 0..max_iterations() {
159-
batch = batch.execute(ctx)?;
160160
if batch.is::<AnyCanonical>() {
161161
break;
162162
}
@@ -180,6 +180,8 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
180180
self.vtable.combine_partials(&mut self.partial, result)?;
181181
return Ok(());
182182
}
183+
184+
batch = batch.execute(ctx)?;
183185
}
184186

185187
// 4. Otherwise, execute the batch until it is columnar and accumulate it into the state.
@@ -273,32 +275,53 @@ mod tests {
273275
use crate::scalar::Scalar;
274276
use crate::session::ArraySession;
275277

276-
/// Stub kernel that always returns the configured `Option<Scalar>`.
277-
/// `Some(_)` means "I handled this batch with this partial"; `None` means
278-
/// "kernel does not apply, fall through".
278+
/// Mean partial sentinel `{sum: 42.0, count: 1}` — distinguishable from the
279+
/// natural fan-out result `{sum: 7.0, count: 1}` that `Combined::try_accumulate`
280+
/// would produce for `dict_of_seven()`.
279281
#[derive(Debug)]
280-
struct StubKernel(Option<Scalar>);
282+
struct SentinelMeanPartialKernel;
283+
impl DynAggregateKernel for SentinelMeanPartialKernel {
284+
fn aggregate(
285+
&self,
286+
_aggregate_fn: &AggregateFnRef,
287+
_batch: &ArrayRef,
288+
_ctx: &mut ExecutionCtx,
289+
) -> VortexResult<Option<Scalar>> {
290+
Ok(Some(sentinel_partial()))
291+
}
292+
}
281293

282-
impl DynAggregateKernel for StubKernel {
294+
/// Kernel that always returns `Ok(None)`, i.e. it declines the batch so dispatch falls through.
295+
#[derive(Debug)]
296+
struct DeclineKernel;
297+
impl DynAggregateKernel for DeclineKernel {
283298
fn aggregate(
284299
&self,
285300
_aggregate_fn: &AggregateFnRef,
286301
_batch: &ArrayRef,
287302
_ctx: &mut ExecutionCtx,
288303
) -> VortexResult<Option<Scalar>> {
289-
Ok(self.0.clone())
304+
Ok(None)
290305
}
291306
}
292307

293-
fn session_with_stub_kernel(kernel_result: Option<Scalar>) -> VortexSession {
294-
let session = VortexSession::empty().with::<ArraySession>();
295-
// Leak the kernel so it has the `'static` lifetime the registry requires.
296-
// The session is short-lived so a couple of bytes of test-only leakage is fine.
297-
let kernel: &'static StubKernel = Box::leak(Box::new(StubKernel(kernel_result)));
298-
session
299-
.get::<AggregateFnSession>()
300-
.register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), kernel);
301-
session
308+
/// Sum partial sentinel `42.0` — distinguishable from the natural Sum of
309+
/// `dict_of_seven()` which is `7.0`.
310+
#[derive(Debug)]
311+
struct SentinelSumPartialKernel;
312+
impl DynAggregateKernel for SentinelSumPartialKernel {
313+
fn aggregate(
314+
&self,
315+
_aggregate_fn: &AggregateFnRef,
316+
_batch: &ArrayRef,
317+
_ctx: &mut ExecutionCtx,
318+
) -> VortexResult<Option<Scalar>> {
319+
Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
320+
}
321+
}
322+
323+
fn fresh_session() -> VortexSession {
324+
VortexSession::empty().with::<ArraySession>()
302325
}
303326

304327
fn dict_of_seven() -> ArrayRef {
@@ -316,8 +339,6 @@ mod tests {
316339
)
317340
}
318341

319-
/// Sentinel partial: `{sum: 42.0, count: 1}`. Distinguishable from the natural
320-
/// fallback `{sum: 7.0, count: 1}` that `Combined::try_accumulate` would produce.
321342
fn sentinel_partial() -> Scalar {
322343
let acc = mean_f64_accumulator().expect("build accumulator");
323344
let sum = Scalar::primitive(42.0f64, Nullability::Nullable);
@@ -329,7 +350,11 @@ mod tests {
329350
/// `Combined::try_accumulate`'s fan-out path — proves the dispatch reorder.
330351
#[test]
331352
fn combined_kernel_fires() -> VortexResult<()> {
332-
let session = session_with_stub_kernel(Some(sentinel_partial()));
353+
static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
354+
let session = fresh_session();
355+
session
356+
.get::<AggregateFnSession>()
357+
.register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
333358
let mut ctx = session.create_execution_ctx();
334359

335360
let mut acc = mean_f64_accumulator()?;
@@ -352,7 +377,11 @@ mod tests {
352377
/// natural fan-out. The natural partial is `{sum: 7.0, count: 1}`.
353378
#[test]
354379
fn fallback_when_kernel_declines() -> VortexResult<()> {
355-
let session = session_with_stub_kernel(None);
380+
static KERNEL: DeclineKernel = DeclineKernel;
381+
let session = fresh_session();
382+
session
383+
.get::<AggregateFnSession>()
384+
.register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
356385
let mut ctx = session.create_execution_ctx();
357386

358387
let mut acc = mean_f64_accumulator()?;
@@ -376,13 +405,11 @@ mod tests {
376405
/// refactor enables: no `(Dict, Combined<Mean>)` kernel is needed.
377406
#[test]
378407
fn child_kernel_fires_through_combined() -> VortexResult<()> {
379-
let session = VortexSession::empty().with::<ArraySession>();
380-
let sum_sentinel: &'static StubKernel = Box::leak(Box::new(StubKernel(Some(
381-
Scalar::primitive(42.0f64, Nullability::Nullable),
382-
))));
408+
static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
409+
let session = fresh_session();
383410
session
384411
.get::<AggregateFnSession>()
385-
.register_aggregate_kernel(Dict.id(), Some(Sum.id()), sum_sentinel);
412+
.register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);
386413
let mut ctx = session.create_execution_ctx();
387414

388415
let mut acc = mean_f64_accumulator()?;

0 commit comments

Comments
 (0)