Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ pub mod vortex_array::aggregate_fn::fns::sum

pub enum vortex_array::aggregate_fn::fns::sum::SumState

pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal(vortex_array::scalar::DecimalValue)
pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal

pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal::dtype: vortex_array::dtype::DecimalDType

pub vortex_array::aggregate_fn::fns::sum::SumState::Decimal::value: vortex_array::scalar::DecimalValue

pub vortex_array::aggregate_fn::fns::sum::SumState::Float(f64)

Expand Down
141 changes: 136 additions & 5 deletions vortex-array/src/aggregate_fn/fns/sum/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::scalar::DecimalValue;
/// Accumulate a decimal array into the sum state.
/// Returns Ok(true) if saturated (overflow), Ok(false) if not.
pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
let SumState::Decimal(acc) = inner else {
let SumState::Decimal { value, dtype } = inner else {
Comment thread
robert3005 marked this conversation as resolved.
vortex_panic!("expected decimal sum state for decimal input");
};

Expand All @@ -23,8 +23,14 @@ pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> Vort
AllOr::None => Ok(false),
AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
for &v in d.buffer::<T>().iter() {
match acc.checked_add(&DecimalValue::from(v)) {
Some(r) => *acc = r,
match value.checked_add(&DecimalValue::from(v)) {
Some(r) => {
*value = r;
// Check for overflow
if !value.fits_in_precision(*dtype) {
return Ok(true);
}
}
None => return Ok(true),
}
}
Expand All @@ -33,8 +39,13 @@ pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> Vort
AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
if valid {
match acc.checked_add(&DecimalValue::from(v)) {
Some(r) => *acc = r,
match value.checked_add(&DecimalValue::from(v)) {
Some(r) => {
*value = r;
if !value.fits_in_precision(*dtype) {
return Ok(true);
}
}
None => return Ok(true),
}
}
Expand All @@ -53,6 +64,9 @@ mod tests {
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::EmptyOptions;
use crate::aggregate_fn::fns::sum::Sum;
use crate::aggregate_fn::fns::sum::sum;
use crate::arrays::DecimalArray;
use crate::dtype::DType;
Expand Down Expand Up @@ -286,4 +300,121 @@ mod tests {
);
Ok(())
}

#[test]
fn sum_decimal_near_precision_boundary() -> VortexResult<()> {
// Input precision 4 → return precision min(76, 4+10) = 14.
// Native type for precision 14 is I64 (max precision 18), so 14 < 18.
// Use combine_partials to push state near (but under) 10^14.
let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;

let near_limit = Scalar::decimal(
DecimalValue::from(99_999_999_999_990i64),
DecimalDType::new(14, 0),
Nullable,
);
Sum.combine_partials(&mut state, near_limit)?;

// Add a small value that keeps us just under 10^14.
let small = Scalar::decimal(DecimalValue::from(9i64), DecimalDType::new(14, 0), Nullable);
Sum.combine_partials(&mut state, small)?;

let result = Sum.flush(&mut state)?;
assert!(!result.is_null());
assert_eq!(
result.as_decimal().decimal_value(),
Some(DecimalValue::I256(i256::from_i128(99_999_999_999_999)))
);
Ok(())
}

#[test]
fn sum_decimal_precision_overflow_within_i256() -> VortexResult<()> {
// Input precision 4 → return precision 14. Native I64 (max 18).
// The max representable value for precision 14 is 10^14 - 1.
// When the sum reaches exactly 10^14, fits_in_precision fails even though
// i256 arithmetic does not overflow. This tests the precision-based
// saturation path in combine_partials.
let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;

let near_limit = Scalar::decimal(
DecimalValue::from(99_999_999_999_999i64),
DecimalDType::new(14, 0),
Nullable,
);
Sum.combine_partials(&mut state, near_limit)?;

// Push the sum to exactly 10^14, exceeding precision 14.
let one_more =
Scalar::decimal(DecimalValue::from(1i64), DecimalDType::new(14, 0), Nullable);
Sum.combine_partials(&mut state, one_more)?;

let result = Sum.flush(&mut state)?;
assert!(result.is_null());
assert_eq!(
result.dtype(),
&DType::Decimal(DecimalDType::new(14, 0), Nullable)
);
Ok(())
}

#[test]
fn sum_decimal_precision_overflow_negative() -> VortexResult<()> {
// Same setup but with negative values: sum reaches -10^14.
let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;

let near_limit = Scalar::decimal(
DecimalValue::from(-99_999_999_999_999i64),
DecimalDType::new(14, 0),
Nullable,
);
Sum.combine_partials(&mut state, near_limit)?;

let one_more = Scalar::decimal(
DecimalValue::from(-1i64),
DecimalDType::new(14, 0),
Nullable,
);
Sum.combine_partials(&mut state, one_more)?;

let result = Sum.flush(&mut state)?;
assert!(result.is_null());
Ok(())
}

#[test]
fn sum_decimal_accumulate_precision_overflow() -> VortexResult<()> {
// Test precision overflow via the accumulate_decimal path (not combine_partials).
// Input precision 28 (I128 storage) → return precision min(76, 38) = 38.
// Native for precision 38 is I128 (max 38), so 38 = 38.
// Use precision 27 → return 37. Native for 37 is I128 (max 38), so 37 < 38.
//
// We use combine_partials to get the state close to 10^37, then accumulate
// a real array that pushes it over.
let input_dtype = DType::Decimal(DecimalDType::new(27, 0), Nullability::NonNullable);
let return_dtype = DecimalDType::new(37, 0);
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;

// Set state to 10^37 - 1 via combine_partials.
let near_limit_val: i128 = 10i128.pow(37) - 1;
let near_limit =
Scalar::decimal(DecimalValue::from(near_limit_val), return_dtype, Nullable);
Sum.combine_partials(&mut state, near_limit)?;

// Now accumulate a real i128 array with a single element = 1 to overflow precision.
let decimal =
DecimalArray::new(buffer![1i128], DecimalDType::new(27, 0), Validity::AllValid);

// Drive accumulate through the vtable directly.
let columnar = crate::Columnar::Canonical(crate::Canonical::Decimal(decimal));
let mut ctx = LEGACY_SESSION.create_execution_ctx();
Sum.accumulate(&mut state, &columnar, &mut ctx)?;

let result = Sum.flush(&mut state)?;
assert!(result.is_null());
Ok(())
}
}
22 changes: 14 additions & 8 deletions vortex-array/src/aggregate_fn/fns/sum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ impl AggregateFnVTable for Sum {
*acc += val;
false
}
SumState::Decimal(acc) => {
SumState::Decimal { value, dtype } => {
let val = other
.as_decimal()
.decimal_value()
.vortex_expect("checked non-null");
match acc.checked_add(&val) {
match value.checked_add(&val) {
Some(r) => {
*acc = r;
false
*value = r;
!value.fits_in_precision(*dtype)
}
None => true,
}
Expand All @@ -186,12 +186,12 @@ impl AggregateFnVTable for Sum {
Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
Some(SumState::Decimal(v)) => {
Some(SumState::Decimal { value, .. }) => {
let decimal_dtype = *partial
.return_dtype
.as_decimal_opt()
.vortex_expect("return dtype must be decimal");
Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
}
};

Expand Down Expand Up @@ -271,7 +271,10 @@ pub enum SumState {
Unsigned(u64),
Signed(i64),
Float(f64),
Decimal(DecimalValue),
Decimal {
value: DecimalValue,
dtype: DecimalDType,
},
}

fn make_zero_state(return_dtype: &DType) -> SumState {
Expand All @@ -281,7 +284,10 @@ fn make_zero_state(return_dtype: &DType) -> SumState {
PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
},
DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
DType::Decimal(decimal, _) => SumState::Decimal {
value: DecimalValue::zero(decimal),
dtype: *decimal,
},
_ => vortex_panic!("Unsupported sum type"),
}
}
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/aggregate_fn/fns/sum/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR
Ok(false)
}
),
SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
SumState::Decimal { .. } => vortex_panic!("decimal sum state with primitive input"),
}
}

Expand Down Expand Up @@ -105,7 +105,7 @@ fn accumulate_primitive_valid(
Ok(false)
}
),
SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
SumState::Decimal { .. } => vortex_panic!("decimal sum state with primitive input"),
}
}

Expand Down
Loading