Skip to content

Commit e1f2e3c

Browse files
robert3005, joseph-isaacs
authored and committed
Decimal sum uses smallest physical type that supports output dtype (#7008)
Instead of always casting values to i256, we take the primitive value type of the return dtype and perform the operations in that space.
---------
Signed-off-by: Robert Kruszewski <github@robertk.io>
Co-authored-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 8d230d2 commit e1f2e3c

1 file changed

Lines changed: 78 additions & 33 deletions

File tree

vortex-array/src/aggregate_fn/fns/sum/decimal.rs

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,57 +2,102 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use itertools::Itertools;
5+
use num_traits::AsPrimitive;
6+
use num_traits::CheckedAdd;
7+
use num_traits::NumOps;
8+
use vortex_buffer::BitBuffer;
9+
use vortex_buffer::Buffer;
10+
use vortex_error::VortexExpect;
511
use vortex_error::VortexResult;
612
use vortex_error::vortex_panic;
7-
use vortex_mask::AllOr;
13+
use vortex_mask::Mask;
814

915
use super::SumState;
1016
use crate::arrays::DecimalArray;
17+
use crate::dtype::DecimalDType;
18+
use crate::dtype::DecimalType;
19+
use crate::dtype::NativeDecimalType;
1120
use crate::match_each_decimal_value_type;
1221
use crate::scalar::DecimalValue;
1322

1423
/// Accumulate the valid (non-null) values of a decimal array into the running decimal sum state.
1524
/// Returns `Ok(true)` if the sum saturated (an addition overflowed, or the total no longer fits the precision of the state's decimal dtype), `Ok(false)` if not.
1625
pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
26+
// (added) Resolve validity before touching the state: `None` means every entry is valid,
// `Some(bits)` carries the per-entry validity bits, and an all-invalid array contributes
// nothing, so it returns early as "not saturated".
let mask = d.validity_mask()?;
27+
let validity = match &mask {
28+
Mask::AllTrue(_) => None,
29+
Mask::Values(mask_values) => Some(mask_values.bit_buffer()),
30+
Mask::AllFalse(_) => {
31+
return Ok(false);
32+
}
33+
};
34+
1735
// The caller must pass the decimal variant of the state; any other variant diverges through
// `vortex_panic!`.
let SumState::Decimal { value, dtype } = inner else {
1836
vortex_panic!("expected decimal sum state for decimal input");
1937
};
2038

21-
// (removed) Previous implementation: every element was wrapped in a `DecimalValue`, added with
// `DecimalValue::checked_add`, and checked against the precision after each step.
let mask = d.validity_mask()?;
22-
match mask.bit_buffer() {
23-
AllOr::None => Ok(false),
24-
AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
25-
for &v in d.buffer::<T>().iter() {
26-
match value.checked_add(&DecimalValue::from(v)) {
27-
Some(r) => {
28-
*value = r;
29-
// Check for overflow
30-
if !value.fits_in_precision(*dtype) {
31-
return Ok(true);
32-
}
33-
}
34-
None => return Ok(true),
35-
}
39+
// (added) Pick the accumulator type `I` from the state's decimal dtype (per the commit message,
// the smallest physical type that supports the output dtype) and dispatch on both the stored
// element type `T` and `I`.
let values_type = DecimalType::smallest_decimal_value_type(dtype);
40+
match_each_decimal_value_type!(d.values_type(), |T| {
41+
match_each_decimal_value_type!(values_type, |I| {
42+
// Bring the running sum into the accumulator type before adding to it.
let initial: I = value
43+
.cast()
44+
.vortex_expect("cannot fail to cast initial value");
45+
// `None` signals saturation; the stored sum is left unchanged in that case.
match sum_decimal_value(initial, d.buffer::<T>(), validity, *dtype) {
46+
Some(v) => *value = v,
47+
None => return Ok(true),
3648
}
3749
Ok(false)
38-
}),
39-
AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
40-
for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
41-
if valid {
42-
match value.checked_add(&DecimalValue::from(v)) {
43-
Some(r) => {
44-
*value = r;
45-
if !value.fits_in_precision(*dtype) {
46-
return Ok(true);
47-
}
48-
}
49-
None => return Ok(true),
50-
}
51-
}
52-
}
53-
Ok(false)
54-
}),
50+
})
51+
})
52+
}
53+
54+
fn sum_decimal_value<T, I>(
55+
initial: I,
56+
values: Buffer<T>,
57+
validity: Option<&BitBuffer>,
58+
output_dtype: DecimalDType,
59+
) -> Option<DecimalValue>
60+
where
61+
T: AsPrimitive<I>,
62+
I: NumOps + CheckedAdd + Copy + NativeDecimalType + 'static,
63+
bool: AsPrimitive<I>,
64+
DecimalValue: From<I>,
65+
{
66+
let sum = match validity {
67+
Some(v) => sum_decimal_with_validity(values, v, initial),
68+
None => sum_decimal(values, initial),
69+
};
70+
71+
sum.map(DecimalValue::from)
72+
// We have to make sure that the decimal value fits the precision of the decimal dtype.
73+
.filter(|v| v.fits_in_precision(output_dtype))
74+
}
75+
76+
fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
77+
values: Buffer<T>,
78+
initial: I,
79+
) -> Option<I> {
80+
let mut sum = initial;
81+
for v in values.iter() {
82+
let v: I = v.as_();
83+
sum = CheckedAdd::checked_add(&sum, &v)?;
84+
}
85+
Some(sum)
86+
}
87+
88+
fn sum_decimal_with_validity<T, I>(values: Buffer<T>, validity: &BitBuffer, initial: I) -> Option<I>
89+
where
90+
T: AsPrimitive<I>,
91+
I: NumOps + CheckedAdd + Copy + 'static,
92+
bool: AsPrimitive<I>,
93+
{
94+
let mut sum = initial;
95+
for (v, valid) in values.iter().zip_eq(validity) {
96+
let v: I = v.as_() * valid.as_();
97+
98+
sum = CheckedAdd::checked_add(&sum, &v)?;
5599
}
100+
Some(sum)
56101
}
57102

58103
#[cfg(test)]

0 commit comments

Comments
 (0)