Skip to content

Commit 0ac3646

Browse files
committed
AggregateFn MinMax
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 925f49d commit 0ac3646

6 files changed

Lines changed: 771 additions & 0 deletions

File tree

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::marker::PhantomData;
5+
use std::ops::BitAnd;
6+
7+
use vortex_error::VortexResult;
8+
use vortex_mask::Mask;
9+
10+
use super::Direction;
11+
use crate::ArrayRef;
12+
use crate::IntoArray;
13+
use crate::aggregate_fn::accumulator::Accumulator;
14+
use crate::arrays::BoolArray;
15+
use crate::canonical::ToCanonical;
16+
use crate::scalar::Scalar;
17+
18+
/// Accumulator for boolean min/max.
19+
///
20+
/// - Min is saturated as soon as `false` is seen (since false < true).
21+
/// - Max is saturated as soon as `true` is seen.
22+
pub(super) struct BoolExtremumAccumulator<D> {
23+
current: Option<bool>,
24+
results: Vec<Option<bool>>,
25+
_direction: PhantomData<D>,
26+
}
27+
28+
impl<D: Direction> BoolExtremumAccumulator<D> {
29+
pub(super) fn new() -> Self {
30+
Self {
31+
current: None,
32+
results: Vec::new(),
33+
_direction: PhantomData,
34+
}
35+
}
36+
37+
#[inline]
38+
fn consider(&mut self, candidate: bool) {
39+
match self.current {
40+
None => self.current = Some(candidate),
41+
Some(cur) => {
42+
if D::should_replace_bool(cur, candidate) {
43+
self.current = Some(candidate);
44+
}
45+
}
46+
}
47+
}
48+
}
49+
50+
/// Number of valid `true` and valid `false` entries in a boolean array.
///
/// Entries masked out by validity contribute to neither counter.
struct BoolCounts {
    /// Valid entries whose value is `true`.
    true_count: u64,
    /// Valid entries whose value is `false`.
    false_count: u64,
}
55+
56+
fn bool_counts(bool_array: &BoolArray) -> VortexResult<BoolCounts> {
57+
let validity = bool_array.validity_mask()?;
58+
let bits = bool_array.to_bit_buffer();
59+
60+
match &validity {
61+
Mask::AllTrue(_) => {
62+
let true_count = bits.true_count() as u64;
63+
let false_count = bool_array.len() as u64 - true_count;
64+
Ok(BoolCounts {
65+
true_count,
66+
false_count,
67+
})
68+
}
69+
Mask::AllFalse(_) => Ok(BoolCounts {
70+
true_count: 0,
71+
false_count: 0,
72+
}),
73+
Mask::Values(v) => {
74+
let valid_bits = bits.bitand(v.bit_buffer());
75+
let true_count = valid_bits.true_count() as u64;
76+
let valid_count = v.bit_buffer().true_count() as u64;
77+
let false_count = valid_count - true_count;
78+
Ok(BoolCounts {
79+
true_count,
80+
false_count,
81+
})
82+
}
83+
}
84+
}
85+
86+
impl<D: Direction> Accumulator for BoolExtremumAccumulator<D> {
87+
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
88+
let bool_array = batch.to_bool();
89+
let counts = bool_counts(&bool_array)?;
90+
91+
if counts.true_count > 0 {
92+
self.consider(true);
93+
}
94+
if counts.false_count > 0 {
95+
self.consider(false);
96+
}
97+
98+
Ok(())
99+
}
100+
101+
fn merge(&mut self, state: &Scalar) -> VortexResult<()> {
102+
if state.is_null() {
103+
return Ok(());
104+
}
105+
if let Some(v) = state.as_bool().value() {
106+
self.consider(v);
107+
}
108+
Ok(())
109+
}
110+
111+
fn is_saturated(&self) -> bool {
112+
self.current.is_some_and(D::is_saturated_bool)
113+
}
114+
115+
fn flush(&mut self) -> VortexResult<()> {
116+
self.results.push(self.current);
117+
self.current = None;
118+
Ok(())
119+
}
120+
121+
fn finish(self: Box<Self>) -> VortexResult<ArrayRef> {
122+
Ok(BoolArray::from_iter(self.results).into_array())
123+
}
124+
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod bool_accumulator;
5+
mod primitive_accumulator;
6+
7+
use num_traits::Bounded;
8+
use vortex_error::VortexResult;
9+
use vortex_error::vortex_bail;
10+
11+
use self::bool_accumulator::BoolExtremumAccumulator;
12+
use self::primitive_accumulator::PrimitiveExtremumAccumulator;
13+
use crate::aggregate_fn::AggregateFnId;
14+
use crate::aggregate_fn::AggregateFnVTable;
15+
use crate::aggregate_fn::accumulator::Accumulator;
16+
use crate::dtype::DType;
17+
use crate::dtype::NativePType;
18+
use crate::dtype::Nullability;
19+
use crate::match_each_native_ptype;
20+
use crate::scalar_fn::EmptyOptions;
21+
22+
/// Compile-time direction for extremum accumulators.
23+
pub(crate) trait Direction: Send + Sync + 'static {
24+
/// Returns `true` if `candidate` should replace `current`.
25+
fn should_replace<T: NativePType>(current: T, candidate: T) -> bool;
26+
27+
/// Returns `true` if `value` is at the type's extreme and no further improvement is possible.
28+
fn is_saturated<T: NativePType + Bounded>(value: T) -> bool;
29+
30+
/// Returns `true` if `candidate` should replace `current` for booleans.
31+
fn should_replace_bool(current: bool, candidate: bool) -> bool;
32+
33+
/// Returns `true` if the boolean value is saturated.
34+
fn is_saturated_bool(value: bool) -> bool;
35+
}
36+
37+
/// Seek the minimum value.
38+
pub(crate) struct FindMin;
39+
40+
impl Direction for FindMin {
41+
#[inline]
42+
fn should_replace<T: NativePType>(current: T, candidate: T) -> bool {
43+
candidate.is_lt(current)
44+
}
45+
46+
#[inline]
47+
fn is_saturated<T: NativePType + Bounded>(value: T) -> bool {
48+
value.total_compare(T::min_value()).is_eq()
49+
}
50+
51+
#[inline]
52+
fn should_replace_bool(_current: bool, candidate: bool) -> bool {
53+
!candidate
54+
}
55+
56+
#[inline]
57+
fn is_saturated_bool(value: bool) -> bool {
58+
!value
59+
}
60+
}
61+
62+
/// Seek the maximum value.
63+
pub(crate) struct FindMax;
64+
65+
impl Direction for FindMax {
66+
#[inline]
67+
fn should_replace<T: NativePType>(current: T, candidate: T) -> bool {
68+
candidate.is_gt(current)
69+
}
70+
71+
#[inline]
72+
fn is_saturated<T: NativePType + Bounded>(value: T) -> bool {
73+
value.total_compare(T::max_value()).is_eq()
74+
}
75+
76+
#[inline]
77+
fn should_replace_bool(_current: bool, candidate: bool) -> bool {
78+
candidate
79+
}
80+
81+
#[inline]
82+
fn is_saturated_bool(value: bool) -> bool {
83+
value
84+
}
85+
}
86+
87+
/// Aggregate function yielding the smallest value of a numeric or boolean input.
///
/// Null and NaN inputs are skipped. The result has the same dtype as the input, except that it
/// is always nullable.
///
/// # Flush semantics
///
/// - **Empty group** (no accumulate/merge calls): the result is **null**.
/// - **All-null group**: the result is **null**.
/// - `is_saturated()` becomes true once the type's minimum value has been observed.
#[derive(Clone)]
pub struct Min;
99+
100+
/// Aggregate function yielding the largest value of a numeric or boolean input.
///
/// Null and NaN inputs are skipped. The result has the same dtype as the input, except that it
/// is always nullable.
///
/// # Flush semantics
///
/// - **Empty group** (no accumulate/merge calls): the result is **null**.
/// - **All-null group**: the result is **null**.
/// - `is_saturated()` becomes true once the type's maximum value has been observed.
#[derive(Clone)]
pub struct Max;
112+
113+
fn return_dtype(input_dtype: &DType) -> VortexResult<DType> {
114+
match input_dtype {
115+
DType::Bool(_) => Ok(DType::Bool(Nullability::Nullable)),
116+
DType::Primitive(p, _) => Ok(DType::Primitive(*p, Nullability::Nullable)),
117+
_ => vortex_bail!(
118+
"Min/Max requires numeric or boolean input, got {}",
119+
input_dtype
120+
),
121+
}
122+
}
123+
124+
fn make_accumulator<D: Direction>(input_dtype: &DType) -> VortexResult<Box<dyn Accumulator>> {
125+
match input_dtype {
126+
DType::Bool(_) => Ok(Box::new(BoolExtremumAccumulator::<D>::new())),
127+
DType::Primitive(p, _) => Ok(match_each_native_ptype!(*p, |T| {
128+
Box::new(PrimitiveExtremumAccumulator::<T, D>::new()) as Box<dyn Accumulator>
129+
})),
130+
_ => vortex_bail!(
131+
"Min/Max requires numeric or boolean input, got {}",
132+
input_dtype
133+
),
134+
}
135+
}
136+
137+
impl AggregateFnVTable for Min {
138+
type Options = EmptyOptions;
139+
140+
fn id(&self) -> AggregateFnId {
141+
AggregateFnId::new_ref("vortex.min")
142+
}
143+
144+
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
145+
return_dtype(input_dtype)
146+
}
147+
148+
fn state_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
149+
self.return_dtype(options, input_dtype)
150+
}
151+
152+
fn accumulator(
153+
&self,
154+
_options: &Self::Options,
155+
input_dtype: &DType,
156+
) -> VortexResult<Box<dyn Accumulator>> {
157+
make_accumulator::<FindMin>(input_dtype)
158+
}
159+
}
160+
161+
impl AggregateFnVTable for Max {
162+
type Options = EmptyOptions;
163+
164+
fn id(&self) -> AggregateFnId {
165+
AggregateFnId::new_ref("vortex.max")
166+
}
167+
168+
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
169+
return_dtype(input_dtype)
170+
}
171+
172+
fn state_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
173+
self.return_dtype(options, input_dtype)
174+
}
175+
176+
fn accumulator(
177+
&self,
178+
_options: &Self::Options,
179+
input_dtype: &DType,
180+
) -> VortexResult<Box<dyn Accumulator>> {
181+
make_accumulator::<FindMax>(input_dtype)
182+
}
183+
}
184+
185+
#[cfg(test)]
186+
mod tests;

0 commit comments

Comments
 (0)