@@ -5,14 +5,14 @@ use vortex_error::VortexResult;
55use vortex_error:: vortex_ensure;
66use vortex_error:: vortex_err;
77
8- use crate :: AnyCanonical ;
98use crate :: ArrayRef ;
109use crate :: Columnar ;
1110use crate :: ExecutionCtx ;
1211use crate :: aggregate_fn:: AggregateFn ;
1312use crate :: aggregate_fn:: AggregateFnRef ;
1413use crate :: aggregate_fn:: AggregateFnVTable ;
1514use crate :: aggregate_fn:: session:: AggregateFnSessionExt ;
15+ use crate :: columnar:: AnyColumnar ;
1616use crate :: dtype:: DType ;
1717use crate :: executor:: max_iterations;
1818use crate :: scalar:: Scalar ;
@@ -72,9 +72,26 @@ pub trait DynAccumulator: 'static + Send {
7272 /// Accumulate a new array into the accumulator's state.
7373 fn accumulate ( & mut self , batch : & ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ( ) > ;
7474
75+ /// Fold an external partial-state scalar into this accumulator's state.
76+ ///
77+ /// The scalar must have the dtype reported by the vtable's `partial_dtype` for the
78+ /// options and input dtype used to construct this accumulator.
79+ fn combine_partials ( & mut self , other : Scalar ) -> VortexResult < ( ) > ;
80+
7581 /// Whether the accumulator's result is fully determined.
7682 fn is_saturated ( & self ) -> bool ;
7783
84+ /// Reset the accumulator's state to the empty group.
85+ fn reset ( & mut self ) ;
86+
87+ /// Read the current partial state as a scalar without resetting it.
88+ ///
89+ /// The returned scalar has the dtype reported by the vtable's `partial_dtype`.
90+ fn partial_scalar ( & self ) -> VortexResult < Scalar > ;
91+
92+ /// Compute the final aggregate result as a scalar without resetting state.
93+ fn final_scalar ( & self ) -> VortexResult < Scalar > ;
94+
7895 /// Flush the accumulation state and return the partial aggregate result as a scalar.
7996 ///
8097 /// Resets the accumulator state back to the initial state.
@@ -99,31 +116,75 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
99116 batch. dtype( )
100117 ) ;
101118
102- // Allow the vtable to short-circuit on the raw array before decompression.
103- if self . vtable . try_accumulate ( & mut self . partial , batch, ctx) ? {
119+ // 0. Stats-driven shortcut: if the aggregate can be derived directly from the batch's
120+ // cached statistics, use that and skip both kernel dispatch and decode. This is the
121+ // only layer that consults `batch.statistics()`; encoding kernels must not.
122+ if let Some ( result) = self . vtable . try_partial_from_stats ( batch) ? {
123+ vortex_ensure ! (
124+ result. dtype( ) == & self . partial_dtype,
125+ "Aggregate try_partial_from_stats returned {}, expected {}" ,
126+ result. dtype( ) ,
127+ self . partial_dtype,
128+ ) ;
129+ self . vtable . combine_partials ( & mut self . partial , result) ?;
104130 return Ok ( ( ) ) ;
105131 }
106132
107133 let session = ctx. session ( ) . clone ( ) ;
108134 let kernels = & session. aggregate_fns ( ) . kernels ;
109135
136+ // 1. Kernel registry first: a registered `(encoding, aggregate_fn)` kernel is strictly
137+ // more specific than the vtable's `try_accumulate` short-circuit. Checking the
138+ // registry first gives kernels for `Combined<V>` aggregates a chance to fire —
139+ // `Combined::try_accumulate` always returns true, so a later kernel check would be
140+ // unreachable.
141+ {
142+ let kernels_r = kernels. read ( ) ;
143+ let batch_id = batch. encoding_id ( ) ;
144+ let kernel = kernels_r
145+ . get ( & ( batch_id, Some ( self . aggregate_fn . id ( ) ) ) )
146+ . or_else ( || kernels_r. get ( & ( batch_id, None ) ) )
147+ . copied ( ) ;
148+ drop ( kernels_r) ;
149+ if let Some ( kernel) = kernel
150+ && let Some ( result) = kernel. aggregate ( & self . aggregate_fn , batch, ctx) ?
151+ {
152+ vortex_ensure ! (
153+ result. dtype( ) == & self . partial_dtype,
154+ "Aggregate kernel returned {}, expected {}" ,
155+ result. dtype( ) ,
156+ self . partial_dtype,
157+ ) ;
158+ self . vtable . combine_partials ( & mut self . partial , result) ?;
159+ return Ok ( ( ) ) ;
160+ }
161+ }
162+
163+ // 2. Allow the vtable to short-circuit on the raw array before decompression.
164+ if self . vtable . try_accumulate ( & mut self . partial , batch, ctx) ? {
165+ return Ok ( ( ) ) ;
166+ }
167+
168+ // 3. Iteratively check the registry against each intermediate encoding, executing one
169+ // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`.
170+ // Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of
171+ // keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant)
172+ // since the vtable's `accumulate(&Columnar)` handles both cases directly.
110173 let mut batch = batch. clone ( ) ;
111174 for _ in 0 ..max_iterations ( ) {
112- if batch. is :: < AnyCanonical > ( ) {
175+ if batch. is :: < AnyColumnar > ( ) {
113176 break ;
114177 }
115178
116179 let kernels_r = kernels. read ( ) ;
117180 let batch_id = batch. encoding_id ( ) ;
118- if let Some ( result ) = kernels_r
181+ let kernel = kernels_r
119182 . get ( & ( batch_id, Some ( self . aggregate_fn . id ( ) ) ) )
120183 . or_else ( || kernels_r. get ( & ( batch_id, None ) ) )
121- . and_then ( |kernel| {
122- kernel
123- . aggregate ( & self . aggregate_fn , & batch, ctx)
124- . transpose ( )
125- } )
126- . transpose ( ) ?
184+ . copied ( ) ;
185+ drop ( kernels_r) ;
186+ if let Some ( kernel) = kernel
187+ && let Some ( result) = kernel. aggregate ( & self . aggregate_fn , & batch, ctx) ?
127188 {
128189 vortex_ensure ! (
129190 result. dtype( ) == & self . partial_dtype,
@@ -135,29 +196,35 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
135196 return Ok ( ( ) ) ;
136197 }
137198
138- // Execute one step and try again
139199 batch = batch. execute ( ctx) ?;
140200 }
141201
142- // Otherwise, execute the batch until it is columnar and accumulate it into the state.
202+ // 4. Otherwise, execute the batch until it is columnar and accumulate it into the state.
143203 let columnar = batch. execute :: < Columnar > ( ctx) ?;
144204
145205 self . vtable . accumulate ( & mut self . partial , & columnar, ctx)
146206 }
147207
208+ fn combine_partials ( & mut self , other : Scalar ) -> VortexResult < ( ) > {
209+ self . vtable . combine_partials ( & mut self . partial , other)
210+ }
211+
148212 fn is_saturated ( & self ) -> bool {
149213 self . vtable . is_saturated ( & self . partial )
150214 }
151215
152- fn flush ( & mut self ) -> VortexResult < Scalar > {
153- let partial = self . vtable . to_scalar ( & self . partial ) ?;
216+ fn reset ( & mut self ) {
154217 self . vtable . reset ( & mut self . partial ) ;
218+ }
219+
220+ fn partial_scalar ( & self ) -> VortexResult < Scalar > {
221+ let partial = self . vtable . to_scalar ( & self . partial ) ?;
155222
156223 #[ cfg( debug_assertions) ]
157224 {
158225 vortex_ensure ! (
159226 partial. dtype( ) == & self . partial_dtype,
160- "Aggregate kernel returned incorrect DType on flush : expected {}, got {}" ,
227+ "Aggregate returned incorrect DType on partial_scalar : expected {}, got {}" ,
161228 self . partial_dtype,
162229 partial. dtype( ) ,
163230 ) ;
@@ -166,17 +233,216 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
166233 Ok ( partial)
167234 }
168235
169- fn finish ( & mut self ) -> VortexResult < Scalar > {
236+ fn final_scalar ( & self ) -> VortexResult < Scalar > {
170237 let result = self . vtable . finalize_scalar ( & self . partial ) ?;
171- self . vtable . reset ( & mut self . partial ) ;
172238
173239 vortex_ensure ! (
174240 result. dtype( ) == & self . return_dtype,
175- "Aggregate kernel returned incorrect DType on finalize : expected {}, got {}" ,
241+ "Aggregate returned incorrect DType on final_scalar : expected {}, got {}" ,
176242 self . return_dtype,
177243 result. dtype( ) ,
178244 ) ;
179245
180246 Ok ( result)
181247 }
248+
249+ fn flush ( & mut self ) -> VortexResult < Scalar > {
250+ let partial = self . partial_scalar ( ) ?;
251+ self . reset ( ) ;
252+ Ok ( partial)
253+ }
254+
255+ fn finish ( & mut self ) -> VortexResult < Scalar > {
256+ let result = self . final_scalar ( ) ?;
257+ self . reset ( ) ;
258+ Ok ( result)
259+ }
260+ }
261+
262+ #[ cfg( test) ]
263+ mod tests {
264+ use vortex_buffer:: buffer;
265+ use vortex_error:: VortexResult ;
266+ use vortex_session:: SessionExt ;
267+ use vortex_session:: VortexSession ;
268+
269+ use crate :: ArrayRef ;
270+ use crate :: ExecutionCtx ;
271+ use crate :: IntoArray ;
272+ use crate :: VortexSessionExecute ;
273+ use crate :: aggregate_fn:: Accumulator ;
274+ use crate :: aggregate_fn:: AggregateFnRef ;
275+ use crate :: aggregate_fn:: AggregateFnVTable ;
276+ use crate :: aggregate_fn:: DynAccumulator ;
277+ use crate :: aggregate_fn:: EmptyOptions ;
278+ use crate :: aggregate_fn:: combined:: Combined ;
279+ use crate :: aggregate_fn:: combined:: PairOptions ;
280+ use crate :: aggregate_fn:: fns:: mean:: Mean ;
281+ use crate :: aggregate_fn:: fns:: sum:: Sum ;
282+ use crate :: aggregate_fn:: kernels:: DynAggregateKernel ;
283+ use crate :: aggregate_fn:: session:: AggregateFnSession ;
284+ use crate :: array:: VTable ;
285+ use crate :: arrays:: Dict ;
286+ use crate :: arrays:: DictArray ;
287+ use crate :: dtype:: DType ;
288+ use crate :: dtype:: Nullability ;
289+ use crate :: dtype:: PType ;
290+ use crate :: scalar:: Scalar ;
291+ use crate :: session:: ArraySession ;
292+
293+ /// Mean partial sentinel `{sum: 42.0, count: 1}` — distinguishable from the
294+ /// natural fan-out result `{sum: 7.0, count: 1}` that `Combined::try_accumulate`
295+ /// would produce for `dict_of_seven()`.
296+ #[ derive( Debug ) ]
297+ struct SentinelMeanPartialKernel ;
298+ impl DynAggregateKernel for SentinelMeanPartialKernel {
299+ fn aggregate (
300+ & self ,
301+ _aggregate_fn : & AggregateFnRef ,
302+ _batch : & ArrayRef ,
303+ _ctx : & mut ExecutionCtx ,
304+ ) -> VortexResult < Option < Scalar > > {
305+ Ok ( Some ( sentinel_partial ( ) ) )
306+ }
307+ }
308+
309+ /// Returns `Ok(None)` => kernel declined, dispatch falls through.
310+ #[ derive( Debug ) ]
311+ struct DeclineKernel ;
312+ impl DynAggregateKernel for DeclineKernel {
313+ fn aggregate (
314+ & self ,
315+ _aggregate_fn : & AggregateFnRef ,
316+ _batch : & ArrayRef ,
317+ _ctx : & mut ExecutionCtx ,
318+ ) -> VortexResult < Option < Scalar > > {
319+ Ok ( None )
320+ }
321+ }
322+
323+ /// Sum partial sentinel `42.0` — distinguishable from the natural Sum of
324+ /// `dict_of_seven()` which is `7.0`.
325+ #[ derive( Debug ) ]
326+ struct SentinelSumPartialKernel ;
327+ impl DynAggregateKernel for SentinelSumPartialKernel {
328+ fn aggregate (
329+ & self ,
330+ _aggregate_fn : & AggregateFnRef ,
331+ _batch : & ArrayRef ,
332+ _ctx : & mut ExecutionCtx ,
333+ ) -> VortexResult < Option < Scalar > > {
334+ Ok ( Some ( Scalar :: primitive ( 42.0f64 , Nullability :: Nullable ) ) )
335+ }
336+ }
337+
338+ fn fresh_session ( ) -> VortexSession {
339+ VortexSession :: empty ( ) . with :: < ArraySession > ( )
340+ }
341+
342+ fn dict_of_seven ( ) -> ArrayRef {
343+ DictArray :: try_new ( buffer ! [ 0u32 ] . into_array ( ) , buffer ! [ 7.0f64 ] . into_array ( ) )
344+ . expect ( "valid dictionary" )
345+ . into_array ( )
346+ }
347+
348+ fn mean_f64_accumulator ( ) -> VortexResult < Accumulator < Combined < Mean > > > {
349+ let dtype = DType :: Primitive ( PType :: F64 , Nullability :: NonNullable ) ;
350+ Accumulator :: try_new (
351+ Mean :: combined ( ) ,
352+ PairOptions ( EmptyOptions , EmptyOptions ) ,
353+ dtype,
354+ )
355+ }
356+
357+ fn sentinel_partial ( ) -> Scalar {
358+ let acc = mean_f64_accumulator ( ) . expect ( "build accumulator" ) ;
359+ let sum = Scalar :: primitive ( 42.0f64 , Nullability :: Nullable ) ;
360+ let count = Scalar :: primitive ( 1u64 , Nullability :: NonNullable ) ;
361+ Scalar :: struct_ ( acc. partial_dtype , vec ! [ sum, count] )
362+ }
363+
364+ /// Kernel registered for `(Dict, Combined<Mean>)` fires in preference to
365+ /// `Combined::try_accumulate`'s fan-out path — proves the dispatch reorder.
366+ #[ test]
367+ fn combined_kernel_fires ( ) -> VortexResult < ( ) > {
368+ static KERNEL : SentinelMeanPartialKernel = SentinelMeanPartialKernel ;
369+ let session = fresh_session ( ) ;
370+ session
371+ . get :: < AggregateFnSession > ( )
372+ . register_aggregate_kernel ( Dict . id ( ) , Some ( Mean :: combined ( ) . id ( ) ) , & KERNEL ) ;
373+ let mut ctx = session. create_execution_ctx ( ) ;
374+
375+ let mut acc = mean_f64_accumulator ( ) ?;
376+ acc. accumulate ( & dict_of_seven ( ) , & mut ctx) ?;
377+ let partial = acc. flush ( ) ?;
378+
379+ let s = partial. as_struct ( ) ;
380+ assert_eq ! (
381+ s. field( "sum" ) . unwrap( ) . as_primitive( ) . as_:: <f64 >( ) ,
382+ Some ( 42.0 )
383+ ) ;
384+ assert_eq ! (
385+ s. field( "count" ) . unwrap( ) . as_primitive( ) . as_:: <u64 >( ) ,
386+ Some ( 1 )
387+ ) ;
388+ Ok ( ( ) )
389+ }
390+
391+ /// Kernel returns `Ok(None)` => dispatch falls through to `Combined::try_accumulate`'s
392+ /// natural fan-out. The natural partial is `{sum: 7.0, count: 1}`.
393+ #[ test]
394+ fn fallback_when_kernel_declines ( ) -> VortexResult < ( ) > {
395+ static KERNEL : DeclineKernel = DeclineKernel ;
396+ let session = fresh_session ( ) ;
397+ session
398+ . get :: < AggregateFnSession > ( )
399+ . register_aggregate_kernel ( Dict . id ( ) , Some ( Mean :: combined ( ) . id ( ) ) , & KERNEL ) ;
400+ let mut ctx = session. create_execution_ctx ( ) ;
401+
402+ let mut acc = mean_f64_accumulator ( ) ?;
403+ acc. accumulate ( & dict_of_seven ( ) , & mut ctx) ?;
404+ let partial = acc. flush ( ) ?;
405+
406+ let s = partial. as_struct ( ) ;
407+ assert_eq ! (
408+ s. field( "sum" ) . unwrap( ) . as_primitive( ) . as_:: <f64 >( ) ,
409+ Some ( 7.0 )
410+ ) ;
411+ assert_eq ! (
412+ s. field( "count" ) . unwrap( ) . as_primitive( ) . as_:: <u64 >( ) ,
413+ Some ( 1 )
414+ ) ;
415+ Ok ( ( ) )
416+ }
417+
418+ /// A kernel registered for the inner `(Dict, Sum)` child fires when accumulating a
419+ /// Dict batch through `Combined<Mean>`. This is the reusable-primitive case the
420+ /// refactor enables: no `(Dict, Combined<Mean>)` kernel is needed.
421+ #[ test]
422+ fn child_kernel_fires_through_combined ( ) -> VortexResult < ( ) > {
423+ static KERNEL : SentinelSumPartialKernel = SentinelSumPartialKernel ;
424+ let session = fresh_session ( ) ;
425+ session
426+ . get :: < AggregateFnSession > ( )
427+ . register_aggregate_kernel ( Dict . id ( ) , Some ( Sum . id ( ) ) , & KERNEL ) ;
428+ let mut ctx = session. create_execution_ctx ( ) ;
429+
430+ let mut acc = mean_f64_accumulator ( ) ?;
431+ acc. accumulate ( & dict_of_seven ( ) , & mut ctx) ?;
432+ let partial = acc. flush ( ) ?;
433+
434+ let s = partial. as_struct ( ) ;
435+ // `Sum` child returned the sentinel 42.0 — proves the (Dict, Sum) kernel fired
436+ // via `Combined<Mean>`'s fan-out. `Count`'s native `try_accumulate` reads the
437+ // batch's valid_count, so count is the real 1.
438+ assert_eq ! (
439+ s. field( "sum" ) . unwrap( ) . as_primitive( ) . as_:: <f64 >( ) ,
440+ Some ( 42.0 )
441+ ) ;
442+ assert_eq ! (
443+ s. field( "count" ) . unwrap( ) . as_primitive( ) . as_:: <u64 >( ) ,
444+ Some ( 1 )
445+ ) ;
446+ Ok ( ( ) )
447+ }
182448}
0 commit comments