diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs index 72628975b3..47e2cf61c3 100644 --- a/native/spark-expr/benches/aggregate.rs +++ b/native/spark-expr/benches/aggregate.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder}; -use arrow::array::builder::{Decimal128Builder, StringBuilder}; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::builder::{Decimal128Builder, Int64Builder, StringBuilder}; +use arrow::array::{ArrayRef, Int64Array, RecordBatch}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Schema}; use criterion::{criterion_group, criterion_main, Criterion}; @@ -25,14 +25,14 @@ use datafusion::datasource::source::DataSourceExec; use datafusion::execution::TaskContext; use datafusion::functions_aggregate::average::avg_udaf; use datafusion::functions_aggregate::sum::sum_udaf; -use datafusion::logical_expr::AggregateUDF; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl, EmitTo}; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_expr::expressions::Column; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use datafusion::physical_plan::ExecutionPlan; -use datafusion_comet_spark_expr::SumDecimal; -use datafusion_comet_spark_expr::{AvgDecimal, EvalMode}; +use datafusion_comet_spark_expr::{AvgDecimal, EvalMode, SumDecimal, SumInteger}; use futures::StreamExt; use std::hint::black_box; use std::sync::Arc; @@ -111,6 +111,153 @@ fn criterion_benchmark(c: &mut Criterion) { }); group.finish(); + + // SumInteger benchmarks + let mut group = c.benchmark_group("sum_integer"); + let int_batch = create_int64_record_batch(num_rows); + let mut int_batches = Vec::new(); + for _ in 0..10 { + int_batches.push(int_batch.clone()); + } + let int_partitions = &[int_batches]; + let int_c0: Arc = Arc::new(Column::new("c0", 0)); + let int_c1: Arc = Arc::new(Column::new("c1", 1)); + + group.bench_function("sum_int64_datafusion", |b| { + let datafusion_sum = sum_udaf(); + b.to_async(&rt).iter(|| { + black_box(agg_test( + int_partitions, + int_c0.clone(), + int_c1.clone(), + datafusion_sum.clone(), + "sum", + )) + }) + }); + + group.bench_function("sum_int64_comet_legacy", |b| { + let comet_sum = Arc::new(AggregateUDF::new_from_impl( + SumInteger::try_new(DataType::Int64, EvalMode::Legacy).unwrap(), + )); + b.to_async(&rt).iter(|| { + black_box(agg_test( + int_partitions, + int_c0.clone(), + int_c1.clone(), + comet_sum.clone(), + "sum", + )) + }) + }); + + group.bench_function("sum_int64_comet_ansi", |b| { + let comet_sum = Arc::new(AggregateUDF::new_from_impl( + SumInteger::try_new(DataType::Int64, EvalMode::Ansi).unwrap(), + )); + b.to_async(&rt).iter(|| { + black_box(agg_test( + int_partitions, + int_c0.clone(), + int_c1.clone(), + comet_sum.clone(), + "sum", + )) + }) + }); + + group.bench_function("sum_int64_comet_try", |b| { + let comet_sum = Arc::new(AggregateUDF::new_from_impl( + SumInteger::try_new(DataType::Int64, EvalMode::Try).unwrap(), + )); + b.to_async(&rt).iter(|| { + black_box(agg_test( + int_partitions, + int_c0.clone(), + int_c1.clone(), + comet_sum.clone(), + "sum", + )) + }) + }); + + group.finish(); + + // Direct accumulator benchmarks (bypassing execution framework) + let mut group = c.benchmark_group("sum_integer_accumulator"); + let int64_array: ArrayRef = Arc::new(Int64Array::from_iter_values(0..8192i64)); + let arrays: Vec = vec![int64_array]; + + let return_field = Arc::new(Field::new("sum", DataType::Int64, true)); + let schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]); + let expr_field = Arc::new(Field::new("c0", DataType::Int64, true)); + let expr_fields: Vec> = vec![expr_field]; + + // Single-row Accumulator benchmarks + for (name, eval_mode) in [ + ("row_legacy", EvalMode::Legacy), + ("row_ansi", EvalMode::Ansi), + ("row_try", EvalMode::Try), + ] { + let return_field = return_field.clone(); + let expr_fields = expr_fields.clone(); + group.bench_function(name, |b| { + let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap(); + b.iter(|| { + let acc_args = AccumulatorArgs { + return_field: return_field.clone(), + schema: &schema, + ignore_nulls: false, + order_bys: &[], + name: "sum", + is_distinct: false, + is_reversed: false, + exprs: &[], + expr_fields: &expr_fields, + }; + let mut acc = udf.accumulator(acc_args).unwrap(); + for _ in 0..10 { + acc.update_batch(&arrays).unwrap(); + } + black_box(acc.evaluate().unwrap()) + }) + }); + } + + // GroupsAccumulator benchmarks + let group_indices: Vec = (0..8192).map(|i| i % 1024).collect(); + for (name, eval_mode) in [ + ("groups_legacy", EvalMode::Legacy), + ("groups_ansi", EvalMode::Ansi), + ("groups_try", EvalMode::Try), + ] { + let return_field = return_field.clone(); + let expr_fields = expr_fields.clone(); + group.bench_function(name, |b| { + let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap(); + b.iter(|| { + let acc_args = AccumulatorArgs { + return_field: return_field.clone(), + schema: &schema, + ignore_nulls: false, + order_bys: &[], + name: "sum", + is_distinct: false, + is_reversed: false, + exprs: &[], + expr_fields: &expr_fields, + }; + let mut acc = udf.create_groups_accumulator(acc_args).unwrap(); + for _ in 0..10 { + acc.update_batch(&arrays, &group_indices, None, 1024) + .unwrap(); + } + black_box(acc.evaluate(EmitTo::All).unwrap()) + }) + }); + } + + group.finish(); } async fn agg_test( @@ -187,6 +334,31 @@ fn create_record_batch(num_rows: usize) -> RecordBatch { RecordBatch::try_new(Arc::new(schema), columns).unwrap() } +fn create_int64_record_batch(num_rows: usize) -> RecordBatch { + let mut int64_builder = Int64Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + int64_builder.append_value(i as i64); + string_builder.append_value(format!("group_{}", i % 1024)); + } + let int64_array = Arc::new(int64_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + // string column for grouping + fields.push(Field::new("c0", DataType::Utf8, false)); + columns.push(string_array); + + // int64 column for summing + fields.push(Field::new("c1", DataType::Int64, false)); + columns.push(int64_array); + + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() +} + fn config() -> Criterion { Criterion::default() .measurement_time(Duration::from_millis(500)) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index d226c5eded..2ea07c743e 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -69,7 +69,11 @@ impl AggregateUDFImpl for SumInteger { } fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { - Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode))) + match self.eval_mode { + EvalMode::Legacy => Ok(Box::new(SumIntegerAccumulatorLegacy::new())), + EvalMode::Ansi => Ok(Box::new(SumIntegerAccumulatorAnsi::new())), + EvalMode::Try => Ok(Box::new(SumIntegerAccumulatorTry::new())), + } } fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { @@ -91,7 +95,11 @@ impl AggregateUDFImpl for SumInteger { &self, _args: AccumulatorArgs, ) -> DFResult> { - Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode))) + match self.eval_mode { + EvalMode::Legacy => Ok(Box::new(SumIntGroupsAccumulatorLegacy::new())), + EvalMode::Ansi => Ok(Box::new(SumIntGroupsAccumulatorAnsi::new())), + EvalMode::Try => Ok(Box::new(SumIntGroupsAccumulatorTry::new())), + } } fn reverse_expr(&self) -> ReversedUDAF { @@ -100,39 +108,222 @@ impl AggregateUDFImpl for SumInteger { } #[derive(Debug)] -struct SumIntegerAccumulator { +struct SumIntegerAccumulatorLegacy { sum: Option, - eval_mode: EvalMode, - has_all_nulls: bool, } -impl SumIntegerAccumulator { - fn new(eval_mode: EvalMode) -> Self { - if eval_mode == EvalMode::Try { - Self { - // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow) - sum: Some(0), - has_all_nulls: true, - eval_mode, +impl SumIntegerAccumulatorLegacy { + fn new() -> Self { + Self { sum: None } + } +} + +impl Accumulator for SumIntegerAccumulatorLegacy { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + fn update_sum(int_array: &PrimitiveArray, mut sum: i64) -> DFResult + where + T: ArrowPrimitiveType, + { + for i in 0..int_array.len() { + if !int_array.is_null(i) { + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to convert value {:?} to i64", + int_array.value(i) + )) + })?; + sum = v.add_wrapping(sum); + } + } + Ok(sum) + } + + let values = &values[0]; + if values.len() == values.null_count() { + return Ok(()); + } + + let running_sum = self.sum.unwrap_or(0); + let sum = match values.data_type() { + DataType::Int64 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int32 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int16 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int8 => update_sum(as_primitive_array::(values), running_sum)?, + _ => { + return Err(DataFusionError::Internal(format!( + "unsupported data type: {:?}", + values.data_type() + ))); } + }; + self.sum = Some(sum); + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + Ok(ScalarValue::Int64(self.sum)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + Ok(vec![ScalarValue::Int64(self.sum)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + if states.len() != 1 { + return Err(DataFusionError::Internal(format!( + "Invalid state while merging batch. Expected 1 element but found {}", + states.len() + ))); + } + + let that_sum_array = states[0].as_primitive::(); + let that_sum = if that_sum_array.is_null(0) { + None } else { - Self { - sum: None, - has_all_nulls: false, - eval_mode, + Some(that_sum_array.value(0)) + }; + + if that_sum.is_none() { + return Ok(()); + } + if self.sum.is_none() { + self.sum = that_sum; + return Ok(()); + } + + self.sum = Some(self.sum.unwrap().add_wrapping(that_sum.unwrap())); + Ok(()) + } +} + +#[derive(Debug)] +struct SumIntegerAccumulatorAnsi { + sum: Option, +} + +impl SumIntegerAccumulatorAnsi { + fn new() -> Self { + Self { sum: None } + } +} + +impl Accumulator for SumIntegerAccumulatorAnsi { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + fn update_sum(int_array: &PrimitiveArray, mut sum: i64) -> DFResult + where + T: ArrowPrimitiveType, + { + for i in 0..int_array.len() { + if !int_array.is_null(i) { + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to convert value {:?} to i64", + int_array.value(i) + )) + })?; + sum = v + .add_checked(sum) + .map_err(|_| DataFusionError::from(arithmetic_overflow_error("integer")))?; + } + } + Ok(sum) + } + + let values = &values[0]; + if values.len() == values.null_count() { + return Ok(()); + } + + let running_sum = self.sum.unwrap_or(0); + let sum = match values.data_type() { + DataType::Int64 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int32 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int16 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int8 => update_sum(as_primitive_array::(values), running_sum)?, + _ => { + return Err(DataFusionError::Internal(format!( + "unsupported data type: {:?}", + values.data_type() + ))); } + }; + self.sum = Some(sum); + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + Ok(ScalarValue::Int64(self.sum)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + Ok(vec![ScalarValue::Int64(self.sum)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + if states.len() != 1 { + return Err(DataFusionError::Internal(format!( + "Invalid state while merging batch. Expected 1 element but found {}", + states.len() + ))); + } + + let that_sum_array = states[0].as_primitive::(); + let that_sum = if that_sum_array.is_null(0) { + None + } else { + Some(that_sum_array.value(0)) + }; + + if that_sum.is_none() { + return Ok(()); + } + if self.sum.is_none() { + self.sum = that_sum; + return Ok(()); + } + + self.sum = Some( + self.sum + .unwrap() + .add_checked(that_sum.unwrap()) + .map_err(|_| DataFusionError::from(arithmetic_overflow_error("integer")))?, + ); + Ok(()) + } +} + +#[derive(Debug)] +struct SumIntegerAccumulatorTry { + sum: Option, + has_all_nulls: bool, +} + +impl SumIntegerAccumulatorTry { + fn new() -> Self { + Self { + // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow) + sum: Some(0), + has_all_nulls: true, } } + + fn overflowed(&self) -> bool { + !self.has_all_nulls && self.sum.is_none() + } } -impl Accumulator for SumIntegerAccumulator { +impl Accumulator for SumIntegerAccumulatorTry { fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { - // accumulator internal to add sum and return null sum (and has_nulls false) if there is an overflow in Try Eval mode - fn update_sum_internal( - int_array: &PrimitiveArray, - eval_mode: EvalMode, - mut sum: i64, - ) -> Result, DataFusionError> + /// Returns Ok(Some(sum)) on success, Ok(None) on overflow + fn update_sum(int_array: &PrimitiveArray, mut sum: i64) -> DFResult> where T: ArrowPrimitiveType, { @@ -144,72 +335,41 @@ impl Accumulator for SumIntegerAccumulator { int_array.value(i) )) })?; - match eval_mode { - EvalMode::Legacy => { - sum = v.add_wrapping(sum); - } - EvalMode::Ansi | EvalMode::Try => { - match v.add_checked(sum) { - Ok(v) => sum = v, - Err(_e) => { - return if eval_mode == EvalMode::Ansi { - Err(DataFusionError::from(arithmetic_overflow_error( - "integer", - ))) - } else { - Ok(None) - }; - } - }; - } + match v.add_checked(sum) { + Ok(new_sum) => sum = new_sum, + Err(_) => return Ok(None), } } } Ok(Some(sum)) } - if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() { - // we saw an overflow earlier (Try eval mode). Skip processing + // Skip if we already saw an overflow + if self.overflowed() { return Ok(()); } + let values = &values[0]; if values.len() == values.null_count() { - Ok(()) - } else { - // No nulls so there should be a non-null sum / null incase overflow in Try eval - let running_sum = self.sum.unwrap_or(0); - let sum = match values.data_type() { - DataType::Int64 => update_sum_internal( - as_primitive_array::(values), - self.eval_mode, - running_sum, - )?, - DataType::Int32 => update_sum_internal( - as_primitive_array::(values), - self.eval_mode, - running_sum, - )?, - DataType::Int16 => update_sum_internal( - as_primitive_array::(values), - self.eval_mode, - running_sum, - )?, - DataType::Int8 => update_sum_internal( - as_primitive_array::(values), - self.eval_mode, - running_sum, - )?, - _ => { - return Err(DataFusionError::Internal(format!( - "unsupported data type: {:?}", - values.data_type() - ))); - } - }; - self.sum = sum; - self.has_all_nulls = false; - Ok(()) + return Ok(()); } + + let running_sum = self.sum.unwrap_or(0); + let sum = match values.data_type() { + DataType::Int64 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int32 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int16 => update_sum(as_primitive_array::(values), running_sum)?, + DataType::Int8 => update_sum(as_primitive_array::(values), running_sum)?, + _ => { + return Err(DataFusionError::Internal(format!( + "unsupported data type: {:?}", + values.data_type() + ))); + } + }; + self.sum = sum; + self.has_all_nulls = false; + Ok(()) } fn evaluate(&mut self) -> DFResult { @@ -225,26 +385,16 @@ impl Accumulator for SumIntegerAccumulator { } fn state(&mut self) -> DFResult> { - if self.eval_mode == EvalMode::Try { - Ok(vec![ - ScalarValue::Int64(self.sum), - ScalarValue::Boolean(Some(self.has_all_nulls)), - ]) - } else { - Ok(vec![ScalarValue::Int64(self.sum)]) - } + Ok(vec![ + ScalarValue::Int64(self.sum), + ScalarValue::Boolean(Some(self.has_all_nulls)), + ]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { - let expected_state_len = if self.eval_mode == EvalMode::Try { - 2 - } else { - 1 - }; - if expected_state_len != states.len() { + if states.len() != 2 { return Err(DataFusionError::Internal(format!( - "Invalid state while merging batch. Expected {} elements but found {}", - expected_state_len, + "Invalid state while merging batch. Expected 2 elements but found {}", states.len() ))); } @@ -255,94 +405,326 @@ impl Accumulator for SumIntegerAccumulator { } else { Some(that_sum_array.value(0)) }; + let that_has_all_nulls = states[1].as_boolean().value(0); - // Check for overflow for early termination - if self.eval_mode == EvalMode::Try { - let that_has_all_nulls = states[1].as_boolean().value(0); - let that_overflowed = !that_has_all_nulls && that_sum.is_none(); - let this_overflowed = !self.has_all_nulls && self.sum.is_none(); - if that_overflowed || this_overflowed { + let that_overflowed = !that_has_all_nulls && that_sum.is_none(); + if that_overflowed || self.overflowed() { + self.sum = None; + self.has_all_nulls = false; + return Ok(()); + } + + if that_has_all_nulls { + return Ok(()); + } + if self.has_all_nulls { + self.sum = that_sum; + self.has_all_nulls = false; + return Ok(()); + } + + // Both sides have non-null values + match self.sum.unwrap().add_checked(that_sum.unwrap()) { + Ok(v) => self.sum = Some(v), + Err(_) => { self.sum = None; self.has_all_nulls = false; - return Ok(()); } - if that_has_all_nulls { - return Ok(()); + } + Ok(()) + } +} + +struct SumIntGroupsAccumulatorLegacy { + sums: Vec>, +} + +impl SumIntGroupsAccumulatorLegacy { + fn new() -> Self { + Self { sums: Vec::new() } + } +} + +impl GroupsAccumulator for SumIntGroupsAccumulatorLegacy { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + fn update_groups_sum( + int_array: &PrimitiveArray, + group_indices: &[usize], + sums: &mut [Option], + ) -> DFResult<()> + where + T: ArrowPrimitiveType, + T::Native: ArrowNativeType, + { + for (i, &group_index) in group_indices.iter().enumerate() { + if !int_array.is_null(i) { + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal("Failed to convert value to i64".to_string()) + })?; + sums[group_index] = Some(sums[group_index].unwrap_or(0).add_wrapping(v)); + } } - if self.has_all_nulls { - self.sum = that_sum; - self.has_all_nulls = false; - return Ok(()); + Ok(()) + } + + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + let values = &values[0]; + self.sums.resize(total_num_groups, None); + + match values.data_type() { + DataType::Int64 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + DataType::Int32 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + DataType::Int16 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + DataType::Int8 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type for SumIntGroupsAccumulatorLegacy: {:?}", + values.data_type() + ))) } - } else { - if that_sum.is_none() { - return Ok(()); + }; + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { + match emit_to { + EmitTo::All => { + let result = Arc::new(Int64Array::from(std::mem::take(&mut self.sums))) as ArrayRef; + Ok(result) } - if self.sum.is_none() { - self.sum = that_sum; - return Ok(()); + EmitTo::First(n) => { + let result = Arc::new(Int64Array::from(self.sums.drain(..n).collect::>())) + as ArrayRef; + Ok(result) } } + } - // safe to unwrap (since we checked nulls above) but handling error just in case state is corrupt - let left = self.sum.ok_or_else(|| { - DataFusionError::Internal( - "Invalid state in merging batch. Current batch's sum is None".to_string(), - ) - })?; - let right = that_sum.ok_or_else(|| { - DataFusionError::Internal( - "Invalid state in merging batch. Incoming sum is None".to_string(), - ) - })?; + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let sums = emit_to.take_needed(&mut self.sums); + Ok(vec![Arc::new(Int64Array::from(sums))]) + } - match self.eval_mode { - EvalMode::Legacy => { - self.sum = Some(left.add_wrapping(right)); + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + if values.len() != 1 { + return Err(DataFusionError::Internal(format!( + "Invalid state while merging batch. Expected 1 element but found {}", + values.len() + ))); + } + let that_sums = values[0].as_primitive::(); + + self.sums.resize(total_num_groups, None); + + for (idx, &group_index) in group_indices.iter().enumerate() { + if that_sums.is_null(idx) { + continue; } - EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) { - Ok(v) => self.sum = Some(v), - Err(_) => { - if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))); - } else { - self.sum = None; - self.has_all_nulls = false; - } + let that_sum = that_sums.value(idx); + + if self.sums[group_index].is_none() { + self.sums[group_index] = Some(that_sum); + } else { + self.sums[group_index] = + Some(self.sums[group_index].unwrap().add_wrapping(that_sum)); + } + } + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +struct SumIntGroupsAccumulatorAnsi { + sums: Vec>, +} + +impl SumIntGroupsAccumulatorAnsi { + fn new() -> Self { + Self { sums: Vec::new() } + } +} + +impl GroupsAccumulator for SumIntGroupsAccumulatorAnsi { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + fn update_groups_sum( + int_array: &PrimitiveArray, + group_indices: &[usize], + sums: &mut [Option], + ) -> DFResult<()> + where + T: ArrowPrimitiveType, + T::Native: ArrowNativeType, + { + for (i, &group_index) in group_indices.iter().enumerate() { + if !int_array.is_null(i) { + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal("Failed to convert value to i64".to_string()) + })?; + sums[group_index] = + Some(sums[group_index].unwrap_or(0).add_checked(v).map_err(|_| { + DataFusionError::from(arithmetic_overflow_error("integer")) + })?); } - }, + } + Ok(()) + } + + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + let values = &values[0]; + self.sums.resize(total_num_groups, None); + + match values.data_type() { + DataType::Int64 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + DataType::Int32 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + DataType::Int16 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + DataType::Int8 => update_groups_sum( + as_primitive_array::(values), + group_indices, + &mut self.sums, + )?, + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type for SumIntGroupsAccumulatorAnsi: {:?}", + values.data_type() + ))) + } + }; + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { + match emit_to { + EmitTo::All => { + let result = Arc::new(Int64Array::from(std::mem::take(&mut self.sums))) as ArrayRef; + Ok(result) + } + EmitTo::First(n) => { + let result = Arc::new(Int64Array::from(self.sums.drain(..n).collect::>())) + as ArrayRef; + Ok(result) + } + } + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let sums = emit_to.take_needed(&mut self.sums); + Ok(vec![Arc::new(Int64Array::from(sums))]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + if values.len() != 1 { + return Err(DataFusionError::Internal(format!( + "Invalid state while merging batch. Expected 1 element but found {}", + values.len() + ))); + } + let that_sums = values[0].as_primitive::(); + + self.sums.resize(total_num_groups, None); + + for (idx, &group_index) in group_indices.iter().enumerate() { + if that_sums.is_null(idx) { + continue; + } + let that_sum = that_sums.value(idx); + + if self.sums[group_index].is_none() { + self.sums[group_index] = Some(that_sum); + } else { + self.sums[group_index] = Some( + self.sums[group_index] + .unwrap() + .add_checked(that_sum) + .map_err(|_| DataFusionError::from(arithmetic_overflow_error("integer")))?, + ); + } } Ok(()) } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } } -struct SumIntGroupsAccumulator { +struct SumIntGroupsAccumulatorTry { sums: Vec>, has_all_nulls: Vec, - eval_mode: EvalMode, } -impl SumIntGroupsAccumulator { - fn new(eval_mode: EvalMode) -> Self { +impl SumIntGroupsAccumulatorTry { + fn new() -> Self { Self { sums: Vec::new(), - eval_mode, has_all_nulls: Vec::new(), } } - fn resize_helper(&mut self, total_num_groups: usize) { - if self.eval_mode == EvalMode::Try { - self.sums.resize(total_num_groups, Some(0)); - self.has_all_nulls.resize(total_num_groups, true); - } else { - self.sums.resize(total_num_groups, None); - self.has_all_nulls.resize(total_num_groups, false); - } + fn group_overflowed(&self, group_index: usize) -> bool { + !self.has_all_nulls[group_index] && self.sums[group_index].is_none() } } -impl GroupsAccumulator for SumIntGroupsAccumulator { +impl GroupsAccumulator for SumIntGroupsAccumulatorTry { fn update_batch( &mut self, values: &[ArrayRef], @@ -350,12 +732,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> DFResult<()> { - fn update_groups_sum_internal( + fn update_groups_sum( int_array: &PrimitiveArray, group_indices: &[usize], sums: &mut [Option], has_all_nulls: &mut [bool], - eval_mode: EvalMode, ) -> DFResult<()> where T: ArrowPrimitiveType, @@ -363,39 +744,18 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { { for (i, &group_index) in group_indices.iter().enumerate() { if !int_array.is_null(i) { - // there is an overflow in prev group in try eval. Skip processing - if eval_mode == EvalMode::Try - && !has_all_nulls[group_index] - && sums[group_index].is_none() - { + // Skip if this group already overflowed + if !has_all_nulls[group_index] && sums[group_index].is_none() { continue; } let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; - match eval_mode { - EvalMode::Legacy => { - sums[group_index] = - Some(sums[group_index].unwrap_or(0).add_wrapping(v)); - } - EvalMode::Ansi | EvalMode::Try => { - match sums[group_index].unwrap_or(0).add_checked(v) { - Ok(new_sum) => { - sums[group_index] = Some(new_sum); - } - Err(_) => { - if eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from( - arithmetic_overflow_error("integer"), - )); - } else { - sums[group_index] = None; - } - } - }; - } - } - has_all_nulls[group_index] = false + match sums[group_index].unwrap_or(0).add_checked(v) { + Ok(new_sum) => sums[group_index] = Some(new_sum), + Err(_) => sums[group_index] = None, + }; + has_all_nulls[group_index] = false; } } Ok(()) @@ -403,40 +763,37 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); let values = &values[0]; - self.resize_helper(total_num_groups); + self.sums.resize(total_num_groups, Some(0)); + self.has_all_nulls.resize(total_num_groups, true); match values.data_type() { - DataType::Int64 => update_groups_sum_internal( + DataType::Int64 => update_groups_sum( as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, - self.eval_mode, )?, - DataType::Int32 => update_groups_sum_internal( + DataType::Int32 => update_groups_sum( as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, - self.eval_mode, )?, - DataType::Int16 => update_groups_sum_internal( + DataType::Int16 => update_groups_sum( as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, - self.eval_mode, )?, - DataType::Int8 => update_groups_sum_internal( + DataType::Int8 => update_groups_sum( as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, - self.eval_mode, )?, _ => { return Err(DataFusionError::Internal(format!( - "Unsupported data type for SumIntGroupsAccumulator: {:?}", + "Unsupported data type for SumIntGroupsAccumulatorTry: {:?}", values.data_type() ))) } @@ -453,7 +810,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .zip(self.has_all_nulls.iter()) .map(|(&sum, &is_null)| if is_null { None } else { sum }), )) as ArrayRef; - self.sums.clear(); self.has_all_nulls.clear(); Ok(result) @@ -472,16 +828,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { fn state(&mut self, emit_to: EmitTo) -> DFResult> { let sums = emit_to.take_needed(&mut self.sums); - - if self.eval_mode == EvalMode::Try { - let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls); - Ok(vec![ - Arc::new(Int64Array::from(sums)), - Arc::new(BooleanArray::from(has_all_nulls)), - ]) - } else { - Ok(vec![Arc::new(Int64Array::from(sums))]) - } + let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls); + Ok(vec![ + Arc::new(Int64Array::from(sums)), + Arc::new(BooleanArray::from(has_all_nulls)), + ]) } fn merge_batch( @@ -493,27 +844,17 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { ) -> DFResult<()> { debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - let expected_state_len = if self.eval_mode == EvalMode::Try { - 2 - } else { - 1 - }; - if expected_state_len != values.len() { + if values.len() != 2 { return Err(DataFusionError::Internal(format!( - "Invalid state while merging batch. Expected {} elements but found {}", - expected_state_len, + "Invalid state while merging batch. Expected 2 elements but found {}", values.len() ))); } let that_sums = values[0].as_primitive::(); + let that_has_all_nulls_array = values[1].as_boolean(); - self.resize_helper(total_num_groups); - - let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try { - Some(values[1].as_boolean()) - } else { - None - }; + self.sums.resize(total_num_groups, Some(0)); + self.has_all_nulls.resize(total_num_groups, true); for (idx, &group_index) in group_indices.iter().enumerate() { let that_sum = if that_sums.is_null(idx) { @@ -521,62 +862,34 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { } else { Some(that_sums.value(idx)) }; + let that_has_all_nulls = that_has_all_nulls_array.value(idx); - if self.eval_mode == EvalMode::Try { - let that_has_all_nulls = that_sums_is_all_nulls.unwrap().value(idx); - - let that_overflowed = !that_has_all_nulls && that_sum.is_none(); - let this_overflowed = - !self.has_all_nulls[group_index] && self.sums[group_index].is_none(); - - if that_overflowed || this_overflowed { - self.sums[group_index] = None; - self.has_all_nulls[group_index] = false; - continue; - } - - if that_has_all_nulls { - continue; - } + let that_overflowed = !that_has_all_nulls && that_sum.is_none(); + if that_overflowed || self.group_overflowed(group_index) { + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; + continue; + } - if self.has_all_nulls[group_index] { - self.sums[group_index] = that_sum; - self.has_all_nulls[group_index] = false; - continue; - } - } else { - if that_sum.is_none() { - continue; - } - if self.sums[group_index].is_none() { - self.sums[group_index] = that_sum; - continue; - } + if that_has_all_nulls { + continue; } - // Both sides have non-null. Update sums now - let left = self.sums[group_index].unwrap(); - let right = that_sum.unwrap(); + if self.has_all_nulls[group_index] { + self.sums[group_index] = that_sum; + self.has_all_nulls[group_index] = false; + continue; + } - match self.eval_mode { - EvalMode::Legacy => { - self.sums[group_index] = Some(left.add_wrapping(right)); - } - EvalMode::Ansi | EvalMode::Try => { - match left.add_checked(right) { - Ok(v) => self.sums[group_index] = Some(v), - Err(_) => { - if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error( - "integer", - ))); - } else { - // overflow. update flag accordingly - self.sums[group_index] = None; - self.has_all_nulls[group_index] = false; - } - } - } + // Both sides have non-null values + match self.sums[group_index] + .unwrap() + .add_checked(that_sum.unwrap()) + { + Ok(v) => self.sums[group_index] = Some(v), + Err(_) => { + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; } } }