Sum aggregate: saturate decimal sums that no longer fit the return precision.

SumState::Decimal becomes a struct variant that stores the decimal dtype taken
from the return dtype next to the running value. After every successful
checked_add the running value is tested with fits_in_precision; a value that
does not fit is reported as saturation, the same way a failed checked_add is.

diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock
index a79061a8313..20d6ad7dde3 100644
--- a/vortex-array/public-api.lock
+++ b/vortex-array/public-api.lock
@@ -80,7 +80,11 @@ pub mod vortex_array::aggregate_fn::fns::sum
 
 pub enum vortex_array::aggregate_fn::fns::sum::SumState
 
-pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal(vortex_array::scalar::DecimalValue)
+pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal
+
+pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal::dtype: vortex_array::dtype::DecimalDType
+
+pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal::value: vortex_array::scalar::DecimalValue
 
 pub vortex_array::aggregate_fn::fns::sum::SumState::Float(f64)
 
diff --git a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs
index fc388c57b49..b89b12ce74e 100644
--- a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs
+++ b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs
@@ -14,7 +14,7 @@ use crate::scalar::DecimalValue;
 /// Accumulate a decimal array into the sum state.
 /// Returns Ok(true) if saturated (overflow), Ok(false) if not.
 pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
-    let SumState::Decimal(acc) = inner else {
+    let SumState::Decimal { value, dtype } = inner else {
         vortex_panic!("expected decimal sum state for decimal input");
     };
 
@@ -23,8 +23,14 @@ pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> Vort
         AllOr::None => Ok(false),
         AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
             for &v in d.buffer::<T>().iter() {
-                match acc.checked_add(&DecimalValue::from(v)) {
-                    Some(r) => *acc = r,
+                match value.checked_add(&DecimalValue::from(v)) {
+                    Some(r) => {
+                        *value = r;
+                        // Check for overflow
+                        if !value.fits_in_precision(*dtype) {
+                            return Ok(true);
+                        }
+                    }
                     None => return Ok(true),
                 }
             }
@@ -33,8 +39,13 @@ pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> Vort
         AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
             for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
                 if valid {
-                    match acc.checked_add(&DecimalValue::from(v)) {
-                        Some(r) => *acc = r,
+                    match value.checked_add(&DecimalValue::from(v)) {
+                        Some(r) => {
+                            *value = r;
+                            if !value.fits_in_precision(*dtype) {
+                                return Ok(true);
+                            }
+                        }
                         None => return Ok(true),
                     }
                 }
@@ -53,6 +64,9 @@ mod tests {
     use crate::IntoArray;
     use crate::LEGACY_SESSION;
     use crate::VortexSessionExecute;
+    use crate::aggregate_fn::AggregateFnVTable;
+    use crate::aggregate_fn::EmptyOptions;
+    use crate::aggregate_fn::fns::sum::Sum;
     use crate::aggregate_fn::fns::sum::sum;
     use crate::arrays::DecimalArray;
     use crate::dtype::DType;
@@ -286,4 +300,121 @@ mod tests {
         );
         Ok(())
     }
+
+    #[test]
+    fn sum_decimal_near_precision_boundary() -> VortexResult<()> {
+        // Input precision 4 → return precision min(76, 4+10) = 14.
+        // Native type for precision 14 is I64 (max precision 18), so 14 < 18.
+        // Use combine_partials to push state near (but under) 10^14.
+        let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
+        let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
+
+        let near_limit = Scalar::decimal(
+            DecimalValue::from(99_999_999_999_990i64),
+            DecimalDType::new(14, 0),
+            Nullable,
+        );
+        Sum.combine_partials(&mut state, near_limit)?;
+
+        // Add a small value that keeps us just under 10^14.
+        let small = Scalar::decimal(DecimalValue::from(9i64), DecimalDType::new(14, 0), Nullable);
+        Sum.combine_partials(&mut state, small)?;
+
+        let result = Sum.flush(&mut state)?;
+        assert!(!result.is_null());
+        assert_eq!(
+            result.as_decimal().decimal_value(),
+            Some(DecimalValue::I256(i256::from_i128(99_999_999_999_999)))
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn sum_decimal_precision_overflow_within_i256() -> VortexResult<()> {
+        // Input precision 4 → return precision 14. Native I64 (max 18).
+        // The max representable value for precision 14 is 10^14 - 1.
+        // When the sum reaches exactly 10^14, fits_in_precision fails even though
+        // i256 arithmetic does not overflow. This tests the precision-based
+        // saturation path in combine_partials.
+        let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
+        let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
+
+        let near_limit = Scalar::decimal(
+            DecimalValue::from(99_999_999_999_999i64),
+            DecimalDType::new(14, 0),
+            Nullable,
+        );
+        Sum.combine_partials(&mut state, near_limit)?;
+
+        // Push the sum to exactly 10^14, exceeding precision 14.
+        let one_more =
+            Scalar::decimal(DecimalValue::from(1i64), DecimalDType::new(14, 0), Nullable);
+        Sum.combine_partials(&mut state, one_more)?;
+
+        let result = Sum.flush(&mut state)?;
+        assert!(result.is_null());
+        assert_eq!(
+            result.dtype(),
+            &DType::Decimal(DecimalDType::new(14, 0), Nullable)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn sum_decimal_precision_overflow_negative() -> VortexResult<()> {
+        // Same setup but with negative values: sum reaches -10^14.
+        let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
+        let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
+
+        let near_limit = Scalar::decimal(
+            DecimalValue::from(-99_999_999_999_999i64),
+            DecimalDType::new(14, 0),
+            Nullable,
+        );
+        Sum.combine_partials(&mut state, near_limit)?;
+
+        let one_more = Scalar::decimal(
+            DecimalValue::from(-1i64),
+            DecimalDType::new(14, 0),
+            Nullable,
+        );
+        Sum.combine_partials(&mut state, one_more)?;
+
+        let result = Sum.flush(&mut state)?;
+        assert!(result.is_null());
+        Ok(())
+    }
+
+    #[test]
+    fn sum_decimal_accumulate_precision_overflow() -> VortexResult<()> {
+        // Test precision overflow via the accumulate_decimal path (not combine_partials).
+        // Input precision 28 (I128 storage) → return precision min(76, 38) = 38.
+        // Native for precision 38 is I128 (max 38), so 38 = 38.
+        // Use precision 27 → return 37. Native for 37 is I128 (max 38), so 37 < 38.
+        //
+        // We use combine_partials to get the state close to 10^37, then accumulate
+        // a real array that pushes it over.
+        let input_dtype = DType::Decimal(DecimalDType::new(27, 0), Nullability::NonNullable);
+        let return_dtype = DecimalDType::new(37, 0);
+        let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
+
+        // Set state to 10^37 - 1 via combine_partials.
+        let near_limit_val: i128 = 10i128.pow(37) - 1;
+        let near_limit =
+            Scalar::decimal(DecimalValue::from(near_limit_val), return_dtype, Nullable);
+        Sum.combine_partials(&mut state, near_limit)?;
+
+        // Now accumulate a real i128 array with a single element = 1 to overflow precision.
+        let decimal =
+            DecimalArray::new(buffer![1i128], DecimalDType::new(27, 0), Validity::AllValid);
+
+        // Drive accumulate through the vtable directly.
+        let columnar = crate::Columnar::Canonical(crate::Canonical::Decimal(decimal));
+        let mut ctx = LEGACY_SESSION.create_execution_ctx();
+        Sum.accumulate(&mut state, &columnar, &mut ctx)?;
+
+        let result = Sum.flush(&mut state)?;
+        assert!(result.is_null());
+        Ok(())
+    }
 }
diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs
index c5b5c94630d..6ef11472099 100644
--- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs
+++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs
@@ -160,15 +160,15 @@ impl AggregateFnVTable for Sum {
                 *acc += val;
                 false
             }
-            SumState::Decimal(acc) => {
+            SumState::Decimal { value, dtype } => {
                 let val = other
                     .as_decimal()
                     .decimal_value()
                     .vortex_expect("checked non-null");
-                match acc.checked_add(&val) {
+                match value.checked_add(&val) {
                     Some(r) => {
-                        *acc = r;
-                        false
+                        *value = r;
+                        !value.fits_in_precision(*dtype)
                     }
                     None => true,
                 }
@@ -186,12 +186,12 @@ impl AggregateFnVTable for Sum {
             Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
             Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
             Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
-            Some(SumState::Decimal(v)) => {
+            Some(SumState::Decimal { value, .. }) => {
                 let decimal_dtype = *partial
                     .return_dtype
                     .as_decimal_opt()
                     .vortex_expect("return dtype must be decimal");
-                Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
+                Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
             }
         };
 
@@ -271,7 +271,10 @@ pub enum SumState {
     Unsigned(u64),
     Signed(i64),
     Float(f64),
-    Decimal(DecimalValue),
+    Decimal {
+        value: DecimalValue,
+        dtype: DecimalDType,
+    },
 }
 
 fn make_zero_state(return_dtype: &DType) -> SumState {
@@ -281,7 +284,10 @@ fn make_zero_state(return_dtype: &DType) -> SumState {
             PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
             PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
         },
-        DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
+        DType::Decimal(decimal, _) => SumState::Decimal {
+            value: DecimalValue::zero(decimal),
+            dtype: *decimal,
+        },
         _ => vortex_panic!("Unsupported sum type"),
     }
 }
diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs
index 292711f95bf..7cb29e252bf 100644
--- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs
+++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs
@@ -59,7 +59,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR
                 Ok(false)
             }
         ),
-        SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
+        SumState::Decimal { .. } => vortex_panic!("decimal sum state with primitive input"),
     }
 }
 
@@ -105,7 +105,7 @@ fn accumulate_primitive_valid(
                 Ok(false)
             }
         ),
-        SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
+        SumState::Decimal { .. } => vortex_panic!("decimal sum state with primitive input"),
     }
 }
 