@@ -73,7 +73,7 @@ impl AggregateUDFImpl for Avg {
7373 }
7474
7575 fn accumulator ( & self , _acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
76- // instantiate specialized accumulator based for the type
76+ // All numeric types use Float64 accumulation after casting
7777 match ( & self . input_data_type , & self . result_data_type ) {
7878 ( Float64 , Float64 ) => Ok ( Box :: < AvgAccumulator > :: default ( ) ) ,
7979 _ => not_impl_err ! (
@@ -115,7 +115,6 @@ impl AggregateUDFImpl for Avg {
115115 & self ,
116116 _args : AccumulatorArgs ,
117117 ) -> Result < Box < dyn GroupsAccumulator > > {
118- // instantiate specialized accumulator based for the type
119118 match ( & self . input_data_type , & self . result_data_type ) {
120119 ( Float64 , Float64 ) => Ok ( Box :: new ( AvgGroupsAccumulator :: < Float64Type , _ > :: new (
121120 & self . input_data_type ,
@@ -172,7 +171,7 @@ impl Accumulator for AvgAccumulator {
172171 // counts are summed
173172 self . count += sum ( states[ 1 ] . as_primitive :: < Int64Type > ( ) ) . unwrap_or_default ( ) ;
174173
175- // sums are summed
174+ // sums are summed - no overflow checking in any eval mode
176175 if let Some ( x) = sum ( states[ 0 ] . as_primitive :: < Float64Type > ( ) ) {
177176 let v = self . sum . get_or_insert ( 0. ) ;
178177 * v += x;
@@ -182,7 +181,7 @@ impl Accumulator for AvgAccumulator {
182181
183182 fn evaluate ( & mut self ) -> Result < ScalarValue > {
184183 if self . count == 0 {
185- // If all input are nulls, count will be 0 and we will get null after the division.
184+ // If all input are nulls, count will be 0, and we will get null after the division.
186185 // This is consistent with Spark Average implementation.
187186 Ok ( ScalarValue :: Float64 ( None ) )
188187 } else {
@@ -198,7 +197,8 @@ impl Accumulator for AvgAccumulator {
198197}
199198
200199/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
201- /// Stores values as native types, and does overflow checking
200+ /// Stores values as native types. No overflow checking is performed in any eval mode,
201+ /// since infinity is a perfectly valid value per the Spark implementation.
202202///
203203/// F: Function that calculates the average value from a sum of
204204/// T::Native and a total count
@@ -260,6 +260,7 @@ where
260260 if values. null_count ( ) == 0 {
261261 for ( & group_index, & value) in iter {
262262 let sum = & mut self . sums [ group_index] ;
263+ // No overflow checking - a floating-point sum that overflows becomes Infinity, which is a valid result
263264 * sum = ( * sum) . add_wrapping ( value) ;
264265 self . counts [ group_index] += 1 ;
265266 }
@@ -296,7 +297,7 @@ where
296297 self . counts [ group_index] += partial_count;
297298 }
298299
299- // update sums
300+ // update sums - no overflow checking in any eval mode
300301 self . sums . resize ( total_num_groups, T :: default_value ( ) ) ;
301302 let iter2 = group_indices. iter ( ) . zip ( partial_sums. values ( ) . iter ( ) ) ;
302303 for ( & group_index, & new_value) in iter2 {
@@ -325,7 +326,6 @@ where
325326 Ok ( Arc :: new ( array) )
326327 }
327328
328- // return arrays for sums and counts
329329 fn state ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
330330 let counts = emit_to. take_needed ( & mut self . counts ) ;
331331 let counts = Int64Array :: new ( counts. into ( ) , None ) ;
0 commit comments