Skip to content

Commit cb6c2ed

Browse files
committed
Check for overflow in decimal sum
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent fb87092 commit cb6c2ed

3 files changed

Lines changed: 152 additions & 15 deletions

File tree

vortex-array/src/aggregate_fn/fns/sum/decimal.rs

Lines changed: 136 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::scalar::DecimalValue;
1414
/// Accumulate a decimal array into the sum state.
1515
/// Returns Ok(true) if saturated (overflow), Ok(false) if not.
1616
pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
17-
let SumState::Decimal(acc) = inner else {
17+
let SumState::Decimal { value, dtype } = inner else {
1818
vortex_panic!("expected decimal sum state for decimal input");
1919
};
2020

@@ -23,8 +23,14 @@ pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> Vort
2323
AllOr::None => Ok(false),
2424
AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
2525
for &v in d.buffer::<T>().iter() {
26-
match acc.checked_add(&DecimalValue::from(v)) {
27-
Some(r) => *acc = r,
26+
match value.checked_add(&DecimalValue::from(v)) {
27+
Some(r) => {
28+
*value = r;
29+
// Check for overflow
30+
if !value.fits_in_precision(*dtype) {
31+
return Ok(true);
32+
}
33+
}
2834
None => return Ok(true),
2935
}
3036
}
@@ -33,8 +39,13 @@ pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> Vort
3339
AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
3440
for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
3541
if valid {
36-
match acc.checked_add(&DecimalValue::from(v)) {
37-
Some(r) => *acc = r,
42+
match value.checked_add(&DecimalValue::from(v)) {
43+
Some(r) => {
44+
*value = r;
45+
if !value.fits_in_precision(*dtype) {
46+
return Ok(true);
47+
}
48+
}
3849
None => return Ok(true),
3950
}
4051
}
@@ -53,6 +64,9 @@ mod tests {
5364
use crate::IntoArray;
5465
use crate::LEGACY_SESSION;
5566
use crate::VortexSessionExecute;
67+
use crate::aggregate_fn::AggregateFnVTable;
68+
use crate::aggregate_fn::EmptyOptions;
69+
use crate::aggregate_fn::fns::sum::Sum;
5670
use crate::aggregate_fn::fns::sum::sum;
5771
use crate::arrays::DecimalArray;
5872
use crate::dtype::DType;
@@ -286,4 +300,121 @@ mod tests {
286300
);
287301
Ok(())
288302
}
303+
304+
#[test]
305+
fn sum_decimal_near_precision_boundary() -> VortexResult<()> {
306+
// Input precision 4 → return precision min(76, 4+10) = 14.
307+
// Native type for precision 14 is I64 (max precision 18), so 14 < 18.
308+
// Use combine_partials to push state near (but under) 10^14.
309+
let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
310+
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
311+
312+
let near_limit = Scalar::decimal(
313+
DecimalValue::from(99_999_999_999_990i64),
314+
DecimalDType::new(14, 0),
315+
Nullable,
316+
);
317+
Sum.combine_partials(&mut state, near_limit)?;
318+
319+
// Add a small value that keeps us just under 10^14.
320+
let small = Scalar::decimal(DecimalValue::from(9i64), DecimalDType::new(14, 0), Nullable);
321+
Sum.combine_partials(&mut state, small)?;
322+
323+
let result = Sum.flush(&mut state)?;
324+
assert!(!result.is_null());
325+
assert_eq!(
326+
result.as_decimal().decimal_value(),
327+
Some(DecimalValue::I256(i256::from_i128(99_999_999_999_999)))
328+
);
329+
Ok(())
330+
}
331+
332+
#[test]
333+
fn sum_decimal_precision_overflow_within_i256() -> VortexResult<()> {
334+
// Input precision 4 → return precision 14. Native I64 (max 18).
335+
// The max representable value for precision 14 is 10^14 - 1.
336+
// When the sum reaches exactly 10^14, fits_in_precision fails even though
337+
// i256 arithmetic does not overflow. This tests the precision-based
338+
// saturation path in combine_partials.
339+
let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
340+
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
341+
342+
let near_limit = Scalar::decimal(
343+
DecimalValue::from(99_999_999_999_999i64),
344+
DecimalDType::new(14, 0),
345+
Nullable,
346+
);
347+
Sum.combine_partials(&mut state, near_limit)?;
348+
349+
// Push the sum to exactly 10^14, exceeding precision 14.
350+
let one_more =
351+
Scalar::decimal(DecimalValue::from(1i64), DecimalDType::new(14, 0), Nullable);
352+
Sum.combine_partials(&mut state, one_more)?;
353+
354+
let result = Sum.flush(&mut state)?;
355+
assert!(result.is_null());
356+
assert_eq!(
357+
result.dtype(),
358+
&DType::Decimal(DecimalDType::new(14, 0), Nullable)
359+
);
360+
Ok(())
361+
}
362+
363+
#[test]
364+
fn sum_decimal_precision_overflow_negative() -> VortexResult<()> {
365+
// Same setup but with negative values: sum reaches -10^14.
366+
let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
367+
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
368+
369+
let near_limit = Scalar::decimal(
370+
DecimalValue::from(-99_999_999_999_999i64),
371+
DecimalDType::new(14, 0),
372+
Nullable,
373+
);
374+
Sum.combine_partials(&mut state, near_limit)?;
375+
376+
let one_more = Scalar::decimal(
377+
DecimalValue::from(-1i64),
378+
DecimalDType::new(14, 0),
379+
Nullable,
380+
);
381+
Sum.combine_partials(&mut state, one_more)?;
382+
383+
let result = Sum.flush(&mut state)?;
384+
assert!(result.is_null());
385+
Ok(())
386+
}
387+
388+
#[test]
389+
fn sum_decimal_accumulate_precision_overflow() -> VortexResult<()> {
390+
// Test precision overflow via the accumulate_decimal path (not combine_partials).
391+
// Input precision 28 (I128 storage) → return precision min(76, 38) = 38.
392+
// Native for precision 38 is I128 (max 38), so 38 = 38.
393+
// Use precision 27 → return 37. Native for 37 is I128 (max 38), so 37 < 38.
394+
//
395+
// We use combine_partials to get the state close to 10^37, then accumulate
396+
// a real array that pushes it over.
397+
let input_dtype = DType::Decimal(DecimalDType::new(27, 0), Nullability::NonNullable);
398+
let return_dtype = DecimalDType::new(37, 0);
399+
let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?;
400+
401+
// Set state to 10^37 - 1 via combine_partials.
402+
let near_limit_val: i128 = 10i128.pow(37) - 1;
403+
let near_limit =
404+
Scalar::decimal(DecimalValue::from(near_limit_val), return_dtype, Nullable);
405+
Sum.combine_partials(&mut state, near_limit)?;
406+
407+
// Now accumulate a real i128 array with a single element = 1 to overflow precision.
408+
let decimal =
409+
DecimalArray::new(buffer![1i128], DecimalDType::new(27, 0), Validity::AllValid);
410+
411+
// Drive accumulate through the vtable directly.
412+
let columnar = crate::Columnar::Canonical(crate::Canonical::Decimal(decimal));
413+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
414+
Sum.accumulate(&mut state, &columnar, &mut ctx)?;
415+
416+
let result = Sum.flush(&mut state)?;
417+
assert!(result.is_null());
418+
Ok(())
419+
}
289420
}

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ impl AggregateFnVTable for Sum {
160160
*acc += val;
161161
false
162162
}
163-
SumState::Decimal(acc) => {
163+
SumState::Decimal { value, dtype } => {
164164
let val = other
165165
.as_decimal()
166166
.decimal_value()
167167
.vortex_expect("checked non-null");
168-
match acc.checked_add(&val) {
168+
match value.checked_add(&val) {
169169
Some(r) => {
170-
*acc = r;
171-
false
170+
*value = r;
171+
!value.fits_in_precision(*dtype)
172172
}
173173
None => true,
174174
}
@@ -186,12 +186,12 @@ impl AggregateFnVTable for Sum {
186186
Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
187187
Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
188188
Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
189-
Some(SumState::Decimal(v)) => {
189+
Some(SumState::Decimal { value, .. }) => {
190190
let decimal_dtype = *partial
191191
.return_dtype
192192
.as_decimal_opt()
193193
.vortex_expect("return dtype must be decimal");
194-
Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
194+
Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
195195
}
196196
};
197197

@@ -271,7 +271,10 @@ pub enum SumState {
271271
Unsigned(u64),
272272
Signed(i64),
273273
Float(f64),
274-
Decimal(DecimalValue),
274+
Decimal {
275+
value: DecimalValue,
276+
dtype: DecimalDType,
277+
},
275278
}
276279

277280
fn make_zero_state(return_dtype: &DType) -> SumState {
@@ -281,7 +284,10 @@ fn make_zero_state(return_dtype: &DType) -> SumState {
281284
PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
282285
PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
283286
},
284-
DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
287+
DType::Decimal(decimal, _) => SumState::Decimal {
288+
value: DecimalValue::zero(decimal),
289+
dtype: *decimal,
290+
},
285291
_ => vortex_panic!("Unsupported sum type"),
286292
}
287293
}

vortex-array/src/aggregate_fn/fns/sum/primitive.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR
5959
Ok(false)
6060
}
6161
),
62-
SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
62+
SumState::Decimal { .. } => vortex_panic!("decimal sum state with primitive input"),
6363
}
6464
}
6565

@@ -105,7 +105,7 @@ fn accumulate_primitive_valid(
105105
Ok(false)
106106
}
107107
),
108-
SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
108+
SumState::Decimal { .. } => vortex_panic!("decimal sum state with primitive input"),
109109
}
110110
}
111111

0 commit comments

Comments
 (0)