Skip to content

Commit afd676d

Browse files
committed
Aggregate Fns: Sum
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 1f43f0b commit afd676d

7 files changed

Lines changed: 797 additions & 0 deletions

File tree

vortex-array/src/aggregate_fn/fns/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
pub mod mean;
5+
pub mod sum;
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::ops::BitAnd;
5+
6+
use vortex_error::VortexResult;
7+
use vortex_mask::Mask;
8+
9+
use crate::ArrayRef;
10+
use crate::IntoArray;
11+
use crate::aggregate_fn::accumulator::Accumulator;
12+
use crate::arrays::PrimitiveArray;
13+
use crate::canonical::ToCanonical;
14+
use crate::scalar::Scalar;
15+
16+
/// Accumulator that sums boolean values by counting `true` as 1 and `false` as 0.
17+
///
18+
/// Output type is `u64` (nullable). Overflow is theoretically possible but extremely
19+
/// unlikely since it would require more than `u64::MAX` true values.
20+
pub(super) struct BoolSumAccumulator {
21+
count: u64,
22+
/// Whether at least one non-null value has been accumulated.
23+
has_values: bool,
24+
/// Whether accumulate() or merge() has been called at all (even with all-null data).
25+
has_input: bool,
26+
checked: bool,
27+
overflowed: bool,
28+
results: Vec<Option<u64>>,
29+
}
30+
31+
impl BoolSumAccumulator {
32+
pub(super) fn new(checked: bool) -> Self {
33+
Self {
34+
count: 0,
35+
has_values: false,
36+
has_input: false,
37+
checked,
38+
overflowed: false,
39+
results: Vec::new(),
40+
}
41+
}
42+
}
43+
44+
impl Accumulator for BoolSumAccumulator {
45+
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
46+
self.has_input = true;
47+
if self.overflowed {
48+
return Ok(());
49+
}
50+
51+
let bool_array = batch.to_bool();
52+
let validity = bool_array.validity_mask()?;
53+
54+
let true_count = match &validity {
55+
Mask::AllTrue(_) => bool_array.to_bit_buffer().true_count() as u64,
56+
Mask::AllFalse(_) => return Ok(()),
57+
Mask::Values(v) => bool_array
58+
.to_bit_buffer()
59+
.bitand(v.bit_buffer())
60+
.true_count() as u64,
61+
};
62+
63+
self.has_values = true;
64+
if self.checked {
65+
if let Some(new_count) = self.count.checked_add(true_count) {
66+
self.count = new_count;
67+
} else {
68+
self.overflowed = true;
69+
}
70+
} else {
71+
self.count = self.count.wrapping_add(true_count);
72+
}
73+
74+
Ok(())
75+
}
76+
77+
fn merge(&mut self, state: &Scalar) -> VortexResult<()> {
78+
if state.is_null() {
79+
return Ok(());
80+
}
81+
self.has_input = true;
82+
if let Some(v) = state.as_primitive().typed_value::<u64>() {
83+
self.has_values = true;
84+
if self.checked {
85+
if let Some(new_count) = self.count.checked_add(v) {
86+
self.count = new_count;
87+
} else {
88+
self.overflowed = true;
89+
}
90+
} else {
91+
self.count = self.count.wrapping_add(v);
92+
}
93+
}
94+
Ok(())
95+
}
96+
97+
fn is_saturated(&self) -> bool {
98+
self.checked && self.overflowed
99+
}
100+
101+
fn flush(&mut self) -> VortexResult<()> {
102+
let result = if self.overflowed {
103+
None
104+
} else if self.has_values {
105+
Some(self.count)
106+
} else if self.has_input {
107+
// All-null group.
108+
None
109+
} else {
110+
// Empty group: identity is zero.
111+
Some(0)
112+
};
113+
self.results.push(result);
114+
self.count = 0;
115+
self.has_values = false;
116+
self.has_input = false;
117+
self.overflowed = false;
118+
Ok(())
119+
}
120+
121+
fn finish(self: Box<Self>) -> VortexResult<ArrayRef> {
122+
Ok(PrimitiveArray::from_option_iter(self.results).into_array())
123+
}
124+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_mask::Mask;
6+
7+
use crate::ArrayRef;
8+
use crate::IntoArray;
9+
use crate::aggregate_fn::accumulator::Accumulator;
10+
use crate::arrays::PrimitiveArray;
11+
use crate::canonical::ToCanonical;
12+
use crate::dtype::NativePType;
13+
use crate::match_each_native_ptype;
14+
use crate::scalar::Scalar;
15+
16+
pub(super) struct FloatSumAccumulator {
17+
sum: f64,
18+
/// Whether at least one non-null value has been accumulated.
19+
has_values: bool,
20+
/// Whether accumulate() or merge() has been called at all (even with all-null data).
21+
has_input: bool,
22+
results: Vec<Option<f64>>,
23+
}
24+
25+
impl FloatSumAccumulator {
26+
pub(super) fn new() -> Self {
27+
Self {
28+
sum: 0.0,
29+
has_values: false,
30+
has_input: false,
31+
results: Vec::new(),
32+
}
33+
}
34+
}
35+
36+
fn accumulate_all_valid<T: NativePType>(values: &[T], sum: &mut f64, has_values: &mut bool) {
37+
for v in values {
38+
*has_values = true;
39+
*sum += v.to_f64().unwrap_or(0.0);
40+
}
41+
}
42+
43+
/// Adds to `sum` only those elements of `values` whose validity bit in `mask` is set.
///
/// `has_values` is raised as soon as one valid element is encountered. An element
/// whose `to_f64()` conversion yields `None` contributes `0.0`.
fn accumulate_with_mask<T: NativePType>(
    values: &[T],
    mask: &vortex_mask::MaskValues,
    sum: &mut f64,
    has_values: &mut bool,
) {
    // Pair each value with its validity bit and keep only the valid ones.
    let valid_values = values
        .iter()
        .zip(mask.bit_buffer().iter())
        .filter_map(|(value, is_valid)| is_valid.then_some(value));

    for value in valid_values {
        *has_values = true;
        *sum += value.to_f64().unwrap_or(0.0);
    }
}
56+
57+
impl Accumulator for FloatSumAccumulator {
58+
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
59+
self.has_input = true;
60+
let primitive = batch.to_primitive();
61+
let validity = primitive.validity_mask()?;
62+
63+
match_each_native_ptype!(primitive.ptype(), integral: |_T| {
64+
unreachable!("FloatSumAccumulator should not be used with integer types");
65+
}, floating: |T| {
66+
let values = primitive.as_slice::<T>();
67+
match &validity {
68+
Mask::AllTrue(_) => accumulate_all_valid(
69+
values,
70+
&mut self.sum,
71+
&mut self.has_values,
72+
),
73+
Mask::AllFalse(_) => {}
74+
Mask::Values(v) => accumulate_with_mask(
75+
values,
76+
v,
77+
&mut self.sum,
78+
&mut self.has_values,
79+
),
80+
}
81+
});
82+
83+
Ok(())
84+
}
85+
86+
fn merge(&mut self, state: &Scalar) -> VortexResult<()> {
87+
if state.is_null() {
88+
return Ok(());
89+
}
90+
self.has_input = true;
91+
if let Some(v) = state.as_primitive().typed_value::<f64>() {
92+
self.has_values = true;
93+
self.sum += v;
94+
}
95+
Ok(())
96+
}
97+
98+
fn flush(&mut self) -> VortexResult<()> {
99+
let result = if self.has_values {
100+
Some(self.sum)
101+
} else if self.has_input {
102+
// All-null group.
103+
None
104+
} else {
105+
// Empty group: identity is zero.
106+
Some(0.0)
107+
};
108+
self.results.push(result);
109+
self.sum = 0.0;
110+
self.has_values = false;
111+
self.has_input = false;
112+
Ok(())
113+
}
114+
115+
fn finish(self: Box<Self>) -> VortexResult<ArrayRef> {
116+
Ok(PrimitiveArray::from_option_iter(self.results).into_array())
117+
}
118+
}

0 commit comments

Comments
 (0)