@@ -7,9 +7,9 @@ use vortex_error::VortexExpect;
77use vortex_error:: VortexResult ;
88use vortex_error:: vortex_bail;
99use vortex_error:: vortex_ensure;
10+ use vortex_error:: vortex_err;
1011use vortex_error:: vortex_panic;
1112use vortex_mask:: Mask ;
12- use vortex_session:: VortexSession ;
1313
1414use crate :: AnyCanonical ;
1515use crate :: ArrayRef ;
@@ -18,7 +18,6 @@ use crate::Columnar;
1818use crate :: DynArray ;
1919use crate :: ExecutionCtx ;
2020use crate :: IntoArray ;
21- use crate :: VortexSessionExecute ;
2221use crate :: aggregate_fn:: Accumulator ;
2322use crate :: aggregate_fn:: AggregateFn ;
2423use crate :: aggregate_fn:: AggregateFnRef ;
@@ -58,20 +57,25 @@ pub struct GroupedAccumulator<V: AggregateFnVTable> {
5857 partial_dtype : DType ,
5958 /// The accumulated state for prior batches of groups.
6059 partials : Vec < ArrayRef > ,
61- /// A session used to lookup custom aggregate kernels.
62- session : VortexSession ,
6360}
6461
6562impl < V : AggregateFnVTable > GroupedAccumulator < V > {
66- pub fn try_new (
67- vtable : V ,
68- options : V :: Options ,
69- dtype : DType ,
70- session : VortexSession ,
71- ) -> VortexResult < Self > {
63+ pub fn try_new ( vtable : V , options : V :: Options , dtype : DType ) -> VortexResult < Self > {
7264 let aggregate_fn = AggregateFn :: new ( vtable. clone ( ) , options. clone ( ) ) . erased ( ) ;
73- let return_dtype = vtable. return_dtype ( & options, & dtype) ?;
74- let partial_dtype = vtable. partial_dtype ( & options, & dtype) ?;
65+ let return_dtype = vtable. return_dtype ( & options, & dtype) . ok_or_else ( || {
66+ vortex_err ! (
67+ "Aggregate function {} cannot be applied to dtype {}" ,
68+ vtable. id( ) ,
69+ dtype
70+ )
71+ } ) ?;
72+ let partial_dtype = vtable. partial_dtype ( & options, & dtype) . ok_or_else ( || {
73+ vortex_err ! (
74+ "Aggregate function {} cannot be applied to dtype {}" ,
75+ vtable. id( ) ,
76+ dtype
77+ )
78+ } ) ?;
7579
7680 Ok ( Self {
7781 vtable,
@@ -81,7 +85,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
8185 return_dtype,
8286 partial_dtype,
8387 partials : vec ! [ ] ,
84- session,
8588 } )
8689 }
8790}
@@ -90,7 +93,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
9093/// function is not known at compile time.
9194pub trait DynGroupedAccumulator : ' static + Send {
9295 /// Accumulate a list of groups into the accumulator.
93- fn accumulate_list ( & mut self , groups : & ArrayRef ) -> VortexResult < ( ) > ;
96+ fn accumulate_list ( & mut self , groups : & ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ( ) > ;
9497
9598 /// Finish the accumulation and return the partial aggregate results for all groups.
9699 /// Resets the accumulator state for the next round of accumulation.
@@ -102,7 +105,7 @@ pub trait DynGroupedAccumulator: 'static + Send {
102105}
103106
104107impl < V : AggregateFnVTable > DynGroupedAccumulator for GroupedAccumulator < V > {
105- fn accumulate_list ( & mut self , groups : & ArrayRef ) -> VortexResult < ( ) > {
108+ fn accumulate_list ( & mut self , groups : & ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ( ) > {
106109 let elements_dtype = match groups. dtype ( ) {
107110 DType :: List ( elem, _) => elem,
108111 DType :: FixedSizeList ( elem, ..) => elem,
@@ -118,17 +121,15 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
118121 elements_dtype
119122 ) ;
120123
121- let mut ctx = self . session . create_execution_ctx ( ) ;
122-
123124 // We first execute the groups until it is a ListView or FixedSizeList, since we only
124125 // dispatch the aggregate kernel over the elements of these arrays.
125- let canonical = match groups. clone ( ) . execute :: < Columnar > ( & mut ctx) ? {
126+ let canonical = match groups. clone ( ) . execute :: < Columnar > ( ctx) ? {
126127 Columnar :: Canonical ( c) => c,
127- Columnar :: Constant ( c) => c. into_array ( ) . execute :: < Canonical > ( & mut ctx) ?,
128+ Columnar :: Constant ( c) => c. into_array ( ) . execute :: < Canonical > ( ctx) ?,
128129 } ;
129130 match canonical {
130- Canonical :: List ( groups) => self . accumulate_list_view ( & groups, & mut ctx) ,
131- Canonical :: FixedSizeList ( groups) => self . accumulate_fixed_size_list ( & groups, & mut ctx) ,
131+ Canonical :: List ( groups) => self . accumulate_list_view ( & groups, ctx) ,
132+ Canonical :: FixedSizeList ( groups) => self . accumulate_fixed_size_list ( & groups, ctx) ,
132133 _ => vortex_panic ! ( "We checked the DType above, so this should never happen" ) ,
133134 }
134135 }
@@ -160,8 +161,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
160161 ctx : & mut ExecutionCtx ,
161162 ) -> VortexResult < ( ) > {
162163 let mut elements = groups. elements ( ) . clone ( ) ;
163- let session = self . session . clone ( ) ;
164-
164+ let session = ctx. session ( ) . clone ( ) ;
165165 let kernels = & session. aggregate_fns ( ) . grouped_kernels ;
166166
167167 for _ in 0 ..* MAX_ITERATIONS {
@@ -205,7 +205,13 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
205205 match_each_integer_ptype ! ( offsets. dtype( ) . as_ptype( ) , |O | {
206206 let offsets = offsets. clone( ) . execute:: <Buffer <O >>( ctx) ?;
207207 let sizes = sizes. execute:: <Buffer <O >>( ctx) ?;
208- self . accumulate_list_view_typed( & elements, offsets. as_ref( ) , sizes. as_ref( ) , & validity)
208+ self . accumulate_list_view_typed(
209+ & elements,
210+ offsets. as_ref( ) ,
211+ sizes. as_ref( ) ,
212+ & validity,
213+ ctx,
214+ )
209215 } )
210216 }
211217
@@ -215,12 +221,12 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
215221 offsets : & [ O ] ,
216222 sizes : & [ O ] ,
217223 validity : & Mask ,
224+ ctx : & mut ExecutionCtx ,
218225 ) -> VortexResult < ( ) > {
219226 let mut accumulator = Accumulator :: try_new (
220227 self . vtable . clone ( ) ,
221228 self . options . clone ( ) ,
222229 self . dtype . clone ( ) ,
223- self . session . clone ( ) ,
224230 ) ?;
225231 let mut states = builder_with_capacity ( & self . partial_dtype , offsets. len ( ) ) ;
226232
@@ -230,7 +236,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
230236
231237 if validity. value ( offset) {
232238 let group = elements. slice ( offset..offset + size) ?;
233- accumulator. accumulate ( & group) ?;
239+ accumulator. accumulate ( & group, ctx ) ?;
234240 states. append_scalar ( & accumulator. finish ( ) ?) ?;
235241 } else {
236242 states. append_null ( )
@@ -246,8 +252,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
246252 ctx : & mut ExecutionCtx ,
247253 ) -> VortexResult < ( ) > {
248254 let mut elements = groups. elements ( ) . clone ( ) ;
249-
250- let session = self . session . clone ( ) ;
255+ let session = ctx. session ( ) . clone ( ) ;
251256 let kernels = & session. aggregate_fns ( ) . grouped_kernels ;
252257
253258 for _ in 0 ..64 {
@@ -291,7 +296,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
291296 self . vtable . clone ( ) ,
292297 self . options . clone ( ) ,
293298 self . dtype . clone ( ) ,
294- self . session . clone ( ) ,
295299 ) ?;
296300 let mut states = builder_with_capacity ( & self . partial_dtype , groups. len ( ) ) ;
297301
@@ -304,7 +308,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
304308 for i in 0 ..groups. len ( ) {
305309 if validity. value ( i) {
306310 let group = elements. slice ( offset..offset + size) ?;
307- accumulator. accumulate ( & group) ?;
311+ accumulator. accumulate ( & group, ctx ) ?;
308312 states. append_scalar ( & accumulator. finish ( ) ?) ?;
309313 } else {
310314 states. append_null ( )
0 commit comments