|
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | 4 | use itertools::Itertools; |
| 5 | +use num_traits::AsPrimitive; |
| 6 | +use num_traits::CheckedAdd; |
| 7 | +use num_traits::NumOps; |
| 8 | +use vortex_buffer::BitBuffer; |
| 9 | +use vortex_buffer::Buffer; |
| 10 | +use vortex_error::VortexExpect; |
5 | 11 | use vortex_error::VortexResult; |
6 | 12 | use vortex_error::vortex_panic; |
7 | | -use vortex_mask::AllOr; |
| 13 | +use vortex_mask::Mask; |
8 | 14 |
|
9 | 15 | use super::SumState; |
10 | 16 | use crate::arrays::DecimalArray; |
| 17 | +use crate::dtype::DecimalDType; |
| 18 | +use crate::dtype::DecimalType; |
| 19 | +use crate::dtype::NativeDecimalType; |
11 | 20 | use crate::match_each_decimal_value_type; |
12 | 21 | use crate::scalar::DecimalValue; |
13 | 22 |
|
14 | 23 | /// Accumulate a decimal array into the sum state. |
15 | 24 | /// Returns Ok(true) if saturated (overflow), Ok(false) if not. |
16 | 25 | pub(super) fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> { |
| 26 | + let mask = d.validity_mask()?; |
| 27 | + let validity = match &mask { |
| 28 | + Mask::AllTrue(_) => None, |
| 29 | + Mask::Values(mask_values) => Some(mask_values.bit_buffer()), |
| 30 | + Mask::AllFalse(_) => { |
| 31 | + return Ok(false); |
| 32 | + } |
| 33 | + }; |
| 34 | + |
17 | 35 | let SumState::Decimal { value, dtype } = inner else { |
18 | 36 | vortex_panic!("expected decimal sum state for decimal input"); |
19 | 37 | }; |
20 | 38 |
|
21 | | - let mask = d.validity_mask()?; |
22 | | - match mask.bit_buffer() { |
23 | | - AllOr::None => Ok(false), |
24 | | - AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| { |
25 | | - for &v in d.buffer::<T>().iter() { |
26 | | - match value.checked_add(&DecimalValue::from(v)) { |
27 | | - Some(r) => { |
28 | | - *value = r; |
29 | | - // Check for overflow |
30 | | - if !value.fits_in_precision(*dtype) { |
31 | | - return Ok(true); |
32 | | - } |
33 | | - } |
34 | | - None => return Ok(true), |
35 | | - } |
| 39 | + let values_type = DecimalType::smallest_decimal_value_type(dtype); |
| 40 | + match_each_decimal_value_type!(d.values_type(), |T| { |
| 41 | + match_each_decimal_value_type!(values_type, |I| { |
| 42 | + let initial: I = value |
| 43 | + .cast() |
| 44 | + .vortex_expect("cannot fail to cast initial value"); |
| 45 | + match sum_decimal_value(initial, d.buffer::<T>(), validity, *dtype) { |
| 46 | + Some(v) => *value = v, |
| 47 | + None => return Ok(true), |
36 | 48 | } |
37 | 49 | Ok(false) |
38 | | - }), |
39 | | - AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| { |
40 | | - for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) { |
41 | | - if valid { |
42 | | - match value.checked_add(&DecimalValue::from(v)) { |
43 | | - Some(r) => { |
44 | | - *value = r; |
45 | | - if !value.fits_in_precision(*dtype) { |
46 | | - return Ok(true); |
47 | | - } |
48 | | - } |
49 | | - None => return Ok(true), |
50 | | - } |
51 | | - } |
52 | | - } |
53 | | - Ok(false) |
54 | | - }), |
| 50 | + }) |
| 51 | + }) |
| 52 | +} |
| 53 | + |
| 54 | +fn sum_decimal_value<T, I>( |
| 55 | + initial: I, |
| 56 | + values: Buffer<T>, |
| 57 | + validity: Option<&BitBuffer>, |
| 58 | + output_dtype: DecimalDType, |
| 59 | +) -> Option<DecimalValue> |
| 60 | +where |
| 61 | + T: AsPrimitive<I>, |
| 62 | + I: NumOps + CheckedAdd + Copy + NativeDecimalType + 'static, |
| 63 | + bool: AsPrimitive<I>, |
| 64 | + DecimalValue: From<I>, |
| 65 | +{ |
| 66 | + let sum = match validity { |
| 67 | + Some(v) => sum_decimal_with_validity(values, v, initial), |
| 68 | + None => sum_decimal(values, initial), |
| 69 | + }; |
| 70 | + |
| 71 | + sum.map(DecimalValue::from) |
| 72 | + // We have to make sure that the decimal value fits the precision of the decimal dtype. |
| 73 | + .filter(|v| v.fits_in_precision(output_dtype)) |
| 74 | +} |
| 75 | + |
| 76 | +fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>( |
| 77 | + values: Buffer<T>, |
| 78 | + initial: I, |
| 79 | +) -> Option<I> { |
| 80 | + let mut sum = initial; |
| 81 | + for v in values.iter() { |
| 82 | + let v: I = v.as_(); |
| 83 | + sum = CheckedAdd::checked_add(&sum, &v)?; |
| 84 | + } |
| 85 | + Some(sum) |
| 86 | +} |
| 87 | + |
| 88 | +fn sum_decimal_with_validity<T, I>(values: Buffer<T>, validity: &BitBuffer, initial: I) -> Option<I> |
| 89 | +where |
| 90 | + T: AsPrimitive<I>, |
| 91 | + I: NumOps + CheckedAdd + Copy + 'static, |
| 92 | + bool: AsPrimitive<I>, |
| 93 | +{ |
| 94 | + let mut sum = initial; |
| 95 | + for (v, valid) in values.iter().zip_eq(validity) { |
| 96 | + let v: I = v.as_() * valid.as_(); |
| 97 | + |
| 98 | + sum = CheckedAdd::checked_add(&sum, &v)?; |
55 | 99 | } |
| 100 | + Some(sum) |
56 | 101 | } |
57 | 102 |
|
58 | 103 | #[cfg(test)] |
|
0 commit comments